题意
分析
因为树的形态固定,我们可以把一个连通块记录在树上深度最浅的节点上,且一个连通块内这样的点有且只有一个,所以这种统计方式不重不漏
那么我们可以利用点分治,每次计算以分治中心作为连通块最高点的符合条件的连通块个数
通过树形dp来解决,按照dfs序dp,记
f
i
,
a
,
b
,
c
f_{i,a,b,c}
fi,a,b,c表示dfs为 i 的有a个A,b个B,c个C颜色的连通块个数,那么对于第 i+1 个位置,可以选或者不选。
选直接转移即可,不选意味着 i+1 这个子树都不选,转移到 i+siz[u] 即可
为了节省空间,把后三维编个号记录
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
int id[105][105][105],n,A,B,C;
int col[205],cnt;
vector <int> G[205];
ll f[205][40000];
int totsiz,gsiz,root;
int siz[205],vis[205];
void findrt(int u,int fa)
{
siz[u]=1;
int maxi=0;
for(auto to:G[u])
{
if(vis[to] || to==fa) continue;
findrt(to,u);
siz[u]+=siz[to];
maxi=max(maxi,siz[to]);
}
maxi=max(maxi,totsiz-siz[u]);
if(maxi<gsiz) gsiz=maxi,root=u;
}
int dfn[205],st[205],top,mx[205];
void dfs(int u,int fa)
{
st[++top]=u; dfn[u]=top;
for(auto to:G[u])
{
if(vis[to] || to==fa) continue;
dfs(to,u);
}
mx[u]=top;
}
ll ans=0;
void solve(int u)
{
findrt(u,0);
vis[u]=1; top=0;
dfs(root,0);
for(int i=0;i<=top;i++)
for(int j=0;j<=cnt;j++)
f[i][j]=0;
f[0][id[0][0][0]]=1;
for(int i=0;i<top;i++)
for(int a=0;a<=A;a++)
for(int b=0;b<=B;b++)
for(int c=0;c<=C;c++)
{
int w=f[i][id[a][b][c]];
if(!w) continue;
//i+1选上
int color=col[st[i+1]];
if(color==0)
if(a+1<=A)
f[i+1][id[a+1][b][c]]=(f[i+1][id[a+1][b][c]]+w)%mod;
if(color==1)
if(b+1<=B)
f[i+1][id[a][b+1][c]]=(f[i+1][id[a][b+1][c]]+w)%mod;
if(color==2)
if(c+1<=C)
f[i+1][id[a][b][c+1]]=(f[i+1][id[a][b][c+1]]+w)%mod;
//i+1不选
f[i+siz[st[i+1]]][id[a][b][c]]=(f[i+siz[st[i+1]]][id[a][b][c]]+w)%mod;
}
for(int a=0;a<=A;a++)
for(int b=0;b<=B;b++)
for(int c=0;c<=C;c++)
if(a+b+c) ans=(ans+f[top][id[a][b][c]])%mod;
// if(a+b+c) printf("%d %d %d: %d\n",a,b,c,f[top][id[a][b][c]]);
// printf("[%d %lld]\n",u,ans);
for(auto to:G[u])
{
if(vis[to]) continue;
totsiz=gsiz=siz[to];
findrt(to,u);
solve(root);
}
}
int main()
{
freopen("b.in","r",stdin);
freopen("b.out","w",stdout);
scanf("%d%d%d%d",&n,&A,&B,&C);
for(int i=1;i<=n;i++) scanf("%d",&col[i]);
int x,y;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
G[x].push_back(y); G[y].push_back(x);
}
for(int i=0;i<=A;i++)
for(int j=0;j<=B;j++)
for(int k=0;k<=C;k++)
id[i][j][k]=++cnt;
totsiz=gsiz=n; findrt(1,0);
solve(root);
printf("%lld",ans);
return 0;
}