题面
有一个n个点A+B条边的无向连通图,有一变量x,每条边的权值都是一个关于x的简单多项式,其中有A条边的权值是k+x,另外B条边的权值是k-x,如果只保留权值形如k+x的边,那么这个图仍是一个连通图,如果只保留权值形如k-x的边,这个图也依然是一个连通图。
给出q组询问,每组询问给出x的值,问此时这个无向连通图的最小生成树权值是多少。
对于100%的数据,1<=n,q<=100000 , n-1<=A,B<=200000, 0<=k<=10^9 , -10^9<=v<=10^9
sol
samjam还是睡觉的时候想的题果然神乎其技
不难发现,答案应该是一个关于x的分段函数,那么最多有多少段呢? 答案是B+1段。下面就来分析。
考虑到x无穷小的情况是全取正边,x无穷大的情况是全取负边。应该能猜到是负边一条条替换正边。
怎么替换呢?
首先要有一个显然的结论: 无论x取何值,正负边都只会选各自最小生成树里的。所以下文会将所有不在最小生成树里的边自动舍去。
当x取负无穷的时候,按照权值将边排序,是这样的:
当x慢慢增大,二者就会交在一起。 我们考虑负边中的某条边e(x,y)
记正边最小生成树上路径(x,y)中,k最大的边为t.
不难发现,当x增大到让e排在t的前面时,t就会被替换掉。而且无论x再怎么增大,t都不会再被替换回来(按顺序取边,t一定不会取到)。 那么这个替换点就是
x=ke−kt2
也就是一个函数上的转折点。
又因为负边肯定不会替换k比他小的负边(一定不优),所以我们只需要从小到大加入负边,算出每一条负边替换掉他环中剩下的最大正边的时间点 ,并且将这条正边删掉。
(因为负边最终一定会替换正边,而被删掉的正边已经有k更小的负边来替换了。),
然后将这些时间点排序就可以了。每一段函数的斜率就是正边数-负边数。
最小生成树用lct维护,边权建成一个点。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#define halfup(x) ((x>>1)+(x&1))
#define min(a,b) ((a)<(b)?(a):(b))
using namespace std;
typedef long long ll;
const ll N = 1e5+10,M = 2e5+10,MX = 4e5+10,INF = 2e9;
ll n,A,B,q,tot,Q;
struct LCT{
ll ptot,fa[MX],c[MX][2],va[MX],mi[MX],rev[MX];
#define get(x) ((c[fa[x]][1]) == (x))
#define isroot(x) (fa[x]==0 || (c[fa[x]][0]!=(x)) && (c[fa[x]][1]!=(x)))
inline void update(ll x) {
mi[x]=x;
if (c[x][0] && va[mi[c[x][0]]]>va[mi[x]]) mi[x]=mi[c[x][0]];
if (c[x][1] && va[mi[c[x][1]]]>va[mi[x]]) mi[x]=mi[c[x][1]];
}
inline void putrev(ll x) {
if (x) rev[x]^=1; swap(c[x][0],c[x][1]);
}
inline void down(ll x) {
if (rev[x]) putrev(c[x][0]),putrev(c[x][1]),rev[x]=0;
}
ll Q[MX];
void downp(ll x) {
while (x) Q[++Q[0]]=x,x=fa[x];
while (Q[0]) down(Q[Q[0]--]);
}
void rotate(ll x) {
ll y=fa[x],z=get(x);
if (c[x][1-z]) fa[c[x][1-z]]=y;
c[y][z]=c[x][1-z];
if (!isroot(y)) c[fa[y]][get(y)]=x;
fa[x]=fa[y];
fa[y]=x;
c[x][1-z]=y;
update(y);
}
void splay(ll x) {
downp(x);
while (!isroot(x)) {
if (!isroot(fa[x]))
if (get(x) == get(fa[x])) rotate(fa[x]);
else rotate(x);
rotate(x);
}
update(x);
}
void access(ll x) {
for (ll i=0; x; i=x,x=fa[x]) {
splay(x);
c[x][1]=i;
update(x);
}
}
void mkroot(ll x) {
access(x);
splay(x);
putrev(x);
}
ll gt(ll x) {
if (c[x][0]) return gt(c[x][0]);
else return x;
}
ll getroot(ll x) {
access(x);
splay(x);
return gt(x);
}
} T;
struct edge{
ll x,y;
ll k;
edge(ll x0=0,ll y0=0,ll k0=0) {
x=x0,y=y0,k=k0;
}
} ez[M],ef[M],zmst[M],fmst[M];
bool cmp(const edge& a,const edge& b) {return a.k<b.k;}
ll totz,totf;
ll f[N];
ll gf(ll x) {return f[x]?f[x]=gf(f[x]):x;}
char fff;
int read(ll &x) {
int v=1;
while ((fff=getchar())<'0' || fff>'9') if (fff=='-') v=-1;
x=fff-'0'; while ((fff=getchar())>='0' && fff<='9') x=x*10+fff-'0';
x*=v;
}
void makeMST(edge *e,ll &tot,edge *mst) {
memset(f,0,sizeof f);
ll mtot=0;
for (ll i=1; i<=A; i++) {
ll u,v,w;
read(u),read(v),read(w);
e[++tot]=edge(u,v,w);
}
sort(e+1,e+1+tot,cmp);
for (ll j=1,fx,fy,js=0; js<n-1; j++) {
if ((fx=gf(e[j].x)) != (fy=gf(e[j].y))) {
f[fx]=fy, js++;
mst[++mtot]=e[j];
}
}
}
ll ptoe[MX],px[MX],py[MX];
ll nowsum;
void link(ll x,ll y,ll v,ll no) {
T.mkroot(x);
ll k; ptoe[k=++T.ptot]=no;
T.va[k]=v;
T.fa[x]=k,T.fa[k]=y;
px[k]=x,py[k]=y;
}
void cut(ll e) {
T.mkroot(px[e]);
T.access(py[e]);
T.splay(px[e]);
T.splay(T.c[T.c[px[e]][1]][0]);
T.fa[px[e]]=T.fa[py[e]]=0;
T.va[T.fa[px[e]]] = 23333+INF;
}
ll getMin(ll x,ll y) {
T.mkroot(x);
T.access(y);
T.splay(y);
return T.mi[y];
}
struct node{
double st;
ll v;
node(double st=0,ll v=0) {
this->st=st,this->v=v;
}
friend bool operator <(const node &a,const node &b) {return a.st<b.st;}
} fun[M];
ll ft;
int main() {
freopen("graph.in","r",stdin);
freopen("graph.out","w",stdout);
cin>>n>>A>>B>>q;
for (ll i=1; i<=n; i++) T.va[i]=-INF;
T.ptot=n;
makeMST(ez,totz,zmst);
makeMST(ef,totf,fmst);
for (ll i=1; i<n; i++) {
link(zmst[i].x,zmst[i].y,zmst[i].k,i),nowsum+=zmst[i].k;
}
fun[++ft] = node(-INF,nowsum);
for (ll i=1; i<n; i++) {
ll e = getMin(fmst[i].x,fmst[i].y), g = ptoe[e];
cut(e);
link(fmst[i].x,fmst[i].y,-INF,0);
fun[++ft] = node((fmst[i].k-zmst[g].k)/2.0,-zmst[g].k+fmst[i].k);
}
sort(fun+1,fun+1+ft);
for (ll i=2; i<=ft; i++) fun[i].v+=fun[i-1].v;
for (ll i=1; i<=q; i++) {
ll Q; read(Q);
ll l=1,r=ft,ans;
while (l<=r) {
if (fun[l+r>>1].st<=Q) {
ans=l=l+r>>1;
l++;
} else r=l+r>>1,r--;
}
printf("%lld\n",fun[ans].v+Q*(n-1-(ans-1)-(ans-1)));
}
}