题意:
给定一个长度为 $n$ 的数组 $a$,以及模数 $mod$,再有 $q$ 次操作,每次操作为:① $1, l, r, x$,表示将 $a_l, a_{l+1}, \dots, a_r$ 乘上 $x$;② $2, p, x$,表示将 $a_p$ 修改为 $\cfrac{a_p}{x}$(保证 $x \mid a_p$);③ $3, l, r$,询问 $\sum\limits_{i=l}^{r} a_i$ 模 $mod$ 的结果。($n, q, x, a_i \leq 10^5$,$2 \leq mod \leq 10^9 + 9$)
链接:
https://codeforces.com/contest/1109/problem/E
解题思路:
不考虑除法操作的话,很容易用线段树维护出来。若 $x$ 与 $mod$ 互质,则除法操作可以变为乘上 $inv(x)$。当 $x$ 与 $mod$ 不互质时,将 $x$ 改写为 $x_0 x_1$,即与 $mod$ 互质的部分 $x_0$ 以及剩下的部分 $x_1$,不互质部分则通过分解质因子来支持除法操作,预处理出 $mod$ 的质因子,线段树分别维护与 $mod$ 互质、不互质的两部分答案,则相乘可得最后答案。
参考代码:
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used by the modular-arithmetic helpers.
typedef long long ll;
// Pair-of-ints shorthand (not referenced in the visible code).
typedef pair<int, int> pii;
// Container size as int / push_back shorthand (neither is referenced in the visible code).
#define sz(a) ((int)a.size())
#define pb push_back
// Child indices of segment-tree node `rt`. These expand textually, so an
// identifier named `rt` must be in scope wherever they are used.
#define lson (rt << 1)
#define rson (rt << 1 | 1)
// Midpoint of [l, r]; `+` binds tighter than `>>`, so this is (l + r) >> 1.
// Relies on identifiers named `l` and `r` being in scope at the use site.
#define gmid (l + r >> 1)
// Capacity of a[] and pi[]; the segment-tree arrays are sized maxn << 2.
const int maxn = 1e5 + 5;
// Large sentinel value (not referenced in the visible code).
const int inf = 0x3f3f3f3f;
// The modulus is read from input in main(), so the fixed constant stays disabled.
// const int mod = 1e9 + 7;
// a[1..n]: input values read in main().
// pi[1..tot]: distinct prime factors stored by init().
int a[maxn], pi[maxn];
// n, mod, q are read in main(); tot is the number of primes stored in pi[] by init().
int n, mod, q, tot;
// Factorises x by trial division and records each distinct prime factor once,
// in increasing order, in pi[1..tot]. Resets tot before starting.
void init(int x){
    tot = 0;
    int rest = x;
    for(int d = 2; d * d <= rest; ++d){
        if(rest % d == 0){
            pi[++tot] = d;
            // Strip every copy of d so later divisors found are prime as well.
            do{
                rest /= d;
            }while(rest % d == 0);
        }
    }
    // Whatever is left above 1 is a single prime larger than sqrt of the original.
    if(rest > 1) pi[++tot] = rest;
}
// Extended Euclid: writes into x and y a pair of Bezout coefficients such that
// a * x + b * y == gcd(a, b).
void exgcd(ll a, ll b, ll &x, ll &y){
    if(b == 0){
        // gcd(a, 0) == a == a * 1 + 0 * 0
        x = 1;
        y = 0;
        return;
    }
    // Solve for (b, a mod b) with the output slots swapped, then correct y.
    exgcd(b, a % b, y, x);
    y -= (a / b) * x;
}
// Modular inverse of a modulo p via the extended Euclidean algorithm.
// The Bezout coefficient of a is shifted into [0, p) for positive p.
// The result is only a true inverse when gcd(a, p) == 1; this is not checked here.
ll inv(ll a, ll p){
    ll coefA = 0, coefP = 0;
    exgcd(a, p, coefA, coefP);
    const ll reduced = coefA % p;
    return (reduced + p) % p;
}
// Computes a^b modulo p by binary exponentiation.
//   a: base; reduced into [0, p) first so that the squaring step cannot
//      overflow even when |a| is far larger than p.
//   b: exponent, expected to be non-negative (the loop halves b until it is 0;
//      NOTE(review): a negative b would presumably never terminate).
//   p: modulus, expected to be positive and small enough that (p - 1)^2 fits in ll.
// Returns a value in [0, p); in particular a^0 mod 1 yields 0 rather than 1.
ll qpow(ll a, ll b, ll p){
    a %= p;
    if(a < 0) a += p;
    ll ret = 1 % p;
    while(b){
        if(b & 1) ret = ret * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return ret;
}
struct Node{
int pro, cnt[10];
Node() { clear(); }
void clear() { pro = 1; memset(cnt, 0, sizeof cnt); }
Node operator * (const Node &o) const{
Node ret;
for(int i = 0; i < 10; ++i) ret.cnt[i] = cnt[i] + o.cnt[i];
ret.pro = pro * 1ll * o.pro % mod;
return ret;
}
Node operator * (int x){
Node ret = *this;
for(int i = 1; i <= tot; ++i){
if(x % pi[i]) continue;
while(x % pi[i] == 0) x /= pi[i], ++ret.cnt[i];
}
ret.pro = ret.pro * 1ll * x % mod;
return ret;
}
Node operator / (int x){
Node ret = *this;
for(int i = 1; i <= tot; ++i){
if(x % pi[i]) continue;
while(x % pi[i] == 0) x /= pi[i], --ret.cnt[i];
}
ret.pro = ret.pro * 1ll * inv(x, mod) % mod;
return ret;
}
int getVal(){
int ret = pro;
for(int i = 1; i <= tot; ++i){
ret = ret * 1ll * qpow(pi[i], cnt[i], mod) % mod;
}
return ret;
}
};
struct SegTree{
Node tag[maxn << 2]; int sum[maxn << 2];
void pushUp(int rt){
sum[rt] = (sum[lson] + sum[rson]) % mod;
}
void build(int l, int r, int rt){
if(l == r){
sum[rt] = a[l] % mod;
tag[rt] = tag[rt] * a[l];
return;
}
int mid = gmid;
build(l, mid, lson);
build(mid + 1, r, rson);
pushUp(rt);
}
void pushDown(int rt){
if(tag[rt].pro == 1 && *max_element(tag[rt].cnt + 1, tag[rt].cnt + 1 + tot) == 0) return;
tag[lson] = tag[lson] * tag[rt];
tag[rson] = tag[rson] * tag[rt];
int val = tag[rt].getVal();
sum[lson] = sum[lson] * 1ll * val % mod;
sum[rson] = sum[rson] * 1ll * val % mod;
tag[rt].clear();
}
void update(int l, int r, int rt, int L, int R, int val){
if(l >= L && r <= R){
sum[rt] = sum[rt] * 1ll * val % mod;
tag[rt] = tag[rt] * val;
return;
}
int mid = gmid;
pushDown(rt);
if(L <= mid) update(l, mid, lson, L, R, val);
if(R > mid) update(mid + 1, r, rson, L, R, val);
pushUp(rt);
}
void update(int l, int r, int rt, int pos, int val){
if(l == r){
tag[rt] = tag[rt] / val;
sum[rt] = tag[rt].getVal();
return;
}
int mid = gmid;
pushDown(rt);
if(pos <= mid) update(l, mid, lson, pos, val);
else update(mid + 1, r, rson, pos, val);
pushUp(rt);
}
int query(int l, int r, int rt, int L, int R){
if(l >= L && r <= R) return sum[rt];
int mid = gmid, ret = 0;
pushDown(rt);
if(L <= mid) ret = (ret + query(l, mid, lson, L, R)) % mod;
if(R > mid) ret = (ret + query(mid + 1, r, rson, L, R)) % mod;
return ret;
}
} tr;
// Reads n, mod and the array, factorises mod, builds the tree, then serves q
// operations: type 1 = range multiply, type 2 = point divide, anything else =
// range-sum query whose answer is printed on its own line.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> mod;
    for(int i = 1; i <= n; ++i) cin >> a[i];

    init(mod);
    tr.build(1, n, 1);

    cin >> q;
    while(q--){
        // Every operation starts with the type followed by two integers.
        int opt, first, second;
        cin >> opt >> first >> second;
        switch(opt){
        case 1: {
            int factor;
            cin >> factor;
            tr.update(1, n, 1, first, second, factor);
            break;
        }
        case 2:
            tr.update(1, n, 1, first, second);
            break;
        default:
            cout << tr.query(1, n, 1, first, second) << "\n";
            break;
        }
    }
    return 0;
}