题目
Time Limits: 2000 ms Memory Limits: 262144 KB
Description
你想举行一场派对,有m个朋友会来参加。
你有n个房间,由n-1条道路连接,形成一个树结构。你需要给每个朋友安排一个房间,满足以下条件:
每个朋友住在一个单独的房间;
存在一个房间(不一定要有人),使得每个朋友到它的距离不超过k。求方案数对998244353取模的结果。
Input
第一行三个整数n,m,k,接下来n-1行每行三个整数ui,vi,wi,表示存在一条连接ui和vi,长度为wi的道路。
Output
输出一行一个整数表示答案。
Sample Input
5 2 7
1 2 4
3 2 8
2 4 2
4 5 6
Sample Output
12
Data Constraint
Subtask 1 (8pts):n<=20。
Subtask 2 (31pts):n<=5000。
Subtask 3 (21pts):m=2且wi=1。
Subtask 4 (40pts):无特殊限制。
对于全部数据,1<=m<=n<=10^5,1<=k,wi<=10^9。
题解
考虑把这一颗树变成一颗以1位根的树
那么对于每一种合法的方案,一定有一个最高的点(也就是深度最浅的点)使得它可以成为这一种方案中的和每一个小朋友的距离都不超过k的点,并且对于任意深度更浅的点都有一个选取的点与其距离>k,我们考虑在这个最高的点统计这一种方案
那么对于一个点x,我们可以选取的点就是与它<=k的点,并且我们至少要选取一个和fa[x]的距离>k的点,使得这种方案对应的最高点是x
设f[i]表示与点i距离<=k的点的个数,这个东西可以用点剖来求
注意到与x距离<=k,与fa[x]距离>k的点一定在x的子树中,那么我们直接使用线段树合并来找到这样的所有点,然后问题就变成了现在有f[x]个点,要从中选取m个点,并且有z个点是必须至少选一个的
正难则反,直接全部方案减去全部不选就好了
贴代码
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define ll long long
using namespace std;
const int maxn=1e5+5,md=998244353;
struct P{
ll x,y;
}w[maxn];
int tree[maxn*120][3],root[maxn],size[maxn],f[maxn];
ll zo[maxn],c[maxn],mx,ni[maxn],wc[maxn],dis[maxn],ans;
int fi[maxn],ne[maxn*2],dui[maxn*2],dui1[maxn*2],qc[maxn];
int i,j,k,l,m,n,x,y,now,p,z,zz,rc,rot;
bool bz[maxn];
int cmp(P x,P y){
return x.x<y.x;
}
void add(int x,int y,int z){
if (fi[x]==0) fi[x]=++now; else ne[qc[x]]=++now;
dui[now]=y; dui1[now]=z; qc[x]=now;
}
ll C(int x,int y){
return (((ni[x]*ni[y-x])%md)*wc[y])%md;
}
ll quickmi(ll x,int y){
ll t1=1;
while (y){
if ((y & 1)==1) t1=(x*t1)%md;
x=(x*x)%md;
y=y/2;
}
return t1;
}
void ge_size(int x,int y){
size[x]=1;
for(int i=fi[x];i;i=ne[i]){
if (dui[i]==y || bz[dui[i]]==true) continue;
ge_size(dui[i],x);
size[x]=size[x]+size[dui[i]];
}
}
void ge_root(int x,int y){
int s=0;
for(int i=fi[x];i;i=ne[i]){
if (dui[i]==y || bz[dui[i]]==true) continue;
s=max(s,size[dui[i]]);
ge_root(dui[i],x);
}
s=max(s,rc-size[x]+1);
if (s<z){
z=s; rot=x;
}
}
void ge_dis(int x,int y){
for(int i=fi[x];i;i=ne[i]){
if (dui[i]==y || bz[dui[i]]==true) continue;
dis[dui[i]]=dis[x]+dui1[i]; w[++p].x=dis[dui[i]]; w[p].y=dui[i];
ge_dis(dui[i],x);
}
}
void calc(int x,int y,int z){
dis[x]=y;
p=1; w[1].x=y; w[1].y=x;
ge_dis(x,0);
sort(w+1,w+1+p,cmp);
int i; l=p;
fo(i,1,p){
while (w[l].x+w[i].x>k && l>=1) l--;
f[w[i].y]=f[w[i].y]+z*l;
if (l<=0) break;
}
}
void dfs(int x){
ge_size(x,0);
z=n; rc=size[x];
ge_root(x,0);
calc(rot,0,1);
int rt=rot;
bz[rt]=true;
for(int i=fi[rt];i;i=ne[i]){
if (bz[dui[i]]==true) continue;
calc(dui[i],dui1[i],-1);
dfs(dui[i]);
}
}
void change(int v,ll l,ll r,ll x){
tree[v][2]++;
if (l==r) return;
ll mid=(l+r)/2;
if (x<=mid){
if (! tree[v][0]) tree[v][0]=++p;
change(tree[v][0],l,mid,x);
} else{
if (! tree[v][1]) tree[v][1]=++p;
change(tree[v][1],mid+1,r,x);
}
}
void find(int v,ll l,ll r,ll x,ll y){
if (!v) return;
if (l==x && r==y) z=z+tree[v][2]; else{
ll mid=(l+r)/2;
if (y<=mid) find(tree[v][0],l,mid,x,y); else
if (x>mid) find(tree[v][1],mid+1,r,x,y); else{
find(tree[v][0],l,mid,x,mid);
find(tree[v][1],mid+1,r,mid+1,y);
}
}
}
void mira_adin(int rt,int x,int y,ll go){
change(root[rt],1,mx,zo[rt]+go);
for(int i=fi[x];i;i=ne[i]){
if (dui[i]==y) continue;
mira_adin(rt,dui[i],x,go+dui1[i]);
}
}
void dfs(int x,int y){
int big=0,ju,la;
size[x]=1;
for(int i=fi[x];i;i=ne[i]){
if (dui[i]==y){
la=dui1[i];
continue;
}
dfs(dui[i],x);
size[x]=size[x]+size[dui[i]];
if (size[dui[i]]>size[big]){
big=dui[i]; ju=dui1[i];
}
}
if (big==0) change(root[x],1,mx,zo[x]); else{
root[x]=root[big]; zo[x]=zo[big]-ju;
change(root[x],1,mx,zo[x]);
for(int i=fi[x];i;i=ne[i]){
if (dui[i]==y || dui[i]==big) continue;
mira_adin(x,dui[i],x,dui1[i]);
}
}
z=0;
find(root[x],1,mx,zo[x],zo[x]+k); zz=z;
z=0;
if (y) find(root[x],1,mx,zo[x]-la,zo[x]-la+k);
z=zz-z;
if (f[x]>=m && z>0) ans=(ans+C(m,f[x]))%md;
if (f[x]-z>=m && z>0) ans=(ans-C(m,f[x]-z)+md)%md;
}
int main(){
// freopen("t2.in","r",stdin);
scanf("%d%d%d",&n,&m,&k);
mx=2e14;
fo(i,1,n){
zo[i]=1e14;
root[i]=i;
}
wc[0]=1;
fo(i,1,n) wc[i]=(wc[i-1]*i)%md;
fo(i,0,n)
ni[i]=quickmi(wc[i],md-2);
fo(i,m,n) c[i]=C(m,i);
fo(i,1,n-1){
scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
dfs(1);
memset(size,0,sizeof(size));
p=n;
dfs(1,0);
ans=(ans*wc[m])%md;
printf("%d\n",ans);
return 0;
}