先给出题目链接
大致题意如下:给定k个关键点,然后要求你在树上找到一个点作为出发点rt,最小化
首先是分子部分,我们可以通过换根DP或者是差分dfn维护来获得以x为根节点时,所有的关键节点到x节点的距离的和。
现在考虑分母部分,我们需要实现所有关键节点到根节点x的距离的GCD。维护一个数组中所有元素的GCD我们有一个很常见的操作,那就是维护这个数组的差分数组中所有元素的GCD。差分数组相较于原数组,便于实现区间修改,并且正确性是一致的。
那么我们先获得每个关键节点到根节点1处的距离,然后对其差分数组建立一棵维护区间GCD的线段树。然后,我们采取和差分dfn序类似的操作,考虑每条边的边权在根节点发生变化的时候会对关键节点到根节点的距离产生什么变化。显然,当根节点从u变化到v的时候,以v节点为根节点的子树内的关键节点到根节点的距离会减少w(u,v),而剩余的关键节点到根节点的距离都会增加w(u,v)。我们将树上的操作转移到平面区间线段树上,维护一个距离的差分数组即可。就是一个朴素的带修改的区间GCD问题。
注:笔者不会换根DP,且没有搞懂官方给的题解..但是在网上找不到合适的文章,自己瞎搞了几天后写下了这篇文章。
下面是代码部分,附有部分注释。
#include <bits/stdc++.h>
using namespace std;
#define ios \
ios::sync_with_stdio(0); \
cin.tie(0);
#define i64 long long
#define pii pair<int, int>
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)
#define lowbit(x) (x & -x)
#define int i64
const int N = 5e5 + 10;
struct node
{
int x;
} tr[N << 2];
int gcd(int a, int b)
{
return b ? gcd(b, a % b) : a;
}
int b[N];
void pushup(int x)
{
tr[x].x = gcd(tr[ls(x)].x, tr[rs(x)].x);
}
void add(int x, int l, int r, int pos, int c)
{
if (pos < l || pos > r)
return;
if (l == r)
{
tr[x].x += c;
return;
}
int m = l + r >> 1;
if (pos <= m)
add(ls(x), l, m, pos, c);
else
add(rs(x), m + 1, r, pos, c);
pushup(x);
}
void build(int x, int l, int r)
{
if (l == r)
{
tr[x].x = b[l];
return;
}
int m = l + r >> 1;
build(ls(x), l, m);
build(rs(x), m + 1, r);
pushup(x);
}
struct BIT // 使用树状数组来统计当前子树内的关键点编号合集
{
int c[N];
void add(int x)
{
for (int i = x; i < N; i += lowbit(i))
c[i]++;
}
int query(int x)
{
int res = 0;
for (int i = x; i > 0; i -= lowbit(i))
res += c[i];
return res;
}
} c;
i64 ans[N], dis[N];
int sum[N], k;
int siz[N], dfn[N], tim; // 经典时间戳和dfn序
int d[N], idx;
vector<pii> e[N];
void dfs(int u, int f)
{
dfn[u] = ++tim, siz[u] = 1;
if (sum[u])
{
c.add(tim);
b[++idx] = dis[u]; // 记录关键点到起点1的距离
}
for (pii t : e[u])
{
int v = t.first, w = t.second;
if (v == f)
continue;
dis[v] = dis[u] + w;
dfs(v, u);
sum[u] += sum[v];
siz[u] += siz[v];
ans[1] += 1ll * sum[v] * w; // 经典差分解决换根DP
ans[dfn[v]] += 1ll * (k - sum[v] - sum[v]) * w;
ans[dfn[v] + siz[v]] -= 1ll * (k - sum[v] - sum[v]) * w;
}
}
void go(int u, int f, int val) // 遍历子树 下传u<->f的边权
{
// lef和rig维护的是以u为根节点的子树内的关键节点点集在差分数组中的起始下标和结束下标
// lef>rig说明以u为根节点的子树内没有关键节点 不会影响差分数组的修改维护
int lef = c.query(dfn[u] - 1) + 1, rig = c.query(dfn[u] + siz[u] - 1);
add(1, 1, k, 1, val); // 经典gcd区间修改转化为差分 进行简单的单点修改
add(1, 1, k, lef, -2ll * val);
add(1, 1, k, rig + 1, 2ll * val);
// d[dfn[u]] = gcd(gcd(query(1, 1, k, 1, lef - 1), query(1, 1, k, lef, rig)), query(1, 1, k, rig + 1, k));
d[dfn[u]] = tr[1].x; // 不需要写多一个query函数 直接取全局gcd即可
for (pii t : e[u])
{
int v = t.first, w = t.second;
if (v == f)
continue;
go(v, u, w);
}
add(1, 1, k, 1, -val);
add(1, 1, k, lef, 2ll * val);
add(1, 1, k, rig + 1, -2ll * val);
}
void solve()
{
int n;
cin >> n >> k;
for (int i = 1, x; i <= k; i++)
{
cin >> x;
sum[x] = 1;
}
for (int i = 1; i < n; i++)
{
int x, y, z;
cin >> x >> y >> z;
e[x].push_back({y, z});
e[y].push_back({x, z});
}
if (k == 1) // 特判
{
cout << "0\n";
return;
}
dfs(1, 0);
for (int i = k; i >= 1; i--) // 维护差分数组 b[i]-b[i-1] 倒序遍历
b[i] -= b[i - 1];
build(1, 1, k); // 建树
go(1, 0, 0);
i64 inf = 1e18; // 1e18就够了 答案至多也就5e17左右的规模且不会有爆long long情况出现
for (int i = 1; i <= n; i++)
{
ans[i] += ans[i - 1]; // ans是差分数组
inf = min(inf, ans[i] / abs(d[i])); // k==1的时候有可能会出现d[i]==0 会re
}
inf <<= 1ll; // *2是因为来回都要算
cout << inf << '\n';
}
signed main()
{
ios;
solve();
return 0;
}