题意:n个节点的树,问有多少对(i,j)其最短距离等于K.
n<=5e4,k<=5e2. (i,j),(j,i) 视为一对.
设d[i][j][0/1]从节点i向下或者向上走长度为J的方法数.
dp[i][j][0]+=dp[son][j-1][0].
dp[i][j][1]+=dp[fa][j-1][1]+dp[fa][j-1][0] (i-fa->fa的前i-1个子树中,总是往左走)
n<=5e4,k<=5e2. (i,j),(j,i) 视为一对.
设d[i][j][0/1]从节点i向下或者向上走长度为J的方法数.
dp[i][j][0]+=dp[son][j-1][0].
dp[i][j][1]+=dp[fa][j-1][1]+dp[fa][j-1][0] (i-fa->fa的前i-1个子树中,总是往左走)
ans+=dp[i][k][1] 第i个点作为最深的点时,长度为k的路径数
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e4+5,M=5e2+5,inf=0x3f3f3f3f;
int n,k;
vector<int> e[N];
int d[N][M][2];
void dfs(int u,int fa)
{
d[u][0][0]=d[u][0][1]=1;
if(fa)
{
d[u][1][1]++;
for(int x=2;x<=k;x++)
d[u][x][1]+=d[fa][x-1][1]+d[fa][x-1][0];
}
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==fa)
continue;
dfs(v,u);
for(int x=1;x<=k;x++)
d[u][x][0]+=d[v][x-1][0];
}
}
int main()
{
cin>>n>>k;
int u,v;
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,0);
int ans=0;
for(int i=1;i<=n;i++)
{
ans=ans+d[i][k][1];
// for(int x=1;x<=k;x++)
// printf("%d %d %d %d\n",i,x,d[i][x][0],d[i][x][1]);
// cout<<endl;
}
cout<<ans<<endl;
return 0;
}
法2:设d[u][x] 为子树u中,距离u长度为x的个数
当u作为路径的最高点时
u作为终点:ans+=d[u][k]
起点和终点为u子树中的某两个点 并且路径经过u : d[son[u]][x-1] * d[u][k-x] 另外一个点不能在u内 所以还要扣掉 dp[son[u]][x-1] * dp[u][k-x-1]
#include<bits/stdc++.h>
#define ll long long
#define ld long double
#define pb push_back
#define x first
#define y second
#define fastread ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
using namespace std;
const int maxn=5e4+7,maxk=600;
ll dp[maxn][maxk];
vector<int> adjlist[maxn];
int n,k,x,y;
ll ans;
void dfs(int cur,int par){
dp[cur][0]++;
for(auto u:adjlist[cur]){
if(u!=par){
dfs(u,cur);
for(int i=0;i<k;i++)
dp[cur][i+1]+=dp[u][i];
}
}
for(auto u:adjlist[cur]){
if(u==par)
continue;
for(int i=0;i<k;i++){
int down=i,up=(k-1-i);
if(up==0)
ans=ans+dp[u][down]*dp[cur][up];
else{
ans=ans+dp[u][down]*(dp[cur][up]-dp[u][up-1]);
}
}
}
ans+=dp[cur][k];
}
int main()
{
fastread;
cin>>n>>k;
for(int i=1;i<n;i++){
cin>>x>>y;
adjlist[x].pb(y);
adjlist[y].pb(x);
}
dfs(1,0);
ans/=2;
cout<<ans;
return 0;
}