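/*
Approach (centroid decomposition):
For a centroid rt, every path through rt is split into two half-paths that both start at rt.
Let x1 and x2 be their node-weight sums (each one includes t[rt]) and let mx be the largest
weight on the whole path. The path is counted when 2*mx < x1 + x2 - t[rt]
(t[rt] is subtracted because it appears in both sums).
Sort the half-paths by mx; then, for each half-path, count the earlier ones whose x2 > 2*mx - x1 + t[rt].
The counting is done with a Fenwick tree (binary indexed tree) over the compressed x values.
A classic problem, but fiddly to debug: there are many details to get right.
*/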
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming prelude.
// NOTE: every `int` after this line is really `long long` (hence `signed main()` below).
#define int long long
typedef long long ll;
typedef unsigned long long ull;
// Because of the macro above, this is pair<long long, long long>.
typedef pair<int, int> pii;
typedef vector<int> vi;
#define fi first
#define se second
#define pb push_back
// NOTE(review): unparenthesized, so e.g. `inf*2` parses as `1ll<<(62*2)`; not used in this file.
#define inf 1ll<<62
// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl "\n"
// NOTE(review): these function-like macros evaluate their arguments twice; avoid side effects in a/b.
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
// Debug helper: prints `name=value` to stderr.
#define de_bug(x) cerr << #x << "=" << x << endl
#define all(a) a.begin(),a.end()
// Disables C stdio synchronization and unties cin/cout for faster stream I/O.
#define IOS std::ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Inclusive loops: fer runs i upward over [a, b], der runs i downward from a to b.
#define fer(i,a,b) for(int i=a;i<=b;i++)
#define der(i,a,b) for(int i=a;i>=b;i--)
// Not used in this file.
const int mod = 1e9 + 7;
// Capacity of every per-node / per-half-path array; presumably n <= 2e5 — TODO confirm against the problem limits.
const int N = 2e5 + 10;
// n: number of nodes of the current test case (read in solve()). m and k are not used in this file.
int n, m , k;
// Adjacency list: g[u] holds (neighbor, edge weight); solve() always stores weight 1.
vector<pii> g[N];
// vis[u] != 0 once u has been chosen as a centroid, i.e. removed from further processing.
int vis[N];
// cnt is reset in solve() but never read (get_d's parameter of the same name shadows it);
// idx2 is the per-subtree half-path counter used by get_ans.
int cnt, idx2;
// Node weights t[1..n], read in solve().
int t[N];
// One half-path that starts at the current centroid.
// q[] holds the half-paths of a single child subtree; p[] accumulates those of all subtrees.
struct node {
    int dis; // sum of node weights t[] along the half-path, centroid included
    int ma;  // largest node weight t[] along the half-path, centroid included
} q[N], p[N];
bool cmp(const node &a,const node &b) {
return a.ma<b.ma;
}
// Fenwick tree (binary indexed tree) over positions 1..nn supporting point add and prefix sum.
// Storage is a fixed array of N slots, so nn must stay below N.
struct BIT {
    int c[N];
    int nn;
    // Resets the tree to cover positions 1..num with every count equal to zero.
    // Callers in this file also undo their own updates; zeroing here makes init()
    // correct on its own instead of relying on that discipline.
    void init(int num) {
        nn = num;
        for (int i = 1; i <= nn; i++) c[i] = 0;
    }
    // Adds val at position x. Positions <= 0 are ignored: for x == 0 the update
    // step x & -x is 0, so the loop would never terminate.
    void change(int x, int val) {
        if (x <= 0) return;
        for (; x <= nn; x += x & -x) c[x] += val;
    }
    // Returns the sum of positions 1..x. x is clamped to nn, because Fenwick nodes above nn
    // are never updated and would give a wrong prefix sum; x <= 0 yields 0.
    int query(int x) {
        if (x > nn) x = nn;
        int res = 0;
        for (; x > 0; x -= x & -x) res += c[x];
        return res;
    }
} B;
// Returns how many nodes are reachable from u, treating fa as u's parent and never entering
// a visited (already removed) node. Returns 0 when u itself is visited.
int get_sz(int u, int fa) {
    if (vis[u]) return 0;
    int total = 1;
    for (const auto &edge : g[u]) {
        const int nxt = edge.first;
        if (nxt == fa) continue;
        total += get_sz(nxt, u);
    }
    return total;
}
// Returns the size of u's subtree (rooted away from fa, stopping at visited nodes).
// As a side effect, it writes into wc every node whose largest remaining component has at most
// tot / 2 nodes; assignments happen in post-order, so the last qualifying node wins.
// tot is the size of the whole component being decomposed.
int get_wc(int u, int fa, int tot, int&wc) {
    if (vis[u]) return 0;
    int subtree = 1;
    int heaviest = -1;
    for (const auto &edge : g[u]) {
        if (edge.first == fa) continue;
        const int child_size = get_wc(edge.first, u, tot, wc);
        if (child_size > heaviest) heaviest = child_size;
        subtree += child_size;
    }
    // The part of the component "above" u also becomes a piece once u is removed.
    const int upward = tot - subtree;
    if (upward > heaviest) heaviest = upward;
    if (heaviest <= tot / 2) wc = u;
    return subtree;
}
void get_d(int u, int fa, int dis,int dis1, int &cnt) {
if(vis[u]) return ;
q[cnt].dis=dis;
q[cnt].ma=dis1;
++cnt;
for(auto a : g[u]) {
int v = a.fi;
if(v == fa)continue;
get_d(v, u, dis + t[v],max(dis1,t[v]), cnt);
}
}
// Counts unordered pairs {i, j} (i != j) of half-paths in p[0..id) such that
//     p[i].dis + p[j].dis - val > 2 * max(p[i].ma, p[j].ma).
// val is the weight of the shared centroid, which is included in both dis values.
// Side effects: reorders p[0..id) (sorted by ma ascending) and uses the global Fenwick
// tree B, which is left all-zero on return.
int get_pre(node p[], int id, int val) {
    if (id <= 1) return 0; // no pair can be formed
    // Coordinate-compress the dis values into ranks 1..dsz.
    vector<int> d;
    d.reserve(id);
    for (int i = 0; i < id; i++) d.push_back(p[i].dis);
    sort(all(d));
    d.erase(unique(all(d)), d.end());
    const int dsz = (int)d.size();
    // Visit half-paths by increasing ma: when p[i] is visited, every earlier half-path has
    // ma <= p[i].ma, so the maximum of the combined path is simply p[i].ma.
    sort(p, p + id, cmp);
    B.init(dsz);
    vector<int> ranks(id);
    for (int i = 0; i < id; i++) ranks[i] = lower_bound(all(d), p[i].dis) - d.begin() + 1;
    int ans = 0;
    for (int i = 0; i < id; i++) {
        // An earlier partner j qualifies iff p[j].dis > 2 * p[i].ma - p[i].dis + val.
        const int threshold = 2 * p[i].ma - p[i].dis + val;
        const int first_bigger = upper_bound(all(d), threshold) - d.begin() + 1; // 1-based rank
        // Exactly i half-paths are in the tree; drop those ranked below first_bigger.
        ans += i - B.query(first_bigger - 1);
        B.change(ranks[i], 1);
    }
    // Undo every insertion so the shared tree is clean for the next caller.
    for (int i = 0; i < id; i++) B.change(ranks[i], -1);
    return ans;
}
// Counts the paths in u's component whose node-weight sum is strictly greater than twice their
// maximum node weight, using centroid decomposition. Nodes with vis set are treated as removed.
// The component's centroid is marked visited, then each remaining neighbouring component is
// processed recursively.
int get_ans(int u) {
    if (vis[u]) return 0;
    // Locate the centroid of u's component and remove it from the tree.
    int c = u;
    get_wc(u, -1, get_sz(u, -1), c);
    vis[c] = 1;
    int ans = 0;
    int total = 0; // half-paths gathered into p[] so far, across all child subtrees
    for (const auto &edge : g[c]) {
        const int v = edge.first;
        int sz = 0; // half-paths of this child subtree, stored in q[0..sz)
        get_d(v, -1, t[c] + t[v], max(t[c], t[v]), sz);
        if (!sz) continue;
        // Pairs lying entirely inside one child subtree do not really pass through c,
        // so they are subtracted here and re-added by the all-subtrees count below.
        ans -= get_pre(q, sz, t[c]);
        for (int i = 0; i < sz; i++) p[total++] = q[i];
    }
    // The centroid on its own acts as an "empty" half-path: pairing it with another
    // half-path counts the path that starts at c.
    p[total++] = {t[c], t[c]};
    if (total > 1) ans += get_pre(p, total, t[c]);
    for (const auto &edge : g[c]) ans += get_ans(edge.first);
    return ans;
}
void solve() {
cin >> n;
fer(i,1,n) {
vis[i]=0;
g[i].clear();
}
cnt=0;
fer(i,1,n) {
cin>>t[i];
}
for(int i = 2; i <=n; i++) {
int a, b;
cin >> a >> b;
g[a].push_back({b, 1});
g[b].push_back({a, 1});
}
cout << get_ans(1) << endl;
}
// Entry point: reads the number of test cases, then solves each one in turn.
signed main() {
    IOS;
    int tests;
    cin >> tests;
    while (tests-- != 0) solve();
    return 0;
}