2019杭电多校1012Sequence(生成函数&卷积NTT)
题目大意
给出一个长度为n的序列 { a n } \{a_n\} {an},现在定义一种‘k操作’其意为有 b i = ∑ j = i − k ∗ x a j ( 0 ≤ x , 1 ≤ j ≤ i ) b_i=\sum_{j=i-k*x}a_j(0\le x,1\le j \le i) bi=∑j=i−k∗xaj(0≤x,1≤j≤i)接下来再将所有的 a i a_i ai用相应位置的 b i b_i bi代替,问经过m次k ( 1 ≤ k ≤ 3 ) (1\le k\le 3) (1≤k≤3)变换后的序列是怎么样的。
解题思路
设一生成函数 f ( x ) = ∑ i = 1 n a i x i f(x)=\sum_{i=1}^na_ix^i f(x)=∑i=1naixi则作一次k变化后的函数可以表示为 ( ∑ i = 1 n a i x i ) ( ∑ i = 0 x k i ) (\sum_{i=1}^na_ix^i)(\sum_{i=0}x^{ki}) (∑i=1naixi)(∑i=0xki)的前n项,那么新的 { b n } \{b_n\} {bn}也就是这个新的多项式的前n项的系数。同时可以发现变换的结果与变换的先后顺序无关,故可以所有变换的次数统计并加上相应次数的多项式即可
则要求的目标多项式就是 ( ∑ i = 1 n a i x i ) ( ∑ i = 0 x i ) c n t [ 1 ] ( ∑ i = 0 x 2 i ) c n t [ 2 ] ( ∑ i = 0 x 3 i ) c n t [ 3 ] (\sum_{i=1}^na_ix^i)(\sum_{i=0}x^{i})^{cnt[1]}(\sum_{i=0}x^{2i})^{cnt[2]}(\sum_{i=0}x^{3i})^{cnt[3]} (∑i=1naixi)(∑i=0xi)cnt[1](∑i=0x2i)cnt[2](∑i=0x3i)cnt[3]
其中根据二项式展开系数的推导原理可以推导出有
(
∑
i
=
1
x
i
)
n
=
∑
i
=
1
C
n
+
i
−
1
i
x
i
(\sum_{i=1}x^i)^n=\sum_{i=1}C_{n+i-1}^{i}x^i
(i=1∑xi)n=i=1∑Cn+i−1ixi
据此进行卷积即可
AC代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod=998244353;
const int g=3;
const int size=5e6+5;
const int lens=5e6+5;
int fac[size],invfac[size];
int quick_pow(int a,int b){int ans=1;while(b){if(b&1)ans=1LL*ans*a%mod;a=1LL*a*a%mod;b>>=1;}return ans;}
void init()
{
fac[0]=1;
for(int i=1;i<size;i++) fac[i]=1LL*fac[i-1]*i%mod;
invfac[size-1]=quick_pow(fac[size-1],mod-2);
for(int i=size-2;i>=1;i--) invfac[i]=1LL*invfac[i+1]*(i+1)%mod;
}
int combi(int a,int b)
{
if(a==0) return 1;
if(a<0) return 0;
return 1LL*fac[b]*invfac[a]%mod*invfac[b-a]%mod;
}
int rev[lens];
void ntt(int a[],int n,int inv)
{
int bit=0;
while((1<<bit)<n) bit++;
for(int i=0;i<n;i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<n;mid*=2)
{
int e=quick_pow(g,(mod-1)/(2*mid));
for(int i=0;i<n;i+=mid*2)
{
int omega=1;
for(int j=0;j<mid;j++,omega=1LL*omega*e%mod)
{
int x=a[i+j],y=1LL*a[i+j+mid]*omega%mod;
a[i+j]=(x+y)%mod;a[i+j+mid]=(x-y+mod)%mod;
}
}
}
if(inv==1) return ;
int nv=quick_pow(n,mod-2);reverse(a+1,a+n);
for(int i=0;i<n;i++)a[i]=1LL*a[i]*nv%mod;
}
int a[lens];
int b[5][lens];
int cnt[5];
int aa[lens],bb[lens];
void conv(int a[],int b[],int lena,int lenb)
{
int lens=0;
while((1<<lens)<lena+lenb-1)lens++;
lens=(1<<lens);
memset(aa,0,sizeof(aa));
memset(bb,0,sizeof(bb));
for(int i=0;i<lena;i++) aa[i]=a[i];
for(int i=0;i<lenb;i++) bb[i]=b[i];
ntt(aa,lens,1),ntt(bb,lens,1);
for(int i=0;i<lens;i++)
{
aa[i]=1LL*aa[i]*bb[i]%mod;
}
ntt(aa,lens,-1);
for(int i=0;i<lena;i++)
{
a[i]=aa[i];
}
}
int main()
{
init();
int t;
scanf("%d",&t);
int n,m;
while(t--)
{
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++) scanf("%d",&a[i]);
int c;
for(int i=1;i<=3;i++) cnt[i]=0;
for(int i=1;i<=m;i++)
{
scanf("%d",&c);
cnt[c]++;
}
for(int i=1;i<=3;i++)
{
if(!cnt[i]) continue;
for(int j=0;j<n;j+=i)
b[i][j]=combi(j/i,cnt[i]-1+j/i);
conv(a,b[i],n,n);
}
LL ans=0;
for(int i=0;i<n;i++)
{
ans=ans^(1LL*(i+1)*a[i]);
}
printf("%lld\n",ans);
}
}
/*
2
5 2
2 4 2 1 1
1 1
5 2
3 2 2 4 1
2 2
*/