解题思路
设 g [ u ] g[u] g[u]表示以u的子树内所有联通块个数(必定选取u), f [ u ] f[u] f[u]为u的子树内不包含最大点权的联通块个数(必定选取u)
则 g [ u ] = ∏ v ∈ s o n u ( f [ v ] + 1 ) g[u]=∏_{v∈son_u}(f[v]+1) g[u]=∏v∈sonu(f[v]+1)
而,当u不是最大点权时, f [ u ] = ∏ v ∈ s o n u ( f [ v ] + 1 ) f[u]=∏_{v∈son_u}(f[v]+1) f[u]=∏v∈sonu(f[v]+1),否则, f [ u ] = 0 f[u]=0 f[u]=0
最终答案为 ∑ ( f [ u ] − g [ u ] ) . ∑(f[u]−g[u]). ∑(f[u]−g[u]).
代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const ll mod=998244353;
ll k,n,u,v,sum,ans,maxn,w[100010],head[200010],f[100010],g[100010];
struct c {
int x,next;
} a[200010];
void add(int x,int y) {
a[++k]=(c) {
y,head[x]
};
head[x]=k;
}
void dfs(int x,int fa) {
if(w[x]!=maxn)
f[x]=1;
else f[x]=0;
g[x]=1;
for(int i=head[x]; i; i=a[i].next) {
int y=a[i].x;
if(y==fa)continue;
dfs(y,x);
f[x]=f[x]*(f[y]+1)%mod;
g[x]=g[x]*(g[y]+1)%mod;
}
sum=(sum+g[x])%mod;
ans=(ans+f[x])%mod;
}
int main() {
scanf("%lld",&n);
maxn=-1e17;
for(int i=1; i<=n; i++) {
scanf("%lld",&w[i]);
maxn=max(maxn,w[i]);
}
for(int i=1; i<n; i++) {
scanf("%lld%lld",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
printf("%lld",(sum+mod-ans)%mod);
}