题目:wannfly 11 E 白兔的刁难
题意:
给定
n
n
n,
k
k
k,对于
t
∈
[
0
,
k
)
t\in [0,k)
t∈[0,k)
求
a
n
s
t
ans_t
anst =
∑
k
∣
i
,
0
≤
i
+
t
≤
n
,
,
(
n
i
+
t
)
\sum_{k|i,0 \le i+t \le n,},\binom{n}{i+t}
∑k∣i,0≤i+t≤n,,(i+tn)
sol:
官方题解:
题目要求的实际上就是,对于任意 t,把 ( 1 + x ) n (1 + x)^n (1+x)n展开后,x 的指数模 k 为 t的所有项的系数之和。那么我们可以利用 ω ? ? = 1 ω_?^? = 1 ωnn=1,带入 x = ω ? 1 x=ω_?^1 x=ωn1。那么 x 的指数模k相同的项就会自动合并。
把1 + x进行长度为 k 的 DFT,然后每个数求 n 次幂,再 IDFT 回来就是答案。
复杂度 k l o g m i n ( n , 998244352 ) + k l o g k klogmin(n,998244352)+klogk klogmin(n,998244352)+klogk。既然是求 n 次幂,那么 n 可以直接对
998244352 取模,那个 10{106}是吓唬人的。
对于长度为 k k k( k k k为2的次幂),求循环卷积只需做长度为 k k k的 D F T DFT DFT + I D F T IDFT IDFT。用 D F T DFT DFT转换为点值后直接快速幂后再 I D F T IDFT IDFT后就是答案。用扩展欧拉定理处理幂次即可。
code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e6+10;
const int mod = 998244353;
const int P = mod;
const int _mod = mod -1;
int g = 3;
char s[maxn];
inline void Add(ll& x,ll y,int _P){
x += y;
if(x >= P) x -= _P;
}
inline void Mul(ll& x,ll y,int _P){
x *= y;
if(x >= P) x %= _P;
}
int qpow(ll a, ll b){
ll sum = 1;
while (b){
if (b & 1) Mul(sum,a,mod);
b >>= 1;
Mul(a,a,mod);
}
return sum;
}
int Inv(ll a){
return qpow(a,mod - 2);
}
struct NTT
{
int rev[maxn], dig[maxn];
int N, L;
void init_rev(int n){
for(N=1,L=0;N<n;N<<=1,L++);
L--;
for(int i = 1;i<N;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<L);
}
void DFT(int a[], int flag)
{
for (int i = 0; i < N; i++)
if (i<rev[i]) swap(a[i], a[rev[i]]);
for (int l = 2; l <= N; l <<= 1){
int wn;
if (flag == 1)
wn = qpow(g, (P - 1) / l);
else
wn = qpow(g, P - 1 - (P - 1) / l);
for (int k = 0; k < N; k += l){
int w = 1;
int x, y;
for (int j = k; j < k + l / 2; j++)
{
x = a[j];
y = (ll)a[j + l / 2] * w % P;
a[j] = (x + y) % P;
a[j + l / 2] = (x - y + P) % P;
w = (ll)w * wn % P;
}
}
}
if (flag == -1)
{
int inv = Inv(N);
for (int i = 0; i < N; i++)
a[i] = 1LL * a[i] * inv % P;
}
}
}ntt;
int a[maxn];
int main(){
scanf("%s",s);
ll n = 0;
int len = strlen(s);
for(int i = 0;i<len;i++){
Mul(n,10,_mod);
Add(n,s[i] - '0',_mod);
}
int k;
scanf("%d",&k);
ntt.init_rev(k);
a[0] = a[1] = 1;
ntt.DFT(a,1);
for(int i = 0;i<k;i++) a[i] = qpow(a[i],n);
ntt.DFT(a,-1);
int ans = 0;
for(int i = 0;i<k;i++) {
// cout<<i<<' '<<a[i]<<endl;
ans ^= a[i];
}
printf("%d\n",ans);
return 0;
}