树上前缀和+LCA
暴力做法:
我们先把不删的sum维护出来,然后遍历跳过的点,假如a1,a2,a3,跳过2,那么答案就是sum-cost(a1,a2)-cost(a2,a3)+cost(a1,a3).
DFS暴力,下面是代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int k,n;
typedef pair<int,int> pii;
int a[100010];
vector<pii> edge[100010];
map<pii,ll> st;
bool dfs(int s,int u,int fa,int v,ll sum)
{
if(u==v)
{
st[{s,v}]=sum;
st[{v,s}]=sum;
return 1;
}
for(int i=0;i<edge[u].size();i++)
{
int son=edge[u][i].first;
if(fa==son) continue;
if(dfs(s,son,u,v,sum+edge[u][i].second)) return 1;
}
return 0;
}
int main()
{
cin>>n>>k;
for(int i=1;i<=n-1;i++)
{
int u,v,t;
cin>>u>>v>>t;
edge[u].push_back({v,t});
edge[v].push_back({u,t});
}
for(int i=1;i<=k;i++) cin>>a[i];
ll ans=0;
for(int i=1;i<=k;i++)
{
dfs(a[i],a[i],-1,a[i+1],0);
ans+=st[{a[i],a[i+1]}];
}
for(int i=1;i<=k;i++)
{
ll tp=ans;
if(i==1) tp-=st[{a[i],a[i+1]}];
if(i==k) tp-=st[{a[i-1],a[i]}];
if(i>1&&i<k)
{
tp-=st[{a[i-1],a[i]}]+st[{a[i],a[i+1]}];
dfs(a[i-1],a[i-1],-1,a[i+1],0);
tp+=st[{a[i-1],a[i+1]}];
}
cout<<tp<<" ";
}
}
正确做法:
我们先预处理出各个点到根节点的距离就是树上前缀和,答案就是sum[a1]+sum[a2]-2*sum[fa],fa为a1,a2的最近公共祖先,下面是LCA用倍增实现的板子:
https://www.luogu.com.cn/problem/P3379
#include<bits/stdc++.h>
using namespace std;
int n,m,s;
vector<int> edge[500010];
int dep[500010];
int fa[500010][22];
int maxd=21;
void dfs(int x,int fath)
{
if(x!=s)
{
fa[x][0]=fath;
dep[x]=dep[fath]+1;
for(int i=1;(1<<i)<=n;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
}
for(int i=0;i<edge[x].size();i++)
{
int ck=edge[x][i];
if(ck==fath) continue;
dfs(ck,x);
}
}
int up(int x,int d){
int ret=x;
for(int i=0;(1<<i)<=n;i++)
{
if(((1<<i)&d)!=0) ret=fa[ret][i];
}
return ret;
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
x=up(x,dep[x]-dep[y]);
if(x==y) return x;
for(int i=maxd;i>=0;i--)
{
if(fa[x][i]!=fa[y][i])
{
x=fa[x][i],y=fa[y][i];
}
}
return fa[x][0];
}
int main()
{
cin>>n>>m>>s;
for(int i=1;i<=n-1;i++)
{
int x,y;
cin>>x>>y;
edge[x].push_back(y);
edge[y].push_back(x);
}
dep[s]=1;
dfs(s,-1);
while(m--)
{
int a1,b1;
cin>>a1>>b1;
cout<<lca(a1,b1)<<endl;
}
}
本题的AC代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int k,n;
typedef pair<int,ll> pii;
int a[100010];
vector<pii> edge[100010];
map<pii,ll> st;
ll sum[100010];
int dep[100010];
int fa[100010][22];
int maxd=21;
void dfss(int x,int fath)
{
if(x!=1)
{
fa[x][0]=fath;
dep[x]=dep[fath]+1;
for(int i=1;(1<<i)<=n;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
}
for(int i=0;i<edge[x].size();i++)
{
int ck=edge[x][i].first;
if(ck==fath) continue;
dfss(ck,x);
}
}
int up(int x,int d){
int ret=x;
for(int i=0;(1<<i)<=n;i++)
{
if(((1<<i)&d)!=0) ret=fa[ret][i];
}
return ret;
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
x=up(x,dep[x]-dep[y]);
if(x==y) return x;
for(int i=maxd;i>=0;i--)
{
if(fa[x][i]!=fa[y][i])
{
x=fa[x][i],y=fa[y][i];
}
}
return fa[x][0];
}
void dfs(int x,int fa)
{
for(int i=0;i<edge[x].size();i++)
{
int ck=edge[x][i].first;
if(ck==fa) continue;
ll num=edge[x][i].second;
sum[ck]=sum[x]+num;
dfs(ck,x);
}
}
ll dis(int x,int y)
{
ll zhi=sum[x]+sum[y];
zhi-=2*sum[lca(x,y)];
return zhi;
}
int main()
{
cin>>n>>k;
for(int i=1;i<=n-1;i++)
{
int u,v,t;
cin>>u>>v>>t;
edge[u].push_back({v,t});
edge[v].push_back({u,t});
}
for(int i=1;i<=k;i++) cin>>a[i];
dep[1]=1;
sum[1]=0;
dfs(1,-1);
dfss(1,-1);
ll summ=0;
for(int i=1;i<=k-1;i++) summ+=dis(a[i],a[i+1]);
for(int i=1;i<=k;i++)
{
ll ans=summ;
if(i>1&&i<k)
{
ans-=dis(a[i-1],a[i])+dis(a[i],a[i+1]);
ans+=dis(a[i-1],a[i+1]);
}
if(i==1) ans-=dis(a[1],a[2]);
if(i==k) ans-=dis(a[k-1],a[k]);
cout<<ans<<" ";
}
}