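/*
 * Problem: http://codeforces.com/gym/102012/problem/G
 *
 * For each node u, its contribution is the number of ways to choose k paths
 * whose intersection has u as its highest common node (the one closest to
 * the root).
 * Using tree difference arrays, compute x = the number of paths passing
 * through u and y = the number of paths whose LCA is exactly u.
 * The contribution of u is then C(x, k) - C(x - y, k): if all k chosen paths
 * have an LCA other than u, the highest common node of their intersection is
 * u's parent or higher, so that case must not be counted at u.
 *
 * (Author's note: the code is patched together and rather ugly.)
 */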
#include<bits/stdc++.h>
#define mod 1000000007
using namespace std;

// Capacity of every node-indexed array (largest node id plus a little slack).
const int N = 300005;

// Adjacency lists of the tree, indexed by node id.
vector<int> edge[N];

// num: per-node difference-array counters updated once per path in main.
long long num[N];
// yes: number of paths whose LCA is the node.
long long yes[N];
// vis: visited flag used by dfs1.
long long vis[N];

// Per-test input: n nodes, m paths, k = number of paths to choose.
int n, m, k;

// Running answer of the current test case, kept modulo `mod`.
long long ans;

// F[i] = i! modulo `mod`, filled by init().
long long F[300010];
/// Precompute factorials modulo `mod`: afterwards F[i] == i! % mod for every
/// 0 <= i <= p. F must have room for p + 1 entries.
void init(long long p)
{
    long long running = 1; // 0! == 1
    F[0] = running;
    for (long long i = 1; i <= p; ++i)
    {
        running = running * i % mod;
        F[i] = running;
    }
}
/// Modular inverse of a modulo m: the x in [0, m) with a * x == 1 (mod m).
///
/// Uses the iterative extended Euclidean algorithm, so it:
///  - works for any modulus coprime to a (not only primes),
///  - accepts any a, including a >= m or a < 0 (reduced first),
///  - needs no recursion.
/// Returns 0 when no inverse exists (gcd(a, m) != 1, e.g. a % m == 0).
long long inv(long long a,long long m)
{
    a %= m;
    if (a < 0) a += m;
    // Invariant: old_s * a == old_r and s * a == r (mod m).
    long long old_r = a, r = m;
    long long old_s = 1, s = 0;
    while (r != 0)
    {
        long long q = old_r / r;
        long long tmp = old_r - q * r;
        old_r = r;
        r = tmp;
        tmp = old_s - q * s;
        old_s = s;
        s = tmp;
    }
    // old_r is now gcd(a, m); an inverse exists only if it is 1.
    if (old_r != 1)
        return 0;
    old_s %= m;
    if (old_s < 0) old_s += m;
    return old_s;
}
/// Binomial coefficient C(n, m) modulo p, computed digit by digit in base p
/// (Lucas' theorem).
/// Returns 0 when n < 0, m < 0 or m > n, and 1 for C(0, 0).
/// NOTE(review): F holds factorials modulo `mod` (see init), so this is only
/// valid for p == mod. Each base-p digit of n must also lie within F's filled
/// range. With p = 1e9+7 and n <= 3e5 the loop body runs at most once.
long long Lucas(long long n,long long m,long long p)
{
    if (n < 0 || m < 0 || m > n)
        return 0;
    long long ans = 1;
    // Since m <= n, n has at least as many base-p digits as m. Once m's digits
    // are exhausted, every remaining factor is C(a, 0) == 1.
    while (m > 0)
    {
        long long a = n % p;
        long long b = m % p;
        if (a < b) return 0; // C(a, b) == 0 for a digit pair with b > a
        ans = ans * F[a] % p * inv(F[b] * F[a - b] % p, p) % p;
        n /= p;
        m /= p;
    }
    return ans;
}
/// LCA via binary lifting.
/// pre[j][v]: the 2^j-th ancestor of node v. init_lca fills levels 0..20 and
/// lca reads levels 0..19, so 21 levels are all that is ever touched. Depths
/// stay below N < 2^20. Extra levels would only cost memory and make clearing
/// the table on every test case slower.
int pre[21][N];
int dep[N]; /// depth of each node below the DFS root, written by dfs
/// Walk the tree from `now`, recording each node's parent in pre[0] and its
/// depth in dep. `pa` is stored as the parent of `now`. Each child gets
/// dep[child] = dep[parent] + 1; dep[now] itself is not written.
/// Uses an explicit stack instead of recursion, so a path-shaped tree with
/// ~3e5 nodes cannot overflow the call stack. Assumes the graph is a tree:
/// only the edge back to a node's parent is skipped.
void dfs (int now, int pa) {
    std::vector<std::pair<int, int> > todo; // (node, parent) pairs still to visit
    todo.push_back(std::make_pair(now, pa));
    while (!todo.empty()) {
        int u = todo.back().first;
        int p = todo.back().second;
        todo.pop_back();
        pre[0][u] = p;
        for (int son : edge[u]) {
            if (son == p) continue;
            dep[son] = dep[u] + 1;
            todo.push_back(std::make_pair(son, u));
        }
    }
}
/// Build the binary-lifting table for a tree rooted at node 1.
/// First dfs fills pre[0] (parents) and dep. Then, for k = 0..19,
/// pre[k + 1][v] = pre[k][pre[k][v]].
/// The root's parent is recorded as 0, an id no real node uses (nodes are
/// 1..n). Assuming pre was zeroed beforehand (main clears it each test),
/// pre[k][0] stays 0, so jumps past the root settle on 0. No parent value is
/// ever a negative array index.
void init_lca (int n) {
    dfs(1, 0); /// root the tree at node 1; 0 = "no parent" sentinel
    for (int k = 0; k < 20; k ++) {
        for (int v = 1; v <= n; v ++) {
            if (pre[k][v]) // 0 means the 2^k-th jump already passed the root
                pre[k + 1][v] = pre[k][ pre[k][v] ];
        }
    }
}
/// Lowest common ancestor of u and v, using the pre/dep binary-lifting tables
/// built by init_lca (levels 0..19).
int lca (int u, int v) {
    // Arrange for v to be the deeper (or equally deep) endpoint.
    if (dep[v] < dep[u]) std::swap(u, v);
    // Lift v onto u's level, taking a 2^bit jump whenever that bit of the
    // current depth gap is set.
    for (int bit = 0; bit < 20; ++bit) {
        if (((dep[v] - dep[u]) >> bit) & 1)
            v = pre[bit][v];
    }
    if (v == u) return u;
    // Raise both while their 2^bit ancestors differ; they stop as the two
    // children of the LCA on the respective sides.
    for (int bit = 19; bit >= 0; --bit) {
        if (pre[bit][u] == pre[bit][v]) continue;
        u = pre[bit][u];
        v = pre[bit][v];
    }
    return pre[0][u];
}
int dfs1(int x)
{
int sum=num[x];
vis[x]=1;
for(int i=0;i<edge[x].size();++i)
{
if(vis[edge[x][i]]==0)
sum+=dfs1(edge[x][i]);
}
// printf("num%d=%d yes%d=%d\n",x,sum+yes[x],x,yes[x]);
if(sum+yes[x]<k)
return sum;
ans=(ans+Lucas(sum+yes[x],k,mod)-Lucas(sum,k,mod)+mod)%mod;
// printf("ans=%d\n",ans);
// if(yes[x]<k)
// ans+=Lucas(k,yes[x],mod);
// else
// ans-=Lucas(yes[x],k,mod);
// printf("sum=%d\n",sum);
return sum;
}
/// Read t test cases. Each one gives a tree (n nodes, n - 1 edges), m paths
/// and k. For each case, print the sum over all nodes u of
/// C(x, k) - C(x - y, k) modulo `mod`. Here x is the number of paths through
/// u and y the number whose LCA is u, per the notes at the top of the file.
int main()
{
    int t;
    init(300000); // factorials 0!..300000! for the binomial coefficients
    if (scanf("%d", &t) != 1)
        return 0;
    while (t--)
    {
        ans = 0;
        scanf("%d%d%d", &n, &m, &k);
        // Every per-node array is indexed by node ids 1..n (plus the 0 slot),
        // so clearing just 0..n is enough. Clearing whole N-sized arrays on
        // every test case (tens of MB for pre) would be needlessly slow.
        std::fill(num, num + n + 1, 0LL);
        std::fill(yes, yes + n + 1, 0LL);
        std::fill(vis, vis + n + 1, 0LL);
        std::fill(dep, dep + n + 1, 0);
        for (auto &level : pre)
            std::fill(level, level + n + 1, 0);
        for (int i = 1; i <= n; i++)
            edge[i].clear();
        for (int i = 1; i <= n - 1; i++)
        {
            int a, b;
            scanf("%d%d", &a, &b);
            edge[a].push_back(b);
            edge[b].push_back(a);
        }
        init_lca(n);
        for (int i = 1; i <= m; i++)
        {
            int a, b;
            scanf("%d%d", &a, &b);
            int w = lca(a, b);
            // Tree difference array. Summed over subtrees, a node gets +1 for
            // each path that passes through it without having it as LCA.
            num[a]++;
            num[b]++;
            num[w] -= 2;
            yes[w]++; // paths whose LCA is exactly w
        }
        dfs1(1);
        printf("%lld\n", ans % mod);
    }
    return 0;
}
// Sample input (expected output: 10):
//1
//3 6 2
//1 2
//1 3
//1 1
//2 2
//3 3
//1 2
//1 3
//2 3