点分治的作用
在一棵树上,我们会经常遇到各类路径问题:例如路径计数问题,路径最值问题等。
此时,我们的主要思想就是选择一个根作为这棵树的根节点(一般题目中的树都是无根树),计算和根相关的答案,再对每一个子树分治。此时,我们需要将树分的尽可能均匀来保证时间复杂度得稳定,我们一般选择计算树的重心。此时,如果操作是线性级别的,那么时间复杂度就是: O ( n l o g n ) O(n\ log\ n) O(n log n)
方法1:双指针法[POJ1741]
这道题以根为计数,可以将答案分为经过根的路径和不经过根的路径。
不经过根节点的可以递归处理,我们来考虑如何计算经过根节点的路径树。
我们预处理每一个节点到根的距离 d [ x ] d[x] d[x],特别地,另 d [ r o o t ] = 1 d[root]=1 d[root]=1.我们现在考虑如何用 O ( n ) O(n) O(n)的时间计算答案。
现在我们将题目转化为,有多少点对满足 d [ x ] + d [ y ] ≤ k d[x]+d[y]\leq k d[x]+d[y]≤k.我们可以使用双指针法。
- 将每一个 d d d值排序。设置左指针 l = 1 l=1 l=1,右指针 r = n r=n r=n.
- 当 d [ l ] + d [ r ] ≤ k d[l]+d[r]\leq k d[l]+d[r]≤k时,l可以和 [ l + 1 , r ] [l+1,r] [l+1,r]的点进行配对, a n s ans ans加上 r − l r-l r−l.否则 r r r减去 1 1 1.
- 如果 r ≤ l r\leq l r≤l,则停止继续,让 l l l加上 1 1 1.
使用这个方法,就能够巧妙的计算有多少点对小于等于k了。
但是如果存在同一子树的点对产生的答案应该如何处理?
- 对每一个子树单独按照原来的 d d d数组进行配对,让答案每一个子树中配对后的值即可。
方法2:DP思想([IOI2011]race)
给一棵树,每条边有权.求一条简单路径,权值和等于 K K K,且边的数量最小
设 f [ v ] f[v] f[v]表示权值和为 v v v的最小边。设 d i s [ x ] dis[x] dis[x]表示 x x x到 r o o t root root的权值和, M i n [ x ] Min[x] Min[x]表示 x x x到 r o o t root root的边的数量。
我们每一次用 f [ k − M i n [ x ] ] + d i s [ x ] f[k-Min[x]]+dis[x] f[k−Min[x]]+dis[x]来更新答案即可。
再用 f [ x ] = m i n ( f [ x ] , M i n [ x ] ) f[x]=min(f[x],Min[x]) f[x]=min(f[x],Min[x])来进行状态转移即可。
方法3:树状数组([SPOJ1825]免费旅行)
题目描述
在两周年纪念日的旅行之后,在第三年,旅行社SPOJ又一次踏上的打折旅行的道路。
这次旅行是ICPC岛屿上进行的,一个位于太平洋上,不可思议的小岛。我们列出了N个地点(编号从 1 1 1到 N N N)供旅客游览。这 N N N个点由 N − 1 N-1 N−1条边连成一个树,每条边都有一个权值,这个权值可能为负。我们可以选择两个地点作为旅行的起点和终点。
由于当地正在庆祝节日,所以某些地方会特别的拥挤(我们称这些地方为拥挤点)。旅行的组织者希望这次旅行最多访问 K K K个拥挤点。同时,我们希望我们经过的道路的权值和最大。
题解
这道题是一个最优性树上问题,我们依然通过固定一个根来考虑这个问题。
我们设
s
[
x
]
s[x]
s[x]表示节点
x
x
x到
r
o
o
t
root
root的路径中,黑点的个数。
d
i
s
[
x
]
dis[x]
dis[x]表示节点
x
x
x到
r
o
o
t
root
root的权值和。
f
a
[
x
]
fa[x]
fa[x]表示
x
x
x子树的根节点。则我们需要求解:
m
a
x
(
d
i
s
[
x
]
+
d
i
s
[
y
]
)
,
s
[
x
]
+
s
[
y
]
≤
k
,
f
a
[
x
]
=
̸
f
a
[
y
]
max(dis[x]+dis[y]),s[x]+s[y]\leq k,fa[x]=\not fa[y]
max(dis[x]+dis[y]),s[x]+s[y]≤k,fa[x]≠fa[y]
此时我们可以分别通过枚举每一个子树来解决不同祖先的问题。
我们设 f [ i ] f[i] f[i]表示节点到根节点的黑点数为 x x x的最大权值,则用 f [ 0... k − s [ x ] ] + d i s [ x ] f[0...k-s[x]]+dis[x] f[0...k−s[x]]+dis[x]来跟新答案即可。
此时我们需要求解 f [ 0... k − s [ x ] ] f[0...k-s[x]] f[0...k−s[x]]的最大值,我们可以用树状数组来维护前缀的最大值。
一些细节:因为有某一个点到根节点的路径,需要一开始将根节点加入树状数组;如果根节点是黑点,要计算两边,所以 k k k需要加上 1 1 1.
4.较复杂的DP计数问题(采药人的路径)
题目描述
采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
题解
将一种路径的权值定为
1
1
1,一种路径的权值定为
−
1
-1
−1.
题目需要求解一条路径,使得权值和为 0 0 0,且能分成两半权值和也分别为 0 0 0.
我们设 f [ x ] [ 0 ] f[x][0] f[x][0]表示权值和为 x x x,不存在起点到根节点的路径中,某一个点的路径和为0的方案数; f [ x ] [ 1 ] f[x][1] f[x][1]表示存在。 g [ x ] [ 0 / 1 ] g[x][0/1] g[x][0/1]表示已经枚举过的子树中所对应的相同状态。
初始化: g [ 0 ] [ 0 ] = 1 g[0][0]=1 g[0][0]=1,表示起点的方案数。
然后对答案的贡献是:
f
(
x
,
0
)
∗
g
(
x
,
1
)
+
f
(
x
,
1
)
+
g
(
x
,
0
)
+
f
(
x
,
1
)
∗
g
(
x
,
1
)
f(x,0)*g(x,1)+f(x,1)+g(x,0)+f(x,1)*g(x,1)
f(x,0)∗g(x,1)+f(x,1)+g(x,0)+f(x,1)∗g(x,1).
当
j
=
0
j=0
j=0时:
f
(
0
,
0
)
+
g
(
0
,
0
)
f(0,0)+g(0,0)
f(0,0)+g(0,0)
每一次f数组求解时,在枚举每一个点直接累加即可;在每一次累加答案以后,让
g
(
x
,
1
)
g(x,1)
g(x,1)加上
f
(
x
,
1
)
f(x,1)
f(x,1),
g
(
x
,
0
)
g(x,0)
g(x,0)加上
f
(
x
,
0
)
f(x,0)
f(x,0)即可。
注意路径负数,加上偏移量。
代码1:
#include <cstdio>
#include <vector>
#include <cstring>
#include <limits.h>
#include <iostream>
#include <algorithm>
#define mp make_pair
using namespace std;
int n,k,ans,cnt,S,root;
int c[20000];
int vis[20000];
int dis[20000];
int Max[20000];
int size[20000];
vector < pair<int,int> > a[20000];
void clear(void)
{
n = k = ans = cnt = S = root = 0;
memset(c,0,sizeof c);
memset(vis,0,sizeof vis);
memset(dis,0,sizeof dis);
memset(Max,0,sizeof Max);
memset(size,0,sizeof size);
for (int i=0;i<20000;++i) a[i].clear();
return;
}
inline int read(void)
{
int s = 0, w = 1;char c = getchar();
while (c<'0' || c>'9') {if (c == '-') w = -1; c = getchar();}
while (c>='0' && c<='9') s = s*10+c-48,c = getchar();
return s*w;
}
void Find_root(int x,int fa)
{
size[x] = 1, Max[x] = 0;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
Find_root(y,x);
size[x] += size[y];
Max[x] = max(Max[x], size[y]);
}
Max[x] = max(Max[x], S-size[x]);
if (Max[x] < Max[root]) root = x;
return;
}
void dfs(int x,int fa)
{
size[x] = 1;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa) continue;
dfs(y,x);
size[x] += size[y];
}
return;
}
void find_dis(int x,int fa)
{
if (dis[x] > k) return;
c[++cnt] = dis[x];
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
int v = a[x][i].second;
if (y == fa || vis[y] == 1) continue;
dis[y] = dis[x]+v;
find_dis(y,x);
}
return;
}
int add(int x,int t)
{
dis[x] = t, cnt = 0;
find_dis(x,0);
sort(c+1, c+cnt+1);
int l = 1, r = cnt, sum = 0;
while (l < r)
{
if (c[l]+c[r] <= k) sum += r-l, l ++;
else r --;
}
return sum;
}
void solve(int x)
{
ans += add(x,0);
vis[x] = 1;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
int v = a[x][i].second;
if (vis[y] == 1) continue;
ans -= add(y,v);
S = size[y], root = 0;
Find_root(y,0);
solve(root);
}
return;
}
void work(void)
{
clear();
n = read(), k = read();
if (n == 0 && k == 0) exit(0);
for (int i=1;i<n;++i)
{
int x = read(), y = read(), v = read();
a[x].push_back(mp(y,v));
a[y].push_back(mp(x,v));
}
Max[0] = INT_MAX, S = n, root = 0;
Find_root(1,0);
dfs(root,0);
solve(root);
printf("%d\n", ans);
}
int main(void)
{
freopen("poj1741_tree.in","r",stdin);
freopen("poj1741_tree.out","w",stdout);
while (1) work();
return 0;
}
代码2:
#include <cstdio>
#include <vector>
#include <cstring>
#include <limits.h>
#include <iostream>
#include <algorithm>
#define mp make_pair
using namespace std;
int n,k,ans,cnt,S,root,m,K;
int v[3000000];
int s[3000000];
int f[3000000];
int vis[3000000];
int dis[3000000];
int Max[3000000];
int size[3000000];
vector < pair<int,int> > a[3000000];
struct BIT
{
int S[500000];
#define lowbit(i) (i&-i)
void init(void)
{
for (int i=0;i<500000;++i)
S[i] = -1e15;
return;
}
int ask(int x)
{
int Max = -1e15;
for (int i=x;i>=1;i-=lowbit(i))
Max = max(Max,S[i]);
return Max;
}
void add(int pos,int v)
{
for (int i=pos;i<=n+10;i+=lowbit(i))
S[i] = max(S[i],v);
return;
}
void del(int x)
{
for (int i=x;i<=n+10;i+=lowbit(i))
S[i] = -1e15;
}
} tree;
inline int read(void)
{
int s = 0, w = 1;char c = getchar();
while (c<'0' || c>'9') {if (c == '-') w = -1; c = getchar();}
while (c>='0' && c<='9') s = s*10+c-48,c = getchar();
return s*w;
}
void get_root(int x,int fa)
{
size[x] = 1, Max[x] = 0;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
get_root(y,x);
size[x] += size[y];
Max[x] = max(Max[x],size[y]);
}
Max[x] = max(Max[x],S-size[x]);
if (Max[x] < Max[root]) root = x;
return;
}
void get_size(int x,int fa)
{
size[x] = 1;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
get_size(y,x);
size[x] += size[y];
}
return;
}
void get_dis(int x,int fa)
{
if (s[x] > K) return;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
int V = a[x][i].second;
if (y == fa || vis[y] == 1) continue;
s[y] = s[x] + v[y];
dis[y] = dis[x] + V;
get_dis(y,x);
}
return;
}
void updata_ans(int x,int fa)
{
if (s[x] > K) return;
ans = max(ans, dis[x]+tree.ask(K-s[x]+1));
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (fa == y || vis[y] == 1) continue;
updata_ans(y,x);
}
return;
}
void updata(int x,int fa)
{
if (s[x] > K) return;
tree.add(s[x]+1,dis[x]);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (fa == y || vis[y] == 1) continue;
updata(y,x);
}
return;
}
void dfs(int x,int fa)
{
if (s[x] > K) return;
tree.del(s[x]+1);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (fa == y || vis[y] == 1) continue;
dfs(y,x);
}
return;
}
void solve(int x)
{
vis[x] = 1;
dis[x] = 0, s[x] = v[x];
K = v[x] ? k+1 : k;//如果根节点是黑点结果会被重复算两遍
get_dis(x,0);
tree.add(s[x]+1,0);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
updata_ans(y,0);
updata(y,0);
}
tree.del(s[x]+1);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
dfs(y,0);
}
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
root = 0, S = size[y];
get_root(y,0);
solve(root);
}
return;
}
signed main(void)
{
freopen("freetourII.in","r",stdin);
freopen("freetourII.out","w",stdout);
n = read(), k = read(), m = read();
while (m --) v[read()] = 1;
for (int i=1;i<n;++i)
{
int x = read(), y = read(), v = read();
a[x].push_back(mp(y,v));
a[y].push_back(mp(x,v));
}
Max[0] = INT_MAX, root = 0, S = n, tree.init(), ans = 0;
get_root(1,0);
get_size(root,0);
solve(root);
cout<<ans<<endl;
return 0;
}
代码3:
#include <cstdio>
#include <vector>
#include <cstring>
#include <limits.h>
#include <iostream>
#include <algorithm>
#define mp make_pair
using namespace std;
const int N = 1200000;
int n,k,ans,cnt,S,root;
int s[N];
int f[N];
int vis[N];
int dis[N];
int Max[N];
int size[N];
vector < pair<int,int> > a[N];
inline int read(void)
{
int s = 0, w = 1;char c = getchar();
while (c<'0' || c>'9') {if (c == '-') w = -1; c = getchar();}
while (c>='0' && c<='9') s = s*10+c-48,c = getchar();
return s*w;
}
void get_root(int x,int fa)
{
size[x] = 1, Max[x] = 0;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
get_root(y,x);
size[x] += size[y];
Max[x] = max(Max[x],size[y]);
}
Max[x] = max(S-size[x], Max[x]);
if (Max[x] < Max[root]) root = x;
return;
}
void get_size(int x,int fa)
{
size[x] = 1;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa) continue;
get_size(y,x);
size[x] += size[y];
}
return;
}
void get_dis(int x,int fa)
{
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
int v = a[x][i].second;
if (y == fa || vis[y] == 1) continue;
dis[y] = dis[x]+v;
s[y] = s[x]+1;
get_dis(y,x);
}
return;
}
void updata_ans(int x,int fa)
{
if (dis[x] > k) return;
ans = min(ans,f[k-dis[x]]+s[x]);
if (dis[x] == k) ans = min(ans,s[x]);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
updata_ans(y,x);
}
return;
}
void updata(int x,int fa)
{
if (dis[x] > k) return;
f[dis[x]] = min(f[dis[x]], s[x]);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
updata(y,x);
}
return;
}
void dfs(int x,int fa)
{
if (dis[x] > k) return;
f[dis[x]] = 1e9;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
dfs(y,x);
}
return;
}
void solve(int x)
{
vis[x] = 1, dis[x] = s[x] = 0, f[0] = 0;
get_dis(x,0);
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
updata_ans(y,0);
updata(y,0);
}
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
dfs(y,0);
}
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
root = 0, S = size[y];
get_root(y,0);
solve(root);
}
return;
}
int main(void)
{
freopen("ioi2011-race.in","r",stdin);
freopen("ioi2011-race.out","w",stdout);
n = read(), k = read();
for (int i=1;i<n;++i)
{
int x = read()+1, y = read()+1, v = read();
a[x].push_back(mp(y,v));
a[y].push_back(mp(x,v));
}
ans = INT_MAX;
for (int i=0;i<=k;++i) f[i] = 1e9;
Max[0] = INT_MAX, root = 0, S = n;
get_root(1,0);
get_size(root,0);
solve(root);
cout<<(ans > n ? -1 : ans)<<endl;
return 0;
}
#include <cstdio>
#include <vector>
#include <cstring>
#include <limits.h>
#include <iostream>
#include <algorithm>
#define int long long
using namespace std;
int n,k,ans,S,root,m,K,Maxv;
int cnt[300000];
int vis[300000];
int dis[300000];
int Max[300000];
int deep[300000];
int f[300000][2];
int g[300000][2];
int size[300000];
vector < pair<int,int> > a[200000];
inline int read(void)
{
int s = 0, w = 1;char c = getchar();
while (c<'0' || c>'9') {if (c == '-') w = -1; c = getchar();}
while (c>='0' && c<='9') s = s*10+c-48,c = getchar();
return s*w;
}
void get_root(int x,int fa)
{
size[x] = 1, Max[x] = 0;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
get_root(y,x);
size[x] += size[y];
Max[x] = max(Max[x], size[y]);
}
Max[x] = max(Max[x], S-size[x]);
if (Max[x] < Max[root]) root = x;
return;
}
void get_size(int x,int fa)
{
size[x] = 1;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (y == fa || vis[y] == 1) continue;
get_size(y,x);
size[x] += size[y];
}
return;
}
#define f(i,j) (f[i+100000][j])
#define g(i,j) (g[i+100000][j])
#define cnt(i) (cnt[i+100000])
void updata(int x,int fa)
{
Maxv = max(Maxv,abs(deep[x]));
if (cnt(dis[x])) f(dis[x],1) ++;//路径上已经存在休息站
else f(dis[x],0) ++;
cnt(dis[x]) ++;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
int v = a[x][i].second;
if (vis[y] == 1 || y == fa) continue;
dis[y] = dis[x] + v;
deep[y] = deep[x] + 1;
updata(y,x);
}
cnt(dis[x]) --;
return;
}
void solve(int x)
{
dis[x] = 0, vis[x] = 1, g(0,0) = 1, Maxv = 0;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
int v = a[x][i].second;
if (vis[y] == 1) continue;
dis[y] = dis[x] + v, deep[y] = 1;
updata(y,0);
for (int j=-Maxv;j<=Maxv;++j)
{
if (j == 0) ans += f(j,0)*(g(-j,0)-1);
//以当前点为休息站 -1因为起点产生了重复
ans += f(j,0) * g(-j,1);
ans += f(j,1) * g(-j,0);
ans += f(j,1) * g(-j,1);
}
for (int j=-Maxv;j<=Maxv;++j)
{
g(j,0) += f(j,0);
g(j,1) += f(j,1);
f(j,0) = f(j,1) = 0;
}
}
for (int i=-Maxv;i<=Maxv;++i)
g(i,0) = g(i,1) = 0;
for (int i=0;i<a[x].size();++i)
{
int y = a[x][i].first;
if (vis[y] == 1) continue;
S = size[y], root = 0;
get_root(y,0);
solve(root);
}
return;
}
signed main(void)
{
freopen("yinyang.in","r",stdin);
freopen("yinyang.out","w",stdout);
n = read();
for (int i=1;i<n;++i)
{
int x = read(), y = read(), v = read();
a[x].push_back(make_pair(y,v?1:-1));
a[y].push_back(make_pair(x,v?1:-1));
}
Max[0] = INT_MAX;
root = 0, ans = 0, S = n;
get_root(1,0);
get_size(root,0);
solve(root);
cout<<ans<<endl;
return 0;
}