Transformation
很麻烦的一道区间处理问题,考虑用线段树来解决。线段树的节点要维护三个和:普通和、平方和、立方和,要有三个懒标记:加法懒标记、乘法懒标记、统一修改懒标记。题目的难点在于如何维护这三个和以及如何处理三个懒标记之间的相互影响。综合起来考虑,此题的关键就是push_down函数,要注意到这三个懒标记之间会相互影响,应先处理优先级高的懒标记,优先级高的懒标记会影响优先级低的懒标记,此题中的优先级:统一修改懒标记 > 乘法懒标记 > 加法懒标记。至于在下放懒标记时如何维护三个和,下放统一修改懒标记和乘法懒标记还好说,下放加法懒标记时建议手推一下公式,实在不行搜一下公式,照着公式写就不容易出错啦。别的就没什么了,upd函数也得跟着改一改,如果嫌代码太长的话可以写些宏定义,把多次用到的变量换个简单的名,另外如果不想多次求余的话可以把三个和跟三个懒标记开成long long,只在每步的最后求余即可。
#include <iostream>
#include <cstring>
using namespace std;
#define N 100005
#define MOD 10007
#define lc rt << 1
#define rc rt << 1 | 1
struct Node
{
int l;
int r;
long long sum[3];
long long lazy[3];
}node[N << 2];
void push_up(int rt)
{
node[rt].sum[0] = (node[lc].sum[0] + node[rc].sum[0]) % MOD;
node[rt].sum[1] = (node[lc].sum[1] + node[rc].sum[1]) % MOD;
node[rt].sum[2] = (node[lc].sum[2] + node[rc].sum[2]) % MOD;
}
void push_down(int rt)
{
if(node[rt].lazy[2] != 0)
{
long long x = node[rt].lazy[2], len1 = node[lc].r - node[lc].l + 1, len2 = node[rc].r - node[rc].l + 1;
node[lc].sum[2] = (x * x * x * len1) % MOD;
node[lc].sum[1] = (x * x * len1) % MOD;
node[lc].sum[0] = (x * len1) % MOD;
node[rc].sum[2] = (x * x * x * len2) % MOD;
node[rc].sum[1] = (x * x * len2) % MOD;
node[rc].sum[0] = (x * len2) % MOD;
node[lc].lazy[0] = 0, node[lc].lazy[1] = 1, node[lc].lazy[2] = x;
node[rc].lazy[0] = 0, node[rc].lazy[1] = 1, node[rc].lazy[2] = x;
node[rt].lazy[2] = 0;
}
if(node[rt].lazy[1] != 1)
{
long long x = node[rt].lazy[1];
node[lc].sum[2] = (x * x * x * node[lc].sum[2]) % MOD;
node[lc].sum[1] = (x * x * node[lc].sum[1]) % MOD;
node[lc].sum[0] = (x * node[lc].sum[0]) % MOD;
node[rc].sum[2] = (x * x * x * node[rc].sum[2]) % MOD;
node[rc].sum[1] = (x * x * node[rc].sum[1]) % MOD;
node[rc].sum[0] = (x * node[rc].sum[0]) % MOD;
node[lc].lazy[0] = (x * node[lc].lazy[0]) % MOD, node[lc].lazy[1] = (x * node[lc].lazy[1]) % MOD;
node[rc].lazy[0] = (x * node[rc].lazy[0]) % MOD, node[rc].lazy[1] = (x * node[rc].lazy[1]) % MOD;
node[rt].lazy[1] = 1;
}
if(node[rt].lazy[0] != 0)
{
long long x = node[rt].lazy[0], len1 = node[lc].r - node[lc].l + 1, len2 = node[rc].r - node[rc].l + 1;
node[lc].sum[2] = (node[lc].sum[2] + 3 * (node[lc].sum[1] * x + node[lc].sum[0] * x * x) + x * x * x * len1) % MOD;
node[lc].sum[1] = (node[lc].sum[1] + 2 * node[lc].sum[0] * x + x * x * len1) % MOD;
node[lc].sum[0] = (node[lc].sum[0] + x * len1) % MOD;
node[rc].sum[2] = (node[rc].sum[2] + 3 * (node[rc].sum[1] * x + node[rc].sum[0] * x * x) + x * x * x * len2) % MOD;
node[rc].sum[1] = (node[rc].sum[1] + 2 * node[rc].sum[0] * x + x * x * len2) % MOD;
node[rc].sum[0] = (node[rc].sum[0] + x * len2) % MOD;
node[lc].lazy[0] = (node[lc].lazy[0] + x) % MOD;
node[rc].lazy[0] = (node[rc].lazy[0] + x) % MOD;
node[rt].lazy[0] = 0;
}
}
void build(int rt, int l, int r)
{
if(l == r)
node[rt] = {l, r, {0, 0, 0}, {0, 1, 0}};
else
{
node[rt] = {l, r, {0, 0, 0}, {0, 1, 0}};
int m = l + r >> 1;
build(lc, l, m);
build(rc, m + 1, r);
}
}
void upd(int rt, int opt, int l, int r, long long val)
{
if(node[rt].l >= l && node[rt].r <= r)
{
if(opt == 1)
{
int len = node[rt].r - node[rt].l + 1;
node[rt].sum[2] = (node[rt].sum[2] + 3 * (node[rt].sum[1] * val + node[rt].sum[0] * val * val) + val * val * val * len) % MOD;
node[rt].sum[1] = (node[rt].sum[1] + 2 * node[rt].sum[0] * val + val * val * len) % MOD;
node[rt].sum[0] = (node[rt].sum[0] + val * len) % MOD;
node[rt].lazy[0] = (node[rt].lazy[0] + val) % MOD;
}
if(opt == 2)
{
node[rt].sum[2] = (val * val * val * node[rt].sum[2]) % MOD;
node[rt].sum[1] = (val * val * node[rt].sum[1]) % MOD;
node[rt].sum[0] = (val * node[rt].sum[0]) % MOD;
node[rt].lazy[0] = (val * node[rt].lazy[0]) % MOD, node[rt].lazy[1] = (val * node[rt].lazy[1]) % MOD;
}
if(opt == 3)
{
int len = node[rt].r - node[rt].l + 1;
node[rt].sum[2] = (val * val * val * len) % MOD;
node[rt].sum[1] = (val * val * len) % MOD;
node[rt].sum[0] = (val * len) % MOD;
node[rt].lazy[0] = 0, node[rt].lazy[1] = 1, node[rt].lazy[2] = val;
}
}
else
{
push_down(rt);
int m = node[rt].l + node[rt].r >> 1;
if(l <= m)
upd(lc, opt, l, r, val);
if(r >= m + 1)
upd(rc, opt, l, r, val);
push_up(rt);
}
}
long long query(int rt, int l, int r, int p)
{
if(node[rt].l >= l && node[rt].r <= r)
return node[rt].sum[p - 1];
else
{
push_down(rt);
long long res = 0;
int m = node[rt].l + node[rt].r >> 1;
if(l <= m)
res = (res + query(lc, l, r, p)) % MOD;
if(r >= m + 1)
res = (res + query(rc, l, r, p)) % MOD;
return res;
}
}
int main()
{
int n, m;
while(cin >> n >> m)
{
if(!n && !m)
break;
memset(node, 0, sizeof(node));
build(1, 1, n);
while(m--)
{
int opt, x, y, z;
cin >> opt >> x >> y >> z;
if(opt != 4)
upd(1, opt, x, y, z);
else
cout << query(1, x, y, z) << endl;
}
}
return 0;
}