题目描述
花花有一棵带n 个顶点的树T,每个节点有一个点权ai。
有一天,他认为拥有两棵树更好一些。所以,他从T 中删去了一条边。
第二天,他认为三棵树或许又更好一些。因此,他又从他拥有的某一棵树中去除了一条边。
如此往复。每一天,花花都会删去一条尚未被删去的边,直到他得到了一个包含了n 棵只有一个点的树的森林。
定义一条简单路径的权值为路径上点权之和,一棵树的直径为树上权值最大的简单路径。
花花认为树最重要的特征就是它的直径。所以他想请你算出任一时刻他拥有的所有树的直径的乘积。因为这个数可能很大,他要求你输出乘积对
109+7
取模之后的结果。
输入
输入的第一行包含一个整数n,表示树T 上顶点的数量。
下一行包含n 个空格分隔的整数ai,表示顶点的权值。
之后的n-1 行中,每一行包含两个用空格分隔的整数xi 和yi,表示节点xi 和yi 之间连有一条边,编号为i。
再之后n-1 行中,每一行包含一个整数kj,表示在第j 天里会被删除的边的编号
输出
输出n 行。
在第i 行,输出删除i-1 条边之后,所有树直径的乘积对10^9 + 7 取模的结果。
样例输入
3
1 2 3
1 2
1 3
2
1
样例输出
6
9
6
提示
初始,树的直径为6(由节点2、1 和3 构成的路径)。在第一天之后,得到了两棵直径都为3 的树。第二天之后,得到了三棵直径分别为1,2,3 的树,乘积为6。
• 对于40% 的数据:
N<=100
;
• 另有20% 的数据:
N<=1000
;
• 另有20% 的数据:
N<=104
;
• 对于100% 的数据:
N<=105;ai<=104
。
Solution
根据一贯的套路,这种删除边的问题很多都是倒着添边做的。
那我们就来考虑倒着做
最后树成了n个点的森林
每次合并两棵树,除去原来两棵树的直径,乘上合并后的新直径
除直径用乘逆元
那么现在的问题就是:如何快速得出两棵树合并后产生的新直径。
考试的时候,我这个渣渣当然不会了,于是我就暴力更新一条链,极限接近
n2
,但快得出奇,100000的随机数据都只需要0.8s
其实,关于直径又有一个套路,当两棵子树合并的时候,x树的直径两端点是a b,y树的直径两端点是c d,则新直径必定是abcd的某个点对。
求路径用LCA
结果我发现正解比暴力慢了1倍
暴力
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
const ll mod=1e9+7;
int n,x,y,u,tot;
int st[100005],ed[100005],a[100005],cut[100005],f[100005];
int s[100005],fa[100005],d[100005];
int head[100005],Next[200005],to[200005];
ll now,ans[100005];
void dfs(int k,int pre)
{
fa[k]=pre;
for(int i=head[k];i!=-1;i=Next[i])
if(to[i]!=pre) dfs(to[i],k);
}
void add(int x,int y)
{
tot++;
Next[tot]=head[x];
to[tot]=y;
head[x]=tot;
}
int get(int x)
{
if(f[x]==x) return x; else return f[x]=get(f[x]);
}
void update(int k)
{
int s1=0,s2=0;
for(int i=head[k];i!=-1;i=Next[i])
if(fa[to[i]]==k)
{
if(s[to[i]]>s1)
{
s2=s1;
s1=s[to[i]];
}
else
if(s[to[i]]>s2) s2=s[to[i]];
d[k]=max(d[k],d[to[i]]);
}
s[k]=s1+a[k];
d[k]=max(d[k],s1+s2+a[k]);
if(k!=u) update(fa[k]); else return;
}
ll ny(ll x,ll y)
{
ll p=1;
while(y>0)
{
if(y%2==1) p=(p*x)%mod;
y=y/2;
x=(x*x)%mod;
}
return p;
}
void prepare()
{
cin>>n;
for(int i=1;i<=n;i++)
{
head[i]=-1;
scanf("%d",&a[i]);
}
for(int i=1;i<n;i++)
{
scanf("%d%d",&st[i],&ed[i]);
add(st[i],ed[i]);
add(ed[i],st[i]);
}
dfs(1,0);
for(int i=1;i<n;i++) scanf("%d",&cut[i]);
tot=0;
for(int i=1;i<=n;i++)
{
f[i]=i;
d[i]=(ll)(a[i]);
s[i]=a[i];
head[i]=-1;
}
ans[n]=1;
for(int i=1;i<=n;i++) ans[n]=(ans[n]*a[i])%mod;
now=ans[n]; //现在的直径乘积
}
int main()
{
prepare();
for(int i=n-1;i>=1;i--)
{
x=st[cut[i]],y=ed[cut[i]];
add(x,y);
add(y,x);
if(fa[x]==y)
{
u=get(y);
f[x]=u;
now=(now*ny((ll)(d[x]),mod-2))%mod;
now=(now*ny((ll)(d[u]),mod-2))%mod;
update(y);
now=(now*d[u])%mod;
}
else
{
u=get(x);
f[y]=u;
now=(now*ny((ll)(d[y]),mod-2))%mod;
now=(now*ny((ll)(d[u]),mod-2))%mod;
update(x);
now=(now*d[u])%mod;
}
ans[i]=now;
}
for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);
return 0;
}
正解
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
const ll mod=1e9+7;
int n,x,y,ty,i1,i2,dis,dad,tot;
int u[100005],v[100005],b[5];
int st[100005],ed[100005],a[100005],cut[100005];
int fa[100005][20],d[100005],f[100005],s[100005];
int head[100005],Next[200005],to[200005],deep[100005];
ll now,ans[100005];
void dfs(int k,int pre)
{
fa[k][0]=pre;
deep[k]=deep[pre]+1;
s[k]=s[pre]+a[k];
for(int i=head[k];i!=-1;i=Next[i])
if(to[i]!=pre) dfs(to[i],k);
}
void add(int x,int y)
{
tot++;
Next[tot]=head[x];
to[tot]=y;
head[x]=tot;
}
int get(int x)
{
if(f[x]==x) return x; else return f[x]=get(f[x]);
}
ll ny(ll x,ll y)
{
ll p=1;
while(y>0)
{
if(y%2==1) p=(p*x)%mod;
y=y/2;
x=(x*x)%mod;
}
return p;
}
void prepare()
{
cin>>n;
for(int i=1;i<=n;i++)
{
head[i]=-1;
scanf("%d",&a[i]);
}
for(int i=1;i<n;i++)
{
scanf("%d%d",&st[i],&ed[i]);
add(st[i],ed[i]);
add(ed[i],st[i]);
}
dfs(1,0);
for(int i=1;(1<<i)<=n;i++)
for(int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1];
for(int i=1;i<n;i++) scanf("%d",&cut[i]);
tot=0;
for(int i=1;i<=n;i++)
{
f[i]=i;
u[i]=v[i]=i;
d[i]=(ll)(a[i]);
head[i]=-1;
}
ans[n]=1;
for(int i=1;i<=n;i++) ans[n]=(ans[n]*a[i])%mod;
now=ans[n]; //现在的直径乘积
}
int LCA(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
for(int i=18;i>=0;i--)
if(deep[x]-(1<<i)>=deep[y]) x=fa[x][i];
if(x==y) return x;
for(int i=18;i>=0;i--)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int main()
{
prepare();
for(int i=n-1;i>=1;i--)
{
x=st[cut[i]],y=ed[cut[i]];
add(x,y);
add(y,x);
if(fa[x][0]==y)
{
ty=get(y);
f[x]=ty;
b[1]=u[x],b[2]=v[x],b[3]=u[ty],b[4]=v[ty];
i1=i2=dis=0;
for(int j=1;j<=4;j++)
for(int k=1;k<=4;k++)
{
dad=LCA(b[j],b[k]);
if(s[b[j]]+s[b[k]]-2*s[dad]+a[dad]>dis)
{
dis=s[b[j]]+s[b[k]]-2*s[dad]+a[dad];
i1=b[j];
i2=b[k];
}
}
now=(now*ny((ll)(d[x]),mod-2))%mod;
now=(now*ny((ll)(d[ty]),mod-2))%mod;
d[ty]=dis,u[ty]=i1,v[ty]=i2;
now=(now*d[ty])%mod;
}
else
{
ty=get(x);
f[y]=ty;
b[1]=u[y],b[2]=v[y],b[3]=u[ty],b[4]=v[ty];
i1=i2=dis=0;
for(int j=1;j<=4;j++)
for(int k=1;k<=4;k++)
{
dad=LCA(b[j],b[k]);
if(s[b[j]]+s[b[k]]-2*s[dad]+a[dad]>dis)
{
dis=s[b[j]]+s[b[k]]-2*s[dad]+a[dad];
i1=b[j];
i2=b[k];
}
}
now=(now*ny((ll)(d[y]),mod-2))%mod;
now=(now*ny((ll)(d[ty]),mod-2))%mod;
d[ty]=dis,u[ty]=i1,v[ty]=i2;
now=(now*d[ty])%mod;
}
ans[i]=now;
}
for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);
return 0;
}