题意即是给出一棵树,统计点到点的距离,按照%3的方式记录
树形dp:cnt1[i][j]表示在子树中(包括自己)%3距离为j的点数,方便更新
dp1[i][j]表示在子树中(包括自己)距离%3为j的总距离
第一次dfs记录子树答案。
设我们计算的u->v的边w,我们还需要记录边对u的兄弟结点的贡献和父亲那边子树的贡献
cnt2[i][j]表示非子树中(不包括自己)%3距离为j的点数,即为父亲节点的cnt2值加上兄弟节点的点数(通过cnt1[u]-cnt1[v])
dp2[i][j]表示非子树中(不包括自己)%3距离为j的总距离,该边对父亲的那边子树的贡献+对兄弟子树的贡献
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e4 + 5;
const ll mod = 1e9 + 7;
vector < pair<int, int> >E[maxn];
int cnt1[maxn][5], cnt2[maxn][5];
ll dp1[maxn][5], dp2[maxn][5];
void dfs1(int u, int fa) {
cnt1[u][0] = 1;
for (int i = 0; i < E[u].size(); ++i) {
int v = E[u][i].first, w = E[u][i].second; if (v == fa) continue;
dfs1(v, u);
for (int j = 0; j < 3; ++j) {
cnt1[u][(j + w) % 3] += cnt1[v][j];
dp1[u][(j + w) % 3] += dp1[v][j] + w * cnt1[v][j] % mod;
dp1[u][(j + w) % 3] %= mod;
}
}
}
void dfs2(int u, int fa) {
for (int i = 0; i < E[u].size(); ++i) {
int v = E[u][i].first, w = E[u][i].second; if (v == fa) continue;
for (int j = 0; j < 3; ++j) {
cnt2[v][(j + w) % 3] = cnt1[u][j] - cnt1[v][((j - w) % 3 + 3) % 3] + cnt2[u][j];
dp2[v][(j + w) % 3] = (dp1[u][j] - dp1[v][((j - w) % 3 + 3) % 3] - w * cnt1[v][((j - w) % 3 + 3) % 3] % mod + w * cnt2[v][(j + w) % 3] % mod + dp2[u][j]) % mod;
}
dfs2(v, u);
}
}
int main()
{
int n; while (~scanf("%d", &n)) {
for (int i = 1; i <= n; ++i) {
E[i].clear();
for (int j = 0; j < 3; ++j)
cnt1[i][j] = cnt2[i][j] = dp1[i][j] = dp2[i][j] = 0;
}
for (int i = 1; i < n; ++i) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
u++, v++;
E[u].push_back(make_pair(v, w));
E[v].push_back(make_pair(u, w));
}
dfs1(1, 0); dfs2(1, 0);
ll ans[3] = {0};
for (int i = 1; i <= n; ++i) {
for (int j = 0; j < 3; ++j) {
ans[j] = (ans[j] + dp1[i][j] + dp2[i][j]) % mod;
}
}
printf("%lld %lld %lld\n", ans[0], ans[1], ans[2]);
}
}
点分治写法
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e4 + 5;
const ll mod = 1e9 + 7;
vector < pair<int, int> >E[maxn];
int n, root, sum;
int vis[maxn], f[maxn], sz[maxn];
ll ans[5], o[5], num[5];
void getroot(int u, int fa) {
sz[u] = 1; f[u] = 0;
for (int i = 0; i < E[u].size(); ++i) {
int v = E[u][i].first; if (v == fa || vis[v]) continue;
getroot(v, u);
sz[u] += sz[v];
f[u] = max(f[u], sz[v]);
}
f[u] = max(f[u], sum - sz[u]);
if (f[u] < f[root]) root = u;
}
void getdis(int u, int dis, int fa) {
o[dis % 3]++; num[dis % 3] += dis;
for (int i = 0; i < E[u].size(); ++i) {
int v = E[u][i].first, w = E[u][i].second;
if (v == fa || vis[v]) continue;
getdis(v, (dis + w) % mod, u);
}
}
void cal(int u, int dis, int add) {
getdis(u, dis, 0);
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
ans[(i + j) % 3] = (ans[(i + j) % 3] + o[i] * num[j] * add % mod + mod) % mod;
ans[(i + j) % 3] = (ans[(i + j) % 3] + o[j] * num[i] * add % mod + mod) % mod;
}
}
for (int i = 0; i < 3; ++i) o[i] = num[i] = 0;
}
void solve(int u) {
cal(u, 0, 1); vis[u] = 1;
for (int i = 0; i < E[u].size(); ++i) {
int v = E[u][i].first, w = E[u][i].second;
if (vis[v]) continue;
cal(v, w, -1);
sum = sz[v]; root = 0;
getroot(v, 0);
solve(root);
}
}
int main()
{
while (~scanf("%d", &n)) {
for (int i = 1; i <= n; ++i) E[i].clear(), vis[i] = 0;
for (int i = 1; i < n; ++i) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
u++, v++;
E[u].push_back(make_pair(v, w));
E[v].push_back(make_pair(u, w));
}
for (int i = 0; i < 3; ++i) ans[i] = 0;
root = 0; f[0] = INT_MAX - 1; sum = n;
getroot(1, 0); solve(root);
printf("%lld %lld %lld\n", ans[0], ans[1], ans[2]);
}
}