题面
思路
首先关于欧拉函数的性质
具体实现
用 b i t s e t bitset bitset维护100内的所有质数;
如果当前要乘的 w w w的全部质因子都被包含了;
那么我们就可以直接乘,也就是线段树裸的区间乘;
如果没有被包含,那么我们就交给子树来处理;
如果处理到叶子了还没被包含,那么直接暴力单点修改;
Code
#include <iostream>
#include <bitset>
#include <vector>
using namespace std;
typedef long long ll;
const ll MOD = 998244353;
const int N = 1e5+10;
ll phi(int n){
ll ret = n;
for(ll i=2;i*i<=n;++i){
if(n%i == 0){
ret = ret /i*(i-1);
ret %= MOD;
while(n%i==0) n/=i;
}
}
if(n>1) ret=ret/n*(n-1);
ret%=MOD;
return ret;
}
int n,m;
vector<int> primes;
int a[N];
struct Node{
int l,r;
ll sum;
ll mul;
bitset<105> bs;
}tr[N<<2];
#define lc (p<<1)
#define rc (p<<1|1)
void push_up(int p){
tr[p].sum = (tr[lc].sum + tr[rc].sum)%MOD;
tr[p].bs = tr[lc].bs & tr[rc].bs;
}
//只有满足直接乘的情况下才能调用这个函数
void node_mul(int p,ll k){
tr[p].sum *= k;
tr[p].sum %= MOD;
tr[p].mul *= k;
tr[p].mul %= MOD;
}
void push_down(int p){
ll mul = tr[p].mul;
if(mul!=1){
node_mul(lc,mul);
node_mul(rc,mul);
tr[p].mul = 1;
}
}
bitset<105> get_bitset(int w){
bitset<105> bs;
bs.reset();
for(auto x : primes){
if(x > w) break;
if(w % x == 0){
bs[x] = 1;
}
}
return bs;
}
void build(int p,int l,int r){
tr[p].l = l,tr[p].r = r;
tr[p].sum = 0,tr[p].mul = 1;
if(l==r){
tr[p].bs = get_bitset(a[l]);
tr[p].sum = phi(a[l]);
return;
}
int mid = (l+r) >> 1;
build(lc,l,mid);
build(rc,mid+1,r);
push_up(p);
}
void range_mul(int p,int l,int r,int k,bitset<105> bs){
if(tr[p].l>=l&&tr[p].r<=r){
//如果全部质因子都被包含了
if((tr[p].bs|bs) == tr[p].bs){
node_mul(p,k);
return;
}
}
if(tr[p].l == tr[p].r){
//计算差异部分
bitset<105> dif = (tr[p].bs|bs)^tr[p].bs;
int ww = k;
for(int i=dif._Find_first();i<dif.size();i=dif._Find_next(i)){
ww = ww/i*(i-1);
}
tr[p].sum *= ww;
tr[p].sum %= MOD;
tr[p].bs |= bs;
return;
}
push_down(p);
int mid = (tr[p].l+tr[p].r) >>1;
if(r<=mid) range_mul(lc,l,r,k,bs);
else if(l>mid) range_mul(rc,l,r,k,bs);
else{
range_mul(lc,l,mid,k,bs);
range_mul(rc,mid+1,r,k,bs);
}
push_up(p);
}
ll query(int p,int l,int r){
if(tr[p].l>=l&&tr[p].r<=r){
return tr[p].sum%MOD;
}
push_down(p);
int mid = (tr[p].l+tr[p].r) >>1;
if(r<=mid) return query(lc,l,r)%MOD;
else if(l>mid) return query(rc,l,r)%MOD;
else{
ll ret = query(lc,l,mid)%MOD +
query(rc,mid+1,r)%MOD;
ret %=MOD;
return ret;
}
}
void init(){
for(int i=2;i<=100;++i){
bool flag = true;
for(int j=2;j*j<=i;++j){
if(i%j == 0){
flag = false;
break;
}
}
if(flag) primes.push_back(i);
}
}
int main()
{
init();
cin >> n >> m;
for(int i=1;i<=n;++i){
cin >> a[i];
}
build(1,1,n);
int op,l,r;
ll w;
while(m--){
cin >> op;
cin >> l >> r;
if(op == 0){
cin >> w;
range_mul(1,l,r,w,get_bitset(w));
}else{
cout << query(1,l,r)%MOD << '\n';
}
}
return 0;
}