>Link
牛客1022普及T3
>Description
给出一棵大小为
n
n
n,边权为1的树
其中有
m
m
m个节点拥有“超能力”
在剩下无超能力的点中,离所有具有超能力的节点的最短距离在区间
[
l
,
r
]
[l,r]
[l,r]中的点,被赋予“旪超能力”
每一个“超能力”点
i
i
i,给其他所有点
j
j
j添加
i
i
i到
j
j
j的最短路径的平方 的磁场强度
每一个“旪超能力”点
i
i
i,给其他所有点
j
j
j添加
i
i
i到
j
j
j的最短路径 的磁场强度
现有 k k k个询问,询问 x x x点得到的磁场强度
>解题思路
比赛时,打了一个暴力拿了5分@__@
首先我们得求出具有旪超能力的点,这时我们可以用bfs,一开始将所有超能力点都放入队列,求最短路过程中记录每个非超能力点的最短距离,然后最后给符合条件的点赋予旪超能力。
(因为旪超能力更容易求,所以先考虑旪超能力)
这是一个树形dp,我们定义在以1为根的这棵树中,每个状态
i
i
i仅表示以i为根的子树
c
n
t
i
cnt_i
cnti→旪超能力的个数,
d
i
s
i
dis_i
disi→所有旪超能力点到i的距离和,
X
c
n
t
i
Xcnt_i
Xcnti→超能力的个数,
X
d
i
s
1
i
Xdis1_i
Xdis1i→所有超能力点到i的距离和,
X
d
i
s
2
i
Xdis2_i
Xdis2i→所有超能力点到i的距离平方和
旪超能力的转移特别好求,
d
i
s
i
=
∑
d
i
s
s
o
n
+
c
n
t
s
o
n
dis_i=∑dis_{son}+cnt_{son}
disi=∑disson+cntson
(我们从儿子
s
o
n
son
son转移到父亲
i
i
i,所以到父亲要多加
c
n
t
s
o
n
cnt_{son}
cntson条边权为1的边)
超能力求的是平方和就比较麻烦,
∑
x
2
→
∑
(
x
+
1
)
2
∑x^2→∑(x+1)^2
∑x2→∑(x+1)2,
这里要用到完全平方公式:
(
x
+
1
)
2
=
x
2
+
2
x
+
1
(x + 1)^2=x^2+2x+1
(x+1)2=x2+2x+1,因此
∑
(
x
+
1
)
2
=
∑
x
2
+
2
x
+
1
∑(x+1)^2=∑x^2+2x+1
∑(x+1)2=∑x2+2x+1,就可以直接得出方程
X
d
i
s
2
i
=
∑
X
d
i
s
2
s
o
n
+
2
X
d
i
s
1
s
o
n
+
X
c
n
t
s
o
n
Xdis2_i=∑Xdis2_{son}+2Xdis1_{son}+Xcnt_{son}
Xdis2i=∑Xdis2son+2Xdis1son+Xcntson 非常丑
但是我们询问的是所有点给的磁场强度,就是要求以每个点为根的不同的树对应的
d
i
s
i
+
X
d
i
s
2
i
dis_i+Xdis2_i
disi+Xdis2i
我们进行换根操作,定义
s
u
m
i
sum_i
sumi→以i为根所有旪超能力点到i的距离和,
X
s
u
m
1
i
Xsum1_i
Xsum1i→以i为根所有超能力点到i的距离和,
X
s
u
m
2
i
Xsum2_i
Xsum2i→以i为根所有超能力点到i的距离平方和
以上面求 d i s dis dis为例,先删去 i i i子树中所有贡献答案的点从 i i i到 f a t h e r i father_i fatheri的那条边,再加上其他所有贡献答案的点从 f a t h e r i father_i fatheri到 i i i的那条边,得出状态转移方程
数据范围需要用longlong
>代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#define N 500005
#define inf 1 << 23
#define int long long
using namespace std;
struct edge
{
int to, next;
} a[2 * N];
int n, m, k, L, R, t, h[N], c[N], minn[N];
int dis[N], cnt[N], Xdis1[N], Xdis2[N], Xcnt[N], sum[N], Xsum1[N], Xsum2[N];
bool mark[N], sp[N], p[N];
queue<int> q;
void add (int u, int v)
{
a[++t] = (edge){v, h[u]}; h[u] = t;
a[++t] = (edge){u, h[v]}; h[v] = t;
}
void bfs_find ()
{
for (int i = 1; i <= n; i++)
{
if (sp[i]) mark[i] = 1, c[i] = 0, q.push(i);
else mark[i] = 0, c[i] = inf, minn[i] = inf;
}
while (!q.empty())
{
int u = q.front();
q.pop(); mark[u] = 0;
for (int i = h[u]; i; i = a[i].next)
if (c[a[i].to] > c[u] + 1)
{
c[a[i].to] = c[u] + 1;
if (!sp[a[i].to]) minn[a[i].to] = min (minn[a[i].to], c[a[i].to]);
if (!mark[a[i].to])
{
mark[a[i].to] = 1;
q.push(a[i].to);
}
}
}
for (int i = 1; i <= n; i++)
if (!sp[i])
if (L <= minn[i] && minn[i] <= R) p[i] = 1;
}
void dfs (int now, int fath)
{
cnt[now] = p[now];
Xcnt[now] = sp[now];
for (int i = h[now]; i; i = a[i].next)
if(a[i].to != fath)
{
dfs (a[i].to, now);
cnt[now] += cnt[a[i].to];
dis[now] += dis[a[i].to] + cnt[a[i].to];
Xcnt[now] += Xcnt[a[i].to];
Xdis1[now] += Xdis1[a[i].to] + Xcnt[a[i].to];
Xdis2[now] += Xdis2[a[i].to] + 2 * Xdis1[a[i].to] + Xcnt[a[i].to];
}
}
void dfs2 (int now, int fath)
{
sum[now] = sum[fath] - cnt[now] + (cnt[1] - cnt[now]);
Xsum1[now] = Xsum1[fath] - Xcnt[now] + (Xcnt[1] - Xcnt[now]);
Xsum2[now] = Xsum2[fath] - 2 * Xdis1[now] - Xcnt[now];
Xsum2[now] = Xsum2[now] + 2 * (Xsum1[fath] - Xdis1[now] - Xcnt[now]) + (Xcnt[1] - Xcnt[now]);
for (int i = h[now]; i; i = a[i].next)
if (a[i].to != fath) dfs2 (a[i].to, now);
}
signed main()
{
int u, v, x;
scanf ("%lld%lld%lld%lld%lld", &n, &m, &k, &L, &R);
for (int i = 1; i < n; i++)
{
scanf ("%lld%lld", &u, &v);
add (u, v);
}
for (int i = 1; i <= m; i++)
{
scanf ("%lld", &x);
sp[x] = 1;
}
bfs_find ();
dfs (1, 0);
sum[1] = dis[1], Xsum1[1] = Xdis1[1], Xsum2[1] = Xdis2[1];
for (int i = h[1]; i; i = a[i].next)
dfs2 (a[i].to, 1);
for (int i = 1; i <= k; i++)
{
scanf ("%lld", &x);
printf ("%lld\n", sum[x] + Xsum2[x]);
}
return 0;
}