解析
毒题
细节有亿点点多
我一开始的思路是没有问题的
尝试统计有多少种方案能砍出大小在一个区间的子树、
当时的想法是线段树合并
但是这个玩意在需要保留原树的情况下空间复杂度炸没了…
因为我垃圾的实现一个dfs里面就玩了七遍merge函数…
空间常数飞起
然后分数就和暴力一毛一样
qwq
考虑一些更灵巧的做法
动态维护一个关于值域的树状数组
要求一个子树内的答案用遍历到子树根前后的结果相减
子树外就是最终的结果减去子树内结果
然后我这个垃圾的实现似乎还需要再开一个树状数组动态维护返祖链的答案…
虽然实现还是垃圾但是树状数组就可以承受这么恶心的常数了
细节有亿点点多
但调来调去终于算调过去了
qwq
在dfs上动态维护一个树状数组是一个值得学习的好思想
不要动不动就开权值线段树暴力搞!!!
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=3e5+100;
const double eps=1e-6;
const int mod=1333331;
inline ll read() {
ll x=0,f=1;
char c=getchar();
while(!isdigit(c)) {
if(c=='-') f=-1;
c=getchar();
}
while(isdigit(c)) {
x=(x<<1)+(x<<3)+c-'0';
c=getchar();
}
return x*f;
}
int n;
struct node {
int to,nxt;
} p[N<<1];
int fi[N],cnt;
inline void addline(int x,int y) {
p[++cnt]=(node) {
y,fi[x]
};
fi[x]=cnt;
}
int mx[N],sec[N],siz[N],mxson[N],secson[N];
void dfs(int x,int f) {
siz[x]=1;
for(int i=fi[x]; ~i; i=p[i].nxt) {
int to=p[i].to;
if(to==f) continue;
dfs(to,x);
siz[x]+=siz[to];
int o=siz[to],oo=to;
if(o>mx[x]) swap(o,mx[x]),swap(oo,mxson[x]);
if(o>sec[x]) swap(o,sec[x]),swap(oo,secson[x]);
}
//printf("x=%d siz=%d mx=%d->%d sec=%d->%d\n",x,siz[x],mxson[x],mx[x],secson[x],sec[x]);
return;
}
ll ans=0;
struct tree{
int f[N];
inline void add(int p,int v=1){
for(;p<=n;p+=p&-p) f[p]+=v;
return;
}
inline int ask(int p){
//printf("%d\n",p);
int res=0;
for(;p;p-=p&-p) res+=f[p];
return res;
}
}t1,t2;
struct query{
int l,r,val;
}Add[N];
int tot;
void solve(int u,int f) {
t1.add(siz[u]);t2.add(siz[u]);
int st,ed,l,r,x,y,res=0;
st=1,ed=n;x=n-siz[u];y=mx[u];
while(st<ed) {
int o=(st+ed)>>1;
if(x-o<=(n-o)/2) ed=o;
else st=o+1;
//if(u==4) printf(" st=%d ed=%d o=%d\n",st,ed,o);
}
l=st;//printf("l=%d\n",l);
r=min(x,n-2*y);
//l=n-l;r=n-r;swap(l,r);
int a=l,b=r;
if(a<=b) res+=t1.ask(b)-t1.ask(a-1)-(t2.ask(b)-t2.ask(a-1)),Add[++tot]=(query){a,b,u};
if(l<=r) res+=t2.ask(n-l)-t2.ask(n-r-1);
//printf("---u=%d x=%d y=%d (%d %d) res=%d\n",u,x,y,a,b,res);
for(int i=fi[u];~i;i=p[i].nxt){
int to=p[i].to;
if(to==f) continue;
x=siz[to],y=max(to==mxson[u]?sec[u]:mx[u],n-siz[u]);
st=1,ed=n;
while(st<ed) {
int o=(st+ed)>>1;
if(x-o<=(n-o)/2) ed=o;
else st=o+1;
}
l=st;
r=min(x,n-2*y);
if(l<=r) res-=t1.ask(r)-t1.ask(l-1);
solve(to,u);
if(l<=r) res+=t1.ask(r)-t1.ask(l-1);
//printf(" u=%d to=%d (%d %d) res=%d\n",u,to,l,r,res);
}
if(a<=b) res-=t1.ask(b)-t1.ask(a-1);
//printf("u=%d res=%d\n\n",u,res);
ans+=1ll*res*u;
t2.add(siz[u],-1);
return;
}
int main() {
/*#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
*/
//printf("%d\n",sizeof(rub)/1024/1024);
int T=read();
while(T--) {
memset(t1.f,0,sizeof(t1.f));
memset(t2.f,0,sizeof(t2.f));
memset(fi,-1,sizeof(fi));
cnt=-1;
memset(mx,0,sizeof(mx));
memset(sec,0,sizeof(sec));
tot=ans=0;
n=read();
for(int i=1; i<n; i++) {
int x=read(),y=read();
addline(x,y);
addline(y,x);
}
dfs(1,0);
solve(1,0);
//printf("\n");
for(int i=1;i<=tot;i++){
ans+=1ll*Add[i].val*(t1.ask(Add[i].r)-t1.ask(Add[i].l-1));
//printf("add:i=%d num=%d\n",Add[i].val,t1.ask(Add[i].r)-t1.ask(Add[i].l-1));
}
printf("%lld\n",ans);
}
return 0;
}
/*
*/