花花的森林

题目描述
花花有一棵带n 个顶点的树T,每个节点有一个点权ai。
有一天,他认为拥有两棵树更好一些。所以,他从T 中删去了一条边。
第二天,他认为三棵树或许又更好一些。因此,他又从他拥有的某一棵树中去除了一条边。
如此往复。每一天,花花都会删去一条尚未被删去的边,直到他得到了一个包含了n 棵只有一个点的树的森林。
定义一条简单路径的权值为路径上点权之和,一棵树的直径为树上权值最大的简单路径。
花花认为树最重要的特征就是它的直径。所以他想请你算出任一时刻他拥有的所有树的直径的乘积。因为这个数可能很大,他要求你输出乘积对 109+7 取模之后的结果。

输入
输入的第一行包含一个整数n,表示树T 上顶点的数量。
下一行包含n 个空格分隔的整数ai,表示顶点的权值。
之后的n-1 行中,每一行包含两个用空格分隔的整数xi 和yi,表示节点xi 和yi 之间连有一条边,编号为i。
再之后n-1 行中,每一行包含一个整数kj,表示在第j 天里会被删除的边的编号
输出
输出n 行。
在第i 行,输出删除i-1 条边之后,所有树直径的乘积对10^9 + 7 取模的结果。

样例输入
3
1 2 3
1 2
1 3
2
1
样例输出
6
9
6
提示
初始,树的直径为6(由节点2、1 和3 构成的路径)。在第一天之后,得到了两棵直径都为3 的树。第二天之后,得到了三棵直径分别为1,2,3 的树,乘积为6。
• 对于40% 的数据: N<=100
• 另有20% 的数据: N<=1000
• 另有20% 的数据: N<=104
• 对于100% 的数据: N<=105;ai<=104

Solution

根据一贯的套路,这种删除边的问题很多都是倒着添边做的。
那我们就来考虑倒着做
最后树成了n个点的森林
每次合并两棵树,除去原来两棵树的直径,乘上合并后的新直径
除直径用乘逆元
那么现在的问题就是:如何快速得出两棵树合并后产生的新直径。
考试的时候,我这个渣渣当然不会了,于是我就暴力更新一条链,极限接近 n2 ,但快得出奇,100000的随机数据都只需要0.8s
其实,关于直径又有一个套路,当两棵子树合并的时候,x树的直径两端点是a b,y树的直径两端点是c d,则新直径必定是abcd的某个点对。
求路径用LCA
结果我发现正解比暴力慢了1倍

暴力

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
const ll mod=1e9+7;
int n,x,y,u,tot;
int st[100005],ed[100005],a[100005],cut[100005],f[100005];
int s[100005],fa[100005],d[100005];
int head[100005],Next[200005],to[200005];
ll now,ans[100005];
void dfs(int k,int pre)
{
    fa[k]=pre;
    for(int i=head[k];i!=-1;i=Next[i]) 
    if(to[i]!=pre) dfs(to[i],k);
}
void add(int x,int y)
{
    tot++;
    Next[tot]=head[x];
    to[tot]=y;
    head[x]=tot;
}
int get(int x)
{
    if(f[x]==x) return x; else return f[x]=get(f[x]);
}
void update(int k)
{
    int s1=0,s2=0;
    for(int i=head[k];i!=-1;i=Next[i]) 
    if(fa[to[i]]==k) 
    {
        if(s[to[i]]>s1) 
        {
            s2=s1;
            s1=s[to[i]];
        }
        else
        if(s[to[i]]>s2) s2=s[to[i]];
        d[k]=max(d[k],d[to[i]]); 
    }
    s[k]=s1+a[k];
    d[k]=max(d[k],s1+s2+a[k]);
    if(k!=u) update(fa[k]); else return;
}
ll ny(ll x,ll y)
{
    ll p=1;
    while(y>0) 
    {
        if(y%2==1) p=(p*x)%mod;
        y=y/2;
        x=(x*x)%mod;
    }
    return p;
}
void prepare()
{
    cin>>n;
    for(int i=1;i<=n;i++) 
    {
        head[i]=-1;
        scanf("%d",&a[i]); 
    }
    for(int i=1;i<n;i++) 
    {
        scanf("%d%d",&st[i],&ed[i]);
        add(st[i],ed[i]);
        add(ed[i],st[i]);
    }
    dfs(1,0);
    for(int i=1;i<n;i++) scanf("%d",&cut[i]);
    tot=0;
    for(int i=1;i<=n;i++) 
    {
        f[i]=i;
        d[i]=(ll)(a[i]);
        s[i]=a[i];
        head[i]=-1;
    }
    ans[n]=1;
    for(int i=1;i<=n;i++) ans[n]=(ans[n]*a[i])%mod;
    now=ans[n]; //现在的直径乘积 
}   
int main()
{
    prepare();
    for(int i=n-1;i>=1;i--) 
    {
        x=st[cut[i]],y=ed[cut[i]];
        add(x,y);
        add(y,x);
        if(fa[x]==y) 
        {
            u=get(y);
            f[x]=u;
            now=(now*ny((ll)(d[x]),mod-2))%mod;
            now=(now*ny((ll)(d[u]),mod-2))%mod; 
            update(y);
            now=(now*d[u])%mod;
        }
        else
        {
            u=get(x);
            f[y]=u;
            now=(now*ny((ll)(d[y]),mod-2))%mod;
            now=(now*ny((ll)(d[u]),mod-2))%mod; 
            update(x); 
            now=(now*d[u])%mod;
        }
        ans[i]=now;
    }
    for(int i=1;i<=n;i++) printf("%lld\n",ans[i]); 
    return 0;
}

正解

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
const ll mod=1e9+7;
int n,x,y,ty,i1,i2,dis,dad,tot;
int u[100005],v[100005],b[5];
int st[100005],ed[100005],a[100005],cut[100005];
int fa[100005][20],d[100005],f[100005],s[100005];
int head[100005],Next[200005],to[200005],deep[100005];
ll now,ans[100005];
void dfs(int k,int pre)
{
    fa[k][0]=pre;
    deep[k]=deep[pre]+1;
    s[k]=s[pre]+a[k];
    for(int i=head[k];i!=-1;i=Next[i]) 
    if(to[i]!=pre) dfs(to[i],k);
}
void add(int x,int y)
{
    tot++;
    Next[tot]=head[x];
    to[tot]=y;
    head[x]=tot;
}
int get(int x)
{
    if(f[x]==x) return x; else return f[x]=get(f[x]);
}
ll ny(ll x,ll y)
{
    ll p=1;
    while(y>0) 
    {
        if(y%2==1) p=(p*x)%mod;
        y=y/2;
        x=(x*x)%mod;
    }
    return p;
}
void prepare()
{
    cin>>n;
    for(int i=1;i<=n;i++) 
    {
        head[i]=-1;
        scanf("%d",&a[i]); 
    }
    for(int i=1;i<n;i++) 
    {
        scanf("%d%d",&st[i],&ed[i]);
        add(st[i],ed[i]);
        add(ed[i],st[i]);
    }
    dfs(1,0);
    for(int i=1;(1<<i)<=n;i++) 
    for(int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1];
    for(int i=1;i<n;i++) scanf("%d",&cut[i]);
    tot=0;
    for(int i=1;i<=n;i++) 
    {
        f[i]=i;
        u[i]=v[i]=i;
        d[i]=(ll)(a[i]);
        head[i]=-1;
    }
    ans[n]=1;
    for(int i=1;i<=n;i++) ans[n]=(ans[n]*a[i])%mod;
    now=ans[n]; //现在的直径乘积 
}
int LCA(int x,int y)
{
    if(deep[x]<deep[y]) swap(x,y);
    for(int i=18;i>=0;i--) 
    if(deep[x]-(1<<i)>=deep[y]) x=fa[x][i];
    if(x==y) return x;
    for(int i=18;i>=0;i--) 
    if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
int main()
{
    prepare();
    for(int i=n-1;i>=1;i--) 
    {
        x=st[cut[i]],y=ed[cut[i]];
        add(x,y);
        add(y,x);
        if(fa[x][0]==y) 
        {
            ty=get(y);
            f[x]=ty;
            b[1]=u[x],b[2]=v[x],b[3]=u[ty],b[4]=v[ty];
            i1=i2=dis=0;
            for(int j=1;j<=4;j++) 
            for(int k=1;k<=4;k++) 
            {
                dad=LCA(b[j],b[k]);
                if(s[b[j]]+s[b[k]]-2*s[dad]+a[dad]>dis) 
                {
                    dis=s[b[j]]+s[b[k]]-2*s[dad]+a[dad];
                    i1=b[j];
                    i2=b[k];
                }
            }
            now=(now*ny((ll)(d[x]),mod-2))%mod;
            now=(now*ny((ll)(d[ty]),mod-2))%mod; 
            d[ty]=dis,u[ty]=i1,v[ty]=i2;
            now=(now*d[ty])%mod;
        }
        else
        {
            ty=get(x);
            f[y]=ty;
            b[1]=u[y],b[2]=v[y],b[3]=u[ty],b[4]=v[ty];
            i1=i2=dis=0;
            for(int j=1;j<=4;j++) 
            for(int k=1;k<=4;k++) 
            {
                dad=LCA(b[j],b[k]);
                if(s[b[j]]+s[b[k]]-2*s[dad]+a[dad]>dis) 
                {
                    dis=s[b[j]]+s[b[k]]-2*s[dad]+a[dad];
                    i1=b[j];
                    i2=b[k];
                }
            }
            now=(now*ny((ll)(d[y]),mod-2))%mod;
            now=(now*ny((ll)(d[ty]),mod-2))%mod;
            d[ty]=dis,u[ty]=i1,v[ty]=i2; 
            now=(now*d[ty])%mod;
        }
        ans[i]=now;
    }
    for(int i=1;i<=n;i++) printf("%lld\n",ans[i]); 
    return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值