考场上切了不考虑没有逆元的情况(出题人真良心).
考场代码:
#include <cstdio>
#include <algorithm>
#define lson (now<<1)
#define rson (now<<1|1)
#define ll long long
#define setIO(s) freopen(s".in","r",stdin) , freopen(s".out","w",stdout)
using namespace std;
char *p1, *p2, buf[100000];
namespace IO
{
#define nc() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1 ++ )
int rd() {
int x = 0, f = 1;
char c = nc();
while (c < 48) {
if (c == '-')
f = -1;
c = nc();
}
while (c > 47) {
x = (((x << 2) + x) << 1) + (c ^ 48), c = nc();
}
return x * f;
}
};
const int mod=998244353,N=120005;
int arr[N],n,Q;
inline ll qpow(ll base,ll k) {
ll tmp=1;
for(;k;base=base*base%mod,k>>=1)if(k&1)tmp=tmp*base%mod;
return tmp;
}
inline ll inv(ll k) {
return qpow(k,mod-2);
}
struct Node {
int len;
ll sum,sqr,sumlen,sqrlen,lazy;
}t[N<<2];
inline void pushup(int l,int r,int now) {
int mid=(l+r)>>1;
t[now].sum=t[lson].sum;
t[now].sqr=t[lson].sqr;
t[now].sumlen=t[lson].sumlen;
t[now].sqrlen=t[lson].sqrlen;
if(r>mid) {
t[now].sum=(t[now].sum+t[rson].sum)%mod;
t[now].sqr=(t[now].sqr+t[rson].sqr)%mod;
t[now].sumlen=(t[now].sumlen+t[rson].sumlen)%mod;
t[now].sqrlen=(t[now].sqrlen+t[rson].sqrlen)%mod;
}
t[now].sqr=(t[now].sqr+(ll)t[now].sum*t[now].sum)%mod;
t[now].sumlen=(t[now].sumlen+t[now].len*t[now].sum%mod)%mod;
t[now].sqrlen=(t[now].sqrlen+(ll)t[now].len*t[now].len%mod)%mod;
}
inline void mark(int l,int r,int now,ll v)
{
t[now].lazy+=v, t[now].lazy%=mod;
t[now].sqr=(t[now].sqr+((v*v)%mod)*t[now].sqrlen%mod+2ll*v*t[now].sumlen%mod)%mod;
t[now].sumlen=(t[now].sumlen+(v*t[now].sqrlen)%mod)%mod;
t[now].sum=(t[now].sum+(t[now].len*v)%mod)%mod;
}
inline void pushdown(int l,int r,int now)
{
int mid=(l+r)>>1;
if(t[now].lazy)
{
mark(l,mid,lson,t[now].lazy);
if(r>mid) mark(mid+1,r,rson,t[now].lazy);
t[now].lazy=0;
}
}
void build(int l,int r,int now) {
t[now].len=r-l+1;
if(l==r) {
t[now].sum=arr[l];
t[now].sqr=(ll)arr[l]*arr[l]%mod;
t[now].sumlen=t[now].len*t[now].sum%mod;
t[now].sqrlen=(ll)t[now].len*t[now].len%mod;
return;
}
int mid=(l+r)>>1;
if(l<=mid) build(l,mid,lson);
if(r>mid) build(mid+1,r,rson);
pushup(l,r,now);
}
void update(int l,int r,int now,int L,int R,ll v) {
if(l>=L&&r<=R) {
mark(l,r,now,v);
return;
}
pushdown(l,r,now);
int mid=(l+r)>>1;
if(L<=mid) update(l,mid,lson,L,R,v);
if(R>mid) update(mid+1,r,rson,L,R,v);
pushup(l,r,now);
}
int main() {
using namespace IO;
int i,j,cas;
// setIO("b");
n=rd(),Q=rd();
for(i=1;i<=n;++i) arr[i]=rd();
build(1,n,1);
for(cas=1;cas<=Q;++cas) {
int opt,l,r,v;
opt=rd();
if(opt==1) {
l=rd(),r=rd(),v=rd(), update(1,n,1,l,r,v);
}
if(opt==2) {
ll a=t[1].sqr,b=t[1].sum;
printf("%lld\n",a*inv(b)%mod);
}
}
return 0;
}