4015
4015
4015: 永琳的竹林迷径
题目描述
竹林可以看作是一个
n
n
n 个点的树,每个边有一个边长
w
i
w_{i}
wi,其中有
k
k
k 个关键点,永琳需要破坏这些关键点才能走出竹林迷径。
然而永琳打算将这 k k k 个点编号记录下来,然后随机排列,按这个随机的顺序走过 k k k 个点,但是两点之间她只走最短路线。初始时永琳会施展一次魔法,将自己传送到选定的k 个点中随机后的第一个点。
现在永琳想知道,她走过路程的期望是多少,答案对 998244353 998244353 998244353 取模。
输入
第一行一个数
C
a
s
e
Case
Case,表示测试点编号。(样例的编号表示其满足第
C
a
s
e
Case
Case 个测试点的性质)
下一行一个 n n n,表示树的点数。
下面 n − 1 n-1 n−1 行,每行三个数 u i , v i , w i u_{i},v_{i},w_{i} ui,vi,wi,表示一条边连接 u i u_{i} ui和 v i v_{i} vi,长度为 w i w_{i} wi。
下面一行一个数 k k k,表示关键点数。
下面一行 k k k 个数,表示 k k k 个关键点的编号。
输出
一行一个数,表示答案(对
998244353
998244353
998244353 取模)。
样例输入
1
3
1 2 1
1 3 2
3
1 2 3
样例输出
4
提示
数据范围
对于 100 % 100\% 100%的数据,保证
1 ≤ w i ≤ 1 0 4 , n ≤ 1 0 6 , k ≤ 1 0 6 1≤w_{i}≤10^4,n\leq 10^6,k\leq 10^6 1≤wi≤104,n≤106,k≤106。
题解:
首先我们考虑相邻两个关键点会产生的贡献。
考虑某两个点,确定一个位置后,另一个点会有
k
−
1
k-1
k−1种取值。
而对于这个点,就会有
k
k
k种位置,所以总共对于某两个点就是共有
k
×
(
k
−
1
)
k\times (k-1)
k×(k−1) 种方案。
而对于除了开头结尾,产生贡献的只有两个位置,即
2
×
(
k
−
2
)
2 \times(k-2)
2×(k−2),加上开头结尾各一个,所以概率即为
2
×
k
−
2
2 \times k-2
2×k−2
所以对于总的概率,就是
2
×
(
k
−
1
)
k
×
(
k
−
1
)
\frac{2\times (k-1)}{k \times (k-1)}
k×(k−1)2×(k−1)
=
=
=
2
k
\frac{2}{k}
k2
再乘上这两个点之间的路径长,就成了期望。
由于期望的和会等于和的期望,提公因式即可得答案:
a
n
s
=
2
k
×
∑
i
=
1
k
∑
j
=
i
+
1
k
d
i
s
i
,
j
ans=\frac{2}{k}\times \sum_{i=1}^{k} \sum_{j=i+1}^{k} dis_{i,j}
ans=k2×∑i=1k∑j=i+1kdisi,j
对于求两两关键点之间的路径,一般会想到使用点分治来暴力求出。
但可能过不去,于是就想到考虑每条边的贡献:
每条边会被使用的次数就是这条边两边关键点数的乘积:
对于第
i
i
i条边:
a
n
s
=
a
n
s
+
s
z
t
o
i
×
(
k
−
s
z
t
o
i
)
ans=ans+sz_{to_{i}} \times (k-sz_{to_{i}})
ans=ans+sztoi×(k−sztoi)
直接
d
f
s
dfs
dfs一遍求解、
#include<bits/stdc++.h>
using namespace std;
#define in inline
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define repd(i,a,b) for(int i=a;i>=b;i--)
#define For(i,a,b) for(int i=a;i<b;i++)
#define _(d) while(d(isdigit(ch=getchar())))
template<class T>in void g(T&t){T x,f=1;char ch;_(!)ch=='-'?f=-1:f;x=ch-48;_()x=x*10+ch-48;t=f*x;}
typedef long long ll;
const ll mod=998244353;
const int N=1e6+3;
struct E{int to,nxt;ll w;}e[N<<1];
ll head[N],tot,n,sz[N],dis[N];ll ans,k;int vis[N];
in void ins(int x,int y,ll z){
e[++tot]={y,head[x],z};head[x]=tot;
}
in void dfs(int x,int fa){
for(int i=head[x];i;i=e[i].nxt){
if(e[i].to==fa) continue;
dfs(e[i].to,x);
sz[x]+=sz[e[i].to];
ans+=e[i].w*sz[e[i].to]%mod*(k-sz[e[i].to])%mod;
ans%=mod;
}
}
in ll qp(ll x,ll y){
ll res=1;
while(y){
if(y&1) res=res*x%mod;
x=x*x%mod;y>>=1;
}return res%mod;
}
ll f[N];
int main(){
// freopen("path.in","r",stdin);freopen("path.out","w",stdout);
g(n);g(n);
For(i,1,n){
int x,y;ll z;g(x),g(y),g(z);
ins(x,y,z);ins(y,x,z);
}
g(k);
if(k==1){printf("0");return 0;}
rep(i,1,k){
int x;g(x);
sz[x]=1;
}
dfs(1,0);
ans=ans*qp(k,mod-2)%mod*2%mod;
printf("%lld\n",(ans%mod+mod)%mod);
return 0;
}