链接:
题解:
首先考虑如果确定了非零数的相对顺序怎么做。
考虑从左到右给每个 $w_i$ 贪心地分配离它最近的 $w_i-1$ 个 $0$,那么我们记 $g$(原文此处的定义式在抓取时丢失)的含义就是将……(原文此处内容缺失)
这样就可以用一个状压DP来求方案数了。
注意到最后答案是一堆……(原文此处截断)
代码:
#include <bits/stdc++.h>
// Contest short-hand macros and aliases.
// NOTE(review): xx, yy, mp, mcpy, LL and pii are not referenced in the code
// shown here; pb and mset are.
#define xx first
#define yy second
#define mp make_pair
#define pb push_back
// mset(x, y): fill every byte of array x with y (size taken from x itself).
#define mset(x, y) memset(x, y, sizeof x)
// mcpy(x, y): copy array y into array x (size taken from x).
#define mcpy(x, y) memcpy(x, y, sizeof x)
using namespace std;
typedef long long LL;
typedef pair <int, int> pii;
/*
 * Reads the next integer from stdin with getchar().
 * Characters before the first digit are skipped; a '-' among the skipped
 * characters makes the result negative. Reading stops at (and consumes)
 * the first non-digit after the number.
 * Returns 0 if end of input is reached before any digit is found
 * (previously this case looped forever, because getchar() keeps
 * returning EOF and isdigit(EOF) is false).
 */
inline int Read()
{
	int x = 0, f = 1, c = getchar();
	for (; !isdigit(c); c = getchar())
	{
		if (c == EOF)
			return 0;  // no digits left: do not spin on EOF forever
		if (c == '-')
			f = -1;
	}
	for (; isdigit(c); c = getchar())
		x = x * 10 + c - '0';
	return x * f;
}
// Table size. fac/inv are indexed by binomial arguments; the partition-state
// arrays (sum, fac-independent val/way, f's second index, trans' first index)
// must hold one entry per partition of every integer 0..n-1.
// NOTE(review): nxt[45], f[45][...] and trans[...][45] suggest n <= 40 is
// assumed — confirm against the problem limits.
const int MAXN = 200005;
const int mod = 998244353;
// n: input size. m: number of partition states built by Dfs (ids 1..m).
// nxt[k]: weight, for the element currently being processed, of adding a part of size k.
// sum[s]: total of the parts of state s.
// fac[k] = k! mod p; inv[k] = (k!)^{-1} mod p once main() has finished its prefix-product pass.
// way[s]: value of the first DP at the final step, in state s.
// f[i][s]: DP value after processing i elements, ending in state s.
// trans[s][k]: state reached from s by adding a part of size k (trans[s][0] == s).
int n, m, nxt[45], sum[MAXN], fac[MAXN], inv[MAXN], way[MAXN], f[45][MAXN], trans[MAXN][45];
// idx: part list (non-increasing order) -> state id.
map <vector <int>, int> idx;
// cur: scratch stack Dfs uses while building a partition; val[s]: parts of state s.
vector <int> cur, val[MAXN];
/*
 * Binomial coefficient C(n, m) modulo `mod`.
 * Uses the precomputed tables fac[k] = k! and inv[k] = (k!)^{-1}, which
 * main() fills before any call.
 * Returns 0 when m < 0 or m > n, instead of indexing the tables out of range.
 * For example, main() would call C(j - 2, j) if some input w_i were 0.
 * Assumes n < MAXN.
 */
inline int C(int n, int m)
{
	if (m < 0 || m > n)
		return 0;
	return 1LL * fac[n] * inv[m] % mod * inv[n - m] % mod;
}
inline void Dfs(int s, int v, int r)
{
if (!s)
idx[cur] = ++ m, sum[m] = r, val[m] = cur;
else
for (int i = 1; i <= v && i <= s; i ++)
cur.pb(i), Dfs(s - i, i, r), cur.pop_back();
}
// Entry point. Reads n, then n integers (the w_i of the write-up), and prints
// the answer modulo 998244353.
int main()
{
// Local-debug hook: stdin is redirected only when compiled with -Dwxh010910.
#ifdef wxh010910
freopen("data.in", "r", stdin);
#endif
// fac[i] = i! mod p; inv[i] = i^{-1} mod p via inv[i] = -(p / i) * inv[p % i].
n = Read(), fac[0] = fac[1] = inv[0] = inv[1] = 1;
for (int i = 2; i < MAXN; i ++)
fac[i] = 1LL * fac[i - 1] * i % mod, inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
// Prefix products turn inv[i] into the inverse factorial (i!)^{-1}.
for (int i = 2; i < MAXN; i ++)
inv[i] = 1LL * inv[i - 1] * inv[i] % mod;
// Register every partition of every i in [0, n) as a DP state.
// Dfs(0, 0, 0) runs first, so state 1 is the empty partition.
for (int i = 0; i < n; i ++)
Dfs(i, i, i);
// trans[i][j]: state obtained from state i by adding a part of size j
// (j = 0 leaves the state unchanged). j is capped so the total stays
// <= n - 1, which guarantees the resulting partition was registered above.
for (int i = 1; i <= m; i ++)
{
trans[i][0] = i;
for (int j = 1; j <= n - 1 - sum[i]; j ++)
{
vector <int> tmp = val[i];
tmp.pb(j);
// Re-sort into the non-increasing order Dfs uses for its keys.
sort(tmp.begin(), tmp.end(), greater <int> ());
trans[i][j] = idx[tmp];
}
}
// First DP: start from the empty partition. At step i, either add nothing
// (k = 0) or add a part k with sum + k <= i. way[s] keeps the count (mod p)
// of ways to end in state s after n steps.
f[0][1] = 1;
for (int i = 0; i < n; i ++)
for (int j = 1; j <= m; j ++)
if (f[i][j])
for (int k = 0; k <= i - sum[j]; k ++)
f[i + 1][trans[j][k]] = (f[i + 1][trans[j][k]] + f[i][j]) % mod;
for (int i = 1; i <= m; i ++)
way[i] = f[n][i];
// The second DP reuses f, so clear it first.
mset(f, 0);
f[0][1] = 1;
int ans = 0, cnt = 0;
for (int i = 0; i < n; i ++)
{
// x = w_i - 1; cnt accumulates the sum of all x.
// NOTE(review): assumes w_i >= 1; w_i = 0 gives x = -1 and out-of-range C() arguments.
int x = Read() - 1;
cnt += x;
// nxt[k]: weight for this element adding a part of size k. For x > 0 it
// is C(x + k - 1, k) (stars and bars: k identical items into x boxes);
// for x = 0 only k = 0 is allowed.
if (x)
for (int j = 0; j < n; j ++)
nxt[j] = C(x + j - 1, j);
else
for (int j = 0; j < n; j ++)
nxt[j] = !j;
for (int j = 1; j <= m; j ++)
if (f[i][j])
for (int k = 0; k <= n - 1 - sum[j]; k ++)
f[i + 1][trans[j][k]] = (1LL * f[i][j] * nxt[k] + f[i + 1][trans[j][k]]) % mod;
}
// Combine both DPs over every final state.
for (int i = 1; i <= m; i ++)
{
// NOTE(review): this local `cur` shadows the global scratch vector `cur`
// used by Dfs; here it is only the running product for state i.
int cur = 1LL * way[i] * f[n][i] % mod;
// Multiply by (run length)! for each run of equal parts. val[i] is sorted,
// so equal parts are adjacent.
for (int l = 0, r = 0; l < val[i].size(); cur = 1LL * cur * fac[r - l] % mod, l = r)
while (r < val[i].size() && val[i][r] == val[i][l])
r ++;
// Then multiply by (n - number of parts)!.
cur = 1LL * cur * fac[n - val[i].size()] % mod;
ans = (ans + cur) % mod;
}
// Finally multiply by cnt! = (sum of (w_i - 1))!.
while (cnt)
ans = 1LL * ans * cnt % mod, cnt --;
return printf("%d\n", ans), 0;
}