线段树log^2的复杂度很显然,怎么优化?其实我们知道一个有序的数组从根往下传是依然有序,只是位置变化了而已,那么我们从根开始维护每个有序数组往下传的时候每个位置在子树中的位置,所以只需要开始的时候二分一下即可。
#include<cstdio>
#include<cstring>
#include<iostream>
#include <algorithm>
#include<vector>
using namespace std;
const int maxn=100005;
const int mod=1000000007;
int A,B;
int a,b,C,M;
void init(){
a = A;
b = B;
C = ~(1<<31);
M = (1<<16)-1;
}
int rnd(int last) {
a = (36969 + (last >> 3)) * (a & M) + (a >> 16);
b = (18000 + (last >> 3)) * (b & M) + (b >> 16);
return (C & ((a << 16) + b)) % 1000000000;
}
typedef long long LL;
int w[maxn];
int x[maxn];
struct pi{
int all,lazy,no;
vector<int>l1;
vector<int>r1;
}pp[maxn<<2];
struct p1{
int a;
int id;
int pos;
}pp1[maxn];
int cmp(p1 a,p1 b){
return a.a<b.a;
}
void build(int tot,int l,int r){
pp[tot].lazy=0;
pp[tot].l1.clear();
pp[tot].r1.clear();
pp[tot].all=0;
if(l==r){
pp[tot].lazy=x[l];
if(x[l]>=w[l]){
pp[tot].all=1;
}
return;
}
pp[tot].l1.push_back(0);
pp[tot].r1.push_back(0);
int cnt=0;
int mid=(l+r)/2;
build(2*tot,l,mid);
build(2*tot+1,mid+1,r);
for(int i=1;i<=mid-l+1;i++){
pp1[cnt].a=w[l+i-1];
pp1[cnt].id=-1;
pp1[cnt].pos=i;
cnt++;
}
for(int i=1;i<=r-mid;i++){
pp1[cnt].a=w[mid+i];
pp1[cnt].id=1;
pp1[cnt].pos=i;
cnt++;
}
sort(pp1,pp1+cnt,cmp);
for(int i=0;i<cnt;i++){
int w=pp[tot].l1[i];
int q=pp[tot].r1[i];
if(pp1[i].id==-1){
w=max(w,pp1[i].pos);
}
else{
q=max(q,pp1[i].pos);
}
pp[tot].l1.push_back(w);
pp[tot].r1.push_back(q);
}
pp[tot].all=pp[2*tot].all+pp[2*tot+1].all;
sort(w+l,w+r+1);
}
void push(int tot){
if(pp[tot].lazy){
pp[2*tot].lazy=pp[tot].lazy;
pp[2*tot].no=pp[tot].l1[pp[tot].no];
pp[2*tot].all=pp[2*tot].no;
pp[2*tot+1].lazy=pp[tot].lazy;
pp[2*tot+1].no=pp[tot].r1[pp[tot].no];
pp[2*tot+1].all=pp[2*tot+1].no;
pp[tot].lazy=0;
pp[tot].all=pp[2*tot].all+pp[2*tot+1].all;
}
}
void merg(int tot,int l,int r,int p,int pos,int le,int ri){
if(le>=l&&ri<=r){
pp[tot].lazy=p;
pp[tot].all=pos;
pp[tot].no=pos;
return;
}
push(tot);
int mid=(le+ri)/2;
if(l<=mid) merg(2*tot,l,r,p,pp[tot].l1[pos],le,mid);
if(r>mid) merg(2*tot+1,l,r,p,pp[tot].r1[pos],mid+1,ri);
pp[tot].all=pp[2*tot].all+pp[2*tot+1].all;
}
int query(int tot,int l,int r,int le,int ri){
if(le>=l&&ri<=r) return pp[tot].all;
push(tot);
int mid=(le+ri)/2;
int s=0;
if(l<=mid)
s+=query(2*tot,l,r,le,mid);
if(r>mid) s+=query(2*tot+1,l,r,mid+1,ri);
return s;
}
int main()
{
int t,n,m;
cin>>t;
while(t--){
scanf("%d%d%d%d",&n,&m,&A,&B);
init();
for(int i=1;i<=n;i++){
scanf("%d",&x[i]);
}
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
int last=0;
int l,r,x;
int ans=0;
build(1,1,n);
for(int i=1;i<=m;i++){
l = rnd(last) % n + 1;
r = rnd(last) % n + 1;
x = rnd(last) + 1;
if(l>r) swap(l,r);
if((l+r+x)%2==0){
last=query(1,l,r,1,n);
ans+=((LL)i*last)%mod;
ans%=mod;
}
else{
int le=1,ri=n;
while(le<=ri){
int mid=(le+ri)/2;
if(w[mid]<=x) le=mid+1;
else ri=mid-1;
}
merg(1,l,r,x,ri,1,n);
}
}
printf("%d\n",ans);
}
}