枚举恰好出现 $s$ 次的颜色种数 $i$,则剩余颜色的出现次数都不能为 $s$ 次,由容斥原理得:
$$
ans = \sum \limits_{i = 0}^{\min\{\lfloor \frac{n}{s} \rfloor, m\}} C_{m}^{i} \frac{n! W_i}{(s!)^i (n - is)!} \sum \limits_{j = 0}^{\min\{\lfloor \frac{n}{s} \rfloor, m\} - i} (-1)^j C_{m - i}^{j} \frac{(n - is)! (m - i - j)^{n - is - js}}{(s!)^j (n - is - js)!}
$$
展开式子,把无关项前移:
$$
ans = n! m! \sum \limits_{i = 0}^{\min\{\lfloor \frac{n}{s} \rfloor, m\}} \frac{W_i}{i! (s!)^i} \sum \limits_{j = 0}^{\min\{\lfloor \frac{n}{s} \rfloor, m\} - i} \frac{(-1)^j}{j! (s!)^j} \cdot \frac{(m - i - j)^{n - is - js}}{(m - i - j)! (n - is - js)!}
$$
好像没什么头绪,尝试把枚举 $j$ 改为枚举 $i + j$:
$$
ans = n! m! \sum \limits_{i = 0}^{\min\{\lfloor \frac{n}{s} \rfloor, m\}} \frac{W_i}{i! (s!)^i} \sum \limits_{j = i}^{\min\{\lfloor \frac{n}{s} \rfloor, m\}} \frac{(-1)^{j - i}}{(j - i)! (s!)^{j - i}} \cdot \frac{(m - j)^{n - js}}{(m - j)! (n - js)!}
$$
已经简单很多了,再改变下 $i, j$ 的枚举顺序:
$$
ans = n! m! \sum \limits_{j = 0}^{\min\{\lfloor \frac{n}{s} \rfloor, m\}} \frac{(m - j)^{n - js}}{(m - j)! (n - js)! (s!)^j} \sum \limits_{i = 0}^{j} \frac{W_i}{i!} \cdot \frac{(-1)^{j - i}}{(j - i)!}
$$
后面是个卷积的形式,NTT 优化即可。
时间复杂度 $O(m \log m + n)$。
Code
注意阶乘要预处理到 $\max\{n, m\}$。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <cstdio>
#include <cmath>
#include <ctime>

// Reads n, m, s and then m + 1 weights w[0..m] from stdin and prints
//
//   n! * m! * sum_{j=0}^{min(n/s, m)}
//       (m-j)^(n-js) / ((m-j)! * (n-js)! * (s!)^j)
//       * sum_{i=0}^{j} w[i]/i! * (-1)^(j-i)/(j-i)!
//
// modulo 1004535809. The inner sum is a convolution and is evaluated with a
// number-theoretic transform (NTT).

// Reads the next unsigned decimal integer from stdin into `res`, skipping any
// leading non-digit characters (a '-' sign is skipped like any other
// non-digit). Returns false, leaving `res` untouched, if end of input is
// reached before a digit is found.
template <class T>
inline bool read(T &res) {
    // Keep the character in an int: getchar() returns either an unsigned-char
    // value or EOF, which is exactly the domain isdigit() is defined on.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch))
        ch = getchar();
    if (ch == EOF)
        return false;
    res = ch ^ 48;
    while (ch = getchar(), ch != EOF && isdigit(ch))
        res = res * 10 + ch - 48;
    return true;
}

const int N = 1e7 + 5;           // capacity of the factorial tables
const int M = 1e5 + 5;           // capacity of the weight table
const int M4 = 4e5 + 5;          // capacity of the NTT buffers
const int mod = 1004535809;      // NTT-friendly prime modulus
const int inv3 = (mod + 1) / 3;  // modular inverse of 3 (3 * inv3 == mod + 1)

// fra[i] = i! mod p, inv[i] = (i!)^{-1} mod p, w[i] = i-th input weight.
// f and g are the two NTT operands; rev is the bit-reversal permutation.
int fra[N], inv[N], w[M], f[M4], g[M4], rev[M4];
int n, m, s, lim, ans;

// x = (x + y) mod p, for x in [0, mod) and y in [0, mod].
inline void add(int &x, int y) {
    x += y;
    if (x >= mod)
        x -= mod;
}

// Returns x^k mod p by binary exponentiation (x^0 == 1, including 0^0).
inline int quick_pow(int x, int k) {
    int res = 1;
    while (k) {
        if (k & 1)
            res = 1ll * res * x % mod;
        x = 1ll * x * x % mod;
        k >>= 1;
    }
    return res;
}

// In-place iterative NTT of a[0..len) using the global `rev` table, which
// the caller must have filled for this `len` (a power of two).
// opt == 1 is the forward transform (primitive root 3); opt == -1 is the
// inverse transform (root inv3) including the final division by len.
inline void NTT(int *a, int len, int opt) {
    const int root = opt == 1 ? 3 : inv3;
    for (int i = 0; i < len; ++i)
        if (i < rev[i])
            std::swap(a[i], a[rev[i]]);
    for (int half = 1; half < len; half <<= 1) {
        const int step = quick_pow(root, (mod - 1) / (half << 1));
        for (int base = 0; base < len; base += half << 1) {
            int tw = 1;
            for (int j = 0; j < half; ++j) {
                const int u = a[base + j];
                const int v = 1ll * tw * a[base + j + half] % mod;
                a[base + j] = a[base + j + half] = u;
                add(a[base + j], v);
                add(a[base + j + half], mod - v);
                tw = 1ll * tw * step % mod;
            }
        }
    }
    if (opt == -1) {
        const int len_inv = quick_pow(len, mod - 2);
        for (int i = 0; i < len; ++i)
            a[i] = 1ll * a[i] * len_inv % mod;
    }
}

template <class T>
inline T Min(T x, T y) { return x < y ? x : y; }

template <class T>
inline T Max(T x, T y) { return x > y ? x : y; }

int main() {
    // Reject truncated input and values that would divide by zero or index
    // past the static tables.
    if (!read(n) || !read(m) || !read(s) || s < 1 || n >= N || m >= M) {
        std::cerr << "invalid input\n";
        return 1;
    }
    for (int i = 0; i <= m; ++i) {
        if (!read(w[i])) {
            std::cerr << "invalid input\n";
            return 1;
        }
    }

    // Factorials and inverse factorials up to max(n, m).
    const int top = Max(n, m);
    fra[0] = 1;
    for (int i = 1; i <= top; ++i)
        fra[i] = 1ll * fra[i - 1] * i % mod;
    inv[top] = quick_pow(fra[top], mod - 2);
    for (int i = top; i >= 1; --i)
        inv[i - 1] = 1ll * inv[i] * i % mod;

    // f[i] = (-1)^i / i!,  g[i] = w[i] / i!  for 0 <= i <= lim.
    lim = Min(n / s, m);
    for (int i = 0; i <= lim; ++i) {
        g[i] = 1ll * w[i] * inv[i] % mod;
        f[i] = (i & 1) ? mod - inv[i] : inv[i];
    }

    // Smallest power of two strictly greater than 2 * lim, so the cyclic
    // convolution does not wrap around.
    int fm = 1, bits = 0;
    while (fm <= 2 * lim) {
        fm <<= 1;
        ++bits;
    }
    for (int i = 1; i < fm; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));

    NTT(f, fm, 1);
    NTT(g, fm, 1);
    for (int i = 0; i < fm; ++i)
        f[i] = 1ll * f[i] * g[i] % mod;
    NTT(f, fm, -1);  // f[j] = sum_{i<=j} w[i]/i! * (-1)^(j-i)/(j-i)!

    int mul_s = 1;  // (1 / s!)^j
    for (int j = 0; j <= lim; ++j) {
        const int tx = m - j, ty = n - j * s;
        long long term = 1ll * quick_pow(tx, ty) * inv[tx] % mod;
        term = term * inv[ty] % mod;
        term = term * mul_s % mod;
        ans = (term * f[j] + ans) % mod;
        // inv[s] is only valid when s <= max(n, m); a further iteration
        // exists only if lim >= 1, which implies s <= n.
        if (j < lim)
            mul_s = 1ll * mul_s * inv[s] % mod;
    }

    std::cout << 1ll * fra[n] * fra[m] % mod * ans % mod << '\n';
}