题面
wzy实在是太强了,所以他百无聊赖之下决定出一道简单的问题。
有一棵n个节点的树,以1号点为根,每个点可以有个权值,一开始每个点的权值都是0。你有m次操作,操作分为修改(modify) 和 查询 (query)。
修改操作有3种,可以把一棵子树打回原形,升级和进化。形如 Modify1 x c 代表把以x为根的子树的权值变成x。 Modify2 x c 代表把以x为根的子树的权值统统加上c 。 Modify3 x c代表把以x为根的子树的权值统统乘上c 。
查询操作询问一颗子树内的权值之和。形如 Query x 代表询问以x为根的子树的权值和。
由于这个问题实在是太简单了,wzy说他有9种方法可以切掉这个题。聪明的你有几种方法切题呢?
注意:为了方便,所有数都对1e9+7取模。
输入格式:
第1行两个正整数n,m 代表节点个数和操作个数。
第2~n 行,给出两个数x, y,代表x 点 和 y 点之间有一条边。 保证 1<= x, y <= n
第n + 1 ~ n + m 行 每行给出Modify x c 或者 Query x 如题面所示含义。
输出格式:
对于每个询问输出答案。
样例输入1:
5 5
1 2
1 3
1 4
1 5
Modify3 1 10
Query 1
Modify1 2 1
Modify2 1 1
Query 1
样例输出1:
0
6
样例输入2:
5 6
1 2
1 3
3 4
3 5
Modify1 1 2
Query 3
Modify2 3 2
Query 3
Modify3 1 2
Query 1
样例输出2:
6
12
32
输入范围:
对于30%的数据:1<=n,m<=1e2 。
对于100%的数据:1<=n,m<=1e5,0 <= c <= 1e9。
30 p t s 30pts 30pts
暴力子树修改什么的。
100 p t s 100pts 100pts
把 d f s dfs dfs序做出来。相当于线段树的区间加,区间乘,区间赋值。 线段树的每个区间打三个标记。注意 s e t set set操作会清空一个区间的乘或者加标记,区间加或者乘会修改 s e t set set操作的标记。
c o d e s : codes: codes:
#include <cstdio>//dfs序+线段树,子树的操作对应到了dfs序的一段连续区间。
#define ls (k << 1)//线段树有3个标记,下传的时候需要注意先后顺序。
#define rs ((k << 1) | 1)//一般规定set > mul > add
typedef long long LL;//具体实现看代码吧,线段树这东西人人写的可能不太一样。。。O((n + m)logn)
const int N = 1e5 + 5, P = 1e9 + 7;
char s[10];
int head[N], nxt[N<<1], v[N<<1], tot;
int siz[N], id[N], pos[N], tim;
int n, m;
struct poi {
int l, r;
LL sum, settag, addtag, multag;
}tr[N<<2];
inline int read() {
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9'){if (ch == '-')f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9'){x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
inline void add(int x, int y) {
v[++tot] = y;
nxt[tot] = head[x];
head[x] = tot;
}
inline void dfs(int x, int fa) {
pos[++tim] = x;
id[x] = tim;
siz[x] = 1;
for (int i = head[x]; i; i = nxt[i]) {
int y = v[i];
if (y != fa) {
dfs(y, x);
siz[x] += siz[y];
}
}
}
inline void build(int k, int s, int t) {
tr[k].l = s; tr[k].r = t;
tr[k].settag = -1; tr[k].multag = 1;
if (s == t) return ;
int mid = (s + t) / 2;
build(ls, s, mid);
build(rs, mid + 1, t);
}
inline void pushup(int k){tr[k].sum = (tr[ls].sum + tr[rs].sum) % P;}
inline void pushdown(int k) {
int x = tr[k].r - tr[k].l + 1;
if (tr[k].settag != -1) { //
tr[ls].sum = (x - x / 2) * tr[k].settag % P;
tr[ls].settag = tr[k].settag;
tr[ls].multag = 1;
tr[ls].addtag = 0;
tr[rs].sum = (x / 2) * tr[k].settag % P;
tr[rs].settag = tr[k].settag;
tr[rs].multag = 1;
tr[rs].addtag = 0;
tr[k].settag = -1;
}
else {
tr[ls].sum = (tr[ls].sum * tr[k].multag % P + tr[k].addtag * (x - x / 2)) % P;
if (tr[ls].settag != -1) tr[ls].settag = (tr[ls].settag * tr[k].multag % P + tr[k].addtag) % P;
else {
tr[ls].addtag = (tr[ls].addtag * tr[k].multag % P + tr[k].addtag) % P;
tr[ls].multag = tr[ls].multag * tr[k].multag % P;
}
tr[rs].sum = (tr[rs].sum * tr[k].multag % P + tr[k].addtag * (x / 2)) % P;
if (tr[rs].settag != -1) tr[rs].settag = (tr[rs].settag * tr[k].multag % P + tr[k].addtag) % P;
else {
tr[rs].addtag = (tr[rs].addtag * tr[k].multag % P + tr[k].addtag) % P;
tr[rs].multag = tr[rs].multag * tr[k].multag % P;
}
tr[k].multag = 1; tr[k].addtag = 0;
}
}
inline void intervaladd(int k, int s, int t, int x) {
int l = tr[k].l, r = tr[k].r;
if (l == s && r == t) {
tr[k].sum += 1LL * (r - l + 1) * x;
tr[k].sum %= P;
if (tr[k].settag == -1) tr[k].addtag = (tr[k].addtag + x) % P;
else tr[k].settag += x, tr[k].settag %= P;
return ;
}
if (tr[k].settag != -1 || tr[k].addtag || tr[k].multag != 1) pushdown(k);
int mid = (l + r) / 2;
if (t <= mid)intervaladd(ls, s, t, x);
else if (s > mid)intervaladd(rs, s, t, x);
else {
intervaladd(ls, s, mid, x);
intervaladd(rs, mid + 1, t, x);
}
pushup(k);
}
inline void intervalset(int k, int s, int t, int x) {
int l = tr[k].l, r = tr[k].r;
if (l == s && r == t) {
tr[k].sum = 1LL * (r - l + 1) * x % P;
tr[k].settag = x;
tr[k].multag = 1;
tr[k].addtag = 0;
return ;
}
if (tr[k].settag != -1 || tr[k].multag != 1 || tr[k].addtag)pushdown(k);
int mid = (l + r) / 2;
if (t <= mid)intervalset(ls, s, t, x);
else if (s > mid)intervalset(rs, s, t, x);
else {
intervalset(ls, s, mid, x);
intervalset(rs, mid + 1, t, x);
}
pushup(k);
}
inline void intervalmul(int k, int s, int t, int x) {
int l = tr[k].l, r = tr[k].r;
if (s == l && r == t) {
tr[k].sum = tr[k].sum * x % P;
if (tr[k].settag == -1) {
tr[k].multag = tr[k].multag * x % P;
tr[k].addtag = tr[k].addtag * x % P;
}
else tr[k].settag = tr[k].settag * x % P;
return ;
}
if (tr[k].settag != -1 || tr[k].multag != 1 || tr[k].addtag)pushdown(k);
int mid = (l + r) / 2;
if (t <= mid)intervalmul(ls, s, t, x);
else if (s > mid)intervalmul(rs, s, t, x);
else {
intervalmul(ls, s, mid, x);
intervalmul(rs, mid + 1, t, x);
}
pushup(k);
}
inline LL query(int k, int s, int t) {
int l = tr[k].l, r = tr[k].r;
if (s == l && t == r) return tr[k].sum;
if (tr[k].settag != -1 || tr[k].multag != 1 || tr[k].addtag)pushdown(k);
int mid = (l + r) / 2;
if (t <= mid)return query(ls, s, t);
else if (s > mid)return query(rs, s, t);
else return (query(ls, s, mid) + query(rs, mid + 1, t)) % P ;
}
int main() {
//freopen("data.in", "r", stdin);
//freopen("data.out", "w", stdout);
n = read(); m = read();
for (int i = 1; i < n; ++i) {
int x = read(), y = read();
add(x, y); add(y, x);
}
dfs(1, 1);
build(1, 1, n);
for (int i = 1; i <= m; ++i) {
scanf("%s", s);
if (s[0] == 'M') {
int x = read(), c = read();
if (s[6] == '2') intervaladd(1, id[x], id[x] + siz[x] - 1, c);
else if (s[6] == '1') intervalset(1, id[x], id[x] + siz[x] - 1, c);
else intervalmul(1, id[x], id[x] + siz[x] - 1, c);
}
else {
int x = read();
printf("%lld\n", query(1, id[x], id[x] + siz[x] - 1));
}
}
return 0;
}