题
OvO http://acm.hdu.edu.cn/showproblem.php?pid=6091
( 2017 Multi-University Training Contest - Team 6 - 1007)
解
记 f[i][j]表示,以i为根的子树的所有子图(包含子树所有节点,删掉一些边得到的子图)中,符合下列条件的子图的个数
1. 记子图最大匹配为h1,子图去掉与i节点相连的边后的最大匹配为h2,满足 h1-h2=0
2. h1%m=j
记g[i][j]表示,以i为根的子树的所有子图(包含子树所有节点,删掉一些边得到的子图)中,符合下列条件的子图的个数
1. 记子图最大匹配为h1,子图去掉与i节点相连的边后的最大匹配为h2,满足 h1-h2=1
2. h1%m=j
则对于每个根节点 p 遍历其后继,回溯计算答案,当遍历到后继 q 时,有下列递推式成立
这个可以由一个经典树形DP(求最大匹配)联想到(反正我是没联想到),
复杂度的话,因为如果要使一个节点k的size变成m的话,那么这个子树的大小也就有了m的大小,所以如果一个节点他的后继的size全为m的话,那么这个节点的后继值最多只有n/m个。
所以复杂度大概似乎好像貌似就O(mn)吧,
(思路来自题解)
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int M=5e4+44;
const int N=244;
const int mod=998244353;
struct node{
int u,v,d;
int next;
} edge[2*M];
int num;
int head[M];
int n,m;
int f[M][N],g[M][N];
int sz[M];
int F[N<<1],G[N<<1];
void addedge(int u,int v,int d)
{
edge[num].u=u;
edge[num].v=v;
edge[num].d=d;
edge[num].next=head[u];
head[u]=num++;
}
void init()
{
num=0;
memset (head,-1,sizeof(head));
}
int inc(int x,int y)
{
x+=y;
if(x>m) return m;
return x;
}
void deal(int rt,int v)
{
int i,j,siz=inc(sz[rt],sz[v]);
memset(F,0,sizeof(F));
memset(G,0,sizeof(G));
for(i=0;i<=sz[rt];i++)
for(j=0;j<=sz[v];j++)
{
// printf("i: %d j: %d\n",i,j);
F[i+j]=(0ll+F[i+j]+1ll*f[rt][i]*f[v][j]+2ll*f[rt][i]*g[v][j])%mod;
G[i+j]=(0ll+G[i+j]+2ll*g[rt][i]*g[v][j]+2ll*g[rt][i]*f[v][j])%mod;
G[i+j+1]=(0ll+G[i+j+1]+1ll*f[rt][i]*f[v][j])%mod;
}
// printf("rt: %d v: %d\n",rt,v);
// printf(" F:\n");
// for(i=0;i<2*m+3;i++)
// printf("%d ",F[i]);
// printf("\n G:\n");
// for(i=0;i<2*m+3;i++)
// printf("%d ",G[i]);
// printf("\n");
sz[rt]=siz;
for(i=0;i<m;i++)
{
f[rt][i]=(0ll+F[i]+F[i+m])%mod;
g[rt][i]=(0ll+G[i]+G[i+m])%mod;
}
}
void dfs(int rt,int pa)
{
int i,j,v;
sz[rt]=0;
memset(f[rt],0,sizeof(f[rt]));
memset(g[rt],0,sizeof(g[rt]));
f[rt][0]=1;
for(i=head[rt];i!=-1;i=edge[i].next)
{
v=edge[i].v;
if(v==pa) continue;
dfs(v,rt);
deal(rt,v);
}
sz[rt]=inc(sz[rt],1);
}
void solve()
{
dfs(1,-1);
}
int main()
{
int i,j,cas,u,v;
scanf("%d",&cas);
while(cas--)
{
init();
scanf("%d%d",&n,&m);
for(i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
addedge(u,v,1);
addedge(v,u,1);
}
solve();
printf("%d\n",(0ll+f[1][0]+g[1][0])%mod);
}
return 0;
}