题目链接
题意:
你现在有一个n*m矩阵,有q次操作,每次询问(x,y)位置的编号,(x,y)的起始编号是(y-1)*m+x。每次操作输出该位置的编号,并且拿出该编号,该行后面的向前补,然后对于最后一列向上补,最后把这个编号放在第n行第m列的位置。(n,m,q<=3e5)
题解:
这题想要用树形数据结构维护,但是发现空间不够。于是线段树就要动态开点,splay的话要把一段连续的编号缩成一个点,查询的时候把一个区间分成三个,中间一个是当前要删除的点,左右就是出去从中间裂开后的两个区间。
我写的是动态开点线段树。维护的方法是我们开n+1棵线段树,前n棵分别维护每一行前n-1个数的编号,第n+1棵维护第m列的编号。对于没有操作过的点,我们不动他,新插入的点用一个vector记录。每次操作将当前点(x,y)查询编号输出并删除,然后加入第n+1棵线段树的最后一个位置,然后把第n+1棵的x位置记录加入到第x行的最后一个位置,注意如果x本身就在最后一列就不用管,并且记录每个区间删除过几个点。
最后说一下动态开点线段树的大小,考虑q的范围,即最多有q个点加入,那么线段树的长度最长为max(n,m)+q。
还有什么不清楚的看下代码理解一下吧。
update on 2021.7.22:
又写了一下splay的做法,真的好麻烦啊。我全程自己yy的写法,写了250行,感觉NOIP当场不太可能调出来(虽说我已经OI退役多年了)。splay的做法关键就是一个点可以维护一个范围的编号,有人离开就将这个点从有人离开的编号处分裂成两个。其他的和线段树差不多,有
n
n
n棵splay,每行维护前
m
−
1
m-1
m−1个,有一棵splay维护第
m
m
m列,一共
n
+
1
n+1
n+1棵。其间调试确实遇到了好多乱七八糟的问题,下面粘的代码是把调试都删了,不删有300多行的。
最后别忘了注意一下爆int的问题
#include <bits/stdc++.h>
using namespace std;
int n,m,q,mx;
long long cnt,root[600100];
struct node
{
long long l,r,sum;//sum是拿走了几个
}tr[10000010];
vector<long long> v[600010];
long long query(long long rt,int l,int r,long long x)
{
if(l==r)
return l;
int mid=(l+r)>>1;
long long cur=mid-l+1-tr[tr[rt].l].sum;
if(x<=cur)
{
if(!tr[rt].l)
tr[rt].l=++cnt;
return query(tr[rt].l,l,mid,x);
}
else
{
if(!tr[rt].r)
tr[rt].r=++cnt;
return query(tr[rt].r,mid+1,r,x-cur);
}
}
void update(long long rt,int l,int r,long long x)
{
tr[rt].sum++;
if(l==r)
return;
int mid=(l+r)>>1;
if(x<=mid)
{
if(!tr[rt].l)
tr[rt].l=++cnt;
update(tr[rt].l,l,mid,x);
}
else
{
if(!tr[rt].r)
tr[rt].r=++cnt;
update(tr[rt].r,mid+1,r,x);
}
}
long long solve1(long long x,long long y)
{
if(!root[n+1])
root[n+1]=++cnt;
long long pos=query(root[n+1],1,mx,x);
update(root[n+1],1,mx,pos);
long long ans;
if(pos<=n)
ans=(long long)pos*m;
else
ans=(long long)v[n+1][pos-n-1];
if(y)
v[n+1].push_back(y);
else
v[n+1].push_back(ans);
return ans;
}
long long solve2(long long x,long long y)
{
if(!root[x])
root[x]=++cnt;
long long pos=query(root[x],1,mx,y);
update(root[x],1,mx,pos);
long long ans;
if(pos<m)
ans=(long long)m*(x-1)+pos;
else
ans=v[x][pos-m];
v[x].push_back(solve1(x,ans));
return ans;
}
int main()
{
scanf("%d%d%d",&n,&m,&q);
mx=max(n,m)+q;
for(int i=1;i<=q;++i)
{
long long x,y;
scanf("%lld%lld",&x,&y);
if(y==m)
printf("%lld\n",solve1(x,0));
else
printf("%lld\n",solve2(x,y));
}
return 0;
}
splay代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,q,cnt;
int root[2400010],sz[2400010],v[2400010],c[2400010][2],f[2400010];
long long id[2400010];
inline void pushup(int x)
{
if(x)
sz[x]=sz[c[x][0]]+sz[c[x][1]]+v[x];
}
inline void build(int l,int r,int rt,int shu)
{
int mid=(l+r)>>1;
int x=++cnt;
f[x]=rt;
sz[x]=1;
v[x]=1;
id[x]=1ll*mid*m;
if(rt)
{
if(mid<shu)
c[rt][0]=x;
else
c[rt][1]=x;
}
if(l<mid)
build(l,mid-1,x,mid);
if(mid+1<=r)
build(mid+1,r,x,mid);
pushup(x);
}
inline long long find(int rt,long long k)
{
int x=rt;
while(x)
{
if(sz[c[x][0]]>=k)
x=c[x][0];
else if(sz[c[x][0]]+v[x]>=k)
{
return id[x]+k-sz[c[x][0]]-1;
}
else
{
k-=sz[c[x][0]]+v[x];
x=c[x][1];
}
}
}
inline void rotate(int x,int i)
{
int y=f[x],z=f[y],k=c[y][1]==x,w=c[x][!k];
if(y!=root[i])
c[z][c[z][1]==y]=x;
c[x][!k]=y;
c[y][k]=w;
if(w)
f[w]=y;
f[y]=x;
f[x]=z;
pushup(y);
}
inline void splay(int x,int rt,int i)
{
while(f[x]!=rt)
{
int y=f[x],z=f[y];
if(f[y]!=rt)
{
if(c[z][0]==y ^ c[y][0]==x)
rotate(x,i);
else
rotate(y,i);
}
rotate(x,i);
if(rt==0)
root[i]=x;
}
pushup(x);
}
inline void del(int i,int k)
{
int x=root[i];
int qaq=k;
int y=0;
while(x)
{
if(sz[c[x][0]]>=k)
x=c[x][0];
else if(sz[c[x][0]]+v[x]>=k)
break;
else
{
k-=sz[c[x][0]]+v[x];
x=c[x][1];
}
}
if(v[x]==1)
{
splay(x,0,i);
int z=0;
if(c[x][0])
{
z=c[x][0];
while(c[z][1])
z=c[z][1];
y=0;
if(c[x][1])
{
y=c[x][1];
while(c[y][0])
y=c[y][0];
}
if(y)
{
splay(z,0,i);
splay(y,z,i);
int ji=f[x];
c[f[x]][0]=0;
f[x]=0;
pushup(ji);
splay(ji,0,i);
}
else
{
splay(z,0,i);
int ji=f[x];
c[f[x]][1]=0;
f[x]=0;
pushup(ji);
splay(ji,0,i);
}
}
else
{
root[i]=c[x][1];
f[c[x][1]]=0;
}
}
else
{
long long ji=k-sz[c[x][0]];
if(ji==1)
{
v[x]--;
sz[x]--;
id[x]++;
splay(x,0,i);
}
else if(ji==v[x])
{
v[x]--;
sz[x]--;
splay(x,0,i);
}
else
{
long long qwq=v[x];
v[x]=ji-1;
splay(x,0,i);
if(c[x][1]==0)
{
y=++cnt;
c[x][1]=y;
f[y]=x;
v[y]=qwq-v[x]-1;
sz[y]=v[y];
id[y]=id[x]+ji;
splay(y,0,i);
}
else
{
int z=c[x][1];
while(c[z][0])
z=c[z][0];
splay(z,x,i);
y=++cnt;
c[z][0]=y;
f[y]=z;
v[y]=qwq-v[x]-1;
sz[y]=v[y];
id[y]=id[x]+ji;
splay(z,0,i);
}
}
}
}
inline void ins(int i,long long k)
{
int x=root[i];
if(x==0)
{
int y=++cnt;
root[i]=y;
sz[y]=1;
v[y]=1;
id[y]=k;
return;
}
while(c[x][1])
x=c[x][1];
splay(x,0,i);
int y=++cnt;
c[x][1]=y;
f[y]=x;
sz[y]=1;
v[y]=1;
id[y]=k;
pushup(x);
splay(x,0,i);
}
int main()
{
scanf("%d%d%d",&n,&m,&q);
for(int i=1;i<=n;++i)
{
root[i]=++cnt;
sz[cnt]=m-1;
v[cnt]=m-1;
id[cnt]=1ll*m*(i-1)+1;
f[cnt]=0;
c[cnt][0]=0;
c[cnt][1]=0;
}
root[n+1]=cnt+1;
build(1,n,0,0);
for(int i=1;i<=q;++i)
{
int x,y;
scanf("%d%d",&x,&y);
if(y==m)
{
long long ans=find(root[n+1],x);
printf("%lld\n",ans);
del(n+1,x);
ins(n+1,ans);
}
else
{
long long ans=find(root[x],y);
printf("%lld\n",ans);
del(x,y);
long long res=find(root[n+1],x);
del(n+1,x);
ins(x,res);
ins(n+1,ans);
}
}
return 0;
}