题意
有这样一个程序:
function add(x,v)
while x <= n do
s[x] = s[x] xor v
x = x + lowbit(x) //注意,这里是 lowbit,这也是两份代码唯一的区别
end while
end function
function query(x)
ans = 0
while x > 0 do
ans = ans xor s[x]
x = x - lowbit(x)
end while
return ans
end function
其中
lowbit(x)
l
o
w
b
i
t
(
x
)
表示
x
x
在进制表示下的最低非0位的值。
现在给出
n,k
n
,
k
,有
q
q
次操作,每次操作是或
query(x)
q
u
e
r
y
(
x
)
中的一种,要求对于每个
query
q
u
e
r
y
输出正确答案。
n≤109,q,k≤2∗105
n
≤
10
9
,
q
,
k
≤
2
∗
10
5
分析
比较神仙的一道题。
add
a
d
d
操作可以看成是每次把最低非0位乘2并进位。
先考虑
k
k
为奇数的时候,我们从向
x+lowbit(x)
x
+
l
o
w
b
i
t
(
x
)
连一条边,因为最低非0位始终不改变,所以这
n
n
个点一定会形成若干条链,那么只要对每条链用线段树维护下就好了。
当不为奇数的时候,就未必会形成若干条链了。但如果只考虑最低非0位中含有的质因子2的个数不少于
k
k
的那些数,它们被划分成了若干条链。
对于不在链上的数,我们可以暴力跳,在跳不超过次后一定会跳到一条链上或跳出去。
现在问题在于如何找到一个点所在的链并求出该点到链头的距离。
注意到现在是最低非0位确定在哪条链,其余位确定在链的哪个位置,又因为最低位的变化有环,环对高位的影响是固定的,所以可以先模掉,然后再通过倍增往前跳确定链头即可。
处理大下标同样可以使用线段树。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
const int N=200005;
int n,k,q,cyc[N],num[N],pri[N],a[N],bz1[N][20],bz2[N][20],root,st,dis,sz;
bool vis[N];
struct tree{int l,r,s;}t[N*80];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int lowbit(int x)
{
while (x%k==0) x/=k;
return x%k;
}
int lowbitv(int x)
{
int lg=0;
while (x%k==0) x/=k,lg++;
x%=k;
while (lg) x*=k,lg--;
return x;
}
void prework()
{
for (int i=1;i<=k;i++)
for (int x=i;x%2==0;x>>=1,pri[i]++);
for (int i=1;i<k;i++)
{
if (vis[i]||pri[i]<pri[k]) continue;
int tot=0;a[tot++]=i;vis[i]=1;
for (int x=i*2%k;x!=i;x=x*2%k) a[tot++]=x,vis[x]=1;
int w=0;
for (int j=0;j<tot;j++) w+=(a[j]*2>=k);
for (int j=0;j<tot;j++)
{
int x=a[j],to=a[(j+tot-1)%tot];
cyc[x]=w;num[x]=tot;
bz1[x][0]=to;
bz2[x][0]=(to*2>=k);
}
for (int j=1;j<=18;j++)
for (int k=0;k<tot;k++)
{
int x=a[k];
bz1[x][j]=bz1[bz1[x][j-1]][j-1];
bz2[x][j]=bz2[x][j-1]+bz2[bz1[x][j-1]][j-1];
}
}
}
void ins(int &d,int l,int r,int x,int y,int z)
{
if (!d) d=++sz;
if (x<=l&&r<=y) {t[d].s^=z;return;}
int mid=(l+r)/2;
if (x<=mid) ins(t[d].l,l,mid,x,y,z);
if (y>mid) ins(t[d].r,mid+1,r,x,y,z);
}
int query(int d,int l,int r,int x)
{
if (l==r||!d) return t[d].s;
int mid=(l+r)/2;
if (x<=mid) return query(t[d].l,l,mid,x)^t[d].s;
else return query(t[d].r,mid+1,r,x)^t[d].s;
}
void get_st(int x)
{
int lg=0;
while (x%k==0) x/=k,lg++;
dis=1;st=lowbit(x);int u=(x-st)/k;
dis+=u/cyc[st]*num[st];
u%=cyc[st];
for (int i=18;i>=0;i--)
if (bz2[st][i]<=u)
{
u-=bz2[st][i];
dis+=(1<<i);
st=bz1[st][i];
}
while (lg) st*=k,lg--;
}
int main()
{
n=read();q=read();k=read();
prework();
while (q--)
{
int op=read();
if (op==1)
{
int x=read(),v=read();
while (x<=n&&pri[lowbit(x)]<pri[k])
{
ins(root,1,n,x,x,v);
x+=lowbitv(x);
}
if (x>n) continue;
get_st(x);int rt=query(root,1,n,st),tmp=rt;
ins(rt,1,n,dis,n,v);
if (!tmp) ins(root,1,n,st,st,rt);
}
else
{
int x=read(),ans=0;
while (x)
{
if (pri[lowbit(x)]<pri[k]) ans^=query(root,1,n,x);
else
{
get_st(x);
int rt=query(root,1,n,st);
if (rt) ans^=query(rt,1,n,dis);
}
x-=lowbitv(x);
}
printf("%d\n",ans);
}
}
return 0;
}