虚树,选择k个点,构成一个树,这个树只包含所选点和他们两两之间的lca,特点为这个虚树上的点不超过2*k个,若有多组k,则复杂度为2 * ∑ k。
注意事项:1.dis数组一般是距离,而dep数组是树的节点的深度,如果每条边的长度都为1则两个数组一样,如果不为1则dep数组是把每条边都看作长度为1.
2.求LCA要注意!!!dep[0]=-1。
3.得到k个关键点之后记得排序之后再建虚树!!!
4.记得清空!!!因为每个虚树最多2*k个节点,所以直接dfs删除时间复杂度是常数的。
5.虚树只保存了sv,没有保存边权,可以在用到边权的时候再求,以为x是y的父节点,所以很好搞,也不需要保存。
模板:
int tp,st[MAX_N];
vector<int>sv[MAX_N];//存放虚树,用完后clear,复杂度为2*k
int dep[MAX_N];//为树的深度
void build(int s,int count){//建虚树;count为所选点的个数,s为根
st[++tp]=s;
for(int i=1;i<=count;i++){
int u=a[i],y=st[tp],lca=LCA(u,y);
while(lca!=y){
tp--;
if(dfn[st[tp]]<dfn[lca])st[++tp]=lca;
sv[st[tp]].push_back(y);
sv[y].push_back(st[tp]);
y=st[tp];
}
if(u!=st[tp])st[++tp]=u;
}
while(tp){
int y=st[tp--];
if(!tp)break;
sv[st[tp]].push_back(y);
sv[y].push_back(st[tp]);
}
}
#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
const int MAX_N=201000;
vector<int>v[MAX_N],sv[MAX_N],dep[MAX_N];
long long a[MAX_N];
void add(int x,int y){
sv[x].push_back(y);
sv[y].push_back(x);
}
int dfn[MAX_N],cnt;
int dp[MAX_N][20],dis[MAX_N],n,log_2[MAX_N];
int max_deep=0;
void dfs(int now,int fa){
int i;
dfn[now]=++cnt;
dep[dis[now]].push_back(now);
for(i=0;i<v[now].size();i++){
int to=v[now][i];
if(to==fa)
continue;
dp[to][0]=now;
dis[to]=dis[now]+1;
max_deep=max(max_deep,dis[to]);
dfs(to,now);
}
}
int LCA(int a,int b){
int k=log_2[n],i;
if(dis[a]<dis[b]){
int c=a;
a=b;
b=c;
}
while(dis[a]!=dis[b]){
for(i=k;i>=0;i--){
if(dis[dp[a][i]]>=dis[b])
a=dp[a][i];
}
}
if(a==b){
return b;
}
for(i=k;i>=0;i--){
if(dp[a][i]!=dp[b][i]){
a=dp[a][i];
b=dp[b][i];
}
}
return dp[a][0];
}
int st[MAX_N],s;
int dist(int x,int y){
int lca=LCA(x,y);
return dis[x]+dis[y]-2*dis[lca];
}
long long dp1[MAX_N];
int now_deep;
void dfs1(int x,int fa){
int i;
if(dis[x]==now_deep)
dp1[x]=a[x];
else
dp1[x]=0;
for(i=0;i<sv[x].size();i++){
int y=sv[x][i];
if(y==fa)
continue;
long long w=dis[y]-dis[x];
//cout<<w<<"\n";
dfs1(y,x);
if(dp1[y]!=0)
dp1[x]+=max(dp1[y]-(w-1),(long long)1);
//cout<<dp1[x]<<" "<<x<<"\n";
}
if(dp1[x]>1)
dp1[x]-=1;
sv[x].clear();
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
int tp;
void build(int deep){//建虚树;
st[++tp]=s;
int count=dep[deep].size();
for(int i=0;i<count;i++){
int u=dep[deep][i],y=st[tp],lca=LCA(u,y);
while(lca!=y){
tp--;
if(dfn[st[tp]]<dfn[lca])st[++tp]=lca;
sv[st[tp]].push_back(y);
sv[y].push_back(st[tp]);
y=st[tp];
}
if(u!=st[tp])st[++tp]=u;
}
while(tp){
int y=st[tp--];
if(!tp)break;
sv[st[tp]].push_back(y);
sv[y].push_back(st[tp]);
}
}
int main(void){
int i,j,x,y;
scanf("%d%d",&n,&s);
for(i=1;i<=n;i++){
log_2[i]=log_2[i-1]+(1<<log_2[i-1]==i);
}
for(i=1;i<=n;i++)
scanf("%lld",&a[i]);
for(i=1;i<n;i++){
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
long long ans;
if(a[s]>1)
ans=a[s]-1;
else
ans=a[s];
dfs(s,0);
dis[0]=-1;
for(i=1;i<=log_2[n];i++){
for(j=1;j<=n;j++){
dp[j][i]=dp[dp[j][i-1]][i-1];
}
}
for(i=1;i<=max_deep;i++){
now_deep=i;
build(i);
dfs1(s,0);
ans+=dp1[s];
}
printf("%lld\n",ans);
return 0;
}
P2495 [SDOI2011]消耗战
o2优化过的,懒得再优化了。
#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
const int MAX_N=250100;
const long long INF=0x3f3f3f3f3f3f3f3f;
vector<int>v[MAX_N],sv[MAX_N];
vector<long long>w[MAX_N];
int dfn[MAX_N],cnt;
int dp[MAX_N][20],n,log_2[MAX_N];
long long dp_min[MAX_N][20];
long long dis[MAX_N];
int dep[MAX_N];//为树的深度
void dfs(int x,int fa){
int i;
dfn[x]=++cnt;
for(i=0;i<v[x].size();i++){
int y=v[x][i];
if(y==fa)
continue;
dis[y]=dis[x]+w[x][i];
dep[y]=dep[x]+1;
dp[y][0]=x;
dp_min[y][0]=w[x][i];
dfs(y,x);
}
}
int LCA(int a,int b){
int k=log_2[n],i;
if(dep[a]<dep[b]){
int c=a;
a=b;
b=c;
}
while(dep[a]!=dep[b]){
for(i=k;i>=0;i--){
if(dep[dp[a][i]]>=dep[b])
a=dp[a][i];
}
}
if(a==b){
return b;
}
for(i=k;i>=0;i--){
if(dp[a][i]!=dp[b][i]){
a=dp[a][i];
b=dp[b][i];
}
}
return dp[a][0];
}
int a[MAX_N];
int tp,st[MAX_N];
void build(int s,int count){//建虚树;count为所选点的个数,s为根
st[++tp]=s;
for(int i=1;i<=count;i++){
//cout<<a[i]<<" "<<i<<" a[i]\n";
int u=a[i],y=st[tp],lca=LCA(u,y);
//cout<<u<<" "<<y<<" "<<lca<<" #\n";
while(lca!=y){
tp--;
if(dfn[st[tp]]<dfn[lca])st[++tp]=lca;
sv[st[tp]].push_back(y);
sv[y].push_back(st[tp]);
//cout<<st[tp]<<" "<<y<<" !\n";
y=st[tp];
}
if(u!=st[tp])st[++tp]=u;
}
//cout<<"hi\n";
while(tp){
int y=st[tp--];
if(!tp)break;
sv[st[tp]].push_back(y);
sv[y].push_back(st[tp]);
//cout<<st[tp]<<" "<<y<<" !\n";
}
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
long long dp1[MAX_N];
long long dist(int x,int y){
int deep=dep[y]-dep[x],i;
long long ans=INF;
for(i=log_2[n];i>=0;i--){
if(dep[dp[y][i]]>=dep[x]){
ans=min(ans,dp_min[y][i]);
y=dp[y][i];
}
}
return ans;
}
void dfs1(int x,int fa){
int i;
for(i=0;i<sv[x].size();i++){
int y=sv[x][i];
if(y==fa)
continue;
long long cost=dist(x,y);
dfs1(y,x);
//cout<<cost<<" cost\n";
dp1[x]+=min(dp1[y],cost);
}
//cout<<x<<" "<<dp1[x]<<" %!\n";
}
void dfs2(int x,int fa){
int i;
for(i=0;i<sv[x].size();i++){
int y=sv[x][i];
if(y==fa)
continue;
dfs2(y,x);
}
sv[x].clear();
dp1[x]=0;
}
int main(void){
int i,j,x,y,m;
long long z;
scanf("%d",&n);
for(i=1;i<=n;i++){
log_2[i]=log_2[i-1]+(1<<log_2[i-1]==i);
}
for(i=1;i<n;i++){
scanf("%d%d%lld",&x,&y,&z);
v[x].push_back(y);
w[x].push_back(z);
v[y].push_back(x);
w[y].push_back(z);
}
dfs(1,0);
dep[0]=-1;
for(i=1;i<=log_2[n];i++){
for(j=1;j<=n;j++){
dp[j][i]=dp[dp[j][i-1]][i-1];
dp_min[j][i]=min(dp_min[j][i-1],dp_min[dp[j][i-1]][i-1]);
}
}
scanf("%d",&m);
int k;
for(i=1;i<=m;i++){
scanf("%d",&k);
for(j=1;j<=k;j++){
scanf("%d",&a[j]);
dp1[a[j]]=INF;
}
sort(a+1,a+k+1,cmp);
build(1,k);
dfs1(1,0);
printf("%lld\n",dp1[1]);
dfs2(1,0);
}
return 0;
}