一、题目
二、解法
设
d
p
[
s
]
dp[s]
dp[s]为选出来的点状压为
s
s
s,所得到的满意度总和,转移:
d
p
[
s
]
=
1
f
[
s
]
∑
i
∈
s
d
p
[
i
]
×
g
[
s
−
i
]
dp[s]=\frac{1}{f[s]}\sum_{i\in s}dp[i]\times g[s-i]
dp[s]=f[s]1i∈s∑dp[i]×g[s−i]其中
f
[
s
]
f[s]
f[s]是
w
w
w总和的
p
p
p次方,
g
[
s
]
g[s]
g[s]是
w
w
w总和的
p
p
p次方 乘上 这个状态是否合法。
显然这个柿子可以用快速子集卷积,然鹅我
T
T
T了,不说了,尽量不要看我大常数的代码 。
UPD 2020-6-12:补充一点卡常的知识,如果这样写的话,就挂了:
for(int i=1;i<=n;i++)
{
for(int j=0;j<lim;j++)
for(int k=0;k<i;k++)
{
dp[i][j]=(dp[i][j]+1ll*dp[k][j]*g[i-k][j])%MOD;
}
如果这样写的话,就快很多
for(int i=1;i<=n;i++)
{
for(int k=0;k<i;k++)
for(int j=0;j<lim;j++)
{
dp[i][j]=(dp[i][j]+1ll*dp[k][j]*g[i-k][j])%MOD;
}
发现了吗,其实是 j j j要在内层枚举,这样调用空间会快很多,现在我过了,开心。
#include <cstdio>
const int N = 405;
const int M = 1<<21;
const int MOD = 998244353;
int read()
{
int num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=(num<<3)+(num<<1)+(c^48),c=getchar();
return num*flag;
}
int n,m,p,lim,a[N],b[N],w[N],f[M],inv[M];
int deg[N],fa[N],bit[M],g[22][M],dp[22][M];
int chk(int s,int i)
{
return s&(1<<i-1);
}
int qkpow(int a,int b)
{
if(a==0) return 0;
int r=1;
while(b>0)
{
if(b&1) r=1ll*a*r%MOD;
a=1ll*a*a%MOD;
b>>=1;
}
return r;
}
int find(int x)
{
return fa[x]==x?x:fa[x]=find(fa[x]);
}
void fwt(int *a,int n,int op)
{
for(int i=1;i<n;i<<=1)
for(int p=i<<1,j=0;j<n;j+=p)
for(int k=0;k<i;k++)
{
if(op==1) a[i+j+k]=(a[i+j+k]+a[j+k])%MOD;
else a[i+j+k]=(a[i+j+k]-a[j+k]+MOD)%MOD;
}
}
signed main()
{
inv[0]=inv[1]=1;
for(int i=2;i<=2100;i++)
inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
n=read();m=read();p=read();
lim=1<<n;
for(int i=1;i<=m;i++)
a[i]=read(),b[i]=read();
for(int i=1;i<=n;i++)
w[i]=read();
for(int i=0;i<lim;i++)
{
int x=0,y=0,z=0;
for(int j=1;j<=n;j++)
{
deg[j]=0;fa[j]=j;
if(chk(i,j))
{
x++;y++;f[i]+=w[j];
}
}
for(int j=1;j<=m;j++)
if(chk(i,a[j]) && chk(i,b[j]))
{
deg[a[j]]++;deg[b[j]]++;
int u=find(a[j]),v=find(b[j]);
if(u!=v) fa[u]=v,y--;
}
for(int j=1;j<=n;j++) z+=(deg[j]&1);
if(z || y!=1) g[x][i]=f[i];
g[x][i]=qkpow(g[x][i],p);
f[i]=qkpow(inv[f[i]],p);
bit[i]=x;
}
for(int i=0;i<=n;i++) fwt(g[i],lim,1);
dp[0][0]=1;fwt(dp[0],lim,1);
int ans=0;
for(int i=1;i<=n;i++)
{
for(int k=0;k<i;k++)
for(int j=0;j<lim;j++)
{
dp[i][j]=(dp[i][j]+1ll*dp[k][j]*g[i-k][j])%MOD;
}
fwt(dp[i],lim,-1);
for(int j=0;j<lim;j++)
dp[i][j]=bit[j]==i?1ll*dp[i][j]*f[j]%MOD:0;
if(i!=n) fwt(dp[i],lim,1);
}
printf("%d\n",dp[n][lim-1]);
}