http://acm.hdu.edu.cn/showproblem.php?pid=6540
这道捧杯爷湘潭现场写出来的题,做了两天还瞟了一眼题解的状态设计才做出来。。。菜不成声.jpg
这题的关键我觉得是想到设 f [ u ] [ j ] 为以u为根的子树离u最远的点距离为 j 的方案数是多少。
那么我们考虑如何得到f [ u ] [ j ]。
f[u][j]=f[u][j]+f[u][j]*sum[v][j-1] 表示v之前的子树的距离为j的方案得到 j 这个长度,然后当前最远距离也要是j,所以乘以sum[v][j-1]
f[u][j]=f[u][j]+sum[u][j-1]*f[v][j-1],表示由v子树中距离为 j-1 的方案和v之前的子树中<=j-1的方案组成,因为去重所以只到sum[u][j-1]
f[u][j]=(f[u][j]+f[v][j-1]); 表示由v子树距离为j-1的方案单独组成,不和v之前子树中的选择方案合并组成。
那么如果 u 是关键点,那么u这个点是可选可不选的,所有f[u][j]*2,且f[u][0]=1;
upd:这题最离谱的地方是我直接输出sum[1][k]竟然过了,是不是数据只有样例啊。。。
这种问题我们一般考虑放到lca上计数,也就是说假设所选的节点的lca就是u点
那么为了避免重复,我们每次枚举一个f[v][j],就把之前已经统计的sum[u][k-j-1]乘上,加到tmpans中,表示离v最大距离恰好为j的方案数*u之前的儿子得到离u最大距离为k-j-1的方案数,这样乘起来的方案一定保证经过u-v这条边,说明他们的lca就在u
具体实现的时候一开始都是不考虑u这个点选不选的,如果u是可选点的话,那么tmp就乘2,而且每一个儿子f[v][j]可以单独和u组成一个lca在u的方案,那么扫儿子v的时候得到一个vsum,如果u可选就加上,还有就是u单独作1个点
所以ans=(ans+tmp*2+vsum+1)否则就直接+tmp
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int maxl=5010;
int n,m,k,cnt;ll ans;
int ehead[maxl];
ll f[maxl][maxl],sum[maxl][maxl];
struct ed
{
int to,nxt;
}e[maxl<<1];
bool in[maxl];
inline void add(int u,int v)
{
e[++cnt].to=v;e[cnt].nxt=ehead[u];ehead[u]=cnt;
}
inline void prework()
{
scanf("%d%d%d",&n,&m,&k);
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
for(int i=1;i<=m;i++)
{
scanf("%d",&u);
in[u]=true;
}
}
inline void dfs(int u,int fa)
{
int v;ll tmp=0,vsum=0;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(v==fa)
continue;
dfs(v,u);
for(int j=0;j<=k-1;j++)
tmp=(tmp+f[v][j]*sum[u][k-j-1])%mod;
vsum=(vsum+sum[v][k-1])%mod;
for(int j=1;j<=k;j++)
{
f[u][j]=(f[u][j]+f[u][j]*sum[v][j-1]%mod)%mod;
f[u][j]=(f[u][j]+sum[u][j-1]*f[v][j-1]%mod)%mod;
f[u][j]=(f[u][j]+f[v][j-1])%mod;
}
for(int j=1;j<=k;j++)
sum[u][j]=(sum[u][j-1]+f[u][j])%mod;
}
if(in[u])
{
for(int j=1;j<=k;j++)
f[u][j]=f[u][j]*2%mod;
f[u][0]=1;
ans=(ans+tmp*2+vsum+1)%mod;
}else
ans=(ans+tmp)%mod;
sum[u][0]=f[u][0];
for(int j=1;j<=k;j++)
sum[u][j]=(sum[u][j-1]+f[u][j])%mod;
}
inline void mainwork()
{
ans=0;
dfs(1,0);
}
inline void print()
{
printf("%lld\n",ans);
}
int main()
{
prework();
mainwork();
print();
return 0;
}