题目大意
一颗点权树,有多少种将树划分成若干条路径的方法,使得每条路径点权和非负?
做法
不妨设f[i]表示i子树里全部成功覆盖方案数,g[i]表示i子树里除了i全部成功覆盖方案数。
g就是各个子树的f乘积。
f可以枚举lca穿过当前点的一条路径,设为j->k,那么j和k都贡献g,其余挂着的子树贡献f,乘起来即可。
这样太慢了。
考虑简单问题。
如果路径是j->i怎么做呢?
可以尝试对每个子树维护一颗线段树,下标为一个点到根路径点权和的排名,值为它的g乘上一直到当前点挂着的那些子树的f的乘积,区间维护和。
可以发现这个线段树只需要支持合并和乘法标记。每次可选决策显然是一段区间。
这个我们会做了,n log n就能做。
现在是j->k,不妨枚举在较小的那边枚举一个端点,另一端在线段树查询即可。
复杂度类似启发式合并,那么是n log^2 n。
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
typedef long long ll;
const int maxn=100000+10,maxtot=10000000+10,mo=1000000007;
int root[maxn],tree[maxtot],siz[maxtot],mu[maxtot],left[maxtot],right[maxtot],size[maxn];
int h[maxn],go[maxn*2],nxt[maxn*2],sum[maxn],a[maxn],id[maxn],rk[maxn],sta[maxn][2];
int s[maxn],pre[maxn],suf[maxn];
int f[maxn],g[maxn];
int i,j,k,l,t,n,m,tot,top,cnt,ans;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void add(int x,int y){
go[++tot]=y;
nxt[tot]=h[x];
h[x]=tot;
}
void dfs(int x,int y){
sum[x]=sum[y]+a[x];
int t=h[x];
while (t){
if (go[t]!=y) dfs(go[t],x);
t=nxt[t];
}
}
bool cmp(int x,int y){
return sum[x]<sum[y];
}
int newnode(){
mu[++tot]=1;
return tot;
}
void mark(int x,int v){
mu[x]=(ll)mu[x]*v%mo;
tree[x]=(ll)tree[x]*v%mo;
}
void down(int x){
if (mu[x]!=1){
if (left[x]) mark(left[x],mu[x]);
if (right[x]) mark(right[x],mu[x]);
mu[x]=1;
}
}
void change(int &x,int l,int r,int a,int b){
if (!x) x=newnode();
(tree[x]+=b)%=mo;
siz[x]++;
if (l==r) return;
down(x);
int mid=(l+r)/2;
if (a<=mid) change(left[x],l,mid,a,b);else change(right[x],mid+1,r,a,b);
}
int merge(int a,int b,int l,int r){
if (!a||!b) return a+b;
down(a);down(b);
int mid=(l+r)/2;
left[a]=merge(left[a],left[b],l,mid);
right[a]=merge(right[a],right[b],mid+1,r);
tree[a]=(tree[left[a]]+tree[right[a]])%mo;
siz[a]=siz[left[a]]+siz[right[a]];
return a;
}
void travel(int x,int l,int r){
if (l==r){
top++;
sta[top][0]=id[l];
sta[top][1]=tree[x];
return;
}
down(x);
int mid=(l+r)/2;
if (left[x]) travel(left[x],l,mid);
if (right[x]) travel(right[x],mid+1,r);
}
int binary(int x){
int l=1,r=n+1,mid;
while (l<r){
mid=(l+r)/2;
if (sum[id[mid]]>=x) r=mid;else l=mid+1;
}
return l;
}
int query(int x,int l,int r,int a,int b){
if (a>b) return 0;
if (l==a&&r==b) return tree[x];
down(x);
int mid=(l+r)/2;
if (b<=mid) return query(left[x],l,mid,a,b);
else if (a>mid) return query(right[x],mid+1,r,a,b);
else return (query(left[x],l,mid,a,mid)+query(right[x],mid+1,r,mid+1,b))%mo;
}
void solve(int id,int x,int v){
int j,l=-sum[id]+a[id];
l=-l;
j=binary(l);
(f[id]+=(ll)query(x,1,n,j,n)*v%mo)%=mo;
}
void work(int id,int x,int y,int v){
if (siz[x]>siz[y]) swap(x,y);
top=0;
travel(x,1,n);
int i,j,k,l,t;
fo(i,1,top){
k=sta[i][0];
l=sum[k]-2*sum[id]+a[id];
l=-l;
j=binary(l);
(f[id]+=(ll)query(y,1,n,j,n)*sta[i][1]%mo*v%mo)%=mo;
}
}
void dg(int x,int y){
int i,j,t=h[x];
while (t){
if (go[t]!=y) dg(go[t],x);
t=nxt[t];
}
cnt=0;
t=h[x];
while (t){
if (go[t]!=y) s[++cnt]=go[t];
t=nxt[t];
}
g[x]=1;
/*while (t){
if (go[t]!=y){
dg(go[t],x);
g[x]=(ll)g[x]*f[go[t]]%mo;
solve(x,root[go[t]]);
work(x,root[x],root[go[t]]);
root[x]=merge(root[x],root[go[t]],1,n);
}
t=nxt[t];
}*/
pre[0]=1;
fo(i,1,cnt) pre[i]=(ll)pre[i-1]*f[s[i]]%mo;
suf[cnt+1]=1;
fd(i,cnt,1) suf[i]=(ll)suf[i+1]*f[s[i]]%mo;
fo(i,1,cnt){
j=s[i];
g[x]=(ll)g[x]*f[j]%mo;
solve(x,root[j],(ll)pre[i-1]*suf[i+1]%mo);
work(x,root[x],root[j],suf[i+1]);
mark(root[x],f[j]);
mark(root[j],pre[i-1]);
root[x]=merge(root[x],root[j],1,n);
}
if (a[x]>=0) (f[x]+=g[x])%=mo;
change(root[x],1,n,rk[x],g[x]);
}
int main(){
freopen("tree.in","r",stdin);freopen("tree.out","w",stdout);
n=read();
fo(i,1,n){
a[i]=read();
id[i]=i;
}
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
dfs(1,0);
sort(id+1,id+n+1,cmp);
fo(i,1,n) rk[id[i]]=i;
tot=0;
dg(1,0);
ans=f[1];
(ans+=mo)%=mo;
printf("%d\n",ans);
}