G. Rikka with Intersections of Paths (树上差分)

73 篇文章 0 订阅

题目

2018-2019 ACM-ICPC, Asia Xuzhou Regional Contest G. Rikka with Intersections of Paths
https://codeforces.com/gym/102012/problem/H

大意就是给一棵 n 个点的树,和其中 m 条简单路径,以及一个数 k 。要你在这 m 个路径中选 k 个,使得这 k 个路径至少包含一个公共点。问有多少种选法。
n,m: 1e5
k: m
多组数据

思路

  • 考虑对每个点,要是知道通过它的路径数为 b, 那么以它为公共点的选法有 C(b, k) 种。
  • 但单纯这样对每个点的这个累加,会有重复。
  • 思考有没有方法能让一个选法与唯一一个节点对应。于是想到这样一个对应方法:对每个结点,只计算 “至少有一条路径,其以这个点为顶点” 的选法。(这里顶点指的是路径端点的最近公共祖先)。
  • 这样一来,每个点得出贡献的那些选法都是独一无二的了。(简单说明一下,因为对于一个点 X,和一条经过它的路径 L,必有 L 的顶点不是 X 就是 X 的祖先,不可能是 X 的后代。也就是说 X 是这个选法中深度最深的顶点。而一个选法不可能有两个不同的深度最深的顶点(否则这两个顶点对应的路径就没有公共点了),所以这个对应方法是唯一的了)
  • 每个结点的贡献就是 C(b,k) - C(b-a,k)。其中 b 是通过它的路径数,a 是以它为顶点的路径数。
  • a 好求,每个路径求个 LCA 然后弄个数组累加一下就行。
  • b 一开始想的是用树链剖分把路径上的点全都 ++,结果多个 lg 就 TLE 了,时间卡得很紧啊。然后改成用树上差分,每个路径,给端点++,LCA–,fa[LCA]–,全部路径弄完后,再从叶子节点累加上去,就是需要的 b 数组了。可能有些细节,代码见。(其实应该先想差分的,差分做不了才用得上剖分……)
  • 时间复杂度 O(mlgn + n) (大概)

代码

#include <bits/stdc++.h>
using namespace std;

#define MAXN 300005
#define Ha 1000000007

#define For(x) for (int h=head[x],o=to[h]; h; o=to[h=nxt[h]])


int head[MAXN*3];
int nxt[MAXN*3], to[MAXN*3];
int num;

//vector<int> graph[MAXN];
int n,m,k;

long long jc[MAXN];
long long ans;

int fa[MAXN];
int ST[MAXN][30];
int dep[MAXN];

int A[MAXN];
int B[MAXN];


void dfs1(int X, int F)
{
	dep[X]=dep[F]+1;
	fa[X]=F;
	ST[X][0]=F;
	For(X) if (o!=F)
		dfs1(o, X);
}

void get_ST()
{
	for (int j=1; j<=20; j++)
		for (int i=1; i<=n; i++)
			ST[i][j]=ST[ST[i][j-1]][j-1];
}


int get_LCA(int X, int Y)
{	
	if (dep[X]<dep[Y])
		swap(X, Y);
	for (int tmp=dep[X]-dep[Y], i=0; tmp>0; tmp>>=1, i++)
		if (tmp&1)
			X=ST[X][i];

	if (X==Y) return X;

	for (int i=20; fa[X]!=fa[Y] && i>=0; i--)
		if (ST[X][i]!=ST[Y][i])
			X=ST[X][i], Y=ST[Y][i];
	
	return fa[X];
}


void dfs3(int X, int F)
{
	For(X) if (o!=F) {
		dfs3(o, X);
		B[X]+=B[o];
	}
}




void fun(int X, int Y)
{
	int LCA=get_LCA(X, Y);
    
    A[LCA]++;
    
    B[X]++;
    B[Y]++;
    B[LCA]--;
    if (LCA!=1)
    	B[fa[LCA]]--;
}


long long ksm(long long x, int y)
{
    long long ret=1;
    for (; y; x=x*x%Ha, y>>=1)
        if (y&1) ret=ret*x%Ha;
    return ret;
}

long long C(int x, int y)
{
    if (x<y) return 0;
    if (x==y) return 1;
    long long ret= jc[x-y]*jc[y] %Ha;
    ret=jc[x]*ksm(ret, Ha-2) %Ha;
    return ret;
}

void dfs_solve(int X, int F)
{
    int a, b;
    a=A[X];
    b=B[X];

    ans+= C(b, k)-C(b-a, k);
    ans%=Ha;
    
    For(X) if (o!=F)
    	dfs_solve(o, X);
}




void solve()
{
    scanf("%d%d%d",&n,&m,&k);


    num=0;
	for (int i=1; i<=n; i++) {
		head[i]=0;
		fa[i]=dep[i]=0;
		A[i]=0;
		B[i]=0;
	}

    for (int i=1, uu, vv; i<n; i++) {
        scanf("%d%d", &uu, &vv);
        to[++num]=vv, nxt[num]=head[uu], head[uu]=num;
        to[++num]=uu, nxt[num]=head[vv], head[vv]=num;
    }

	dfs1(1, 1);
	get_ST();

    for (int i=1, xx, yy; i<=m; i++) {
        scanf("%d%d", &xx, &yy);
        fun(xx,yy);
    }
	dfs3(1, 1);

    ans=0;
    dfs_solve(1, 1);
	ans=(ans+Ha)%Ha;
    printf("%lld\n",ans);
}


int main()
{
    jc[0]=jc[1]=1;
    for (int i=2; i<MAXN; i++)
        jc[i]=jc[i-1]*i%Ha;

    int ttt;
    scanf("%d",&ttt);

    while (ttt--) {
        solve();
    }
}


/*
10 10 6
2 1
3 2
4 1
5 3
6 1
7 1
8 1
9 5
10 1
6 9
2 1
1 1
1 2
1 1
9 3
1 1
9 1
1 1
1 1



1
5 5 3
2 1
3 1
4 1
5 3
4 1
1 4
1 5
1 5
2 1





1
5 10 4
2 1
3 3
4 3
5 1
1 2
1 1
1 3
1 1
1 1
1 5
4 1
5 3
4 1
1 5

*/
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值