题目大意:给定一棵树,m次询问,每次给出k个关键点,询问这k个点之间的两两距离和、最小距离和最大距离
n<=100W,m<=50000,Σk<=2*n
思路:利用LCA单调性,每次询问的时候重新建树,在这棵树上做DP,使得总体时间复杂度降到O(nlogn)。
我的做法是维护四个数组,sum,size,_min,_max,分别表示以当前节点为根节点的子树中的所有关键点到根节点的距离的总和,共有多少个关键点,距离根节点最近的关键点的距离,距离跟节点最远的关键点的距离。此外,在做DP的同时,除了最值还要记录一下次值,用__min和__max表示。记录一个全局变脸来表示最终答案。DP方程(y表示x的一个子树的根节点):
size[x] = ∑size[y] + super[x];
sum[x] = ∑(sum[y] + length * size[y]);
_min[x] = min{_min[y] + length}
_max[x] = max{_max[y] + length}
注意还要更新一下次值
更新答案的表达式:
ans += ∑((sum[y] + length * size[y]) * (size[x] - size[y]));
ans_min = min(ans_min,_min[x] + __min);
ans_max = max(ans_max,_max[x] + __max);
没了。。记得开long long
其实完全不用记录次值,只需要两两子树合并就可以了,然后都取最值。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;
typedef long long sint;
#define pii pair<int,int>
#define mp make_pair
#define deg 25
#define maxn 1001000
#define inf 0x3f3f3f3f
#define INF (1ll<<40)
int getint()
{
int res;char c;
while(c=getchar(),c<'0'||c>'9');
res=c-'0';
while(c=getchar(),c>='0'&&c<='9')
res=res*10+c-'0';
return res;
}
struct node
{
int u,v,len,next;
}tree[maxn];
int first[maxn],next[maxn<<1],to[maxn<<1],fa[maxn][deg];
sint dp[maxn];
int indexs,en,en2,pre[maxn],n,dfn[maxn],dep[maxn],num,T;
sint anssum=0,ansmin,ansmax=0;
pii g[maxn];
int vis[maxn],meet[maxn],slen[maxn][deg];
int siz[maxn],sumsiz;
long long lmin[maxn],lmin2[maxn],lmax[maxn],lmax2[maxn];
void build(int a,int b)
{
en++;
to[en]=b;
next[en]=first[a];
first[a]=en;
}
void add(int u,int v,int len)
{
en2++;
tree[en2].v=v;
tree[en2].len=len;
if(vis[u]!=T)
{
vis[u]=T;
tree[en2].next=0;
}
else
{
tree[en2].next=pre[u];
}
pre[u]=en2;
}
void dfs(int now)
{
int v;
indexs++;
dfn[now]=indexs;
for(int i=first[now];i;i=next[i])
{
v=to[i];
if(v==fa[now][0]) continue;
fa[v][0]=now;
slen[v][0]=1;
for(int j=1;j<deg;j++)
{
fa[v][j]=fa[fa[v][j-1]][j-1];
slen[v][j]=slen[v][j-1]+slen[fa[v][j-1]][j-1];
}
dep[v]=dep[now]+1;
dfs(v);
}
}
int lca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=24;i>=0;i--)
{
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
}
if(u==v)return v;
for(int i=24;i>=0;i--)
if(fa[u][i]!=fa[v][i])
{
u=fa[u][i];
v=fa[v][i];
}
return fa[u][0];
}
int sta[maxn],top;
inline int getlen(int x,int y)
{
int ans=0;
if(dep[x]<dep[y])swap(x,y);
for(int i=24;i>=0;i--)
if(dep[fa[x][i]]>=dep[y])
{
ans+=slen[x][i];
x=fa[x][i];
}
return ans;
}
void dfs2(int x)
{
int v,temp;
if(meet[x]==T) lmax[x]=lmax2[x]=0;
else lmax[x]=lmax2[x]=-inf;
lmin[x]=lmin2[x]=inf;
siz[x]=(meet[x]==T);
if(vis[x]!=T) return;
for(int i=pre[x];i;i=tree[i].next)
{
v=tree[i].v;
dfs2(v);
siz[x]+=siz[v];
temp=tree[i].len+(meet[v]==T?0:lmin[v]);
anssum+=(long long)tree[i].len*siz[v]*(sumsiz-siz[v]);
if(temp<lmin[x])
{
lmin2[x]=lmin[x];
lmin[x]=temp;
}
else if(temp<lmin2[x])
{
lmin2[x]=temp;
}
temp=tree[i].len+lmax[v];
if(temp>lmax[x])
{
lmax2[x]=lmax[x];
lmax[x]=temp;
}
else if(temp>lmax2[x])
{
lmax2[x]=temp;
}
if(meet[x]==T) ansmin=min(ansmin,lmin[x]);
else ansmin=min(ansmin,lmin[x]+lmin2[x]);
ansmax=max(ansmax,lmax[x]+lmax2[x]);
}
}
void virtree()
{
sta[top=1]=1;
for(int i=1;i<=num;i++)
{
int j=g[i].second;
int anc=lca(sta[top],j);
for(;dep[anc]<dep[sta[top]];)
{
if(dep[sta[top-1]]<=dep[anc])
{
int last=sta[top--];
if(sta[top]!=anc)sta[++top]=anc;
add(anc,last,getlen(anc,last));
break;
}
add(sta[top-1],sta[top],getlen(sta[top-1],sta[top])),top--;
}
if(sta[top]!=j)sta[++top]=j;
}
for(;top;)add(sta[top-1],sta[top],getlen(sta[top-1],sta[top])),top--;
}
void solve()
{
int a,b,c;
num=getint();
for(int i=1;i<=num;i++)
{
a=getint();
g[i]=mp(dfn[a],a);
meet[a]=T,sumsiz++;
}
sort(g+1,g+1+num);
virtree();
anssum=0,ansmin=inf,ansmax=0;
dfs2(1);
printf("%lld %lld %lld\n",anssum,ansmin,ansmax);
}
int main()
{
int a,b,c;
n=getint();
for(int i=1;i<n;i++)
{
a=getint();
b=getint();
build(a,b);
build(b,a);
}
dep[1]=1;
dfs(1);
int cas=getint();
while(cas--)
{
T++;en2=0,sumsiz=0;
solve();
}
return 0;
}
/*
10
1 2
3 2
4 1
5 2
6 4
7 5
8 6
9 7
10 9
5
2
5 4
2
10 4
2
5 2
2
6 1
2
6 1
*/