显然需要用线段树维护一个数组p[c]表示选c个数相乘的方案总和。合并的时候只要枚举c,然后枚举左边选几个数然后和右边的乘起来累加进去就好了。取反实际上就是把所有c为奇数的取反,关键是区间加。
对于[l,r],区间+x,那么枚举c。注意到形式是这样的:
newp[c]=Σ(a1+x)(a2+x)...(ac+x),然后按照x的次数合并同类项,可以发现这个东西可以通过x^k,组合数C(r-l+1,k)和oldp[c-k]得到答案。k为x的次数。
其余就是普通的线段树了。
这样时间复杂度O(NlogNc^2),注意到c^2实际上就是求卷积的过程,可以加速到clogc。不过考虑到FFT的大常数估计这样会变慢。。。。
AC代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 19940417
#define M 140005
#define ll long long
using namespace std;
int n,m,sz[M],icr[M],cbn[50005][21]; struct node{ int p[21]; }sum[M]; bool rev[M];
int read(){
int x=0,fu=1; char ch=getchar();
while (ch<'0' || ch>'9'){ if (ch=='-') fu=-1; ch=getchar(); }
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x*fu;
}
void ad(int &x,int y){ x+=y; if (x>=mod) x-=mod; }
void maintain(int k){
int i,j;
for (i=1; i<=20; i++){
sum[k].p[i]=0;
for (j=1; j<i; j++) ad(sum[k].p[i],(ll)sum[k<<1].p[j]*sum[k<<1|1].p[i-j]%mod);
ad(sum[k].p[i],sum[k<<1].p[i]); ad(sum[k].p[i],sum[k<<1|1].p[i]);
}
}
void ins(int k,int x){
int i,j,y; ad(icr[k],x);
for (i=20; i; i--){
y=x;
for (j=i-1; j; j--,y=(ll)y*x%mod)
ad(sum[k].p[i],(ll)y*sum[k].p[j]%mod*cbn[sz[k]-j][i-j]%mod);
ad(sum[k].p[i],(ll)y*cbn[sz[k]][i]%mod);
}
}
void turn(int k){
int i; rev[k]^=1; if (icr[k]) icr[k]=mod-icr[k];
for (i=19; i>0; i-=2) if (sum[k].p[i]) sum[k].p[i]=mod-sum[k].p[i];
}
void pushdown(int k){
if (rev[k]){
turn(k<<1); turn(k<<1|1); rev[k]=0;
}
if (icr[k]){
ins(k<<1,icr[k]); ins(k<<1|1,icr[k]); icr[k]=0;
}
}
void build(int k,int l,int r){
sz[k]=r-l+1;
if (l==r){ sum[k].p[1]=read()%mod; return; }
int mid=(l+r)>>1;
build(k<<1,l,mid); build(k<<1|1,mid+1,r); maintain(k);
}
void mdy(int k,int l,int r,int x,int y,int v){
if (l==x && r==y){ ins(k,v); return; }
int mid=(l+r)>>1; pushdown(k);
if (y<=mid) mdy(k<<1,l,mid,x,y,v); else
if (x>mid) mdy(k<<1|1,mid+1,r,x,y,v); else{
mdy(k<<1,l,mid,x,mid,v); mdy(k<<1|1,mid+1,r,mid+1,y,v);
}
maintain(k);
}
void ovr(int k,int l,int r,int x,int y){
if (l==x && r==y){ turn(k); return; }
int mid=(l+r)>>1; pushdown(k);
if (y<=mid) ovr(k<<1,l,mid,x,y); else
if (x>mid) ovr(k<<1|1,mid+1,r,x,y); else{
ovr(k<<1,l,mid,x,mid); ovr(k<<1|1,mid+1,r,mid+1,y);
}
maintain(k);
}
node qry(int k,int l,int r,int x,int y,int z){
if (l==x && r==y) return sum[k];
int mid=(l+r)>>1; pushdown(k);
if (y<=mid) return qry(k<<1,l,mid,x,y,z); else
if (x>mid) return qry(k<<1|1,mid+1,r,x,y,z); else{
node t1=qry(k<<1,l,mid,x,mid,z),t2=qry(k<<1|1,mid+1,r,mid+1,y,z),t;
int i,j;
for (i=1; i<=z; i++){
t.p[i]=(t1.p[i]+t2.p[i])%mod;
for (j=1; j<i; j++) ad(t.p[i],(ll)t1.p[j]*t2.p[i-j]%mod);
}
return t;
}
}
int main(){
n=read(); m=read(); int i,j,x,y,z;
cbn[0][0]=1;
for (i=1; i<=n; i++){
cbn[i][0]=1;
for (j=1; j<=i && j<=20; j++) cbn[i][j]=(cbn[i-1][j-1]+cbn[i-1][j])%mod;
}
build(1,1,n); char ch;
while (m--){
ch=getchar(); while (ch<'A' || ch>'Z') ch=getchar();
if (ch=='I'){
x=read(); y=read(); z=read()%mod;
if (z<0) z+=mod; mdy(1,1,n,x,y,z);
} else if (ch=='R'){
x=read(); y=read(); ovr(1,1,n,x,y);
} else{
x=read(); y=read(); z=read();
printf("%d\n",qry(1,1,n,x,y,z).p[z]);
}
}
return 0;
}
by lych
2016.4.6