给定一棵 n 个点的树,要求支持一下两种操作:
1 u v a b:在删去原有边 (u,v),加入新边 (a,b),保证操作后整张图还是一棵树。
2 u v:求从 u 到 v 经过每个点不超过 2 次的不同路径条数(初始在 u 及最后到达 v 均计算入经过次数)。两个路径被认为是不同的,当且仅当其经过点的序列不同。
真点好题
如果是矩阵的话要注意矩乘没有交换律,所以要维护左到右和右到左。
现在才意识到push_up的时候直接把两个儿子节点push_down会比较方便。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define p 998244353
#define lint long long
#define gc getchar()
#define N 100010
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
struct node{
int k1,k2,k3,k4;
node(int _k1=1,int _k2=0,int _k3=0,int _k4=1) { k1=_k1,k2=_k2,k3=_k3,k4=_k4; }
inline node operator*(const node &b)const
{
const node &a=*this;node c;
c.k1=(1ll*a.k1*b.k1+1ll*a.k2*b.k3)%p;
c.k2=(1ll*a.k1*b.k2+1ll*a.k2*b.k4)%p;
c.k3=(1ll*a.k3*b.k1+1ll*a.k4*b.k3)%p;
c.k4=(1ll*a.k3*b.k2+1ll*a.k4*b.k4)%p;
return c;
}
inline int show()
{
debug(k1)sp,debug(k2)sp,debug(k3)sp,debug(k4)ln;return 0;
}
inline node operator*=(const node &b) { return (*this)=(*this)*b; }
}val[N],ltr[N],rtl[N];
int ch[N][2],fa[N],tsz[N],pf[N],rev[N];
inline int show(int n=6)
{
for(int i=1;i<=n;i++)
debug(i)sp,debug(fa[i])sp,debug(ch[i][0])sp,debug(ch[i][1])sp,debug(pf[i])sp,debug(rev[i])sp,debug(tsz[i])ln,val[i].show(),ltr[i].show(),rtl[i].show(),cerr ln;
return 0;
}
inline int gw(int x) { return ch[fa[x]][1]==x; }
inline node& LTR(int x) { return rev[x]?rtl[x]:ltr[x]; }
inline node& RTL(int x) { return rev[x]?ltr[x]:rtl[x]; }
inline int push_up(int x)
{
ltr[x]=LTR(ch[x][0])*val[x]*LTR(ch[x][1]);
rtl[x]=RTL(ch[x][1])*val[x]*RTL(ch[x][0]);
return tsz[x]=tsz[ch[x][0]]+tsz[ch[x][1]]+val[x].k1;
}
inline int setc(int x,int y,int z) { if(!x) return fa[y]=0;ch[x][z]=y;if(y) fa[y]=x;return push_up(x); }
inline int rotate(int x)
{
int y=fa[x],z=fa[y],a=gw(x),b=gw(y),c=ch[x][a^1];
return swap(pf[x],pf[y]),setc(y,c,a),setc(x,y,a^1),setc(z,x,b);
}
inline int push_down(int x)
{
if(!rev[x]) return 0;
if(ch[x][0]) rev[ch[x][0]]^=1;if(ch[x][1]) rev[ch[x][1]]^=1;
return swap(ltr[x],rtl[x]),swap(ch[x][0],ch[x][1]),rev[x]=0;
}
inline int all_down(int x) { return (fa[x]?all_down(fa[x]):0),(rev[x]?push_down(x):0),0; }
inline int splay(int x,int tar=0)
{
for(all_down(x);fa[x]^tar;rotate(x))
if(fa[fa[x]]) rotate((gw(x)^gw(fa[x]))?x:fa[x]);
return 0;
}
inline int expose(int x)
{
splay(x);int y=ch[x][1];if(!y) return 0;
return pf[y]=x,fa[y]=0,ch[x][1]=0,val[x].k1+=tsz[y],push_up(x);
}
inline int splice(int x)
{
splay(x);int y=pf[x];if(!y) return 0;
return expose(y),splay(y),val[y].k1-=tsz[x],setc(y,x,1),pf[x]=0,1;
}
inline int access(int x) { expose(x);while(splice(x));return 0; }
inline int evert(int x) { return access(x),splay(x),rev[x]^=1; }
inline int link(int x,int y) { return evert(x),evert(y),splay(y),pf[x]=y,tsz[y]+=tsz[x],val[y].k1+=tsz[x]; }
inline int cut(int x,int y) { return evert(x),access(y),splay(x),ch[x][1]=fa[y]=pf[y]=0,push_up(x); }//wrong
inline int query(int x,int y) { return evert(x),access(y),splay(x),ltr[x].k1; }
int main()
{
int n=inn(),m=inn(),u,v,a,b;
for(int i=1;i<=n;i++) ltr[i]=rtl[i]=val[i]=node(tsz[i]=1,1,1,0);
for(int i=1;i<n;i++) u=inn(),v=inn(),link(u,v);
while(m--)
{
if(inn()==1) u=inn(),v=inn(),a=inn(),b=inn(),cut(u,v),link(a,b);
else u=inn(),v=inn(),printf("%d\n",query(u,v));
}
return 0;
}