Description:
你在一个有n 个点的环上,环上点按逆时针顺序标号为0 到n - 1。你一
开始在0 号点。你在每一回合可以使用k 种传送中的一种,第i 种传送会将你
按逆时针方向移动ai 个点。有m 个限制条件,对于每个限制条件(xi; yi),要
求不能在第xi 步之后在yi 号点上。你要求出经过l 步之后在0 号点的方案数
模998244353。
题解:
直接NTT?
O(mnlog2n) 成功拿到60分。
考虑优化一下.
现在把转移数组看作c.
要求 ck 。
正常做法:快速幂NTT
一次复杂度: O(nlog2n)
在这道题中,n是二的整次幂,所以模n刚好回到原位,其实有:
DFT(c∗c)=DFT(c)∗DFT(c)
因此可以先对c进行DFT点值运算,搞个k次幂,再插值回来。
这为什么是对的?
还记得为什么FFT要开两倍。
因为它实际上是一个循环卷积。
c=a∗b
c(i+j) mod n=∑n−1i=0∑n−1j=0a[i]∗b[j]
随便你用点积自我乘个无数遍,它都会刚好溢出,溢出就是mo个次数界,这里的次数界=n*2,刚好符合我们的需求。
因此复杂度降为 O(mnlogn)
Code:
#include<cstdio>
#include<algorithm>
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define ff(i, x, y) for(int i = x; i < y; i ++)
using namespace std;
const int N = 2e5 + 5;
const ll mo = 998244353;
int n, l, m, k, x;
ll s[N], b[N], c[N];
struct node {
int x, y;
} a[N];
ll w[N], tx;
ll ksm(ll x, ll y) {
ll s = 1;
for(; y; y /= 2, x = x * x % mo)
if(y & 1) s = s * x % mo;
return s;
}
void dft(ll *a, int n) {
ff(i, 0, n) {
int p = i, q = 0;
fo(j, 1, tx) q = q * 2 + p % 2, p /= 2;
if(q > i) swap(a[q], a[i]);
}
for(int m = 2; m <= n; m *= 2) {
int h = m / 2;
ff(i, 0, h) {
ll W = w[i * (n / m)];
for(int j = i; j < n; j += m) {
int k = j + h;
ll u = a[j], v = a[k] * W % mo;
a[j] = (u + v) % mo; a[k] = (u - v + mo) % mo;
}
}
}
}
ll ni;
void fft(ll *a, ll *b, int n) {
dft(a, n); dft(b, n); ff(i, 0, n) a[i] = a[i] * b[i] % mo;
fo(i, 0, n / 2) swap(w[i], w[n - i]);
dft(a, n); ff(i, 0, n) a[i] = a[i] * ni % mo;
fo(i, 0, n / 2) swap(w[i], w[n - i]);
}
int cmp(node a, node b) {
return a.x < b.x;
}
int main() {
scanf("%d %d", &n, &l);
scanf("%d", &m);
fo(i, 1, m) scanf("%d %d", &a[i].x, &a[i].y);
scanf("%d", &k);
fo(i, 1, k) {
scanf("%d", &x);
b[x] ++;
}
while(1 << tx ++ < n) tx ++;
sort(a + 1, a + m + 1, cmp);
s[0] = 1; a[0].x = 0; a[m + 1].x = l;
int n0 = n;
n = 1 << tx; ll v = ksm(3, (mo - 1) / n);
w[0] = 1; fo(i, 1, n) w[i] = w[i - 1] * v % mo;
dft(b, n);
ni = ksm(n, mo - 2);
fo(i, 1, m + 1) if(i == 1 || a[i].x != a[i - 1].x) {
ff(j, 0, n) c[j] = ksm(b[j], a[i].x - a[i - 1].x);
dft(s, n);
ff(j, 0, n) s[j] = s[j] * c[j] % mo;
fo(j, 0, n / 2) swap(w[j], w[n - j]);
dft(s, n);
ff(j, 0, n) s[j] = s[j] * ni % mo;
fo(j, 0, n / 2) swap(w[j], w[n - j]);
ff(j, n0, n) s[j % n0] = (s[j % n0] + s[j]) % mo, s[j] = 0;
int l = i;
while(l <= m && a[l].x == a[i].x) {
s[a[l].y] = 0;
l ++;
}
}
printf("%lld", s[0]);
}