Description
给定一棵n个点的树,求有多少种选择m个点的方法,使得存在一个点x,使得所有m个点到x的距离不超过k。
n,m<=1e5
Solution
感觉自己最近智商下线的厉害啊
考虑如何使一种方案被唯一计算。
就是所有点到x的距离都<=k,并且存在至少一个点到fa[x]的距离>k
如果求出了这个东西我们就可以直接组合数计算答案。
直接求可能比较麻烦,我们不妨求出到x和fa[x]都<=k的点的数量。
这个东西可以在每一层分治树上考虑,判断x和fa[x]哪个的深度更深,那么就是更深的这个点到其他子树的点的答案。
注意判断当x为分治中心的情况。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a) for(int i=lst[a];i;i=nxt[i])
using namespace std;
typedef long long ll;
int read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
int x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
const int N=1e5+5,Mo=998244353;
int pwr(int x,int y) {
int z=1;
for(;y;y>>=1,x=(ll)x*x%Mo)
if (y&1) z=(ll)z*x%Mo;
return z;
}
int t[N<<1],v[N<<1],nxt[N<<1],lst[N],l;
void add(int x,int y,int z) {
t[++l]=y;v[l]=z;nxt[l]=lst[x];lst[x]=l;
}
int n,m,k,x,y,z,rt,all,tot,size[N],fa[N],dep[N],father[N][17];
bool vis[N];
void get_size(int x,int y) {
size[x]=1;
rep(i,x)
if (t[i]!=y&&!vis[t[i]]) {
get_size(t[i],x);
size[x]+=size[t[i]];
}
}
void get_root(int x,int y) {
bool ok=1;
rep(i,x)
if (t[i]!=y&&!vis[t[i]]) {
get_root(t[i],x);
if (size[t[i]]>all/2) ok=0;
}
if (all-size[x]>all/2) ok=0;
if (ok) rt=x;
}
int cnt[N][17];
int f[N],g[N],inv[N],fact[N];
struct node{ll v;int id,w;}a[N];
bool cmp1(node x,node y) {return x.v<y.v;}
bool cmp2(node x,node y) {return x.id<y.id||x.id==y.id&&x.v<y.v;}
void dfs(int x,int y,ll z,int rt,int dep,int u) {
if (z<=k) g[u]++;
a[++tot].w=x;a[tot].v=z;a[tot].id=rt;father[x][dep]=y;
rep(i,x) if (t[i]!=y&&!vis[t[i]]) dfs(t[i],x,z+v[i],rt,dep,u);
}
void calc(int L,int R,int opt,int dep) {
int r=R;
fo(i,L,R) {
while (r>=L&&a[r].v+a[i].v>k) r--;
cnt[a[i].w][dep]+=opt*(r-L+1);
}
}
void solve(int x,int d) {
get_size(x,0);all=size[x];
get_root(x,0);int z=rt;
vis[z]=1;dep[x]=d;tot=0;cnt[z][d]++;
rep(i,z) if (!vis[t[i]]) dfs(t[i],z,v[i],t[i],d,fa[t[i]]==z?t[i]:z);
fo(i,1,tot) if (a[i].v<=k) cnt[z][d]++,cnt[a[i].w][d]++;
sort(a+1,a+tot+1,cmp1);
calc(1,tot,1,d);
sort(a+1,a+tot+1,cmp2);
for(int l=1,r=0;l<=tot;l=r+1) {
while (r<tot&&a[r+1].id==a[l].id) r++;
calc(l,r,-1,d);
}
rep(i,z) if (!vis[t[i]]) solve(t[i],d+1);
}
void get_father(int x,int y) {
fa[x]=y;
rep(i,x) if (t[i]!=y) get_father(t[i],x);
}
int C(int m,int n) {
if (m<n) return 0;
return (ll)fact[m]*inv[n]%Mo*inv[m-n]%Mo;
}
int main() {
n=read();m=read();k=read();
fo(i,1,n-1) {
x=read();y=read();z=read();
add(x,y,z);add(y,x,z);
}
get_father(1,0);solve(1,0);
fo(i,1,n) {
fo(j,0,16) {
f[i]+=cnt[i][j];
if (!father[i][j]) break;
}
if (fa[i]) {
fo(j,0,16) {
g[i]+=(father[i][j]==fa[i])?cnt[i][j]:cnt[fa[i]][j];
if (!father[i][j]||!father[fa[i]][j]) break;
}
}
}
int ans=0;
fact[0]=1;fo(i,1,n) fact[i]=(ll)fact[i-1]*i%Mo;
inv[n]=pwr(fact[n],Mo-2);fd(i,n-1,0) inv[i]=(ll)inv[i+1]*(i+1)%Mo;
fo(i,1,n) {
(ans+=C(f[i],m))%=Mo;
if (fa[i]) (ans+=Mo-C(g[i],m))%=Mo;
}
ans=(ll)ans*fact[m]%Mo;
printf("%d\n",ans);
return 0;
}