点分治核心模板
const int maxn = 1e6+5;
vector<int> edge[maxn];
int rt,size[maxn],son[maxn];
int all;
int vis[maxn];
int ans;
void getroot(int u, int fa)
{
size[u] = 1;
son[u] = 0;
int tmp = 0;
for(auto v : edge[u]){
if(v == fa) continue;
getroot(v, u);
size[u] += size[v];
son[u] = max(son[u], size[v]); //son[u]:u结点最大的儿子块
}
son[u] = max(son[u], all - size[u]);
if(son[u] < son[rt]) rt = u; //最大儿子最小的为重心
}
void dfs(int u)
{
vis[u] = 1;
ans += cal(u);
for(auto v : edge[u]){
if(vis[v]) continue;
ans -= cal(v);
rt = 0; //son[0]为inf,保证第一次rt能被置换
all = size[v];
getroot(v, 0);
dfs(rt);
}
}
上述的遍历保证了logn的复杂度
问题
给定一棵树和一个整数 k ,求树上等于 k 的路径有多少条?
点分治解决形似如上的问题,求树上所有的路径中满足某种条件的路径数
#include<cstdio>
#include<algorithm>
#include<cstring>
#define F(i,a,b) for(int i=a;i<=(b);++i)
#define eF(i,u) for(int i=h[u];i;i=nxt[i])
using namespace std;
const int INF=0x3f3f3f3f;
int n,k,Ans;
int h[10001],nxt[20001],to[20001],w[20001],tot;
inline void ins(int x,int y,int z){nxt[++tot]=h[x];to[tot]=y;w[tot]=z;h[x]=tot;}
bool vis[10001];
int Root,Tsiz,siz[10001],wt[10001];
int arr[10001],cnt;
void GetRoot(int u,int f){
siz[u]=1; wt[u]=0;
eF(i,u) if(to[i]!=f&&!vis[to[i]])
GetRoot(to[i],u), siz[u]+=siz[to[i]], wt[u]=max(wt[u],siz[to[i]]);
wt[u]=max(wt[u],Tsiz-siz[u]);
if(wt[Root]>wt[u]) Root=u;
}
void Dfs(int u,int D,int f){
arr[++cnt]=D;
eF(i,u) if(to[i]!=f&&!vis[to[i]]) Dfs(to[i],D+w[i],u);
}
int calc(int u,int D){
cnt=0; Dfs(u,D,0); int l=1,r=cnt,sum=0;
sort(arr+1,arr+cnt+1);
for(;;++l){
while(r&&arr[l]+arr[r]>k) --r;
if(r<l) break;
sum+=r-l+1;
}
return sum;
}
void DFS(int u){
Ans+=calc(u,0); vis[u]=1;
eF(i,u) if(!vis[to[i]]){
Ans-=calc(to[i],w[i]);
Root=0, Tsiz=siz[to[i]], GetRoot(to[i],0);
DFS(Root);
}
}
int main(){
int x,y,z;
while(~scanf("%d%d",&n,&k)&&n&&k){
tot=Ans=0; memset(vis,0,sizeof vis); memset(h,0,sizeof h);
F(i,2,n) scanf("%d%d%d",&x,&y,&z), ins(x,y,z), ins(y,x,z);
wt[Root = 0]=INF; Tsiz=n; GetRoot(1,0);
DFS(Root);
printf("%d\n",Ans-n);
}
return 0;
}
思想
点分治思想如下
用点分治模板搜索一颗树,保证时间复杂度为logn,再调用cal函数暴力搜索计算所有的 d i s ( r t , v ) dis(rt,v) dis(rt,v),将满足条件的dis相加,对ans产生贡献。但注意,我们枚举的是所有的从rt出发的链,将所有的链保存到一个数组里再进行计算,计算将任意两个链合并产生新的链,两个链可能位于同一子树,此时情况非法,需用容斥原理将每一个子树的cal减去。
练习
牛客练习赛81D
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 1e6+5;
const int mod = 998244353;
const int inf = 1e18;
vector<int> edge[maxn];
int n,a[maxn],ans;
int rt,size[maxn],son[maxn],all,vis[maxn];
int cnt;
struct node{
int dis, val;
friend bool operator < (struct node a, struct node b){
return a.val > b.val;
}
}arr[maxn];
void getroot(int u, int fa)
{
size[u] = 1;
son[u] = 0;
for(auto v : edge[u]){
if(vis[v] || v == fa) continue;
getroot(v, u);
size[u] += size[v];
son[u] = max(son[u], size[v]); //son[u]:u结点最大的儿子块
}
son[u] = max(son[u], all - size[u]);
if(son[u] < son[rt]) rt = u; //最大儿子最小的为重心
}
void dfs2(int u, int dis, int fa)
{
arr[cnt].dis = dis;
arr[cnt++].val = a[u];
for(auto v : edge[u]){
if(vis[v] || v == fa) continue;
dfs2(v, dis+1, u);
}
}
int cal(int u, int dis)
{
cnt = 0;
dfs2(u, dis, 0);
sort(arr, arr+cnt);
int res = 0;
for(int i = 0, sum=0; i < cnt; i++){
res += sum * arr[i].val + (i * arr[i].dis % mod) * arr[i].val;
sum += arr[i].dis;
sum %= mod;
res %= mod;
}
return res;
}
void dfs(int u)
{
vis[u] = 1;
ans += cal(u, 0);
ans %= mod;
for(auto v : edge[u]){
if(vis[v]) continue;
ans -= cal(v, 1);
ans = (ans + mod) % mod;
rt = 0; //son[0]为inf,保证第一次rt能被置换
all = size[v];
getroot(v, 0);
dfs(rt);
}
}
signed main()
{
scanf("%lld", &n);
for(int i = 1; i <= n; i++) scanf("%lld", &a[i]);
for(int i = 1; i < n; i++){
int u, v;
scanf("%lld%lld", &u, &v);
edge[u].push_back(v);
edge[v].push_back(u);
}
son[rt=0] = inf;
all = n;
getroot(1, 0);
dfs(rt);
ans = (ans * 2 + mod) % mod;
printf("%lld", ans);
return 0;
}