Description
给出两棵树,T1和T2
对于T1中的每一条边e1,你需要求出T2中有多少条边e2满足
1:T1-e1+e2是一棵树
2:T2-e2+e1是一棵树
n<=2e5
Solution
我们考虑一组限制的两种方法,并且这两种方法能够套在一起
首先,我们知道可以对于T2中的每一条边(u,v),在第一棵树上的u,v打上标记,在lca(u,v)处撤销
这样子我们可以在遍历的时候求出那些e2边可能满足条件
接下来对于每个e1(x,y)满足条件的e2需要存在在T2中x到y的路径上
也就是我们需要满足单点修改,路径询问,然后用线段树合并维护
这个并没有必要用树链剖分,可以直接括号序维护
复杂度O(n log n)
Code
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a,b) for(int i=lst[b][a];i;i=nxt[i])
using namespace std;
typedef vector<int> vec;
#define pb(a) push_back(a)
int read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
int x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
void write(int x) {
if (!x) {putchar('0');putchar(' ');return;}
char ch[20];int tot=0;
for(;x;x/=10) ch[++tot]=x%10+'0';
fd(i,tot,1) putchar(ch[i]);
putchar(' ');
}
const int N=4e5+5;
int t[N<<1],nxt[N<<1],lst[2][N],l;
void add(int x,int y,int a) {
t[++l]=y;nxt[l]=lst[a][x];lst[a][x]=l;
}
int n,dfn[2][N],fir[2][N],dep[2][N],tot,f[2][N][19],lg[N];
void dfs(int x,int y,int a) {
dfn[a][++tot]=x;fir[a][x]=tot;dep[a][x]=dep[a][y]+1;
rep(i,x,a) if (t[i]!=y) dfs(t[i],x,a),dfn[a][++tot]=x;
}
int lca(int x,int y,int a) {
x=fir[a][x];y=fir[a][y];
if (x>y) swap(x,y);
int z=lg[y-x+1];
x=f[a][x][z];y=f[a][y-(1<<z)+1][z];
return dep[a][x]<dep[a][y]?x:y;
}
int in[N],out[N];
void travel(int x,int y) {
in[x]=++tot;
rep(i,x,1) if (t[i]!=y) travel(t[i],x);
out[x]=++tot;
}
int u1[N],u2[N],v1[N],v2[N],an[N];
vec q[N];
int tr[N<<5],ls[N<<5],rs[N<<5],rt[N],cnt;
void modify(int &v,int l,int r,int x,int y) {
if (!v) v=++cnt;
if (l==r) {tr[v]+=y;return;}
int mid=l+r>>1;
if (x<=mid) modify(ls[v],l,mid,x,y);
else modify(rs[v],mid+1,r,x,y);
tr[v]=tr[ls[v]]+tr[rs[v]];
}
int merge(int x,int y,int l,int r) {
if (!x||!y) return x+y;
if (l==r) {tr[x]+=tr[y];return x;}
int mid=l+r>>1;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
tr[x]=tr[ls[x]]+tr[rs[x]];
return x;
}
int query(int v,int l,int r,int x,int y) {
if (!v) return 0;
if (l==x&&r==y) return tr[v];
int mid=l+r>>1;
if (y<=mid) return query(ls[v],l,mid,x,y);
else if (x>mid) return query(rs[v],mid+1,r,x,y);
else return query(ls[v],l,mid,x,mid)+query(rs[v],mid+1,r,mid+1,y);
}
void solve(int x,int y) {
rep(i,x,0) if (t[i]!=y) solve(t[i],x);
rep(i,x,0) rt[x]=merge(rt[x],rt[t[i]],1,tot);
if (!q[x].empty())
fo(i,0,q[x].size()-1) {
int z=q[x][i];
if (z>0) modify(rt[x],1,tot,in[z],1),modify(rt[x],1,tot,out[z],-1);
else modify(rt[x],1,tot,in[-z],-1),modify(rt[x],1,tot,out[-z],1);
}
if (y) {
int z=lca(x,y,1);
an[x]=query(rt[x],1,tot,1,in[x])+query(rt[x],1,tot,1,in[y])-2*query(rt[x],1,tot,1,in[z]);
}
}
int main() {
for(int ty=read();ty;ty--) {
n=read();
memset(lst,0,sizeof(lst));l=0;
fo(i,1,n-1) {
u1[i]=read();v1[i]=read();
add(u1[i],v1[i],0);add(v1[i],u1[i],0);
}
fo(i,1,n-1) {
u2[i]=read();v2[i]=read();
add(u2[i],v2[i],1);add(v2[i],u2[i],1);
}
tot=0;dfs(1,0,0);
tot=0;dfs(1,0,1);
fo(a,0,1) fo(i,1,tot) f[a][i][0]=dfn[a][i];
fo(i,1,tot) lg[i]=log(i)/log(2);
fo(a,0,1)
fo(j,1,18)
fo(i,1,tot-(1<<j)+1) {
int x=f[a][i][j-1],y=f[a][i+(1<<j-1)][j-1];
f[a][i][j]=dep[a][x]<dep[a][y]?x:y;
}
tot=0;travel(1,0);
fo(i,1,cnt) tr[i]=ls[i]=rs[i]=0;cnt=0;
fo(i,1,n) q[i].clear(),rt[i]=0;
fo(i,1,n-1) {
int x=u2[i],y=v2[i],z=lca(x,y,0);
int val=dep[1][x]<dep[1][y]?y:x;
q[x].pb(val);q[y].pb(val);
q[z].pb(-val);q[z].pb(-val);
}
solve(1,0);
fo(i,1,n-1) {
int x=u1[i],y=v1[i];
write(dep[0][x]<dep[0][y]?an[y]:an[x]);
}
puts("");
}
return 0;
}