题目大意
一棵大小为
n
的树,每个节点可以赋予一个
求方案数对
一个测试点
T
组数据。
题目分析
一个很显然的做法,令
fx,y
表示点
x
值取
fx,y=∏v∈children(x)∑i∈[1,m],|y−i|≥kfv,i
处理前缀和,我们可以在 O(nm) 的时间内计算出来。
然而这个做法,弱、逊、水还有naive,是肯定会T的。但是我们能鄙视它吗?不能。
从这个算法继续深入思考。首先可以发现 fx 肯定是对称的,一定是前面一段不同的数,中间一段相同的数,后面一段是最前面对称过去。
感性理解(sui bian gao gao)就可以想(cai)到前面一段的长度一定是至多 (n−1)k 的(超过这个范围为边界已经失去了影响,应该可以使用数学归纳来归纳一下,读者自行思考吧)。
那么我们直接判断 m 是否大于
时间复杂度 O(n2k) 。
代码实现
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
using namespace std;
const int P=1000000007;
const int N=105;
const int K=105;
const int S=N*K*2;
const int E=N<<1;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}
int T,n,m,k,tot,ans,s;
int f[N][S],g[N][S];
int next[E],tov[E];
int last[N],fa[N];
void clear()
{
memset(f,0,sizeof f);
for (;tot;tot--) next[tot]=tov[tot]=0;
for (int i=1;i<=n;i++) fa[i]=last[i]=0;
}
inline void insert(int x,int y){tov[++tot]=y,next[tot]=last[x],last[x]=tot;}
inline int sum(int x,int l,int r){return ((g[x][r]-g[x][l-1])%P+P)%P;}
void brute(int x)
{
for (int i=1;i<=m;i++) f[x][i]=1;
for (int i=last[x],y;i;i=next[i])
if ((y=tov[i])!=fa[x])
{
fa[y]=x,brute(y);
for (int j=1;j<=m;j++) f[x][j]=1ll*f[x][j]*((sum(y,min(j+k,m+1),m)+sum(y,1,max(j-k,0)))%P)%P;
}
for (int i=1;i<=m;i++) g[x][i]=(g[x][i-1]+f[x][i])%P;
}
void dp(int x)
{
for (int i=1;i<=s+1;i++) f[x][i]=1;
for (int i=last[x],y;i;i=next[i])
if ((y=tov[i])!=fa[x])
{
fa[y]=x,dp(y);
for (int j=1;j<=s;j++) f[x][j]=1ll*f[x][j]*((((j+k<=m-s?(sum(y,1,s)+1ll*min(m-s-j-k+1,m-s*2)*f[y][s+1]%P)%P:sum(y,1,m-j-k+1))+sum(y,1,max(j-k,0)))%P+(j+k<=s?sum(y,j+k,s):0))%P)%P;
if (m-s*2>=1) f[x][s+1]=1ll*f[x][s+1]*(((s+1+k<=m-s?(sum(y,1,s)+1ll*(m-s-s-k)*f[y][s+1]%P)%P:sum(y,1,m-s-k))%P+sum(y,1,max(s+1-k,0)))%P)%P;
}
for (int i=1;i<=s+1;i++) g[x][i]=(g[x][i-1]+f[x][i])%P;
}
int main()
{
freopen("label.in","r",stdin),freopen("label.out","w",stdout);
for (T=read();T--;)
{
clear();
n=read(),m=read(),k=read();
for (int i=1,x,y;i<n;i++) x=read(),y=read(),insert(x,y),insert(y,x);
if (!k)
{
ans=1;
for (int i=1;i<=n;i++) ans=1ll*ans*m%P;
}
else if (m<=(s=(n-1)*k)*2)
{
brute(1);
ans=0;
for (int i=1;i<=m;i++) (ans+=f[1][i])%=P;
}
else
{
dp(1);
ans=0;
for (int i=1;i<=s;i++) (ans+=f[1][i]*2%P)%=P;
(ans+=1ll*(m-s*2)*f[1][s+1]%P)%=P;
}
printf("%d\n",ans);
}
fclose(stdin),fclose(stdout);
return 0;
}