题目
2018-2019 ACM-ICPC, Asia Xuzhou Regional Contest G. Rikka with Intersections of Paths
https://codeforces.com/gym/102012/problem/H
大意就是给一棵 n 个点的树,和其中 m 条简单路径,以及一个数 k 。要你在这 m 个路径中选 k 个,使得这 k 个路径至少包含一个公共点。问有多少种选法。
n,m: 1e5
k: m
多组数据
思路
- 考虑对每个点,要是知道通过它的路径数为 b, 那么以它为公共点的选法有 C(b, k) 种。
- 但单纯这样对每个点的这个累加,会有重复。
- 思考有没有方法能让一个选法与唯一一个节点对应。于是想到这样一个对应方法:对每个结点,只计算 “至少有一条路径,其以这个点为顶点” 的选法。(这里顶点指的是路径端点的最近公共祖先)。
- 这样一来,每个点得出贡献的那些选法都是独一无二的了。(简单说明一下,因为对于一个点 X,和一条经过它的路径 L,必有 L 的顶点不是 X 就是 X 的祖先,不可能是 X 的后代。也就是说 X 是这个选法中深度最深的顶点。而一个选法不可能有两个不同的深度最深的顶点(否则这两个顶点对应的路径就没有公共点了),所以这个对应方法是唯一的了)
- 每个结点的贡献就是 C(b,k) - C(b-a,k)。其中 b 是通过它的路径数,a 是以它为顶点的路径数。
- a 好求,每个路径求个 LCA 然后弄个数组累加一下就行。
- b 一开始想的是用树链剖分把路径上的点全都 ++,结果多个 lg 就 TLE 了,时间卡得很紧啊。然后改成用树上差分,每个路径,给端点++,LCA–,fa[LCA]–,全部路径弄完后,再从叶子节点累加上去,就是需要的 b 数组了。可能有些细节,代码见。(其实应该先想差分的,差分做不了才用得上剖分……)
- 时间复杂度 O(mlgn + n) (大概)
代码
#include <bits/stdc++.h>
using namespace std;
#define MAXN 300005
#define Ha 1000000007
#define For(x) for (int h=head[x],o=to[h]; h; o=to[h=nxt[h]])
int head[MAXN*3];
int nxt[MAXN*3], to[MAXN*3];
int num;
//vector<int> graph[MAXN];
int n,m,k;
long long jc[MAXN];
long long ans;
int fa[MAXN];
int ST[MAXN][30];
int dep[MAXN];
int A[MAXN];
int B[MAXN];
void dfs1(int X, int F)
{
dep[X]=dep[F]+1;
fa[X]=F;
ST[X][0]=F;
For(X) if (o!=F)
dfs1(o, X);
}
void get_ST()
{
for (int j=1; j<=20; j++)
for (int i=1; i<=n; i++)
ST[i][j]=ST[ST[i][j-1]][j-1];
}
int get_LCA(int X, int Y)
{
if (dep[X]<dep[Y])
swap(X, Y);
for (int tmp=dep[X]-dep[Y], i=0; tmp>0; tmp>>=1, i++)
if (tmp&1)
X=ST[X][i];
if (X==Y) return X;
for (int i=20; fa[X]!=fa[Y] && i>=0; i--)
if (ST[X][i]!=ST[Y][i])
X=ST[X][i], Y=ST[Y][i];
return fa[X];
}
void dfs3(int X, int F)
{
For(X) if (o!=F) {
dfs3(o, X);
B[X]+=B[o];
}
}
void fun(int X, int Y)
{
int LCA=get_LCA(X, Y);
A[LCA]++;
B[X]++;
B[Y]++;
B[LCA]--;
if (LCA!=1)
B[fa[LCA]]--;
}
long long ksm(long long x, int y)
{
long long ret=1;
for (; y; x=x*x%Ha, y>>=1)
if (y&1) ret=ret*x%Ha;
return ret;
}
long long C(int x, int y)
{
if (x<y) return 0;
if (x==y) return 1;
long long ret= jc[x-y]*jc[y] %Ha;
ret=jc[x]*ksm(ret, Ha-2) %Ha;
return ret;
}
void dfs_solve(int X, int F)
{
int a, b;
a=A[X];
b=B[X];
ans+= C(b, k)-C(b-a, k);
ans%=Ha;
For(X) if (o!=F)
dfs_solve(o, X);
}
void solve()
{
scanf("%d%d%d",&n,&m,&k);
num=0;
for (int i=1; i<=n; i++) {
head[i]=0;
fa[i]=dep[i]=0;
A[i]=0;
B[i]=0;
}
for (int i=1, uu, vv; i<n; i++) {
scanf("%d%d", &uu, &vv);
to[++num]=vv, nxt[num]=head[uu], head[uu]=num;
to[++num]=uu, nxt[num]=head[vv], head[vv]=num;
}
dfs1(1, 1);
get_ST();
for (int i=1, xx, yy; i<=m; i++) {
scanf("%d%d", &xx, &yy);
fun(xx,yy);
}
dfs3(1, 1);
ans=0;
dfs_solve(1, 1);
ans=(ans+Ha)%Ha;
printf("%lld\n",ans);
}
int main()
{
jc[0]=jc[1]=1;
for (int i=2; i<MAXN; i++)
jc[i]=jc[i-1]*i%Ha;
int ttt;
scanf("%d",&ttt);
while (ttt--) {
solve();
}
}
/*
10 10 6
2 1
3 2
4 1
5 3
6 1
7 1
8 1
9 5
10 1
6 9
2 1
1 1
1 2
1 1
9 3
1 1
9 1
1 1
1 1
1
5 5 3
2 1
3 1
4 1
5 3
4 1
1 4
1 5
1 5
2 1
1
5 10 4
2 1
3 3
4 3
5 1
1 2
1 1
1 3
1 1
1 1
1 5
4 1
5 3
4 1
1 5
*/