YbtOJ「动态规划」第4章 树形DP

YbtOJ 大全

【例题1】树上求和

f [ u ] [ 0 ] f[u][0] f[u][0] 表示这个点不选, f [ u ] [ 1 ] f[u][1] f[u][1] 表示这个点选。

那么转移方程 f [ u ] [ 0 ] + = max ⁡ ( f [ v ] [ 0 ] , f [ v ] [ 1 ] ) f[u][0] += \max(f[v][0],f[v][1]) f[u][0]+=max(f[v][0],f[v][1]) f [ u ] [ 1 ] = f [ v ] [ 0 ] + a [ u ] f[u][1] = f[v][0] + a[u] f[u][1]=f[v][0]+a[u]。最后的答案就是 a n s = max ⁡ ( a n s , max ⁡ ( f [ u ] [ 0 ] , f [ u ] [ 1 ] ) ) ans = \max(ans,\max(f[u][0],f[u][1])) ans=max(ans,max(f[u][0],f[u][1]))

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long 
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 2e4+10;
int head[M],f[M][2],a[M];
int n,cnt,ans = 1;
struct edge{
	int to,nxt;
}e[M];
inline void add(int u,int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs(int u,int fa){
	f[u][1] = a[u],f[u][0] = 0;
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v,u);
		f[u][0] += max(f[v][0],f[v][1]);
		f[u][1] = f[v][0] + a[u];
		ans = max(ans,max(f[u][0],f[u][1]));
	}
}
signed main(){
	n = read();
	rep(i,1,n) a[i] = read();
	rep(i,1,n-1){
		int u = read(),v = read();
		add(v,u),add(u,v);
	}
	dfs(1,0);
	printf("%d\n",ans);
	return 0;
}

【例题2】结点覆盖

f [ u ] [ 0 ] f[u][0] f[u][0] u u u 的父亲结点, f [ u ] [ 1 ] f[u][1] f[u][1] 表示选 u u u 自己, f [ u ] [ 2 ] f[u][2] f[u][2] 表示选 u u u 的儿子。

  • u u u 被其父结点覆盖,那么子结点只能被其子结点或其本身覆盖。 f [ u ] [ 0 ] + = min ⁡ ( f [ v ] [ 1 ] , f [ v ] [ 2 ] ) f[u][0] += \min(f[v][1],f[v][2]) f[u][0]+=min(f[v][1],f[v][2])
  • u u u 被其本身覆盖,那它的子结点可以被其父亲或自己或孩子覆盖。 f [ u ] [ 1 ] + = min ⁡ ( f [ v ] [ 0 ] , min ⁡ ( f [ v ] [ 1 ] , f [ v ] [ 2 ] ) ) f[u][1] += \min(f[v][0],\min(f[v][1],f[v][2])) f[u][1]+=min(f[v][0],min(f[v][1],f[v][2]))
  • u u u 被其子结点覆盖。由于至少要选一个子结点,其余的子结点可以被其子结点或本身覆盖。我们用一个变量记录是否选了 u u u 的一个子结点。如果都没有选,只能选最小的 v v v 来覆盖 u u u。具体转移看代码。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 3010;
int f[M][3],a[M],head[M];
int n,cnt;
struct edge{
	int to,nxt;
}e[M];
inline void add(int u,int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs(int u,int fa){
	int fl = 0,mi = 1e9;
	f[u][0] = 0,f[u][1] = a[u],f[u][2] = 0;
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v,u);
		f[u][0] += min(f[v][1],f[v][2]);
		f[u][1] += min(f[v][0],min(f[v][1],f[v][2]));
		if(f[v][1] < f[v][2]) fl = 1,f[u][2] += f[v][1];
		else f[u][2] += f[v][2],mi = min(mi,f[v][1]-f[v][2]);
	}
	if(fl == 0) f[u][2] += mi;
}
signed main(){
	n = read();
	rep(i,1,n){
		int u = read();
		a[u] = read();
		int m = read();
		rep(i,1,m){
			int v = read();
			add(u,v),add(v,u);
		}
	}
	dfs(1,0);
	printf("%d\n",min(f[1][1],f[1][2]));
	return 0;
}

【例题3】最长距离

求出树的直径,然后从直径的两端分别进行一次 d f s dfs dfs 算出离每一个点的距离,最后的答案就是二者的 max ⁡ \max max

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 2e4+10;
int head[M],dis1[M],dis2[M];
int n,cnt,s,maxlen;
struct edge{
	int to,nxt,w;
}e[M];
inline void add(int u,int v,int w){
	e[++cnt].to = v;
	e[cnt].w = w;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs1(int u,int fa){
	if(dis1[u] > maxlen){
		maxlen = dis1[u];
		s = u;
	}
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dis1[v] = dis1[u] + e[i].w;
		dfs1(v,u);
	}
}
inline void dfs2(int u,int fa){
	if(dis1[u] > maxlen){
		maxlen = dis1[u];
		s = u;
	}
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dis1[v] = dis1[u] + e[i].w;
		dfs2(v,u);
	}
}
inline void dfs3(int u,int fa){
	if(dis2[u] > maxlen){
		maxlen = dis2[u];
		s = u;
	}
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dis2[v] = dis2[u] + e[i].w;
		dfs3(v,u);
	}
}
signed main(){
	while(cin >> n){
		memset(dis1,0,sizeof(dis1));
		memset(dis2,0,sizeof(dis2));
		memset(head,0,sizeof(head));
		memset(e,0,sizeof(e));
		cnt = 0;
		maxlen = 0;
		s = 0;
		rep(i,2,n){
			int u = read(),w = read();
			add(u,i,w),add(i,u,w);
		}
		dfs1(1,0);
		memset(dis1,0,sizeof(dis1));
		maxlen = 0;
		dfs2(s,0);
		maxlen = 0;
		dfs3(s,0);
		rep(i,1,n) printf("%lld\n",max(dis1[i],dis2[i]));
	}
	return 0;
}

【例题4】选课方案

树上背包。设 f [ u ] [ t ] f[u][t] f[u][t] 表示在以 u u u 为根的子树中考虑前 t t t 个结点的最大学分。

由于这道题会形成森林,所以我们将每一门没有先修课的课程向 0 0 0 连边,我们强制选 0 0 0 号结点,也就是最后的答案就是 f [ 0 ] [ m + 1 ] f[0][m+1] f[0][m+1]

在考虑到 v v v 子树时,我们枚举一个 k k k,表示在以 v v v 为根的子树中选 k k k 门课程,那么就需要在 u u u 为根的其它子树中选 t − k t-k tk 门课程。

转移方程 f [ u ] [ j ] = max ⁡ ( f [ u ] [ j ] , f [ v ] [ k ] + f [ u ] [ j − k ] ) f[u][j] = \max(f[u][j],f[v][k]+f[u][j-k]) f[u][j]=max(f[u][j],f[v][k]+f[u][jk])。注意这里需要 k ≤ j k \leq j kj 而且我们枚举 j j j 的时候需要倒着枚举,类似于 01 01 01 背包。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 610;
int n,m;
int head[M],f[M][M];
int cnt;
struct edge{
	int to,nxt;
}e[M];
inline void add(int u,int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs(int u){
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		dfs(v);
		for(re int j(m+1) ; j>=1 ; --j){
			for(re int k(0) ; k<j ; ++k){
				f[u][j] = max(f[u][j],f[v][k]+f[u][j-k]);
			}
		}
	}
}
signed main(){
	n = read(),m = read();
	rep(i,1,n){	
		int u = read(),x = read();
		add(u,i);
		f[i][1] = x;
	}
	dfs(0);
	printf("%d\n",f[0][m+1]);
	return 0;
}

1. 1. 1. 路径求和

我们设 s i z [ u ] siz[u] siz[u] 表示以 u u u 为根的子树大小, s u m [ u ] sum[u] sum[u] 表示以 u u u 为根的子树的叶子结点数量。根据乘法原理,我们在枚举边的时候,设边的两个端点为 u u u v v v,它对答案的贡献就是 w × s i z [ v ] × ( s u m [ 1 ] − s u m [ v ] ) + w × ( s i z [ 1 ] − s i z [ v ] ) × s u m [ v ] w \times siz[v] \times (sum[1]-sum[v]) + w \times (siz[1]-siz[v]) \times sum[v] w×siz[v]×(sum[1]sum[v])+w×(siz[1]siz[v])×sum[v]

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 2e5+10;
int head[M],siz[M],du[M],sum[M];
int n,m,cnt,ans;
struct edge{
	int to,nxt,w;
}e[M];
inline void add(int u,int v,int w){
	e[++cnt].to = v;
	e[cnt].w = w;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs1(int u,int fa){
	siz[u] = 1;
	if(du[u] == 1) sum[u] = 1;
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dfs1(v,u);
		siz[u] += siz[v];
		sum[u] += sum[v];
	}
}
inline void dfs2(int u,int fa){
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to,w = e[i].w;
		if(v == fa) continue;
		dfs2(v,u);
		ans += w * (siz[v] * (sum[1]-sum[v]));
		ans += w * ((siz[1]-siz[v]) * sum[v]);
	}
}
signed main(){
	n = read(),m = read();
	rep(i,1,m){
		int w = read(),u = read(),v = read();
		add(u,v,w),add(v,u,w);
		du[u]++,du[v]++;
	}
	dfs1(1,0);
	dfs2(1,0);
	printf("%lld\n",ans);
	return 0;
}

2. 2. 2. 树上移动

只有一个人走的话,我们需要找到离 S S S 的最长路,这些边走一遍,其余的边走两遍。

两个人走的话,两条走到底的路径拼起来正好是最长链的情况最优。我们在 d f s dfs dfs 的过程中记录最长路以及次长路即可。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 2e5+10;
int n,s,sum,cnt,ans1,ans2;
int head[M],dis1[M],dis2[M];
struct edge{
	int to,nxt,w;
}e[M];
inline void add(int u,int v,int w){
	e[++cnt].to = v;
	e[cnt].w = w;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs(int u,int fa){
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v,u);
		if(dis1[v] + e[i].w > dis1[u]){
			dis2[u] = dis1[u];
			dis1[u] = dis1[v] + e[i].w;
		}
		else if(dis1[v] + e[i].w > dis2[u]) dis2[u] = dis1[v] + e[i].w;
	}
}
signed main(){
	n = read(),s = read();
	rep(i,1,n-1){
		int u = read(),v = read(),w = read();
		add(u,v,w),add(v,u,w);
		sum += 2*w;
	}
	dfs(s,0);
	int maxn = 0;
	rep(i,1,n) maxn = max(maxn,dis1[i] + dis2[i]);
	ans1 = sum - dis1[s];
	ans2 = sum - maxn;
	printf("%d\n%d\n",ans1,ans2);
	return 0;
}

3. 3. 3. 块的计数

有一种解题的思维是正难则反。

直接算包含的不好求,那么我们可以算出一共的减去不包含的。

f [ u ] f[u] f[u] 表示以结点 u u u 为根的联通块总数 ( ( (包含 u u u ) ) ) g [ u ] g[u] g[u] 表示以 u u u 为根的不包含最大值的联通块总数 ( ( (同样包含 u u u ) ) )。那么以 u u u 为根的包含最大值的联通块总数就是 f [ u ] − g [ u ] f[u] - g[u] f[u]g[u]

对于一个结点 v v v,可以选,可以不选,所以 f [ u ] = ∏ v ⊆ s o n [ u ] ( f [ v ] + 1 ) f[u] = \prod\limits_{v\subseteq son[u]}(f[v]+1) f[u]=vson[u](f[v]+1)

对于 g [ u ] g[u] g[u] 来说,如果 u u u 本身是最大值,那么 g [ u ] = 0 g[u] = 0 g[u]=0,否则跟 f [ u ] f[u] f[u] 的转移类似, g [ u ] = ∏ v ⊆ s o n [ u ] ( g [ v ] + 1 ) g[u] = \prod\limits_{v\subseteq son[u]}(g[v]+1) g[u]=vson[u](g[v]+1)

最后的答案 a n s = ∑ i = 1 n ( f [ i ] − g [ i ] ) ans = \sum_{i=1}^n(f[i]-g[i]) ans=i=1n(f[i]g[i])

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long 
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 2e5+10;
const int mod = 998244353;
int head[M],f[M],g[M],a[M];
int n,cnt,ans1,ans2,maxn = -1e17;
struct edge{
	int to,nxt;
}e[M];
inline void add(int u,int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs(int u,int fa){
	f[u] = 1;
	g[u] = (a[u] != maxn);
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v,u);
		f[u] = f[u]*(f[v]+1)%mod;
		g[u] = g[u]*(g[v]+1)%mod;
	}
	ans1 = (ans1+f[u])%mod;
	ans2 = (ans2+g[u])%mod;
}
signed main(){
	n = read();
	rep(i,1,n) a[i] = read(),maxn = max(maxn,a[i]);
	rep(i,1,n-1){
		int u = read(),v = read();
		add(u,v),add(v,u);
	}
	dfs(1,0);
	printf("%lld\n",(ans1-ans2+mod)%mod);
	return 0;
}

4. 4. 4. 树的合并

先把两棵树的直径分别求出来,去最大的直径设为 m a x l e n maxlen maxlen。我们发现,将 u u u v v v 连一条边,直径只有两种情况,要么是 m a x l e n maxlen maxlen,要么是 u u u v v v 在其子树中的最大长度 + 1 +1 +1

于是我们可以预处理两棵树每一个点在其树上的最远距离,然后对于数组从大到小排序,进行二分。我们要找到 f 1 [ u ] + f 2 [ v ] + 1 ≥ m a x l e n f1[u] + f2[v] + 1 \geq maxlen f1[u]+f2[v]+1maxlen 的贡献就是 f 1 [ u ] + f 2 [ v ] + 1 f1[u] + f2[v] + 1 f1[u]+f2[v]+1,否则贡献就是 m a x l e n maxlen maxlen

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long 
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 4e5+10;
int n,m,cnt,maxlen,s,mx;
int head[M],dis1[M],dis2[M],f1[M],f2[M],sum[M];
struct edge{
	int to,nxt;
}e[M];
inline bool cmp(int x,int y) { return x > y; }
inline void add(int u,int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs1(int u,int fa){
	if(dis1[u] > maxlen){
		maxlen = dis1[u];
		s = u;
	}
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dis1[v] = dis1[u] + 1;
		dfs1(v,u);
	}
}
inline void dfs2(int u,int fa){
	if(dis1[u] > maxlen){
		maxlen = dis1[u];
		s = u;
	}
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dis1[v] = dis1[u] + 1;
		dfs2(v,u);
	}
}
inline void dfs3(int u,int fa){
	if(dis2[u] > maxlen){
		maxlen = dis2[u];
		s = u;
	}
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dis2[v] = dis2[u] + 1;
		dfs3(v,u);
	}
}
signed main(){
	n = read(),m = read();
	rep(i,1,n-1){
		int u = read(),v = read();
		add(u,v),add(v,u);
	}
	dfs1(1,0);
	memset(dis1,0,sizeof(dis1));
	maxlen = 0;
	dfs2(s,0);
	maxlen = 0;
	dfs3(s,0);
	mx = max(mx,maxlen);
	rep(i,1,n) f1[i] = max(dis1[i],dis2[i]);
	rep(i,1,m-1){
		int u = read(),v = read();
		add(u+n,v+n),add(v+n,u+n);
	}
	memset(dis1,0,sizeof(dis1));
	memset(dis2,0,sizeof(dis2));
	maxlen = 0;
	s = 0;
	dfs1(n+1,0);
	memset(dis1,0,sizeof(dis1));
	maxlen = 0;
	dfs2(s,0);
	maxlen = 0;
	dfs3(s,0);
	mx = max(mx,maxlen);
	rep(i,n+1,n+m) f2[i] = max(dis1[i],dis2[i]);
	sort(f1+1,f1+n+1,cmp);
	rep(i,1,n) sum[i] = sum[i-1] + f1[i];
	int ans = 0;
	rep(i,n+1,n+m){
		int l = 0,r = n,p = 0;
		while(l <= r){
			int mid = (l+r)>>1;
			if(f1[mid]+f2[i]+1 >= mx) l = mid+1,p = mid;
			else r = mid-1;
		}
		ans += sum[p] + p + f2[i]*p + (n-p)*mx;
	}
	printf("%lld\n",ans);
	return 0;
}

5. 5. 5. 权值统计

f [ u ] f[u] f[u] 表示 u u u 子树中的答案。分两种情况考虑 u u u。若 u u u 是路径的端点,那么答案就是 f [ u ] f[u] f[u]。如果 u u u 是路径上的点,答案就是 l s o n × r s o n × a [ u ] lson \times rson \times a[u] lson×rson×a[u]

这里我们记录 ( ∑ v ⊆ s o n [ u ] f [ v ] ) 2 (\sum_{v\subseteq son[u]} f[v])^2 (vson[u]f[v])2 ∑ v ⊆ s o n [ u ] f [ v ] 2 \sum_{v\subseteq son[u]} f[v]^2 vson[u]f[v]2,相减除以 2 2 2 就是答案。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define re register
#define int long long 
#define drep(a,b,c) for(re int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) 	for(re int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void print(int x){
	if(x < 0) putchar('-'),x = -x;
	if(x >= 10) print(x / 10);
	putchar(x % 10 + '0');
}
const int M = 2e5+10;
const int mod = 10086;
int head[M],a[M],f[M];
int n,cnt,ans;
struct edge{
	int to,nxt;
}e[M];
inline void add(int u,int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}
inline void dfs(int u,int fa){
	int s1 = 0,s2 = 0,s3 = 0;
	for(re int i(head[u]) ; i ; i=e[i].nxt){
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v,u);
		s1 += f[v];
		s2 += f[v]*f[v];
	}
	f[u] = (s1+1) * a[u] % mod;
	s3 = ((s1*s1-s2)>>1) % mod;
	ans = (ans + f[u] + s3 * a[u]) % mod;
}
signed main(){
	n = read();
	rep(i,1,n) a[i] = read();
	rep(i,1,n-1){
		int u = read(),v = read();
		add(u,v),add(v,u);
	}
	dfs(1,0);
	printf("%lld\n",ans);
	return 0;
}
  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值