题目出自 Owaski
题目大意
n<=1e5
解法1
点分。
对于当前的分治,假设走到了点
x
,那么
然后用一个线段树维护当前哪些点能走,哪些不能走。离开这棵子树的时候,把不在这棵子树的标记撤销掉。
由于要用到撤销,所以要用主席树。
解法2
先求不合法的路径的数量。
若
(a,b)
这条路径不合法,则是它内部包含了形如
(x,kx)
的路径。
对于所有形如
(x,kx)
的路径,假设
x
是 dfs 序小的那个点,
1、若
x
是
2、若
x
不是
第一种情况是两个矩形,第二种情况是一个矩形,扫描线求面积并。
代码
//解法2
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long LL;
const int maxn=1e5+5, maxrec=24e5+5, MX=17;
struct TRST{
int nmin,num;
TRST(int NMIN=0,int NUM=0) {nmin=NMIN, num=NUM;}
};
int n;
int tot,go[2*maxn],next[2*maxn],f1[maxn];
void ins(int x,int y)
{
go[++tot]=y;
next[tot]=f1[x];
f1[x]=tot;
}
int tt[2],rx[2][maxrec],ry[2][maxrec],nt[2][maxrec],fr[2][maxn];
void inr(int ty,int i,int x,int y)
{
rx[ty][++tt[ty]]=x;
ry[ty][tt[ty]]=y;
nt[ty][tt[ty]]=fr[ty][i];
fr[ty][i]=tt[ty];
}
int st[maxn],en[maxn],sum,fa[maxn][MX+5],deep[maxn];
void dfs_dfn(int k,int last)
{
deep[k]=deep[last]+1;
fa[k][0]=last;
fo(j,1,MX) fa[k][j]=fa[fa[k][j-1]][j-1];
st[k]=++sum;
for(int p=f1[k]; p; p=next[p]) if (go[p]!=last) dfs_dfn(go[p],k);
en[k]=sum;
}
int find(int x,int y)
{
fd(j,MX,0) if (deep[fa[y][j]]>deep[x]) y=fa[y][j];
return y;
}
TRST tr[4*maxn];
int bz[4*maxn];
void tr_js(int k,int l,int r)
{
tr[k].num=r-l+1;
if (l==r) return;
int t=k<<1, t1=(l+r)>>1;
tr_js(t,l,t1), tr_js(t+1,t1+1,r);
}
TRST merge(TRST a,TRST b)
{
if (a.nmin<b.nmin) return a;
else if (a.nmin>b.nmin) return b;
else return TRST(a.nmin,a.num+b.num);
}
void update(int k,int t)
{
if (!bz[k]) return;
tr[t].nmin+=bz[k], tr[t+1].nmin+=bz[k];
bz[t]+=bz[k], bz[t+1]+=bz[k];
bz[k]=0;
}
void tr_xg(int k,int l,int r,int x,int y,int z)
{
if (l==x && r==y)
{
tr[k].nmin+=z;
bz[k]+=z;
return;
}
int t=k<<1, t1=(l+r)>>1;
update(k,t);
if (y<=t1) tr_xg(t,l,t1,x,y,z);
else if (x>t1) tr_xg(t+1,t1+1,r,x,y,z);
else tr_xg(t,l,t1,x,t1,z), tr_xg(t+1,t1+1,r,t1+1,y,z);
tr[k]=merge(tr[t],tr[t+1]);
}
LL ans;
void Scanline()
{
tr_js(1,1,n);
fo(i,1,n)
{
for(int p=fr[0][i]; p; p=nt[0][p]) tr_xg(1,1,n,rx[0][p],ry[0][p],1);
ans+=n-tr[1].num;
for(int p=fr[1][i]; p; p=nt[1][p]) tr_xg(1,1,n,rx[1][p],ry[1][p],-1);
}
}
int main()
{
scanf("%d",&n);
fo(i,1,n-1)
{
int x,y;
scanf("%d %d",&x,&y);
ins(x,y), ins(y,x);
}
dfs_dfn(1,0);
fo(i,1,n)
for(int j=2*i; j<=n; j+=i)
{
int x=i, y=j;
if (st[x]>st[y]) swap(x,y);
if (st[y]<=en[x])
{
int g=find(x,y);
inr(0,1,st[y],en[y]), inr(1,st[g]-1,st[y],en[y]);
if (en[g]<n) inr(0,st[y],en[g]+1,n), inr(1,en[y],en[g]+1,n);
} else
{
inr(0,st[x],st[y],en[y]), inr(1,en[x],st[y],en[y]);
}
}
Scanline();
printf("%lld\n",(LL)n*(n-1)/2-ans);
}