题解:
这个DP显然可以用李超线段树优化……然后由于每个子树都要求一遍,所以可以直接线段树合并。具体怎么合并应该自己yy一下就可以了。
代码:
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=100010;
const LL inf=(1LL<<60);
const int K=100001;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
int n,A[Maxn],B[Maxn];LL f[Maxn];
struct Edge{int y,next;}e[Maxn<<1];
int last[Maxn],len=0;
void ins(int x,int y)
{
int t=++len;
e[t].y=y;e[t].next=last[x];last[x]=t;
}
struct Line
{
LL k,b;
Line(LL _k=0,LL _b=0){k=_k,b=_b;}
}l[Maxn];int cnt=0;
LL Y(LL x,int a){return l[a].k*x+l[a].b;}
LL X(int a,int b){return (l[b].b-l[a].b)/(l[a].k-l[b].k);}
int root[Maxn],lc[Maxn*20],rc[Maxn*20],tag[Maxn*20],tot=0;
void insert(int &u,int l,int r,int x)
{
if(!u)u=++tot;
if(!tag[u]){tag[u]=x;return;}
if(Y(l-K,tag[u])<=Y(l-K,x)&&Y(r-K,tag[u])<=Y(r-K,x))return;
if(Y(l-K,tag[u])>=Y(l-K,x)&&Y(r-K,tag[u])>=Y(r-K,x)){tag[u]=x;return;}
int mid=l+r>>1;
if(X(tag[u],x)<=mid-K)
{
if(Y(r-K,tag[u])>=Y(r-K,x))
{
insert(lc[u],l,mid,tag[u]);
tag[u]=x;
}
else insert(lc[u],l,mid,x);
}
else
{
if(Y(l-K,tag[u])>=Y(l-K,x))
{
insert(rc[u],mid+1,r,tag[u]);
tag[u]=x;
}
else insert(rc[u],mid+1,r,x);
}
}
void merge(int &u1,int u2,int l,int r)
{
if(!u1){u1=u2;return;}
if(!u2)return;
int mid=l+r>>1;
merge(lc[u1],lc[u2],l,mid);
merge(rc[u1],rc[u2],mid+1,r);
if(!tag[u2])return;
insert(u1,l,r,tag[u2]);
}
LL query(int u,int l,int r,LL x)
{
if(!u||!tag[u])return inf;
int mid=l+r>>1;
LL re=Y(x,tag[u]);
if(x<=mid-K)return min(re,query(lc[u],l,mid,x));
return min(re,query(rc[u],mid+1,r,x));
}
void dfs(int x,int fa)
{
int son=0;
for(int i=last[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==fa)continue;
son++;
dfs(y,x);
merge(root[x],root[y],1,200001);
}
if(!son)f[x]=0;
else f[x]=query(root[x],1,200001,A[x]);
l[++cnt]=Line(B[x],f[x]);
insert(root[x],1,200001,cnt);
}
int main()
{
n=read();
for(int i=1;i<=n;i++)A[i]=read();
for(int i=1;i<=n;i++)B[i]=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
ins(x,y),ins(y,x);
}
dfs(1,0);
for(int i=1;i<=n;i++)printf("%lld ",f[i]);
}