http://www.elijahqi.win/archives/3954
首先考虑一条边受到的影响 考虑dp 我们会枚举我的子树里有几个点等待配对 但如果这个点数超过1 那么显然是不优的 所以我们只考虑子树内有1的情况 那么我们再次把子树对应的dll 在序列上标出来 子树内的标记1 其他为0 统计我子树内的0,1序列中有多少个偶数区间和是奇数即可我们接着观察,对于一个确定的ddl区间,一条树边被匹配包含了当且仅当把它删去后树被分成两个包含奇数个ddl的块。若我们令 1 为根做dfs,则一个非根点父向边被包含当且仅当其子树内有奇数个ddl,因为这个信息我们可以o(1)维护所以我们选择用线段树为该信息 那么我们在每个节点的时候把这个节点的信息添加进去 然后线段树合并上去 每次合并之后 最后算合并完成之后我对父边的贡献即可 具体可以在线段树上维护4个信息 v[0][0/1] v[1][0/1]分别表示下标是奇数/偶数情况下 我标记0,1之后前缀和在奇数/偶数的位置 是奇数/偶数的个数 注意每次插入都需要用update维护 这个值
#include<bits/stdc++.h>
#define ll long long
using namespace std;
inline char gc(){
static char now[1<<16],*S,*T;
if (T==S){T=(S=now)+fread(now,1,1<<16,stdin);if (T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(!isdigit(ch)) {if (ch=='-') f=-1;ch=gc();}
while(isdigit(ch)) x=x*10+ch-'0',ch=gc();
return x*f;
}
const int N=1e5+10;
const int mod=998244353;
struct node1{
int y,next,z;
}data[N<<1];
struct node{
bool flag;int v[2][2],left,right;
inline ll get(){return (ll)v[0][1]*(v[0][0]+1)+(ll)v[1][1]*v[1][0];}
inline void init(int l,int r){
flag=v[0][1]=v[1][1]=0;int mid=l+r>>1,len=(r-l+1)>>1;
v[0][0]=v[1][0]=len;(r-l+1)&1?++v[l&1][0]:v[l&1][0];
}
}tree[N*22];
inline void inc(int &x,int v){x+=v;x>=mod?x-=mod:x;}
int h[N],n,m,num,ans,rt[N];vector<int>c[N];
inline void update(int x,int l,int r){
int mid=l+r>>1;static node lc,rc;
if (tree[x].left) lc=tree[tree[x].left];else lc.init(l,mid);
if (tree[x].right) rc=tree[tree[x].right];else rc.init(mid+1,r);
for (int i=0;i<2;++i)
for (int j=0;j<2;++j) tree[x].v[i][j]=lc.v[i][j]+rc.v[i][j^lc.flag];
tree[x].flag=lc.flag^rc.flag;
}
inline int merge(int x,int y,int l,int r){
if (!x||!y) return x+y;int mid=l+r>>1;
tree[x].left=merge(tree[x].left,tree[y].left,l,mid);
tree[x].right=merge(tree[x].right,tree[y].right,mid+1,r);
update(x,l,r);return x;
}
inline void insert1(int &x,int l,int r,int p){
if (!x) x=++num;if (l==r) {++tree[x].v[p&1][1];tree[x].flag^=1;return;}
int mid=l+r>>1;
if (p<=mid) insert1(tree[x].left,l,mid,p);
else insert1(tree[x].right,mid+1,r,p);update(x,l,r);
}
inline void dfs(int x,int fa){
for (int i=0;i<c[x].size();++i) insert1(rt[x],1,m,c[x][i]);
for (int i=h[x];i;i=data[i].next){
int y=data[i].y,z=data[i].z;if (y==fa) continue;
dfs(y,x);inc(ans,tree[rt[y]].get()*z%mod);rt[x]=merge(rt[x],rt[y],1,m);
}
}
int main(){
freopen("uoj388.in","r",stdin);
n=read();m=read();
for (int i=1;i<n;++i){
int x=read(),y=read(),z=read();
data[++num].y=y;data[num].next=h[x];h[x]=num;data[num].z=z;
data[++num].y=x;data[num].next=h[y];h[y]=num;data[num].z=z;
}num=0;
for (int i=1;i<=m;++i) c[read()].push_back(i);dfs(1,1);
printf("%d\n",ans);
return 0;
}