题目描述:
无向连通图G 有n 个点,n - 1 条边。点从1 到n 依次编号,编号为 i 的点的权值为W i ,每条边的长度均为1 。图上两点( u , v ) 的距离定义为u 点到v 点的最短距离。对于图G 上的点对( u, v) ,若它们的距离为2 ,则它们之间会产生Wu*Wv的联合权值。
请问图G 上所有可产生联合权值的有序点对中,联合权值最大的是多少?所有联合权
值之和是多少?
【输入】
输入文件名为link .in。
第一行包含1 个整数n 。
接下来n - 1 行,每行包含 2 个用空格隔开的正整数u 、v ,表示编号为 u 和编号为v 的点之间有边相连。
最后1 行,包含 n 个正整数,每两个正整数之间用一个空格隔开,其中第 i 个整数表示图G 上编号为i 的点的权值为W i 。
【输出】
输出文件名为link .out 。
输出共1 行,包含2 个整数,之间用一个空格隔开,依次为图G 上联合权值的最大值和所有联合权值之和。由于所有联合权值之和可能很大,输出它时要对10007 取余。
解题思路:
说实话我不想说其他分的做法。。。(其实是我不会。。。。。)
这道题目有多重算法。。。。蒟蒻的我只会树形dp,。。。。。。。。
首先看一看数据,n<=200000,卧槽这TM坑爹啊,居然不能打邻接矩阵,所以果断邻接表搞起啊。。。。。。。。说实话链表这东西我无力吐槽了。。。。。。。。(邻接表不会写的大爷请自行查阅。。。。由于本人水平问题无法描述。。。。。)
解决了边的问题后,就开始进入正题。
事实上,虽说是树形dp,但是我们只需要进行一次dfs就行了,不需要真正的构树。我们只需要随便抓一个点作为根在进行树的遍历就行了。那么接下来我就默认以1为根节点。
因为是距离为2的有序点对,所以对于一个节点,它只有可能和它的爷爷、孙子、还有兄弟构成联合权值。
考虑好这些之后,我们就有一个大体的思路。那么怎么实现呢?
首先在遍历是我们需要在dfs时维护一些值:sum[i](i节点的所有儿子的和),max1[i](i节点的所有儿子的最大值),max2[i](i节点的所有儿子的次大值)我相信大家都会维护!
那么如何用这些值来求解ans_sum和ans_max呢?
对于ans_sum:
{
很简单,对于节点k与孙子构成的联合权值,我们只需要将k的儿子的sum值全部加起来存在q里,然后将ans_sum加上q*W[k]*2,这样我们就可以得到k与孙子构成的所有有序联合权值的和,那么和他的爷爷呢?显然我们会在对其爷爷进行操作时考虑到。
然后对于与兄弟构成的联合点对,我们需要借助其父节点来操作,对于父节点k,我们已经用sum[k]存下了其所有子节点的和,那么根据乘法结合律,父节点所有的儿子构成的有序联合权值总和就等于Σ(sum[k]-W[i])*W[i](i为k的儿子),为什么请自己脑补(其实是我不会证明),这是后将ans_sum加上求得的总和就行了
}
对于ans_max:
{
也很简单,对于节点k,我们只需要考虑他和儿子联合权值的最大值,他的儿子中的联合权值的最大值,所以ans_max=MAX(ans_max,W[k]*max1[i] (i为k的儿子),max1[k]*max2[k])
}
吐槽:
话说我考试的之后自己作死,然后在windows下开了一个一条链的极限数据,然后递归暴栈了,我就果断人工栈搞起,但是事后发现我同学写的dfs,他在考场上写人工栈没写出来,然后直接交的暴力,但是TMD后来发现linux下不会暴栈,艹,壮哉我大linux。。。。。。下面附上我的代码和同学的代码
代码1(本人的人工栈的丑代码):
//本人是淳朴的C党
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#define MAX(a,b) (a>b?a:b)
#define MIN(a,b) (a>b?b:a)
struct tree
{
int num;
struct tree *p;
}*h[200010]={NULL},*t[200010]={NULL};
int ans=0;
int w[200010]={0};
int max=-2e9;
int n;
int f[200010]={0};
int maxn[200010][4]={0};
int hash[200010]={0};
int zhan[200010]={0};
struct tree * zhanp[200010]={0};
int top=0;
void dfs()
{
struct tree *o;
struct tree *l;
zhan[++top]=1;
zhanp[top]=h[1];
for(;top>0;)
{
hash[zhan[top]]=1;
for(;zhanp[top]!=NULL;zhanp[top]=zhanp[top]->p)
{
if(hash[zhanp[top]->num]==1)
continue;
zhan[top+1]=zhanp[top]->num;
top++;
zhanp[top]=h[zhan[top]];
break;
}
if(zhanp[top]==NULL && top>1)
{
f[zhan[top-1]]=(long long)(f[zhan[top-1]]+w[zhan[top]])%10007;
if(maxn[zhan[top-1]][1]==0)
maxn[zhan[top-1]][1]=w[zhan[top]];
else
{
if(w[zhan[top]]>maxn[zhan[top-1]][1])
{
maxn[zhan[top-1]][2]=maxn[zhan[top-1]][1];
maxn[zhan[top-1]][1]=w[zhan[top]];
}
else if(w[zhan[top]]>maxn[zhan[top-1]][2])
maxn[zhan[top-1]][2]=w[zhan[top]];
}
max=MAX(max,w[zhan[top-1]]*maxn[zhan[top]][1]);
ans=(long long)(ans+(long long)f[zhan[top]]*w[zhan[top-1]]%10007*2%10007)%10007;
}
if(zhanp[top]==NULL)
{
max=MAX(max,maxn[zhan[top]][1]*maxn[zhan[top]][2]);
for(l=h[zhan[top]];l!=NULL;l=l->p)
if(l->num!=zhan[top-1])
ans=(long long)(ans+(long long)(f[zhan[top]]+10007-w[l->num])%10007*w[l->num]%10007)%10007;
zhan[top]=0;
zhanp[top]=NULL;
top--;
}
}
return;
}
int main()
{
int i,j,p,q;
freopen("link.in","r",stdin);
freopen("link.out","w",stdout);
scanf("%d",&n);
for(i=1;i<n;i++)
{
scanf("%d%d",&j,&q);
if(h[j]==NULL)
{
h[j]=(struct tree*)malloc(sizeof(struct tree));
t[j]=h[j];
}
else
{
t[j]->p=(struct tree*)malloc(sizeof(struct tree));
t[j]=t[j]->p;
}
t[j]->p=NULL;
t[j]->num=q;
if(h[q]==NULL)
{
h[q]=(struct tree*)malloc(sizeof(struct tree));
t[q]=h[q];
}
else
{
t[q]->p=(struct tree*)malloc(sizeof(struct tree));
t[q]=t[q]->p;
}
t[q]->p=NULL;
t[q]->num=j;
}
for(i=1;i<=n;i++)
scanf("%d",&w[i]);
dfs();
printf("%d %d",max,ans%10007);
fclose(stdin);
fclose(stdout);
return 0;
}
代码2(我同学的深搜代码,来自"期待变成大神的owaski"):
//果然是大神,写的C++,蒟蒻的我。。。。
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#define MAX(a, b) (a>b?a:b)
using namespace std;
const int MAXN = 200005, MOD = 10007;
int n = 0;
int w[MAXN] = {0};
long long sumd[MAXN] = {0}, maxd[MAXN] = {0};
long long ansmax[MAXN] = {0}, anssum[MAXN] = {0};
int hash[MAXN] = {0};
struct biao
{
int node;
biao *nxt;
}*edge[MAXN] = {NULL}, *t[MAXN] = {NULL};
void add(int u, int v)
{
biao *tmp;
tmp = (biao *)malloc(sizeof(biao));
tmp->node = v;
tmp->nxt = NULL;
if(edge[u] == NULL)
edge[u] = t[u] = tmp;
else
{
t[u]->nxt = tmp;
t[u] = t[u]->nxt;
}
}
int top = 0;
void find(int cur)
{
long long maxd2 = 0;
long long maxdd = 0, sumdd = 0, flag = 0;
maxd2 = maxdd = sumdd = flag = 0;
hash[cur] = ++top;
for(biao *i = edge[cur]; i != NULL; i = i->nxt)
if(!hash[i->node])
{
find(i->node);
sumd[cur] += (long long)w[i->node];
sumd[cur] %= MOD;
if((long long)w[i->node] > maxd[cur])
{
maxd[cur] = (long long)w[i->node];
flag = i->node;
}
maxdd = MAX(maxd[i->node], maxdd);
sumdd += sumd[i->node];
sumdd %= MOD;
ansmax[cur] = MAX(ansmax[i->node], ansmax[cur]);
anssum[cur] += anssum[i->node];
anssum[cur] %= MOD;
}
for(biao *i = edge[cur]; i != NULL; i = i->nxt)
if(hash[i->node] > hash[cur])
{
if(i->node != flag && maxd2 < (long long)w[i->node])
maxd2 = (long long)w[i->node];
anssum[cur] += (long long)w[i->node]*(sumd[cur]-(long long)w[i->node]);
anssum[cur] %= MOD;
}
ansmax[cur] = MAX(ansmax[cur], MAX(maxdd*(long long)w[cur], maxd2*maxd[cur]));
anssum[cur] += (sumdd*(long long)w[cur])<<1;
anssum[cur] %= MOD;
}
int main()
{
freopen("link.in", "r", stdin);
freopen("link.out", "w", stdout);
scanf("%d", &n);
for(int i = 1; i < n; ++i)
{
int u = 0, v = 0;
scanf("%d%d", &u, &v);
add(u, v);add(v, u);
}
for(int i = 1; i <= n; ++i) scanf("%d", w+i);
find(1);
printf("%d %d", (int)ansmax[1], (int)(anssum[1]%MOD));
fclose(stdin);
fclose(stdout);
return 0;
}