题目链接:点击打开链接
思路:除了树形DP,自己没想到其他思路,直接就写的DP......
开的数组形式为dp[n][3],对于当前父节点,dp[x][0]、dp[x][1]表示已经遍历过的子树中,从该父节点出发到子树节点的路径长度为奇、偶的路径条数;dp[x][2]表示以该父节点为树根的树,满足条件的情况总数。dp[x][2]可以不开,更节省空间,直接用一个变量存储最后的结果即可。具体状态转移可以看代码。
// Treepath 运行/限制:56ms/1000ms
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
#define LL long long
LL dp[100005][3];
vector<int> v[100005];//存图
void DP(int x,int fa) {
dp[x][0] = 0;
dp[x][1] = 0;
dp[x][2] = 0;
for (int i = 0; i < v[x].size(); i++) {
int y = v[x][i];
if (y != fa) {
DP(y, x);
dp[x][2] += dp[x][0] * (dp[y][1] + 1) + dp[x][1] * dp[y][0] + dp[y][0] + dp[y][2];//奇*奇 偶*偶 子树满足条件总数
//加1是因为从当前节点到子树树根也是奇数的一种情况。虽然子树树根到自己的路径长度为0,为偶数,但是是不合法的,dp[y][1]内不包含这种情况。
dp[x][0] += dp[y][1] + 1;
dp[x][1] += dp[y][0];
}
}
return;
}
int main(){
int n;
int a, b;
while (scanf("%d", &n) != EOF) {
for (int i = 1; i <= n; i++) {
v[i].clear();
}
for (int i = 1; i < n; i++) {
scanf("%d%d", &a, &b);
v[a].push_back(b);
v[b].push_back(a);
}
DP(1, 0);
printf("%lld\n", dp[1][2]);
}
return 0;
}
思路2:稍加思维就可以想到,得到每个节点的高度,统计出偶数高度和奇数高度的数量,就可以直接计算了。因为从一个节点到另一个节点,路径是唯一的,路径长度为两者到最近公共祖先的路径长度之和,两者求之到公共祖先的路径长度,均为节点高度减去树根到最近公共祖先的路径长度,也就相当于两者高度之和减去 树根到最近公共祖先的路径长度的二倍,由偶数 - 偶数 = 偶数,奇数 - 偶数 = 奇数,只要两个节点的高度之和为偶数就是满足条件的情况。又由于偶数 + 偶数 = 偶数,奇数 + 奇数 = 偶数,所以组合的两个节点需要满足高度均为偶数或者奇数。设偶数高度节点的数量为sum1,奇数高度节点的数量为sum2,那么最后结果为【(sum1 - 1)+ (sum1 - 2)+ ... + 2 + 1】 + 【(sum2 - 1)+ (sum2 - 2)+ ... + 2 + 1】,即sum1 * (sum1 - 1)/ 2 + sum2 * (sum2 - 1)/ 2。
// Treepath2 运行/限制:74ms/1000ms
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
#define LL long long
LL sum1, sum2;
int height[100005];
vector<int> v[100005];
void dfs(int x, int fa) {
height[x] = height[fa] + 1;
if (height[x] % 2 == 1) {
sum1++;
}
else{
sum2++;
}
for (int i = 0; i < v[x].size(); i++) {
int y = v[x][i];
if (y != fa) {
dfs(y, x);
}
}
}
int main(){
int n;
int a, b;
LL re;
while (scanf("%d", &n) != EOF) {
sum1 = 0, sum2 = 0;
memset(height, -1, sizeof(height));
for (int i = 1; i <= n; i++) {
v[i].clear();
}
for (int i = 1; i < n; i++) {
scanf("%d%d", &a, &b);
v[a].push_back(b);
v[b].push_back(a);
}
dfs(1, 0);
re = sum1 * (sum1 - 1) / 2 + sum2 * (sum2 - 1) / 2;
printf("%lld\n", re);
}
return 0;
}