感谢b站up大佬:不分解的AgOH。
点分治用于树上的大规模路径操作、统计。它的灵活性高,适用范围广。很多树形dp也可以用点分治搞。其实,树分治中,还有一个边分治,不过没有点分治常用。
基本步骤:
第一步
找树的重心。我们重心为根节点,然后按题目的不同情况统计树各个点到根节点的路径信息。
第二步
递归至每一个子树,重复第一步,直到分割至叶子节点。
非常简洁。
下面看一看具体的实现:
找重心函数:
找到原以v为根的子树的重心,并切换根。
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;//siz数组数组树子树大小
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];//mp数组是去掉u节点后,剩余部分的最大一部分。
}
mp[u] = max(mp[u], sum-siz[u]);//不要忘记上子树
if (mp[u] < mp[rt]) rt = u;//换根
}
分割函数:
对原树进行划分
//------main----
getrt(1, 0);//找整树的重心
getrt(rt, 0);//为啥是两遍?因为我们需要让siz数组正确(换根后siz数组就不正确了)
divide(rt);
//--------------
void divide(int u)
{
vis[u] = 1;//vis[i],表示u节点已被选中
solve(u);//已u为根节点,统计路径信息
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);//找v子树的根节点
getrt(rt, 0);
divide(rt);//递归划分子树
}
}
统计信息的函数,随题目而异,这里引出5道例题。
例1:CF161D Distance in Tree
最基本的树分治,求树上距离为k的点的对数。
直接上代码,注意看solve函数:
下面是ac代码:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans, k;
int tmp[N], siz[N], dis[N], mp[N], jd[N*10];
bool vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt++] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
queue<int> que;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
getdis(v, u);//统计v子树的所有节点到v的距离
for (int j = 0; j < cnt; j++)
if (k >= tmp[j])
ans += jd[k-tmp[j]];//jd数组是一个桶数组,jd[i]为路径i的数量
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j]]++;//每找一个子树就压进去一批
}
}
while(que.size())
{
jd[que.front()]--;
que.pop();//以u为根的子树统计完毕,清空。
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y, 1); add(y, x, 1);
}
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
printf("%d\n", ans);
return 0;
}
例2:洛谷 P3806 【模板】点分治1
统计距离k的点是否存在,离线后,和上一个题基本一样。挨个统计,挨个比较每一个询问。
下面是ac代码:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt;
int tmp[N], siz[N], dis[N], mp[N], q[105];
bool jd[N*10], ans[105], vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt++] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
queue<int> que;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
getdis(v, u);
for (int j = 0; j < cnt; j++)
for (int k = 0; k <m; k++)
if (q[k] >= tmp[j])
ans[k] |= jd[q[k]-tmp[j]];
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j]] = 1;
}
}
while(que.size())
{
jd[que.front()] = 0;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x, y, w); add(y, x, w);
}
for (int i = 0; i < m; i++)
scanf("%d", &q[i]);
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
for (int i = 0; i < m; i++)
{
if (ans[i]) puts("AYE");
else puts("NAY");
}
return 0;
}
例3:洛谷 P4178 Tree
这个题要求我们统计有多少对点之间的距离小于等于k的。这时我们发现,solve函数中更新ans的复杂度不太理想(因为要对于一个距离g需要累计从0至k-g),不过好在我们可以用树状数组维护桶数组jd达到我们的目的。
下面是ac代码:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[N*100];
bool vis[N];
void change(int x, int y)
{
for (;x <= 100*N-2; x += x & -x) jd[x] += y;
}
ll ask(int x)
{
ll ans = 0;
for (; x; x -= x &-x) ans += jd[x];
return ans;
}
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[++cnt] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
// cout << u << " ::::" << endl;
queue<int> que;
que.push(0);
change(1, 1);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
// cout << v <<":";
getdis(v, u);
// for (int i = 1; i <= cnt; i++)
// cout << tmp[i] << " ";
// cout << endl;
for (int j = 1; j <= cnt; j++)
{
if (tmp[j] > k) continue;
ans += ask(k-tmp[j]+1);
}
for (int j = 1; j <= cnt; j++)
{
que.push(tmp[j]);
change(tmp[j]+1, 1);
}
}
while(que.size())
{
change(que.front()+1, -1);
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x, y, w); add(y, x, w);
}
scanf("%d", &k);
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
printf("%d\n", ans);
return 0;
}
例4:洛谷 P2634 [国家集训队]聪聪可可
也是比较基础的问题,就是统计所有距离是3的倍数的对数。然后,值得注意的是,点分治没有考虑两个点重合的情况,看题目样例,我们是需要讨论的。好在两个点重合(dis==0)都符合题目要求,我们之间吧ans加上一个n就ok了。
下面是ac代码:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[5];
bool vis[N];
int gcd(int a, int b)
{
return b?gcd(b,a%b):a;
}
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[++cnt] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
// cout << u << " ::::" << endl;
queue<int> que;
que.push(0);
jd[0] = 1;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
// cout << v <<":";
getdis(v, u);
// for (int i = 1; i <= cnt; i++)
// cout << tmp[i] << " ";
// cout << endl;
for (int j = 1; j <= cnt; j++)
{
if(tmp[j]%3 == 0) ans += jd[0];
else if (tmp[j]%3 == 1) ans += jd[2];
else ans += jd[1];
}
for (int j = 1; j <= cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j]%3]++;
}
}
while(que.size())
{
jd[que.front()%3]--;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x, y, w); add(y, x, w);
}
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
ans *= 2;
ans += n;
int gg = n *n;
int d = gcd(ans, gg);
printf("%d/%d\n", ans/d, gg/d);
return 0;
}
例5:洛谷 P4149 [IOI2011]Race
比较有意思,求出路径为k的点对中,边数最小的数量。不难,我们在getdis的函数里同时跑出深度dep。然后在路径为k的情况下最小化ans。但是,,但是!这狗题没说权值范围。。。。没法看了1e6的jd数组交了一份re了3个点。随改成2e7还是re了一个点,在开就mlt了。。。。用了unormap,,t了3个点我擦。。。然后改回jd数组强行hash了一波。。可算是过了。不过这也不是正经解法啊。。于是用unormap吸着氧2.6s险些超时。。。两份代码都放一下吧。。可能有正经解法,本憨批不知道。
下面是ac代码(数组强行hash):
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int mod = N *200;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], jd[N*200], dep[N];
pair<int, int> tmp[N];
bool vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt].first = dep[u];
tmp[cnt++].second = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
dep[v] = dep[u] + 1;
getdis(v, u);
}
}
void solve(int u)
{
queue<pair<int, int> > que;
que.push(make_pair(0, 0));
jd[0] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
dep[v] = 1;
getdis(v, u);
for (int j = 0; j < cnt; j++)
if (k >= tmp[j].second)
ans = min(ans, jd[k-tmp[j].second%mod] + tmp[j].first);
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j].second%mod] = min(jd[tmp[j].second%mod], tmp[j].first);
}
}
while(que.size())
{
jd[que.front().second%mod] = inf;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x+1, y+1, w); add(y+1, x+1, w);
}
memset(jd, inf, sizeof(jd));
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
if (ans != inf)
printf("%d\n", ans);
else
puts("-1");
return 0;
}
(吸氧map):
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], dep[N];
pair<int, int> tmp[N];
unordered_map<int, int> jd;
bool vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt].first = dep[u];
tmp[cnt++].second = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
dep[v] = dep[u] + 1;
getdis(v, u);
}
}
void solve(int u)
{
queue<pair<int, int> > que;
que.push(make_pair(0, 0));
jd[0] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
dep[v] = 1;
getdis(v, u);
for (int j = 0; j < cnt; j++)
if (k >= tmp[j].second)
{
if (jd.find(k-tmp[j].second) == jd.end()) jd[k-tmp[j].second] = inf;
ans = min(ans, jd[k-tmp[j].second] + tmp[j].first);
}
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
if (jd.find(tmp[j].second) == jd.end()) jd[tmp[j].second] = inf;
jd[tmp[j].second] = min(jd[tmp[j].second], tmp[j].first);
}
}
while(que.size())
{
jd[que.front().second] = inf;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x+1, y+1, w); add(y+1, x+1, w);
}
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
if (ans != inf)
printf("%d\n", ans);
else
puts("-1");
return 0;
}