SPOJ Free Tour 2
链接:http://www.spoj.com/problems/FTOUR2/
树上分治的经典题目。
每次找到这棵树的重心。
递归的处理子树。后合并处理整棵树。
对于合并的过程。记重心为
z
tmp[c][i] 表示不包括
c
为根的子树的节点。从z 出发。不超过
i
个节点的最远距离。
显然有了tmp[][] 数组后。合并是非常快的。只需要查表即可。
我们不需要直接求出 tmp 数组。而是通过不断的合并。
合并 tmp[c1][],tmp[c2][] ,复杂度线性。
等于节点数最多的那个子树。没有顺序的合并。复杂度 O(n2)
根据上面情况。我们只需要按子树结点个数之和从小到大的顺序合并子树。显然复杂度等于子树节点数量之和。总复杂度线性。
时间复杂度可以记为
T(n)=O(n)+∑i=1kT(ni)
影响 T(n) 大小的。显然于其递归层数有关。
这是因为。所有同层子问题合并到上一层合并总耗费都是 O(n)
令
h
是其最坏递归深度。
那么。T(n)=O(hn)
因为重心子树节点数不超过 n2
所以:
h=O(log n)
所以:
T(n)=O(nlog n)
#include <algorithm>
#include <string.h>
#include <stdio.h>
#include <vector>
#include <time.h>
#include <stdlib.h>
#include <queue>
#define MAXN 200005
#define MMM 500000
using namespace std;
const int INF=0x3f3f3f3f;
struct IO
{
char A[MMM],*L,*R;
IO()
{
L=R=A;
}
void IO_fread()
{
L=A;
R=A+fread(A,sizeof(char),MMM,stdin);
}
int read()
{
int a=0,on=1;
if(L==R)
{
IO_fread();
if(L==R)return a;
}
while(*L<'0'||*L>'9')
{
if(*L=='-')on=-1;
L++;
if(L==R)
{
IO_fread();
if(L==R)return a;
}
}
while(*L>='0'&&*L<='9')
{
a=a*10+(*L)-'0';
L++;
if(L==R)
{
IO_fread();
if(L==R)return a*on;
}
}
return a*on;
}
}Io;
struct edge
{
int to;
int next;
int w;
edge(int to=0,int next=0,int w=0):to(to),next(next),w(w){};
}E[MAXN*2];
int deep=1;
int inof[MAXN],ans=0;
void add(int a,int b,int w)
{
E[deep]=edge(b,inof[a],w);
inof[a]=deep++;
}
int cl[MAXN];
int size[MAXN];
int ph[MAXN];
int D[MAXN];
int C[MAXN];
int tmp[MAXN];
int Q[MAXN];
int ma[MAXN];
bool vis[MAXN];
struct node
{
int h;//
int w;
int next;
node()
{
h=w=next=-1;
}
}V[MAXN*2];
struct link
{
int head;
int sz;
link()
{
head=-1;
sz=0;
}
}Li[MAXN],Ti[MAXN];
struct QU
{
int l,r;
int A[MAXN*2+10];
int sz;
QU()
{
sz=MAXN*2+8;
l=0;
r=0;
}
int pop()
{
int ans=A[l++];
if(l>sz)l=0;
return ans;
}
void push(int a)
{
A[r++]=a;
if(r>sz)r=0;
}
}qu;
void addL(int a,int h,int w,int c)
{
int v=qu.pop();
V[v].h=h;
V[v].w=w;
V[v].next=Li[a].head;
Li[a].head=v;
if(Li[a].sz<h)Li[a].sz=h;
}
void clear(int a)
{
for(int i=Li[a].head ;i>-1;i=V[i].next) qu.push(i);
Li[a].head=-1;
Li[a].sz=0;
}
int BFS1(int root)
{
int z=root,l=0,r=1,n,m=INF;
Q[0]=z;
ph[root]=0;
while(l<r)
{
int v=Q[l++];
for(int i=inof[v];i;i=E[i].next)
{
edge &e=E[i];
if(e.to==ph[v]||vis[e.to])continue;
ph[e.to]=v;
Q[r++]=e.to;
}
}
if(r==1)return-1;
n=r;
while(r--)
{
int v=Q[r];
size[v]++;
if(ma[v]<n-size[v])ma[v]=n-size[v];
if(m>ma[v])
{
m=ma[v];
z=v;
}
if(ma[ph[v]]<size[v])ma[ph[v]]=size[v];
size[ph[v]]+=size[v];
size[v]=0;
ma[v]=0;
}
ma[0]=0;
size[0]=0;
return z;
}
void BFS2(int root,int w,int a,int k)
{
D[0]=w;
int l=0,r=1;
Q[0]=root;
C[0]=cl[root];
addL(a,C[0],D[0],cl[root]);
ph[root]=0;
while(l<r)
{
int v=Q[l];
for(int i=inof[v];i;i=E[i].next)
{
edge &e=E[i];
if(vis[e.to]||e.to==ph[v]||C[l]+cl[e.to]>k)continue;
ph[e.to]=v;
C[r]=C[l]+cl[e.to];
D[r]=D[l]+e.w;
Q[r]=e.to;
addL(a,C[r],D[r],cl[e.to]);
r++;
}
l++;
}
}
void DFS(int root,int k)
{
int z=BFS1(root);
if(z==-1)return;
vis[z]=true;
int cnt=0;
for(int i=inof[z];i;i=E[i].next)
{
edge &e=E[i];
if(vis[e.to])continue;
DFS(e.to,k);
}
int dep=0;
for(int i=inof[z];i;i=E[i].next)
{
edge &e=E[i];
if(vis[e.to]||k-cl[z]<cl[e.to])continue;
BFS2(e.to,e.w,dep,k-cl[z]);
if(cnt<Li[dep].sz)cnt=Li[dep].sz;
dep++;
}
memset(tmp,0,(cnt+1)*sizeof(int));
for(int i=0;i<dep;i++) tmp[Li[i].sz]++;
for(int i=1;i<=cnt;i++) tmp[i]+=tmp[i-1];
for(int i=dep-1;i>-1;i--) Ti[--tmp[Li[i].sz]]=Li[i];
for(int i=0;i<dep;i++) Li[i]=Ti[i];
for(int i=0;i<=cnt;i++) tmp[i]=-INF;
int sz=0;
for(int j=0;j<dep;j++)
{
for(int i=Li[j].head;i>-1;i=V[i].next)
{
int u=k-cl[z]-V[i].h;
if(u>sz)u=sz;
if(ans<tmp[u]+V[i].w)ans=tmp[u]+V[i].w;
if(ans<V[i].w)ans=V[i].w;
}
for(int i=Li[j].head;i>-1;i=V[i].next)
if(tmp[V[i].h]<V[i].w)tmp[V[i].h]=V[i].w;
for(int i=1;i<=Li[j].sz;i++)
if(tmp[i]<tmp[i-1]) tmp[i]=tmp[i-1];
sz=Li[j].sz;
clear(j);
}
vis[z]=false;
return;
}
int da[MAXN];
int W[MAXN];
int P[MAXN];
int main ()
{
for(int i=0;i<MAXN*2;i++)qu.push(i);
int n,m,k,a,b,c;
n=Io.read();
k=Io.read();
m=Io.read();
for(int i=0;i<m;i++)
{
a=Io.read();
cl[a]=1;
}
for(int i=1;i<n;i++)
{
a=Io.read();
b=Io.read();
c=Io.read();
add(a,b,c);
add(b,a,c);
}
DFS(1,k);
printf("%d\n",ans);
}