http://acm.hdu.edu.cn/showproblem.php?pid=6368
题意: 构造最小方差生成树
首先我们先从方差的定义出发可知 方差必定是一段连续的值 方差最小
最暴力的方法是 我们枚举平均数 然后转化最小生成树check 明显的复杂度太高
我们考虑到每条边都有一个作用区间 因为当加入某一条边形成环时 为了维持稳定的方差 我们需要将这个环里面的最小边踢掉(至于为什么 可以手动模拟一下
这样的话 我们就把问题简化成插入或者加入某一条边 来维护Σwe^2和Σwe 很显然的 这是LCT加速即可
这个题 在场上并没有转化成区间 不然还可以搏一下 赛后补题遇到了三个问题:
1.在加边和删边维护LCT是取最小方差不应该在模的意义下
2.因为我们考虑的是枚举每一条边的最小方差 也就是说每条边都要构造一颗最小方差树 那也就是要离线将后面权值比较大但是又独一无二的边先加入进来 也就是当处理有环时应该同时加入边和最小边(不用分先后)
3.然后就是写戳了...还好抢救成功
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#define ll long long
const int MAXN=5e5+10;
using namespace std;
const int mod=998244353;
const int inf=1e9;
int pre[MAXN],ch[MAXN][2],res[MAXN],pos[MAXN];
ll key[MAXN];
int cnt;
int st[MAXN],tp;
pair<int,int>d[MAXN];
bool rt[MAXN];int n,m;
typedef struct node{
int u,v,vul;
friend bool operator<(node aa,node bb){
return aa.vul<bb.vul;
}
}node;
node que[MAXN];
ll ans1,ans2;
inline int newnode(ll vul){
int x;x=++cnt;pre[x]=ch[x][0]=ch[x][1]=res[x]=0;key[x]=vul;
rt[x]=1;pos[x]=x;
return x;
}
inline void reverse(int x){
if(!x)return ;
swap(ch[x][0],ch[x][1]);
res[x]^=1;
}
inline void push(int x){
if(res[x]){
reverse(ch[x][0]);
reverse(ch[x][1]);
res[x]^=1;
}
}
inline void up(int r){
pos[r]=r;
if(key[pos[ch[r][0]]]<key[pos[r]]) pos[r]=pos[ch[r][0]];
if(key[pos[ch[r][1]]]<key[pos[r]]) pos[r]=pos[ch[r][1]];
}
inline bool pd1(int x){
return ch[pre[x]][0]!=x&&ch[pre[x]][1]!=x;
}
inline void P(int x){
int i;st[++tp]=x;
for(i=x;!pd1(i);i=pre[i]) st[++tp]=pre[i];
for(;tp;tp--) push(st[tp]);
}
/*void P(int r){
if(!rt[r]) P(pre[r]);
push(r);
}*/
inline void rotate(int x,int kind){
int y=pre[x];
pre[ch[x][kind]]=y;ch[y][!kind]=ch[x][kind];
if(!rt[y])ch[pre[y]][ch[pre[y]][1]==y]=x;
else rt[y]=0,rt[x]=1;
pre[x]=pre[y];ch[x][kind]=y;pre[y]=x;
up(y);
}
inline void splay(int x){
P(x);
while(!rt[x]){
if(rt[pre[x]])rotate(x,ch[pre[x]][0]==x);
else{
int y=pre[x];int kind=ch[pre[y]][0]==y;
if(ch[y][kind]==x)rotate(x,!kind),rotate(x,kind);
else rotate(y,kind),rotate(x,kind);
}
}
up(x);
}
inline void access(int x){
int y=0;
while(x){
splay(x);
if(ch[x][1])rt[ch[x][1]]=1,pre[ch[x][1]]=x,ch[x][1]=0;
if(y)rt[y]=0;
ch[x][1]=y;up(x);
y=x;
x=pre[x];
}
}
inline void mroot(int u){
access(u);
splay(u);
reverse(u);
}
inline bool pd(int u,int v){
while(pre[u])u=pre[u];
while(pre[v])v=pre[v];
return u==v;
}
inline void Link(int u,int v,ll vul){
int t1=newnode(vul);d[t1].first=u;d[t1].second=v;
mroot(u);mroot(v);
pre[u]=t1;
pre[v]=t1;
}
void destory1(int u,int v){
mroot(u);access(v);splay(v);
rt[u]=rt[v]=1;pre[u]=pre[v]=0;ch[v][0]=0;
up(u);up(v);
}
int vis[MAXN],id;
inline void destory(int u,int v){
mroot(u);access(v);splay(v);
int t1=pos[v];id=t1;
destory1(t1,d[t1].first);destory1(t1,d[t1].second);
}
inline ll ksm(ll a,ll b,ll c){
ll ans=1;
while(b){
if(b&1)ans=ans*a%c;
a=a*a%c;b=b>>1;
}
return ans;
}
typedef struct Node{
int vul;ll sum1,sum2;
friend bool operator<(Node aa,Node bb){
if(aa.vul==bb.vul)return aa.sum1<bb.sum1;
return aa.vul<bb.vul;
}
}Node;
Node p[MAXN];
int main(){
//freopen("1.in","r",stdin);
int _;scanf("%d",&_);
while(_--){
scanf("%d%d",&n,&m);
cnt=0;key[0]=inf;ans1=0;ans2=0;ll ans3=ksm(n-1,mod-2,mod);
for(int i=1;i<=m;i++)scanf("%d%d%lld",&que[i].u,&que[i].v,&que[i].vul),vis[i]=-1;
for(int i=1;i<=n;i++)newnode(inf),vis[i]=-1;
sort(que+1,que+m+1);int t;
int cnt1=0;ll h1=1e18,h2=1e18;
for(int i=1;i<=m;i++){
if(pd(que[i].u,que[i].v)==0)Link(que[i].u,que[i].v,que[i].vul),p[++cnt1].vul=-1*inf,p[cnt1].sum1=1LL*que[i].vul*1LL*que[i].vul,p[cnt1].sum2=que[i].vul;
else{
destory(que[i].u,que[i].v);t=id;
vis[t-n]=i;p[++cnt1].vul=key[t]+que[i].vul;p[cnt1].sum1=1ll*que[i].vul*que[i].vul;
p[cnt1].sum2=que[i].vul;
Link(que[i].u,que[i].v,que[i].vul);
}
}
for(int i=1;i<=m;i++){
if(vis[i]==-1)p[++cnt1].vul=inf,p[cnt1].sum1=-1*1LL*que[i].vul*que[i].vul-1,p[cnt1].sum2=-1*que[i].vul;
else p[++cnt1].vul=que[vis[i]].vul+que[i].vul,p[cnt1].sum1=-1*1LL*que[i].vul*que[i].vul-1,p[cnt1].sum2=-1*que[i].vul;
}
sort(p+1,p+cnt1+1);
//for(int i=1;i<=cnt1;i++)cout<<p[i].vul<<" "<<p[i].sum1<<" "<<p[i].sum2<<endl;
int cnt2=0;
for(int i=1;i<=cnt1;i++){
if(p[i].sum1>=0)cnt2++;
else cnt2--,p[i].sum1++;
ans1+=p[i].sum1;ans2+=p[i].sum2;
// cout<<ans1<<" "<<ans2<<endl;
if(cnt2==n-1){
ll h = ans1-(ans2/(n - 1))*ans2-(ans2%(n-1))*ans2/(n - 1);
ll l=-((ans2%(n-1))*ans2%(n - 1));
if(l < 0) l += n - 1 , h--;
// cout<<h<<"====="<<l<<endl;
if(make_pair(h1,h2)>make_pair(h,l)){
h1=h;h2=l;
}
}
}
printf("%lld\n",((h1+h2*ans3)%mod*ans3%mod));
}
return 0;
}