GP of China H Inner Product: 边分治 + 虚树dp

本文详细解析了一种结合边分治与 01 动态规划解决特定树上距离求和问题的方法:对两棵树 $T$ 和 $T'$ 计算所有节点对之间的距离乘积之和,采用边分治策略控制复杂度,并在虚树上应用 01 DP 进行高效求解。文章深入分析了算法实现细节,包括数据结构设计、状态转移方程和复杂度分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

题意

给出两棵树 $T$ 和 $T'$,求
$$\sum_{i,j \in [1,n]} dis(i,j) \cdot dis'(i,j)$$
其中 $dis$、$dis'$ 分别是 $T$、$T'$ 上的带权距离,答案对 $10^9+7$ 取模。

题解

对 $T$ 进行边分治。设当前分治的边为 $\langle u,v \rangle$,边权为 $w$。设 $u$ 一侧的点集为 $L$,$v$ 一侧的点集为 $R$。记 $dis[i]$ 为 $i$ 到其所在一侧端点($u$ 或 $v$)的距离,$dep'[i]$ 为 $i$ 在 $T'$ 中的带权深度,$lca'$ 为 $T'$ 上的最近公共祖先。则经过当前边的 $dis(i,j)$ 对答案的贡献是:
$$
\begin{aligned}
ANS &= \sum_{i \in L,j \in R}{(dis(i,u) + dis(j,v) + w)\cdot dis'(i,j)}\\
&= \sum_{i \in L,j\in R}{(dis[i] + dis[j])\cdot dis'(i,j)} + w\cdot \sum_{i \in L,j\in R}{dis'(i,j)}\\
&= \sum_{i \in L,j\in R}{(dis[i] + dis[j])\cdot (dep'[i] + dep'[j] - 2 \cdot dep'[lca'(i,j)])} + w\cdot \sum_{i \in L,j\in R}{dis'(i,j)}\\
&= \sum_{i \in L,j\in R}{(dis[i] + dis[j])\cdot (dep'[i] + dep'[j])} - 2 \cdot\sum_{i \in L,j \in R}{(dis[i] + dis[j])\cdot dep'[lca'(i,j)]} + w\cdot \sum_{i \in L,j\in R}{dis'(i,j)}\\
&= \sum_{i \in L,j\in R}{(dis[i] \cdot dep'[i] + dis[j] \cdot dep'[j])} + \sum_{i \in L,j \in R}{(dis[i] \cdot dep'[j] + dis[j] \cdot dep'[i])}\\
&\quad - 2 \cdot\sum_{i \in L,j \in R}{(dis[i] + dis[j])\cdot dep'[lca'(i,j)]} + w\cdot \sum_{i \in L,j\in R}{dis'(i,j)}\\
&= |R| \cdot\sum_{i \in L}{dis[i] \cdot dep'[i]} + |L| \cdot \sum_{j \in R}{dis[j] \cdot dep'[j]} + \sum_{i \in L}{dis[i]} \cdot \sum_{j \in R}{dep'[j]} + \sum_{i \in L}{dep'[i]} \cdot \sum_{j \in R}{dis[j]}\\
&\quad - 2 \cdot\sum_{i \in L,j \in R}{(dis[i] + dis[j])\cdot dep'[lca'(i,j)]} + w\cdot \sum_{i \in L,j\in R}{dis'(i,j)}
\end{aligned}
$$
每个求和式都可以直接 DP,所以用点集 $L \cup R$ 在 $T'$ 上建虚树做 01 DP 即可:两种颜色分别对应 $L$ 和 $R$。实现时把 $w$ 与 $dis[i]+dis[j]$ 合并,一起参与上述求和。
复杂度大约为 $O(n \log^2 n)$。那么问题来了:为什么不用点分治呢?

#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
// Shared numeric setup: 64-bit accumulator type, the answer modulus, and the
// per-tree vertex capacity used to size the global arrays below.
using ll = long long;
const int mod = 1000000007;
const int maxn = 100000 + 100;
// Number of vertices in each tree for the current test case.
int n;
// Adjacency list of the second tree T': (neighbour, edge weight).
vector<pair<int,int> > E2[maxn];
namespace HLD{
    int wson[maxn],sz[maxn],depth[maxn];
    ll dep[maxn];
    int fa[maxn],top[maxn];
    int dfs_clock,l[maxn],r[maxn];
    void dfs1(int u,int Fa){
        l[u] = ++dfs_clock;
        depth[u] = depth[Fa] + 1;
        sz[u] = 1;
        wson[u] = 0;
        fa[u] = Fa;
        for (auto e : E2[u]){
            int v,len;
            tie(v,len) = e;
            if (v == Fa)continue;
            dep[v] = dep[u] + len;
            if (dep[v] >= mod)dep[v] -= mod;
            dfs1(v,u);
            sz[u] += sz[v];
            if (sz[v] > sz[wson[u]])wson[u] = v;
        }
        r[u] = dfs_clock;
    }
    void dfs2(int u,int Fa,int chain){
        top[u] = chain;
        if (wson[u])dfs2(wson[u],u,chain);
        for (auto e : E2[u]){
            int v,len;
            tie(v,len) = e;
            if (v == Fa || v == wson[u])continue;
            dfs2(v,u,v);
        }
    }
    void init(int root){
        dfs_clock = 0;
        dep[root] = 0;
        dfs1(root,0);
        dfs2(root,0,root);
    }
    int lca(int x,int y){
        while (top[x] != top[y]){
            if (depth[top[x]] < depth[top[y]])swap(x,y);
            x = fa[top[x]];
        }
        if (depth[x] < depth[y])swap(x,y);
        return y;
    }
}
// Adjacency list of the first tree T as read from input: (neighbour, weight).
vector<pair<int,int> > EE1[maxn];
// Forward-star storage for the binarized copy of T. For arc t: des = head node,
// llen = weight, edgeid = id shared by both arcs of one undirected edge,
// nxt = next arc leaving the same tail. first[u] is the most recently added
// arc out of u (0 = none); tot is the number of arcs used so far.
int first[maxn*3],des[maxn*6],llen[maxn*6],edgeid[maxn*6],nxt[maxn*6],tot;
// Prepends the directed arc u -> v (weight w, undirected-edge id `id`) to u's list.
inline void add_edge_(int u,int v,int w,int id){
    const int arc = ++tot;
    nxt[arc] = first[u];
    des[arc] = v;
    llen[arc] = w;
    edgeid[arc] = id;
    first[u] = arc;
}
// cnt: nodes in the binarized tree; edge_cnt: undirected edges in it.
int cnt,edge_cnt;
// banned[e]: undirected edge e has already been chosen as a split edge.
bool banned[maxn * 3];
// pos[v]: binarized node of original vertex v.
// id[x]: original vertex of binarized node x, or 0 for an auxiliary node.
int pos[maxn * 3];
int id[maxn * 3];
ll ans = 0;
// Resets the per-test-case graph storage, split-edge marks and the answer
// accumulator so the next test case starts from a clean state.
void clear(){
    fill(first + 1, first + cnt + 1, 0);
    tot = 0;
    cnt = 0;
    for (int v = 1; v <= n; ++v){
        EE1[v].clear();
        E2[v].clear();
    }
    fill(banned + 1, banned + edge_cnt + 1, false);
    edge_cnt = 0;
    ans = 0;
}
// Adds the undirected edge u -- v of weight w as two opposite arcs that share
// a freshly allocated edge id (so banning the id removes both directions).
inline void add_edge(int u,int v,int w){
    const int eid = ++edge_cnt;
    add_edge_(u,v,w,eid);
    add_edge_(v,u,w,eid);
}
// Copies the subtree of original vertex u (first tree, parent `fa`) into the
// binarized tree and returns u's node there. Each child hangs off a new
// auxiliary node; the auxiliary nodes form a chain from u joined by weight-0
// edges, so no node gets more than three neighbours. The original edge weight
// sits on the edge from the auxiliary node to the child's node.
int dfs3d(int u, int fa){
    const int self = ++cnt;
    pos[u] = self;
    id[self] = u;
    int tail = self;   // last node of u's auxiliary chain so far
    for (const auto &edge : EE1[u]){
        const int child = edge.first;
        if (child == fa) continue;
        const int link = ++cnt;
        id[link] = 0;
        add_edge(tail,link,0);
        const int child_node = dfs3d(child,u);
        add_edge(link,child_node,edge.second);
        tail = link;
    }
    return self;
}
int sz[maxn * 3];
void dfs_sz(int u,int fa){
    sz[u] = 1;
    for (int t = first[u];t;t=nxt[t]){
        int v = des[t],e_id = edgeid[t];
        if (v == fa || banned[e_id])continue;
        dfs_sz(v,u);
        sz[u] += sz[v];
    }
}
void dfs_edge(int u,int fa,int tot_node,
              int &uu,int &vv,int &ww,int &edge_id,int &max_sz){
    for (int t = first[u];t;t=nxt[t]){
        int v = des[t],len = llen[t],e_id = edgeid[t];
        if (v == fa || banned[e_id])continue;
        int max_sz_t = max(sz[v],tot_node - sz[v]);
        if (max_sz_t < max_sz){
            max_sz = max_sz_t;
            uu = u;vv = v;
            ww = len;
            edge_id = e_id;
        }
        dfs_edge(v,u,tot_node,uu,vv,ww,edge_id,max_sz);
    }
}
ll dis[maxn * 3];
void dfs_node(int u,int fa,ll length,vector<int> &nodes){
    if (id[u])nodes.push_back(id[u]);
    dis[u] = length;
    for (int t = first[u];t;t=nxt[t]){
        int v = des[t],len = llen[t],e_id = edgeid[t];
        if (v == fa || banned[e_id])continue;
        int le = length + len;
        if (le >= mod) le -= mod;
        dfs_node(v,u,le, nodes);
    }
}
// Scratch state for the virtual-tree DP in calc(), indexed by T' vertex.
//   vis   : 0 = untouched, 1 = real member of L or R, 2 = added LCA / root
//   stk   : stack used while building the virtual tree
//   fa    : parent in the virtual tree
//   dp_sum[x][c], dp_cnt[x][c] : sum of dis / count of colour-c vertices
//                                merged into x's virtual subtree so far
//   dp[x] : accumulated pair sum for pairs whose virtual-tree LCA is x
//   color : 1 for L, 2 for R
int vis[maxn];
int stk[maxn];
int fa[maxn];
ll dp_sum[maxn][2], dp_cnt[maxn][2],dp[maxn];
int color[maxn];
// Resets all DP accumulators of vertex x to zero.
inline void clear(int x){
    dp[x] = 0;
    dp_sum[x][0] = dp_sum[x][1] = 0;
    dp_cnt[x][0] = dp_cnt[x][1] = 0;
}
// Adds to `ans` the contribution of every pair (i in L, j in R) whose path in
// the first tree crosses the split edge (u, v) of weight w:
//   sum (dis[i] + dis[j] + w) * dis'(i, j)          (mod `mod`)
// L and R are the original vertices on u's and v's side. dis[] is measured
// from u and from v (via dfs_node). dis'(i, j) = dep[i] + dep[j] - 2*dep[lca(i, j)]
// is taken in the second tree (HLD). The dep[lca] part is evaluated with a
// two-colour DP over the virtual tree of L ∪ R (plus vertex 1) in the second tree.
void calc(int u, int v, int w){
    vector<int> L(0),R(0),nodes(0);
    // Collect both sides of the (already banned) split edge; fills dis[].
    dfs_node(u,0,0,L);dfs_node(v,0,0,R);
    // If one side has no original vertex, no pair crosses this edge.
    if (L.size() == 0 || R.size() == 0)return;
    // Mark real vertices: color 1 = L, color 2 = R; vis 1 = real member.
    for (int x : L){
        color[x] = 1;
        vis[x] = 1;
        nodes.push_back(x);
    }
    for (int y : R){
        color[y] = 2;
        vis[y] = 1;
        nodes.push_back(y);
    }
    // Sort by DFS entry index in the second tree.
    sort(nodes.begin(),nodes.end(),[](int x,int y){
        return HLD::l[x] < HLD::l[y];
    });
    // Add the LCA of every DFS-adjacent pair of real vertices (vis 2 marks
    // these auxiliary vertices) so that the set is closed under LCA.
    int SZ = nodes.size();
    for (int i=1;i<SZ;i++){
        int temp = HLD::lca(nodes[i-1],nodes[i]);
        if (!vis[temp]){
            nodes.push_back(temp);
            vis[temp] = 2;
        }
    }
    // Always include vertex 1 so the virtual tree has a single root.
    if (!vis[1]){
        nodes.push_back(1);
        vis[1] = 2;
    }
    sort(nodes.begin(),nodes.end(),[](int x,int y){
        return HLD::l[x] < HLD::l[y];
    });
    // Build virtual-tree parents with a stack holding the current root-to-node
    // path: pop until the top's [l, r] interval contains the new vertex.
    int top = 1;
    stk[0] = nodes.front();
    for (int i=1;i<nodes.size(); i ++){
        while (HLD::l[nodes[i]] > HLD::r[stk[top-1]]) top--;
        fa[nodes[i]] = stk[top-1];
        stk[top ++] = nodes[i];
    }
    for (int x : nodes)clear(x);
    // Terms that do not involve the LCA:
    //   |R| * sum_L (dis+w)*dep + |L| * sum_R (dis+w)*dep
    //   + sum_L dis * sum_R dep + sum_L dep * sum_R dis
    ll temp_ans = 0;
    ll sum = 0;
    for (int x : L) sum += 1ll * (dis[pos[x]] + w) * HLD::dep[x] % mod;
    temp_ans += sum % mod * R.size();
    sum = 0;
    for (int y : R) sum += 1ll * (dis[pos[y]] + w) * HLD::dep[y] % mod;
    temp_ans += sum % mod * L.size();
    ll sum1 = 0,sum2 = 0;
    for (int x : L)sum1 += dis[pos[x]];
    for (int y : R)sum2 += HLD::dep[y];
    sum1 %= mod;sum2 %= mod;
    temp_ans += sum1 * sum2 % mod;
    sum1 = sum2 = 0;
    for (int x : L)sum1 += HLD::dep[x];
    for (int y : R)sum2 += dis[pos[y]];
    sum1 %= mod;sum2 %= mod;
    temp_ans += sum1 * sum2 % mod;
    // LCA term. Reverse DFS order visits children before parents. dp[x]
    // collects (dis[i] + dis[j] + w) over cross pairs whose virtual-tree LCA is x.
    // dp_sum / dp_cnt hold the per-colour sum of dis and the count merged so far.
    for (int i = nodes.size() - 1;i >=0; i--){
        int u = nodes[i], c = vis[u] == 1?color[u] - 1 : -1;
        // A real vertex pairs with every opposite-colour vertex already merged
        // into its subtree (their LCA is u), then joins its own colour class.
        if (c != -1){
            ll A = dis[pos[u]];
            dp[u] += A * dp_cnt[u][!c] % mod;
            dp[u] += dp_sum[u][!c];
            dp[u] += dp_cnt[u][!c] * w % mod;
            dp[u] %= mod;
            dp_sum[u][c] += A;
            if (dp_sum[u][c] >= mod)dp_sum[u][c] -= mod;
            dp_cnt[u][c] ++;
        }
        // Subtract 2 * dep[u] * (pair sum) for all pairs whose LCA is u.
        temp_ans -= 2ll * dp[u] * HLD::dep[u] % mod;
        // Merge u's subtree into its parent. Pairs formed between u's subtree
        // and the part already merged into fa[u] have LCA fa[u].
        // NOTE(review): vertex 1 sorts first (HLD is rooted at 1), so fa[1] is
        // never assigned here and stays 0; the root's merge writes into the
        // unused slot 0, which is never read or cleared.
        dp[fa[u]] += dp_sum[u][0] * dp_cnt[fa[u]][1]% mod + dp_sum[fa[u]][1] * dp_cnt[u][0]% mod;
        dp[fa[u]] += dp_sum[u][1] * dp_cnt[fa[u]][0]% mod + dp_sum[fa[u]][0] * dp_cnt[u][1]% mod;
        dp[fa[u]] += (dp_cnt[fa[u]][1] * dp_cnt[u][0]% mod + dp_cnt[fa[u]][0] * dp_cnt[u][1] % mod) * w % mod;
        dp[fa[u]] %= mod;
        for (int c = 0;c < 2;c ++){
            dp_cnt[fa[u]][c] += dp_cnt[u][c];
            dp_sum[fa[u]][c] += dp_sum[u][c];
            if(dp_cnt[fa[u]][c] >= mod)dp_cnt[fa[u]][c] -= mod;
            if(dp_sum[fa[u]][c] >= mod)dp_sum[fa[u]][c] -= mod;
        }
    }
    // temp_ans % mod can be negative; adding mod keeps each contribution positive
    // (the caller reduces ans mod `mod` at the end).
    ans += (temp_ans % mod) + mod;
    // Clear the marks so the next call starts clean.
    for (int x : nodes){
        vis[x] = 0;
    }
}
// Edge-centroid decomposition of the component that contains `root`. It picks
// the most balanced non-banned edge, bans it, and accounts for every pair of
// original vertices separated by that edge via calc(). It then recurses into
// the two resulting components.
void dfs(int root){
    dfs_sz(root,0);
    const int node_cnt = sz[root];
    if (node_cnt == 1) return;   // a single node has no edge left to split
    int uu,vv,ww,e_id;
    int max_sz = mod;            // larger than any possible component size
    dfs_edge(root,0,node_cnt,uu,vv,ww,e_id,max_sz);
    banned[e_id] = true;
    calc(uu, vv, ww);
    dfs(uu);
    dfs(vv);
}
// Reads the next non-negative decimal integer from stdin into x.
// Any non-digit characters before the number (including '-') are skipped.
// If end of input is reached before a digit is found, x is left as 0 rather
// than looping forever on EOF.
inline void read(int &x){
    x = 0;
    // int, not char: EOF must stay distinguishable from every byte value
    // (with char, EOF is never detected where char is unsigned).
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9'){
        // Parenthesised so the intermediate value never exceeds the final one.
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
}
// Solves one test case:
//   1. reads n and the two trees (first tree into EE1, second into E2);
//   2. prepares LCA on the second tree rooted at vertex 1;
//   3. binarizes the first tree and runs the edge-centroid decomposition;
//   4. prints the result.
// The decomposition counts each unordered pair {i, j} once (when it handles the
// split edge separating them), so the total is doubled to count ordered pairs.
void work(){
    read(n);
    for (int i=1;i<n;i++){
        int u,v,len;
        read(u);read(v);read(len);
        EE1[u].push_back(make_pair(v,len));
        EE1[v].push_back(make_pair(u,len));
    }
    for (int i=1;i<n;i++){
        int u,v,len;
        read(u);read(v);read(len);
        E2[u].push_back(make_pair(v,len));
        E2[v].push_back(make_pair(u,len));
    }
    HLD::init(1);
    int root = dfs3d(1,0);
    dfs(root);
    ans %= mod;
    ans = ans * 2 % mod;
    printf("%lld\n",ans);
}
// Reads the number of test cases and solves each one, resetting the shared
// per-test state after every case.
int main(){
    int T;
    read(T);
    for (int tc = 0; tc < T; ++tc){
        work();
        clear();
    }
    return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值