题目
分析
我们可以把这看成一个图：每个点 $x$ 向点 $x^2 \bmod p$ 连一条有向边。每个点恰有一条出边，所以从任意点出发最终都会进入一个环。打表发现（在本题 $p$ 的范围内）所有环长的 lcm 不超过 60，且每个点到环的距离不超过 11。
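那么用线段树维护。若某节点区间内还有数不在环上，就递归到叶子暴力平方；每个数最多被暴力约 11 次就会进入环。一旦区间内所有数都已在环上，就在该节点存下区间和在一个周期（所有环长的 lcm，记为 $T$）内依次取到的 $T$ 个值。这样区间平方就变成把这个数组循环左移一位，并打上懒标记。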
时间复杂度O(60nlogn)
代码
#include <bits/stdc++.h>
// Capacity for the sequence; positions are stored 1-indexed in a[].
constexpr int N = 100005;
// Capacity for value-indexed tables; values are residues modulo p (presumably p < P).
constexpr int P = 10005;
using namespace std;

int n;          // sequence length
int m;          // number of operations
int tim;        // dfs pass counter: now[v] == tim means v was visited in the current pass
int now[P];     // pass id of the last dfs pass that visited value v
int tmp[65];    // scratch buffer used while rotating a node's period table
int p;          // modulus read from input
int T;          // lcm of all cycle lengths of v -> v*v % p (initialised to 1 in main)
int a[N];       // initial sequence, 1-indexed
bool vis[P];    // vis[v] is set when v lies on a cycle of v -> v*v % p

// One segment-tree node.
struct tree
{
    int s;       // how many positions of this node's range are already on a cycle
    int sum;     // current sum of the node's range
    int w[65];   // once the whole range is cyclic: w[i] = range sum after i more squarings (i < T)
    int tag;     // pending number of squarings not yet pushed to the children
} t[N * 4];
// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If end of input is reached before any
// digit, returns 0 instead of looping forever (the previous version stored
// getchar()'s result in a char and never checked for EOF).
int read()
{
    int x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}
// Greatest common divisor via the iterative Euclidean algorithm; gcd(x, 0) == x.
int gcd(int x, int y)
{
    while (y != 0)
    {
        const int rem = x % y;
        x = y;
        y = rem;
    }
    return x;
}
// Walks the functional graph v -> v*v % p starting at x, stamping every
// visited value with the current pass id `tim`. The walk stops when the next
// value was already stamped in this pass; the last value walked then lies on
// a cycle. Every value on that cycle is flagged in vis[], and T becomes
// lcm(T, cycle length).
void dfs(int x)
{
    // Advance until the successor of `cur` has already been seen this pass.
    int cur = x;
    for (;;)
    {
        now[cur] = tim;
        const int nxt = cur * cur % p;
        if (now[nxt] == tim)
            break;
        cur = nxt;
    }

    // `cur` closes the cycle: go around it once, flagging and counting values.
    int len = 0;
    int v = cur;
    do
    {
        vis[v] = 1;
        ++len;
        v = v * v % p;
    } while (v != cur);

    T = T / gcd(T, len) * len;
}
void change(int d)
{
if (!vis[t[d].sum])
return;
t[d].s = 1;
int y = t[d].sum;
for (int i = 0; i < T; i++)
t[d].w[i] = y, y = y * y % p;
}
void updata(int d,int l,int r)
{
t[d].s = t[d * 2].s + t[d * 2 + 1].s;
t[d].sum = t[d * 2].sum + t[d * 2 + 1].sum;
if (t[d].s == r - l + 1)
{
for (int i = 0; i < T; i++)
t[d].w[i] = t[d * 2].w[i] + t[d * 2 + 1].w[i];
}
}
void move(int d,int x)
{
t[d].tag += x;
for (int i = 0; i < T; i++)
tmp[i] = t[d].w[(i + x) % T];
for (int i = 0; i < T; i++)
t[d].w[i] = tmp[i];
t[d].sum = t[d].w[0];
}
// Pushes node d's pending squaring count down to both children and clears it.
// Leaves, and nodes with nothing pending, are left untouched.
void pushdown(int d, int l, int r)
{
    if (l != r && t[d].tag != 0)
    {
        const int pending = t[d].tag;
        move(2 * d, pending);
        move(2 * d + 1, pending);
        t[d].tag = 0;
    }
}
// Builds the subtree rooted at node d over positions [l, r] of a[].
void build(int d, int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) / 2;
        build(2 * d, l, mid);
        build(2 * d + 1, mid + 1, r);
        updata(d, l, r);
        return;
    }
    t[d].sum = a[l];
    change(d);
}
// Replaces every element in [x, y] with its square mod p. Node d covers
// [l, r], and callers keep l <= x <= y <= r.
// - Range exactly covered and fully cyclic: rotate the period table by one.
// - Range exactly covered, single leaf, not yet cyclic: square it directly,
//   then check whether it has entered a cycle.
// - Otherwise: recurse into the overlapping children and recombine.
void modify(int d, int l, int r, int x, int y)
{
    pushdown(d, l, r);
    const int mid = (l + r) / 2;

    if (l != x || r != y)
    {
        if (x <= mid)
            modify(2 * d, l, mid, x, min(y, mid));
        if (mid < y)
            modify(2 * d + 1, mid + 1, r, max(x, mid + 1), y);
        updata(d, l, r);
        return;
    }

    if (t[d].s == r - l + 1)
    {
        move(d, 1);
        return;
    }

    if (l == r)
    {
        t[d].sum = t[d].sum * t[d].sum % p;
        change(d);
        return;
    }

    // Exact cover here means x == l and y == r, so each child is covered in full.
    modify(2 * d, l, mid, l, mid);
    modify(2 * d + 1, mid + 1, r, mid + 1, r);
    updata(d, l, r);
}
// Returns the sum of the elements in [x, y]. Node d covers [l, r], and
// callers keep l <= x <= y <= r.
int query(int d, int l, int r, int x, int y)
{
    pushdown(d, l, r);
    if (x == l && y == r)
        return t[d].sum;

    const int mid = (l + r) / 2;
    int total = 0;
    if (x <= mid)
        total += query(2 * d, l, mid, x, min(y, mid));
    if (mid < y)
        total += query(2 * d + 1, mid + 1, r, max(x, mid + 1), y);
    return total;
}
// Reads n, m and p, then finds which residues lie on cycles of v -> v*v % p
// and the lcm period T. Builds the tree from the n initial values and
// processes m operations: op 0 squares [l, r] mod p; any other op prints the
// sum of [l, r].
int main()
{
    n = read();
    m = read();
    p = read();

    T = 1;
    for (int v = 0; v < p; ++v)
    {
        ++tim;  // fresh pass id so dfs can tell this walk's visits from earlier ones
        dfs(v);
    }

    for (int i = 1; i <= n; ++i)
        a[i] = read();
    build(1, 1, n);

    while (m--)
    {
        const int op = read();
        const int l = read();
        const int r = read();
        if (op == 0)
            modify(1, 1, n, l, r);
        else
            printf("%d\n", query(1, 1, n, l, r));
    }
    return 0;
}