问你一棵树里,删除一条边之后,剩下的两个树里面最长的是多少
然后n−1种情况的都加起来
有两种方法,第一种就是正常的树形dp,dfs两次,分类讨论维护一些东西
对于删除u−>v这条边,有如下情况
1.最长的经过u
2.最长的不经过u
经过u,那么就是要么是u的两条不经过v的最长的往下的,要么就是一条最长的,外加一条从上面到达u的最长的
不经过u,就是u的儿子里最大的直径,或者是u的祖先里面最大的,以及u的兄弟里最大的
维护几条太烦了,multiset水过
代码:
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <bitset>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker,"/STACK:102400000,102400000")
using namespace std;
#define MAX 100005
#define MAXN 1000005
#define maxnode 15
#define sigma_size 30
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000
//const int prime = 999983;
const int INF = 0x3f3f3f3f;
const LL INFF = 0x3f3f;
const double pi = acos(-1.0);
const double inf = 1e18;
const double eps = 1e-4;
const LL mod = 1e9+7;
const ull mx = 133333331;
/*****************************************************/
inline void RI(int &x) {
char c;
while((c=getchar())<'0' || c>'9');
x=c-'0';
while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
}
/*****************************************************/
struct Edge{
int v,next,c;
}edge[MAX*2];
int head[MAX];
int dp[MAX][2];
int ans[MAX];
int tot;
LL ret;
void init(){
mem(head,-1);
tot=0;
}
void add_edge(int a,int b,int c){
edge[tot]=(Edge){b,head[a],c};
head[a]=tot++;
}
void dfs1(int u,int fa){
ans[u]=0;
dp[u][0]=dp[u][1]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue;
dfs1(v,u);
ans[u]=max(ans[u],ans[v]);
int temp=dp[v][0]+edge[i].c;
if(temp>dp[u][0]){
swap(dp[u][0],dp[u][1]);
dp[u][0]=temp;
}
else if(temp>dp[u][1]) dp[u][1]=temp;
}
ans[u]=max(ans[u],dp[u][0]+dp[u][1]);
//cout<<ans[u]<<" "<<u<<endl;
}
void dfs2(int u,int fa,int d,int c){
int tmp=0;
multiset<int> s,ss;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue;
s.insert(dp[v][0]+edge[i].c);
tmp=max(tmp,ans[v]);
ss.insert(ans[v]);
}
tmp=max(d,tmp);
tmp=max(c,tmp);
//cout<<tmp<<" "<<u<<endl;
for(int i=head[u];i!=-1;i=edge[i].next){
int kk=tmp;
int v=edge[i].v;
if(v==fa) continue;
multiset<int>::iterator it=s.find(dp[v][0]+edge[i].c);
multiset<int>::iterator iit=ss.find(ans[v]);
s.erase(it);
ss.erase(iit);
it=s.end();
iit=ss.end();
if(s.empty()){
ret+=kk;
//cout<<ret<<endl;
dfs2(v,u,d,c+edge[i].c);
}
else{
iit--;
it--;
int temp=0;
int x=*it;
temp+=*it;
kk=max(kk,temp+c);
if(it==s.begin()){
ret+=kk;
//cout<<v<<endl;
dfs2(v,u,max(max(d,temp+c),*iit),max(c+edge[i].c,*it+edge[i].c));
}
else{
int xx=*it;
int xxx=temp+c;
it--;
temp+=*it;
kk=max(kk,temp);
ret+=kk;
dfs2(v,u,max(max(d,temp),max(*iit,xxx)),max(c+edge[i].c,xx+edge[i].c));
}
}
s.insert(dp[v][0]+edge[i].c);
ss.insert(ans[v]);
}
}
int main(){
//freopen("in.txt","r",stdin);
int t;
cin>>t;
while(t--){
int n;
cin>>n;
init();
for(int i=1;i<n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add_edge(a,b,c);
add_edge(b,a,c);
}
dfs1(1,-1);
ret=0;
dfs2(1,-1,0,0);
cout<<ret<<endl;
}
return 0;
}
方法2是题解的方法,感觉很不错诶
先找一个直径a−b,如果删除的不是直径上的边,那么很显然
如果删除的是直径上的边u−v,那么就是u为根的子树里的直径和v为根的子树里的最大的
如何处理这两个子树里的直径呢,只要从a开始dfs一遍,b开始dfs一遍预处理即可
代码:
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker,"/STACK:102400000,102400000")
using namespace std;
#define MAX 100005
#define MAXN 1000005
#define maxnode 15
#define sigma_size 30
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000
//const int prime = 999983;
const int INF = 0x3f3f3f3f;
const LL INFF = 0x3f3f;
const double pi = acos(-1.0);
const double inf = 1e18;
const double eps = 1e-4;
const LL mod = 1e9+7;
const ull mx = 133333331;
/*****************************************************/
inline void RI(int &x) {
char c;
while((c=getchar())<'0' || c>'9');
x=c-'0';
while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
}
/*****************************************************/
struct Edge{
int v,next,c;
}edge[MAX*2];
int head[MAX];
int dp[MAX][2];
int nxt[MAX][2];
int pre[MAX];
int max1[MAX];
int max2[MAX];
int tot;
struct Node{
int u,v;
}tmp[MAX];
void init(){
mem(head,-1);
tot=0;
}
void add_edge(int a,int b,int c){
edge[tot]=(Edge){b,head[a],c};
head[a]=tot++;
}
void dfs(int u,int fa){
pre[u]=fa;
dp[u][0]=dp[u][1]=0;
nxt[u][0]=nxt[u][1]=u;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue;
dfs(v,u);
if(dp[u][0]<dp[v][0]+edge[i].c){
swap(dp[u][0],dp[u][1]);
swap(nxt[u][0],nxt[u][1]);
dp[u][0]=dp[v][0]+edge[i].c;
nxt[u][0]=nxt[v][0];
}
else if(dp[u][1]<dp[v][0]+edge[i].c){
dp[u][1]=dp[v][0]+edge[i].c;
nxt[u][1]=nxt[v][0];
}
}
}
void dfs1(int u,int fa){
max1[u]=0;
dp[u][0]=dp[u][1]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue;
dfs1(v,u);
max1[u]=max(max1[u],max1[v]);
if(dp[u][0]<dp[v][0]+edge[i].c){
swap(dp[u][0],dp[u][1]);
dp[u][0]=dp[v][0]+edge[i].c;
}
else if(dp[u][1]<dp[v][0]+edge[i].c) dp[u][1]=dp[v][0]+edge[i].c;
}
max1[u]=max(max1[u],dp[u][0]+dp[u][1]);
}
void dfs2(int u,int fa){
max2[u]=0;
dp[u][0]=dp[u][1]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue;
dfs2(v,u);
max2[u]=max(max2[u],max2[v]);
if(dp[u][0]<dp[v][0]+edge[i].c){
swap(dp[u][0],dp[u][1]);
dp[u][0]=dp[v][0]+edge[i].c;
}
else if(dp[u][1]<dp[v][0]+edge[i].c) dp[u][1]=dp[v][0]+edge[i].c;
}
max2[u]=max(max2[u],dp[u][0]+dp[u][1]);
}
int main(){
//freopen("in.txt","r",stdin);
int t;
cin>>t;
while(t--){
int n;
cin>>n;
init();
for(int i=1;i<n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add_edge(a,b,c);
add_edge(b,a,c);
}
dfs(1,-1);
int ret=0,u,v,w;
for(int i=1;i<=n;i++){
if(ret<dp[i][0]+dp[i][1]){
ret=dp[i][0]+dp[i][1];
u=nxt[i][0];
v=nxt[i][1];
w=i;
}
}
//cout<<u<<" "<<v<<endl;
//cout<<ret<<endl;
dfs1(u,-1);
dfs2(v,-1);
int cnt=0;
while(u!=w){
tmp[cnt++]=(Node){u,pre[u]};
u=pre[u];
}
while(v!=w){
tmp[cnt++]=(Node){pre[v],v};
v=pre[v];
}
LL ans=(LL)(n-1-cnt)*ret;
for(int i=0;i<cnt;i++){
ans+=max(max1[tmp[i].v],max2[tmp[i].u]);
//cout<<max1[tmp[i].v]<<" "<<max2[tmp[i].u]<<endl;
}
cout<<ans<<endl;
}
return 0;
}