更好的阅读体验 Press Here
Problem
题目大意:
有
n
n
件物品,每件物品有一个权值,可以用
1,2,3
1
,
2
,
3
个价值不同的物品组合出一个总价值,问每种总价值有多少种组成方案
Solution
既然每种价值的物品只能选一个,那么不用管每种价值有多少个,只用关心有没有就好了。作为一个组合问题,使用普通型生成函数
考虑到直接算答案比较麻烦,利用容斥进行计算
A(i) A ( i ) 表示选择一件物品的生成函数
B(i) B ( i ) 表示选择两件相同物品的生成函数
C(i) C ( i ) 表示选择三件相同物品的生成函数
由容斥原理可得
选择一件物品的贡献: A(i) A ( i )
选择两件物品的贡献: A2(i)−B(i)2 A 2 ( i ) − B ( i ) 2
选择三件物品的贡献: A3(i)−3∗A∗B(i)+2∗C[i]6 A 3 ( i ) − 3 ∗ A ∗ B ( i ) + 2 ∗ C [ i ] 6
输出三者系数之和即可
这次的实现用了NTT和快速乘,感觉跑的好慢的说…
代码如下
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = (35ll << 31) + 1 , G = 3 , N = 525000;
ll rev[N] , a[N] , b[N] , c[N] , ANS[N];
ll n , m;
inline ll mul(ll a , ll b) {
ll d = (ll) double(a * (double)b / mod + 0.5);
ll ret = a * b - d * mod;
if(ret < 0) ret += mod;
return ret;
}
ll read() {
ll ans = 0 , flag = 1;
char ch = getchar();
while(ch > '9' || ch < '0') {if(ch == '-') flag = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
return ans * flag;
}
ll qpow(ll a , ll b) {
ll ans = 1;
while(b) {
if(b & 1) ans = mul(ans , a);
a = mul(a , a);
b >>= 1;
}
return ans;
}
void dft(ll *now , ll n , ll f) {
for(ll i = 0 ; i < n ; ++ i)
if(i < rev[i]) swap(now[i] , now[rev[i]]);
for(ll i = 1 ; i < n ; i <<= 1) {
ll gn = qpow(G , (mod - 1) / (i << 1));
if(f != 1) gn = qpow(gn , mod - 2);
for(int j = 0 ; j < n ; j += (i << 1)) {
ll x , y , g = 1;
for(int k = 0 ; k < i ; ++ k , g = mul(g , gn)) {
x = now[j + k] , y = mul(now[i + j + k] , g);
now[j + k] = (x + y) % mod;
now[i + j + k] = (x - y + mod) % mod;
}
}
}
if(f != 1) {
ll ny = qpow(n , mod - 2);
for(int i = 0 ; i < n ; ++ i) now[i] = mul(now[i] , ny);
}
}
int main() {
n = read();
for(ll i = 0 ; i < n ; ++ i) {
ll w = read();
m = max(m , w);
a[w] = b[w * 2] = c[w * 3] = 1;
}
m *= 3 + 1;
ll nn , l =0;
for(nn = 1 ; nn < m ; nn <<= 1) ++ l;
for(int i = 0 ; i < nn ; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
dft(a , nn , 1); dft(b , nn , 1); dft(c , nn , 1);
ll inv2 = qpow(2 , mod - 2) , inv6 = qpow(6 , mod - 2);
for(int i = 0 ; i < nn ; ++ i) {
ANS[i] += a[i];
if(ANS[i] > mod) ANS[i] -= mod;
ANS[i] += mul((mul(a[i] , a[i]) - b[i]) , inv2);
if(ANS[i] > mod) ANS[i] -= mod;
ANS[i] += mul(((mul(mul(a[i] , a[i]) , a[i]) - mul(mul(3 , a[i]) , b[i]) + mul(2 , c[i])) % mod + mod) % mod , inv6);
if(ANS[i] > mod) ANS[i] -= mod;
}
dft(ANS , nn , -1);
for(int i = 0 ; i < m ; ++ i)
if(ANS[i]) printf("%d %lld\n" , i ,ANS[i]);
return 0;
}