题目链接:点击打开链接
一开始没看数据范围傻傻的用O(n^2)的算法怼了一波于是就妥妥的吃了一发TLE。。so sad。。
按照这道题的规模需要使用O(n)的方法,那么很明显就是用树形dp了。
首先随便取个点拉成颗有根树,这里取的是点1。
首先,我们将要去的目的地称为【指定点】,设其所有权值和为P。
要求的值可以看做是Σ【所有指定点】(当前点到一个指定点的路径权值和*指定点权值*2),然后将所有点求一遍再找最小值。
那么,将这个问题放在已经拉好的有根树里,当前点要求的Σ就可以分为两个部分:
1. 指定点为自己的后代们的时候的Σ
2. 指定点不是自己的后代们的时候的Σ
第一个Σ记为dp1[i],且设根为j的子树的指定点权值和为p[i],则有
dp1[i]=Σ(dp1[son]+p[son]*edge[i][son]),其中edge[i][son]为点i到这个儿子的边的权值。
接着是求第二个Σ。记为dp2[i],则有
dp2[i]=(dp1[prt]+dp2[prt]-dp1[i])+(P-p[i])*edge[i][prt],其中prt为i的父亲。
可以这么理解,dp2[i]的来源就是从父亲开始的【外面】的结点往下遍历分配给点i的。
那么每个点的解就是dp1[i]+dp2[i]了,取最小值并遍历一遍取出所有最小值点即可。
代码如下:
#include<bits/stdc++.h>
using namespace std;
struct edge
{
int u,v;
int w;
edge(int uu=0,int vv=0,int ww=0):u(uu),v(vv),w(ww){}
};
vector<edge> e[50005];
bool vis[50005];
long long ans[50005];
long long Dans[50005],Dp[50005];
long long a[50005];
long long Solve(int pos)
{
vis[pos]=true;
vector<long long> v,ret;
long long Ret=0;
for(int i=0;i<e[pos].size();i++)
{
int &p=e[pos][i].v;
if(vis[p])continue;
v.push_back(Solve(p));
ret.push_back(ans[p]);
ans[pos]+=(v[v.size()-1]*e[pos][i].w+ret[ret.size()-1]);
Ret+=v[v.size()-1];
}
Ret+=a[pos];
int cnt=0;
for(int i=0;i<e[pos].size();i++)
{
int &p=e[pos][i].v;
if(vis[p])continue;
Dp[p]=Ret-v[cnt];
Dans[p]=ans[pos]-(v[cnt]*e[pos][i].w+ret[cnt])+Dp[p]*e[pos][i].w;
cnt++;
}
vis[pos]=false;
return Ret;
}
void dfs(int pos)
{
vis[pos]=true;
ans[pos]+=Dans[pos];
for(int i=0;i<e[pos].size();i++)
{
int &p=e[pos][i].v;
if(vis[p])continue;
Dans[p]+=Dans[pos];
Dans[p]+=Dp[pos]*e[pos][i].w;
Dp[p]+=Dp[pos];
dfs(p);
}
vis[pos]=false;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
e[i].clear();
for(int i=1;i<n;i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
e[u].push_back(edge(u,v,w));
e[v].push_back(edge(v,u,w));
}
int m;
scanf("%d",&m);
memset(ans,0,sizeof(ans));
memset(a,0,sizeof(a));
memset(Dp,0,sizeof(Dp));
memset(Dans,0,sizeof(Dans));
for(int i=0;i<m;i++)
{
int p,w;
scanf("%d%d",&p,&w);
a[p]=w;
}
Solve(1);
dfs(1);
// for(int i=1;i<=n;i++)
// printf("%lld %lld %lld\n",ans[i],Dans[i],Dp[i]);
int mx=1;
for(int i=2;i<=n;i++)
if(ans[mx]>ans[i])mx=i;
printf("%lld\n",ans[mx]*2);
bool flag=false;
for(int i=1;i<=n;i++)
{
if(ans[i]>ans[mx])continue;
if(flag)printf(" ");
else flag=true;
printf("%d",i);
}
printf("\n");
}
return 0;
}
【这次有return 0了!】