初见安~这里是传送门:P7423 「PMOI-2」简单构造题
题解
好像确实是第一步难想…其他的都比较套路了……
显然我们要脱离枚举区间左右端点这种策略,尝试找到某个共性可以将一些区间收起来。考虑枚举区间长度。假设表示区间长度为i的所有区间对总答案的贡献。则答案就是:
含义为,这个区间首先有n-i+1个位置,其次另外的n-i个位置可以随便放,这就是这个区间出现的所有情况。因为题意就是枚举了所有情况的所有区间,所以对于所求不会有重复计算的情况。
接下来我们就是算g了。因为现在要往里面放各个颜色的球,总数固定,显然可以想到生成函数;又因为题意不是单纯的划分集合,而是前后有序,所以我们搞一个对于某一个颜色(不管标号),我们选择的EGF:
那么关于g的生成函数就是:
这就是经典套路了。我们可以先On求出F,然后得到lnF(x)。因为对于每一个i,F的第k项都个相当于乘上了一个i^k,所以F的第k项要乘一个:
这就是一个经典的自然数幂和了。但是i是连续的,所以没必要分治NTT。我们可以再次写成EGF:
e^x和e^{(m+1)x}分别是两个多项式,展开n项后下面求个inv就是T了。
但是有个问题是,-1后常数项为0了,无法求inv。这里用到一个技巧是:上下两个多项式同时除以x,也就是每一项都往前移一步。这样操作后多项式T的值不会变。
于是这个题就解决掉啦。又是一个多项式全家桶:)
上代码——
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
#define maxn 1000006
using namespace std;
typedef long long ll;
const int mod = 998244353, mx = 3e5;
int read() {
int x = 0, f = 1, ch = getchar();
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
return x * f;
}
int l, len, r[maxn];
void init(int n) {
len = 1, l = 0;
while(len <= n) l++, len <<= 1;
for(int i = 1; i <= len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << l - 1);
}
int pw(int a, int b) {int res = 1; while(b) {if(b & 1) res = 1ll * res * a % mod; a = 1ll * a * a % mod, b >>= 1;} return res;}
void NTT(int *c, int flag) {
for(int i = 1; i < len; i++) if(i < r[i]) swap(c[i], c[r[i]]);
for(int mid = 1; mid < len; mid <<= 1) {
register int gn = pw(3, (mod - 1) / (mid << 1));
if(flag == -1) gn = pw(gn, mod - 2);
for(int ls = 0, L = mid << 1; ls < len; ls += L)
for(int k = 0, g = 1; k < mid; k++, g = 1ll * g * gn % mod) {
register int x = c[ls + k], y = 1ll * c[ls + mid + k] * g % mod;
c[ls + k] = (x + y) % mod, c[ls + mid + k] = (x - y + mod) % mod;
}
}
if(flag == -1) {
register int rev = pw(len, mod - 2);
for(int i = 0; i < len; i++) c[i] = 1ll * c[i] * rev % mod;
}
}
int n, m, F[maxn], f[maxn], A[maxn], B[maxn], tmp[maxn], fac[maxn], inv[maxn];
void get_inv(int *a, int *b, int n) {
if(n == 1) {b[0] = pw(a[0], mod - 2); return;}
get_inv(a, b, n + 1 >> 1); init(n + n);
for(int i = 0; i < n; i++) tmp[i] = a[i];
for(int i = n; i <= len; i++) tmp[i] = 0;
NTT(tmp, 1); NTT(b, 1);
for(int i = 0; i < len; i++) b[i] = 1ll * b[i] * (2ll - 1ll * b[i] * tmp[i] % mod + mod) % mod;
NTT(b, -1); for(int i = n; i <= len; i++) b[i] = 0;
}
void deriv(int *a, int n) {for(int i = 1; i < n; i++) a[i - 1] = 1ll * a[i] * i % mod; a[n - 1] = 0;}
void integ(int *a, int n) {for(int i = n - 1; i; i--) a[i] = 1ll * a[i - 1] * pw(i, mod - 2) % mod; a[0] = 0;}
void get_ln(int *a, int *b, int n) {
get_inv(a, b, n);
for(int i = 0; i < n; i++) tmp[i] = a[i];
for(int i = n; i <= len; i++) tmp[i] = 0;
deriv(tmp, n); init(n + n);
NTT(b, 1), NTT(tmp, 1);
for(int i = 0; i < len; i++) b[i] = 1ll * tmp[i] * b[i] % mod;
NTT(b, -1); integ(b, n);
}
void get_exp(int *a, int *b, int n) {
if(n == 1) {b[0] = 1; return;}
get_exp(a, b, n + 1 >> 1);
int c[maxn]; get_ln(b, c, n);
c[0] = (1ll - c[0] + a[0] + mod) % mod;
for(int i = 1; i < n; i++) c[i] = (a[i] - c[i] + mod) % mod;
NTT(c, 1), NTT(b, 1);
for(int i = 0; i < len; i++) b[i] = 1ll * b[i] * c[i] % mod;
NTT(b, -1);
for(int i = n; i <= len; i++) b[i] = 0;
}
signed main() {
n = read(), m = read();
fac[0] = inv[0] = 1;
for(int i = 1; i <= mx; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
inv[mx] = pw(fac[mx], mod - 2);
for(int i = mx - 1; i; i--) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
for(int i = 1; i <= n; i++) f[i] = inv[i - 1]; f[0] = 1;
get_ln(f, F, n + 1);//F: lnF(x)
register int M = m + 1;
for(int i = 0; i <= n; i++) A[i] = 1ll * inv[i + 1] * M % mod, M = 1ll * M * (m + 1) % mod;
for(int i = 0; i <= n; i++) B[i] = inv[i + 1];//A上B下
memset(f, 0, sizeof f); get_inv(B, f, n + 1);
NTT(f, 1), NTT(A, 1);
for(int i = 0; i < len; i++) A[i] = 1ll * A[i] * f[i] % mod;
NTT(A, -1);//A是T(x)
for(int i = 0; i <= n; i++) F[i] = 1ll * F[i] * A[i] % mod * fac[i] % mod;//EGF记得乘一个i!回去。
memset(B, 0, sizeof B);
get_exp(F, B, n + 1);
int ans = 0;
for(int i = 1; i <= n; i++) ans = (ans + 1ll * B[i] * (n - i + 1) % mod * pw(m, n - i) % mod * fac[i] % mod) % mod;
printf("%d\n", 1ll * ans * pw(pw(m, n), mod - 2) % mod);//最后因为是期望,记得除去总方案数。
return 0;
}//623902740
迎评:)
——End——