题目大意是:
对n个数进行如下操作
1 x y z , 令a[i] = a[i] + z, x <= i <= y
2 x y z, 令 a[i] = a[i] * z, x <= i <= y
3 x y z, 令 a[i] = z, x <= i <= y
4 x y p, 输出 a[i]^z 的和, x<= i <=y
这一题如果要查询的只是一次方的和的话,这一题就是一道非常简单的线段树了。
所以要特别处理的就是怎么维护区间的“平方和” 与“立方和”
我们可以对一段区间乘法和加法操作过后的值进行维护
如a[i]变成了A*a[i]+B
(A*a[i]+b)^1 = A*a[i]+B
(A*a[i]+b)^2 = (A^2)*(a[I]^2) + 2*A*B*a[i] + (B^2)
(A*a[i]+b)^3 = (A^3)*(a[i]^3) + 3*(A^2)*B*(a[i]^2) + 3*A*(B^2)*(a[i]^1) + (B^3)
我们在之前有
a[i]的和:value[1]
a[i]^2的和:value[2]
a[i]^3的和:value[3]
那么对于区间[x, y]
(A*a[i]+b)^1的和就是:A*value[1] + b*(y-x+1)
(A*a[i]+b)^2的和就是: (A^2)*value[2] + 2*A*B*value[1] + (B^2)*(y-x+1)
(A*a[i]+b)^3的和就是:(A^3)*value[3] + 3*(A^2)*B*value[2] + 3*A*(B^2)*value[1] + (B^3)*(y-x+1)
这样我们就可以用O(1)的时间来维护一段区间的乘法和加法操作
对于将一段区间的赋值操作,只要加两个标记:
is_num:表示这个一段数都是同一个整数
the_num:表示这个一段是同一个整数的情况下,a[i]为什么。
并且这个标记会随着DFS下放,这样,就只有最上层的is_num是有效的。
我线段树的类是这么建立的
class Node
{
public:
long long a, b, the_num; //a*a[i]+b这个是需要下方的标记,表示该段区间被乘和被加过多少
bool is_num; //判断是否是被整段覆盖的数
long long value[4]; //记录这一段数的和、平方和、立方和
long long x, y; //计算该段区间是[x, y]
//输出该类的内容
void show()
{
cout << x << ' ' << y << ' ' << value[1] << ' ' << value[2] << ' ' << value[3] << endl;
cout << a << ' ' << b << endl;
}
//初始化
void init(long long xx = 0, long long yy = 0)
{
x = xx;
y = yy;
value[1] = 0;
value[2] = 0;
value[3] = 0;
a = 1;
b = 0;
the_num = 0;
}
//将整段的a[i] * new_a + new_b 并更新和、平方和、立方和的值
void updata(long long new_a, long long new_b)
{
value[3] = pow(new_a, 3) * value[3]
+ 3 * pow(new_a, 2) * pow(new_b, 1) * value[2]
+ 3 * pow(new_a, 1) * pow(new_b, 2) * value[1]
+ pow(new_b, 3) * (y - x + 1);
value[3] %= MM;
value[2] = pow(new_a, 2) * value[2]
+ 2 * new_a * new_b * value[1]
+ pow(new_b, 2) * (y - x + 1);
value[2] %= MM;
value[1] = new_a * value[1] + new_b * (y - x + 1);
value[1] %= MM;
a *= new_a;
a %= MM;
b *= new_a;
b += new_b;
b %= MM;
}
//对一整段的值赋值为num
void to_num(long long num)
{
is_num = true;
the_num = num;
value[1] = pow(num, 1) * (y - x + 1);
value[1] %= MM;
value[2] = pow(num, 2) * (y - x + 1);
value[2] %= MM;
value[3] = pow(num, 3) * (y - x + 1);
value[3] %= MM;
a = 1;
b = 0;
}
};
更新操作:
void updata(long long type, long long x, long long y, long long z, long long t)
{
long long sum = 0;
if (x > T[t].y || y < T[t].x)
return;
else if (x <= T[t].x && T[t].y <= y)
{
if (type == 1)
T[t].updata(1, z);
else if (type == 2)
T[t].updata(z, 0);
else if (type == 3)
{
T[t].is_num = true;
T[t].to_num(z);
}
}
else
{
if (T[t].is_num)
{
T[t].is_num = false;
T[t*2].is_num = true;
T[t*2].to_num(T[t].the_num);
T[t*2+1].is_num = true;
T[t*2+1].to_num(T[t].the_num);
}
T[t*2].updata(T[t].a, T[t].b);
T[t*2+1].updata(T[t].a, T[t].b);
updata(type, x, y, z, t*2);
updata(type, x, y, z, t*2+1);
T[t].value[1] = T[t*2].value[1] + T[t*2+1].value[1];
T[t].value[1] %= MM;
T[t].value[2] = T[t*2].value[2] + T[t*2+1].value[2];
T[t].value[2] %= MM;
T[t].value[3] = T[t*2].value[3] + T[t*2+1].value[3];
T[t].value[3] %= MM;
T[t].a = 1;
T[t].b = 0;
return;
}
}
查询操作和更新类似,同样需要在DFS过程中下放标记
最后贴上整个的代码:
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
using namespace std;
const long long MM = 10007;
const long long MAXN = 1E5+100;
long long pow(long long x, long long n)
{
long long ans = 1, i;
for (i = 1; i <= n; i++)
{
ans *= x;
ans %= MM;
}
return ans;
}
class Node
{
public:
long long a, b, the_num; //a*a[i]+b这个是需要下方的标记,表示该段区间被乘和被加过多少
bool is_num; //判断是否是被整段覆盖的数
long long value[4]; //记录这一段数的和、平方和、立方和
long long x, y; //计算该段区间是[x, y]
//输出该类的内容
void show()
{
cout << x << ' ' << y << ' ' << value[1] << ' ' << value[2] << ' ' << value[3] << endl;
cout << a << ' ' << b << endl;
}
//初始化
void init(long long xx = 0, long long yy = 0)
{
x = xx;
y = yy;
value[1] = 0;
value[2] = 0;
value[3] = 0;
a = 1;
b = 0;
the_num = 0;
}
//将整段的a[i] * new_a + new_b 并更新和、平方和、立方和的值
void updata(long long new_a, long long new_b)
{
value[3] = pow(new_a, 3) * value[3]
+ 3 * pow(new_a, 2) * pow(new_b, 1) * value[2]
+ 3 * pow(new_a, 1) * pow(new_b, 2) * value[1]
+ pow(new_b, 3) * (y - x + 1);
value[3] %= MM;
value[2] = pow(new_a, 2) * value[2]
+ 2 * new_a * new_b * value[1]
+ pow(new_b, 2) * (y - x + 1);
value[2] %= MM;
value[1] = new_a * value[1] + new_b * (y - x + 1);
value[1] %= MM;
a *= new_a;
a %= MM;
b *= new_a;
b += new_b;
b %= MM;
}
//对一整段的值赋值为num
void to_num(long long num)
{
is_num = true;
the_num = num;
value[1] = pow(num, 1) * (y - x + 1);
value[1] %= MM;
value[2] = pow(num, 2) * (y - x + 1);
value[2] %= MM;
value[3] = pow(num, 3) * (y - x + 1);
value[3] %= MM;
a = 1;
b = 0;
}
};
Node T[MAXN*4];
long long n, m;
void build(long long x, long long y, long long tt)
{
long long mid;
if (x == y)
{
T[tt].init(x, y);
T[tt].is_num = false;
}
else
{
mid = (x+y)/2;
build(x, mid, tt*2);
build(mid+1, y, tt*2+1);
T[tt].init(x, y);
T[tt].is_num = false;
}
}
long long query(long long type, long long x, long long y, long long z, long long t)
{
long long sum = 0;
if (x > T[t].y || y < T[t].x)
return 0;
else if (x <= T[t].x && T[t].y <= y)
return T[t].value[z];
else
{
if (T[t].is_num)
{
T[t].is_num = false;
T[t*2].is_num = true;
T[t*2].to_num(T[t].the_num);
T[t*2+1].is_num = true;
T[t*2+1].to_num(T[t].the_num);
}
T[t*2].updata(T[t].a, T[t].b);
T[t*2+1].updata(T[t].a, T[t].b);
sum = query(type, x, y, z, t*2);
sum += query(type, x, y, z, t*2+1);
sum %= MM;
T[t].value[1] = T[t*2].value[1] + T[t*2+1].value[1];
T[t].value[1] %= MM;
T[t].value[2] = T[t*2].value[2] + T[t*2+1].value[2];
T[t].value[2] %= MM;
T[t].value[3] = T[t*2].value[3] + T[t*2+1].value[3];
T[t].value[3] %= MM;
T[t].a = 1;
T[t].b = 0;
return sum;
}
}
void updata(long long type, long long x, long long y, long long z, long long t)
{
long long sum = 0;
if (x > T[t].y || y < T[t].x)
return;
else if (x <= T[t].x && T[t].y <= y)
{
if (type == 1)
T[t].updata(1, z);
else if (type == 2)
T[t].updata(z, 0);
else if (type == 3)
{
T[t].is_num = true;
T[t].to_num(z);
}
}
else
{
if (T[t].is_num)
{
T[t].is_num = false;
T[t*2].is_num = true;
T[t*2].to_num(T[t].the_num);
T[t*2+1].is_num = true;
T[t*2+1].to_num(T[t].the_num);
}
T[t*2].updata(T[t].a, T[t].b);
T[t*2+1].updata(T[t].a, T[t].b);
updata(type, x, y, z, t*2);
updata(type, x, y, z, t*2+1);
T[t].value[1] = T[t*2].value[1] + T[t*2+1].value[1];
T[t].value[1] %= MM;
T[t].value[2] = T[t*2].value[2] + T[t*2+1].value[2];
T[t].value[2] %= MM;
T[t].value[3] = T[t*2].value[3] + T[t*2+1].value[3];
T[t].value[3] %= MM;
T[t].a = 1;
T[t].b = 0;
return;
}
}
void show(long long t)
{
long long mid;
if (T[t].x == T[t].y)
T[t].show();
else
{
T[t].show();
show(t*2);
show(t*2+1);
}
}
int main()
{
long long i, type, x, y, z, ans;
while (scanf("%I64d%I64d",&n,&m))
{
if (n == 0 && m == 0)
break;
build(1, n, 1);
for (i = 1; i <= m; i ++)
{
scanf("%I64d%I64d%I64d%I64d",&type, &x, &y, &z);
if (type == 4)
{
ans = query(type, x, y, z, 1);
printf("%I64d\n",ans);
}
else
updata(type, x, y, z, 1);
}
}
return 0;
}
最后附上Hdu上面别人提供的一组数据。。。。我WA了一次,就是用这个数据改的。。。。。
13 4 1 1 4 6988 4 1 3 1 2 1 13 640 4 1 3 1 15 3 1 1 12 10 3 1 15 2 4 1 4 1 78 4 1 1 44 7815 2 1 20 542 3 1 47 1 4 1 36 1 0 0 ans: 950 7580 8 36