题意
T T T组数据,每组数据给你一个正整数 n n n,然后每个点的权值 a i a_i ai,再给你 n − 1 n-1 n−1条无向边 ( u i , v i ) (u_i,v_i) (ui,vi),保证构成一棵树。求有多少条合法的路径(相当于多少个点对),使得路径上经过的所有点的权值可以构成一个简单多边形。
数据范围:
1
⩽
n
⩽
2
×
1
0
5
,
1
⩽
a
i
⩽
1
0
9
1\leqslant n\leqslant 2\times10^5,1\leqslant a_i\leqslant 10^9
1⩽n⩽2×105,1⩽ai⩽109
1
⩽
u
i
,
v
i
⩽
n
,
∑
n
⩽
4
×
1
0
5
1\leqslant u_i,v_i\leqslant n,\sum_{}n\leqslant 4\times 10^5
1⩽ui,vi⩽n,∑n⩽4×105
思路
考察点分治的运用。
①一个结论:长度为
a
i
(
i
=
1
,
2
,
.
.
.
,
x
)
a_i(i=1,2,...,x)
ai(i=1,2,...,x)的边能够构成多边形的充要条件是:
∑
i
=
1
x
a
i
>
2
×
m
a
x
{
a
i
}
(
i
=
1
,
2
,
.
.
.
,
x
)
\sum_{i=1}^xa_i>2 \times max \{a_i\}(i=1,2,...,x)
∑i=1xai>2×max{ai}(i=1,2,...,x)
②设
s
u
m
[
i
]
sum[i]
sum[i]为点
i
i
i到根节点
r
t
rt
rt的路径上的所有点的权值和,
m
x
[
i
]
mx[i]
mx[i]为点
i
i
i到根节点
r
t
rt
rt的路径上的所有点中点权最大值,若点对
(
u
,
v
)
(u,v)
(u,v)之间的路径满足条件,则有
s
u
m
[
u
]
+
s
u
m
[
v
]
−
a
[
r
t
]
>
2
×
m
a
x
(
m
x
[
u
]
,
m
x
[
v
]
)
sum[u]+sum[v]-a[rt]>2\times max(mx[u],mx[v])
sum[u]+sum[v]−a[rt]>2×max(mx[u],mx[v])
移项得
2
×
m
a
x
(
m
x
[
u
]
,
m
x
[
v
]
)
−
s
u
m
[
v
]
+
a
[
r
t
]
<
s
u
m
[
u
]
2\times max(mx[u],mx[v])-sum[v]+a[rt]<sum[u]
2×max(mx[u],mx[v])−sum[v]+a[rt]<sum[u]
因此可以对
m
x
[
i
]
mx[i]
mx[i]从小到大排序,枚举
v
v
v,用树状数组维护合法的答案数。
③注意初始化。
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dep(i,a,b) for(int i=(a);i>=(b);i--)
#define VV vector<int>
#define PP pair<int,int>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define fi first
#define se second
#define pb push_back
using namespace std;
typedef long long ll;
template <typename T>
inline void read(T &X){
X=0;int w=0; char ch=0;
while(!isdigit(ch)) {w|=ch=='-';ch=getchar();}
while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
if(w) X=-X;
}
const int maxn=2e5+5;
//const double pi=acos(-1.0);
//const double eps=1e-9;
//const ll mo=1e9+7;
int n,m,k,rt;
int he[maxn], tot;
int dp[maxn];
struct edge{
int v;
int nxt;
} e[maxn << 1];
void add(int u,int v){
e[tot].v = v;
e[tot].nxt = he[u];
he[u] = tot++;
}
int a[maxn],vis[maxn];
int sz[maxn],sum;
ll ans;
void init(){
rep(i, 0, n) he[i] = -1, vis[i] = 0;
tot = 0;
ans = 0;
}
void getroot(int u,int fa){
sz[u] = 1;
dp[u] = 0;//
for (int i = he[u]; ~i;i=e[i].nxt){
int v = e[i].v;
if(vis[v]||v==fa)
continue;
getroot(v, u);
sz[u] += sz[v];
dp[u] = max(dp[u], sz[v]);//
}
dp[u] = max(dp[u], sum - sz[u]);//
if(dp[u]<dp[rt])
rt = u;
}
ll cal(int u, int w);
void go(int u){
vis[u] = 1;
ans += cal(u, 0);
for (int i = he[u]; ~i;i=e[i].nxt){
int v = e[i].v;
if(vis[v])
continue;
ans -= cal(v, a[u]);
rt = 0;
dp[0] = sum = sz[v];//
getroot(v, 0);//
go(rt);
}
}
int cnt;
pair<ll, int> p[maxn], val[maxn];
void add_dfs(int u,int mx,ll w,int fa){
cnt++;
p[cnt] = {w, cnt};
val[cnt].fi = mx;
for (int i = he[u]; ~i;i=e[i].nxt){
int v = e[i].v;//
if(vis[v]||v==fa)
continue;
add_dfs(v, max(mx, a[v]), w + a[v], u);
}
}
int C[maxn];
int lb(int x) { return x & (-x); }
void addv(int x,int v){
while(x<=cnt) {
C[x] += v;
x += lb(x);
}
}
int query(int x){
int sum = 0;
while(x){
sum += C[x];
x -= lb(x);
}
return sum;
}
ll cal(int u,int w){
cnt = 0;
add_dfs(u, max(w, a[u]), w + a[u], 0);
int hd = (w ? w : val[1].fi);
sort(p + 1, p + 1 + cnt);
rep(i, 1, cnt) val[p[i].se].se = i;
rep(i, 0, cnt) C[i] = 0;
ll res = 0;
sort(val + 1, val + 1 + cnt);
rep(i,1,cnt){
ll sw = 2LL * val[i].fi + hd - p[val[i].se].fi;
int l = 1, r = cnt, as = 0;
while(l<=r){
int mid = (l + r) >> 1;
if(p[mid].fi<=sw){
l = mid + 1;
as = mid;
}
else
r = mid - 1;
}
res += i - 1 - query(as);
addv(val[i].se, 1);
}
return res;
}
void solve(){
read(n);
init();
rep(i,1,n){
read(a[i]);
}
rep(i,1,n-1) {
int x, y;
read(x);
read(y);
add(x, y);
add(y, x);
}
rt = 0;
dp[0] = sum = n;//
getroot(1, 0);
go(rt);
printf("%lld\n", ans);
}
int main(){
// freopen("e://duipai//amyout.txt","w",stdout);
int T=1,cas=1;
read(T);
while(T--){
solve();
}
//system("pause");
return 0;
}