题目链接
听说是当年徐州现场赛的银牌的(从银牌到打铁全是2题队,太可怕了)。
很容易想到的就是树上差分,但是显然是算重了。怎么去重是重点。。这里有一个理论。
一个树上任意两条路径如果有交点的话,那么这些交点中肯定有一个为两条路径中的一条路径两端点的lca
这就很奇妙了?不太清楚是怎么推出来的。
如果借此理论,题目就好做了。对于每个点:
ans = C(cnt, k) - C(cnt - lcacnt, k), 其中cnt是经过该点的路线总数,lcacnt是以该点位为lca的路线总数。
然后把ans累加就行了。
下面是ac代码:
#include <iostream>
#include <cmath>
#include <queue>
#include <cstring>
#include <cstdlib>
#include <string>
#include <vector>
#include <algorithm>
#include <map>
#define ll long long
using namespace std;
const int N = 3e5+5;
const int mod = 1e9 + 7;
int n, m, k;
int sum, cnt, tot;
int f[N][20], d[N], ans[N], dif[N];
int ne[N<<1], he[N], ver[N<<1];
ll q[N], inv[N];
int sumlca[N];
int t;
ll _pow(ll a, ll b)
{
ll res = 1;
while(b)
{
if (b&1) res = res*a%mod;
a=a*a%mod;
b>>=1;
}
return res%mod;
}
void initp()
{
q[1]=1;
for (int i =2; i < N; i++) q[i] = q[i-1]*i%mod;
inv[N - 1] = _pow(q[N-1], mod-2);
for (int i = N-2; i >= 1; i--) inv[i] = (inv[i+1]*(i+1))%mod;
}
void init(int n)
{
memset(d, 0, sizeof(d));
memset(f, 0, sizeof(f));
t =(log(n)/log(2)) + 1;
memset(he, 0, sizeof(he));
memset(dif, 0, sizeof(dif));
memset(sumlca, 0, sizeof(sumlca));
sum = cnt = tot = 0;
}
void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
void bfs()
{
d[1] = 1;
queue<int> q;
while(q.size()) q.pop();
q.push(1);
while(q.size())
{
int te = q.front();
q.pop();
for (int i = he[te]; i; i = ne[i])
{
int v = ver[i];
if (d[v]) continue;
d[v] = d[te] + 1;
f[v][0] = te;
for (int j = 1; j <= t; j++)
f[v][j] = f[f[v][j-1]][j-1];
q.push(v);
}
}
}
int lca(int x, int y)
{
if (d[x] > d[y]) swap(x, y);
for (int i = t; i >= 0; i--)
{
if (d[f[y][i]] < d[x]) continue;
y = f[y][i];
}
if (x == y) return x;
for (int i = t; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
void dfs(int cur)
{
int sum = 0;
for (int i = he[cur]; i; i = ne[i])
{
int v = ver[i];
if (d[v] < d[cur]) continue;
dfs(v);
sum += ans[v];
}
sum += dif[cur];
ans[cur] = sum;
}
ll C(int n, int m)
{
if (n<0||m<0||m>n) return 0;
if (m==0||m==n) return 1;
return q[n]*inv[n-m]%mod*inv[m]%mod;
}
int main()
{
int t0;
cin >> t0;
initp();
while(t0--)
{
scanf("%d%d%d", &n,&m, &k);
init(n);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
bfs();
for (int i = 0; i < m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
int lp = lca(x, y);
sumlca[lp]++;
dif[x]++; dif[y]++;
dif[lp]--;
if (lp != 1) dif[f[lp][0]]--;
}
dfs(1);
ll ans0 = 0;
for (int i = 1; i <= n; i++)
ans0 = ((ans0 + ((C(ans[i], k) - C(ans[i]- sumlca[i], k)) + mod)%mod)%mod + mod)%mod;
printf("%lld\n", ans0);
}
return 0;
}