对A而言这里给出官方题解的实现方式,不需要exgcd和int128
//
// Created by 13730 on 2023/11/15.
//
/*
* 2022杭州区域赛A
* 考点数论
* 我们令x = (n + 1) * n / 2,也就是d的系数
* 观察可得,题目希望求出(ns + xd + sum) % m的最小值
* 不难发现,当n为奇数时, x必然为n的倍数,也就是说ns可以与xd合并,d可以为0
* 当n为偶数时,我们可以利用等差数列的性质,构造出公差为1的序列,例如1,4,7,10可以写成4,5,6,7
* 换句话来说,相当于直接让d为1,然后xd等价与kn + x,kn又可以与ns合并
* 对奇数而言,有ns - km = ans - sum, ns - km = gcd(n, m) * k2
* ans = k2 * gcd(n, m) + sum, 所以ans的最小值就是sum % gcd(n, m)
* 因此可以枚举k,因为k必然不超过n,如果超过则有ns - (k1 + n) * m可以与前一项合并
* 对偶数而言则需要多枚举ans = (sum + x) % gcd(n, m)的情况
* 该解法不需要exgcd
*/
#include <iostream>
#include <cmath>
#include <vector>
#include <algorithm>
#define endl '\n'
#define int long long
using namespace std;
const int N = 2e5 + 10;
int a[N];
int n, m;
int gcd(int a, int b)
{
return b ? gcd(b, a % b) : a;
}
signed main()
{
#ifdef DEBUG
freopen("in.in", "r", stdin);
freopen("out.out", "w", stdout);
#endif
ios::sync_with_stdio(false); cin.tie(nullptr);
cin >> n >> m;
int sum = 0;
int x = (n + 1) * n / 2;
int ans, s, d;
for (int i = 1; i <= n; i ++) cin >> a[i], sum += a[i];
int g = gcd(n, m);
if (sum % m == 0) ans = s = d = 0;
else if (n % 2)
{
d = 0;
ans = sum % g;
for (int i = 1; i <= n; i ++)
{
int y = i * m;
int z = y + ans - sum;
if (z % n == 0)
{
s = z / n;
s = (s % m + m) % m;
break;
}
}
}
else
{
d = 0;
ans = sum % g;
for (int i = 1; i <= n; i ++)
{
int y = i * m;
int z = y + ans - sum;
if (z % n == 0)
{
s = z / n;
s = (s % m + m) % m;
break;
}
}
if ((sum + x) % g < ans)
{
d = 1;
ans = (sum + x) % g;
for (int i = 1; i <= n; i ++)
{
int y = i * m;
int z = y + ans - sum - x;
if (z % n == 0)
{
s = z / n;
s = (s % m + m) % m;
break;
}
}
}
}
cout << ans << endl;
cout << s << " " << d << endl;
return 0;
}
M题
* 思路:差分数组,线段树,换根dp * 首先明确该题任务,选择一个点i,维护出k个病毒地点到它的距离和,同时需要维护每个地点到它路径的gcd * 思路: * 首先可以暴力求出1到k个地点的距离和,然后换根dp,思路非常简单 * 具体实现:len = now - cnt[v] * w + (tot - cnt[v]) * w * now表示上次算出的长度, cnt[v]表示包括v在内,v的子树有多少病毒地点,tot等于m * 麻烦的点在第二个操作,如何快速维护出变化后所有距离的gcd * 考虑差分数组 * gcd(a[1], a[2], a[3], a[4], ... a[n]) = gcd(a[1], a[2] - a[1], a[3] - a[2], ..., a[n] - a[n - 1]) * 为什么往差分数组上想?原因在于a[x], a[y]如果同时加上或减去一个数,那么它们的差不变,该性质可以使得换根的每次修改在3次以内 * 为什么在3次以内?我们利用树链剖分的想法,把一个子树里所有的病毒地点都打上连续的序号,我们将开头设成mn[v], 结尾设成mx[v] * 对于mn[v] <= i < mx[i],不难发现a[i + 1] - a[i]是不会发生改变的 * 因此需要修改的地方只有mn[v](表示a[mn[v]] - a[mn[v] - 1]), mx[v + 1]和1处 * 对于mn[v] = 1没有第一种修改,对于mx[v] = m来说没有第二种修改,所以说一定在3次以内
//
// Created by 13730 on 2023/11/14.
//
/*
* 2022杭州区域赛M
* 思路:差分数组,线段树,换根dp
* 首先明确该题任务,选择一个点i,维护出k个病毒地点到它的距离和,同时需要维护每个地点到它路径的gcd
* 思路:
* 首先可以暴力求出1到k个地点的距离和,然后换根dp,思路非常简单
* 具体实现:len = now - cnt[v] * w + (tot - cnt[v]) * w
* now表示上次算出的长度, cnt[v]表示包括v在内,v的子树有多少病毒地点,tot等于m
* 麻烦的点在第二个操作,如何快速维护出变化后所有距离的gcd
* 考虑差分数组
* gcd(a[1], a[2], a[3], a[4], ... a[n]) = gcd(a[1], a[2] - a[1], a[3] - a[2], ..., a[n] - a[n - 1])
* 为什么往差分数组上想?原因在于a[x], a[y]如果同时加上或减去一个数,那么它们的差不变,该性质可以使得换根的每次修改在3次以内
* 为什么在3次以内?我们利用树链剖分的想法,把一个子树里所有的病毒地点都打上连续的序号,我们将开头设成mn[v], 结尾设成mx[v]
* 对于mn[v] <= i < mx[i],不难发现a[i + 1] - a[i]是不会发生改变的
* 因此需要修改的地方只有mn[v](表示a[mn[v]] - a[mn[v] - 1]), mx[v + 1]和1处
* 对于mn[v] = 1没有第一种修改,对于mx[v] = m来说没有第二种修改,所以说一定在3次以内
*/
#include <iostream>
#include <vector>
#include <algorithm>
#define endl '\n'
#define int long long
using namespace std;
typedef pair<int, int> PII;
const int N = 5e5 + 10;
int n, m, sum;
int mark[N], dist[N], mn[N], mx[N];
/*
* mark表示当前点及它的子树中病毒地点的数量数量,
* dist表示点1到所有点的距离,
* mn,mx分别表示当前点及它的子树中最小最大病毒地点的编号
*/
int d1[N];//表示所有点到第一个源头的距离
int diff[N];//距离的差分数组
int dp[N];//表示第i个点的答案
vector<PII> e[N];
vector<int> lsh;//依次保存所有病毒地点
int id, sta;
void dfs1(int u, int fa, int d)
{
if (mark[u])
{
dist[u] = d;
++ id;
lsh.push_back(u);
mx[u] = max(mx[u], id);
mn[u] = min(mn[u], id);
}
for (auto [v, w] : e[u])
{
if (v == fa) continue;
dfs1(v, u, d + w);
mx[u] = max(mx[u], mx[v]);
mn[u] = min(mn[u], mn[v]);
}
}
//第一次dfs,处理病毒地点的编号,同时处理所有点的mn, mx
void dfs2(int u, int fa)
{
for (auto [v, w] : e[u])
{
if (v == fa) continue;
dfs2(v, u);
mark[u] += mark[v];
}
}
//第二次dfs,计算每个点及它的子树中病毒地点的数量
struct Node
{
int l, r, g;
void init(int _l, int _r, int _v)
{
l = _l, r = _r;
g = _v;
}
}tr[N << 2];
void pushup(int id)
{
tr[id].g = __gcd(tr[id << 1].g, tr[id << 1 | 1].g);
}
void build(int id, int l, int r)
{
if (l == r)
{
tr[id].init(l, r, diff[l]);
return;
}
tr[id].init(l, r, 0);
int mid = l + r >> 1;
build(id << 1, l, mid), build(id << 1 | 1, mid + 1, r);
pushup(id);
}
void modify(int id, int pos, int v)
{
if (tr[id].l == tr[id].r)
{
tr[id].g += v;
return;
}
int mid = tr[id].l + tr[id].r >> 1;
if (pos <= mid) modify(id << 1, pos, v);
else modify(id << 1 | 1, pos, v);
pushup(id);
}
int query_gcd(int id, int ql, int qr)
{
if (tr[id].l == ql && tr[id].r == qr) return tr[id].g;
int mid = tr[id].l + tr[id].r >> 1;
if (qr <= mid) return query_gcd(id << 1, ql, qr);
else if (ql > mid) return query_gcd(id << 1 | 1, ql, qr);
return __gcd(query_gcd(id << 1, ql, mid), query_gcd(id << 1 | 1, mid + 1, qr));
}
//线段树板子老生常谈
void dfs3(int u, int fa, int now)
{
d1[u] = now;
for (auto [v, w] : e[u])
{
if (v == fa) continue;
dfs3(v, u, now + w);
}
}
//第三次dfs计算lsh中的第一个点到其他点的距离,也就是gcd(a[1] ...)中的a[1],情况比较特殊
void dfs4(int u, int fa, int now)
{
for (auto [v, w] : e[u])
{
if (v == fa) continue;
int len = now - mark[v] * w + (m - mark[v]) * w;
modify(1, 1, d1[v] - d1[u]);
if (mx[v])
{
if (mn[v] != 1) {
modify(1, mn[v], -2 * w);
}
if (mx[v] != m) {
modify(1, mx[v] + 1, 2 * w);
}
}
int x = abs(query_gcd(1, 1, m));
dp[v] = len / x;
dfs4(v, u, len);
modify(1, 1, d1[u] - d1[v]);//记得回溯
if (mx[v])
{
if (mn[v] != 1) {
modify(1, mn[v], 2 * w);
}
if (mx[v] != m) {
modify(1, mx[v] + 1, -2 * w);
}
}
}
}
//第四次dfs进行换根dp
void sv1()
{
cout << 0 << endl;
exit(0);
}
//记得特判m等于1的情况,不然会RE
signed main()
{
#ifdef DEBUG
freopen("in.in", "r", stdin);
freopen("out.out", "w", stdout);
#endif
ios::sync_with_stdio(false); cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; i ++) mn[i] = 1e9;
for (int i = 1; i <= m; i ++)
{
int x; cin >> x;
mark[x] = 1;
}
if (m == 1) sv1();
for (int i = 0; i < n - 1; i ++)
{
int u, v, w; cin >> u >> v >> w;
e[u].push_back({v, w});
e[v].push_back({u, w});
}
dfs1(1, 0, 0);
int num = 0;
for (auto i : lsh) diff[++ num] = dist[i];
dfs2(1, 0);
for (int i = m; i; i --) sum += diff[i], diff[i] -= diff[i - 1];
build(1, 1, m);
dfs3(lsh[0], 0, 0);
dp[1] = sum / abs(query_gcd(1, 1, m));
dfs4(1, 0, sum);
int ans = 1e18;
for (int i = 1; i <= n; i ++) ans = min(ans, 2 * dp[i]);
cout << ans << endl;
return 0;
}