题解
按照题意 将m课树合并为一颗树 题目保证合并后必定为一棵树且端点数量不超过n*m<=1e6 遍历所有节点不会超时
使用树形dp d[i]表示以i为根的子树节点数量 使用DFS递归回溯计算d的值 n*m-d[i]表示以d的父节点为根的子节点数量(把树反转过来)
则每次把ans加上d[i]*(n*m-d[i])可计算i节点与父节点连接的边对答案的贡献
AC代码
#include <stdio.h>
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int MAXN = 1e6 + 10;
ll n, m, ans;
int d[MAXN];
vector<int> e[MAXN];
void DFS(int x, int f)
{
d[x] = 1; //只有一个节点为1
for (int y : e[x]) if (y != f)
{
DFS(y, x);
d[x] += d[y]; //将子节点的数量加到父节点上
}
ans = (ans + 1LL * d[x] * (n * m - d[x]) % MOD) % MOD;
}
int main()
{
#ifdef LOCAL
freopen("C:/input.txt", "r", stdin);
#endif
cin >> n >> m;
for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
for (int j = 0; j < m; j++)
e[u + j * n].push_back(v + j * n), e[v + j * n].push_back(u + j * n);
}
for (int i = 1; i < m; i++)
{
int a, b, u, v;
scanf("%d%d%d%d", &a, &b, &u, &v);
a--, b--;
e[u + a * n].push_back(v + b * n), e[v + b * n].push_back(u + a * n);
}
DFS(1, 0);
cout << ans << endl;
return 0;
}