Query on the subtree
Time Limit: 16000/8000 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others)Total Submission(s): 1297 Accepted Submission(s): 405
Problem Description
Bobo has a tree whose vertices are conveniently labeled 1, 2, …, n. At the very beginning, the i-th vertex is assigned weight w_i.
There are q operations. Each operation is one of the following 2 types:
Change the weight of vertex v to x (denoted as "! v x"),
Ask for the total weight of the vertices whose distance from vertex v is at most d (denoted as "? v d").
Note that the distance between vertices u and v is the number of edges on the shortest path between them.
There are q operations. Each operation is one of the following 2 types:
Change the weight of vertex v to x (denoted as "! v x"),
Ask for the total weight of the vertices whose distance from vertex v is at most d (denoted as "? v d").
Note that the distance between vertices u and v is the number of edges on the shortest path between them.
Input
The input consists of several tests. For each test:
The first line contains n, q (1 ≤ n, q ≤ 10^5). The second line contains n integers w_1, w_2, …, w_n (0 ≤ w_i ≤ 10^4). Each of the following (n - 1) lines contains 2 integers a_i, b_i denoting an edge between vertices a_i and b_i (1 ≤ a_i, b_i ≤ n). Each of the following q lines contains an operation (1 ≤ v ≤ n, 0 ≤ x ≤ 10^4, 0 ≤ d ≤ n).
The first line contains n, q (1 ≤ n, q ≤ 10^5). The second line contains n integers w_1, w_2, …, w_n (0 ≤ w_i ≤ 10^4). Each of the following (n - 1) lines contains 2 integers a_i, b_i denoting an edge between vertices a_i and b_i (1 ≤ a_i, b_i ≤ n). Each of the following q lines contains an operation (1 ≤ v ≤ n, 0 ≤ x ≤ 10^4, 0 ≤ d ≤ n).
Output
For each test:
For each query, output a single number denoting the total weight.
For each query, output a single number denoting the total weight.
Sample Input
4 3
1 1 1 1
1 2
2 3
3 4
? 2 1
! 1 0
? 2 1
3 3
1 2 3
1 2
1 3
? 1 0
? 1 1
? 1 2
Sample Output
3
2
1
6
6
Author
Xiaoxu Guo (ftiasch)
Source
题意:给出一颗n个点的树,每个点有一个权值,有两种操作,一种是将某个点的权值修改为v,另一种是查询距离点u不超过d的点的权值和
解题思路:动态点分治+树状数组
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <vector>
#include <bitset>
#include <functional>
using namespace std;
// Shorthand for the 64-bit signed integer type.
#define LL long long
// pi computed as acos(-1.0).
const double pi = acos(-1.0);
// Large sentinel value; main stores it in mx[0] so that index 0 acts as the
// "no candidate yet" entry during the centroid search in dfs.
const int INF = 0x3f3f3f3f;
// Capacity of every global array below (also bounds the number of directed
// edge slots, since each undirected edge is stored twice).
const int maxn = 2e5 + 10;
// n, q: vertex count and operation count read per test in main.
// x, y: scratch values read by main (edge endpoints, then operation arguments).
int n, q, x, y;
// Adjacency lists: s[v] is the index of the first edge slot of v (-1 when
// there is none), nt[i] is the next slot after i, e[i] is the vertex that slot
// i points to, and cnt is the number of slots used so far.
int s[maxn], nt[maxn], e[maxn], cnt;
// sum[v], mx[v]: subtree size and largest remaining piece computed by dfs.
// vis[v]: set while v is a selected centroid during build (cleared when that
//         build call returns).
// w[v]: current weight of vertex v.
// pre[v]: parent of v in the centroid tree as assigned by build (-1 at the top).
int sum[maxn], mx[maxn], vis[maxn], w[maxn], pre[maxn];
// Buffer for the operation token ("?" or "!") read in main.
char ch[5];
// Fenwick (binary indexed) arrays indexed by distance, filled by put:
// d[c] holds the weights of c's component keyed by distance from c;
// D[c] holds the weights of c's component keyed by distance from pre[c].
vector<int>d[maxn], D[maxn];
// dis[c]: (vertex, distance from c) for every vertex of c's component,
// sorted by build so that a vertex can be located with lower_bound.
vector<pair<int, int> >dis[maxn];
// Returns the value of the lowest set bit of k (0 when k is 0); this is the
// step size used when walking a Fenwick tree.
int lowbit(int k)
{
    const int negated = -k;
    return k & negated;
}
// Finds the centroid of the not-yet-removed component that contains k.
//   k  - vertex being visited
//   fa - vertex we arrived from (skipped when scanning neighbours)
//   p  - total number of vertices in the component
// Side effects: fills sum[v] (subtree size) and mx[v] (largest piece left
// after removing v) for every visited vertex.
// Returns the visited vertex with the smallest mx value; index 0 serves as the
// initial "no candidate" entry and relies on mx[0] holding a large sentinel.
int dfs(int k, int fa, int p)
{
    sum[k] = 1;
    mx[k] = 0;
    int best = 0;
    for (int edge = s[k]; edge != -1; edge = nt[edge])
    {
        const int child = e[edge];
        if (child != fa && !vis[child])
        {
            const int candidate = dfs(child, k, p);
            sum[k] += sum[child];
            if (sum[child] > mx[k]) mx[k] = sum[child];
            if (mx[candidate] < mx[best]) best = candidate;
        }
    }
    // The part of the component lying "above" k is also a piece once k is removed.
    const int above = p - sum[k];
    if (above > mx[k]) mx[k] = above;
    if (mx[k] < mx[best]) return k;
    return best;
}
// Walks the unremoved component reachable from k (not going back through fa)
// and returns the largest distance label encountered.
//   k   - current vertex
//   fa  - vertex we arrived from
//   len - distance label of k
//   rt  - when non-zero, every visited (vertex, distance) pair is appended to
//         dis[rt]; when zero nothing is recorded
int getlen(int k, int fa, int len, int rt)
{
    if (rt != 0) dis[rt].push_back(make_pair(k, len));
    int deepest = len;
    for (int edge = s[k]; edge != -1; edge = nt[edge])
    {
        const int next = e[edge];
        if (next == fa || vis[next]) continue;
        const int below = getlen(next, k, len + 1, rt);
        if (below > deepest) deepest = below;
    }
    return deepest;
}
// Adds the weight of every vertex in the unremoved component reachable from k
// into the Fenwick array p, at the index equal to the vertex's distance label.
//   k   - current vertex
//   fa  - vertex we arrived from
//   len - distance label of k; a label of 0 is not inserted (Fenwick indices
//         start at 1)
//   p   - Fenwick array to update in place
void put(int k, int fa, int len, vector<int>&p)
{
    if (len != 0)
    {
        const int limit = p.size();
        int pos = len;
        while (pos < limit)
        {
            p[pos] += w[k];
            pos += lowbit(pos);
        }
    }
    for (int edge = s[k]; edge != -1; edge = nt[edge])
    {
        const int next = e[edge];
        if (next != fa && !vis[next]) put(next, k, len + 1, p);
    }
}
// Builds the centroid tree for the component that contains k.
//   k  - any vertex of the component
//   p  - number of vertices in the component
//   fa - centroid of the enclosing component (stored into pre[]), -1 at the top
// For the chosen centroid c this fills dis[c] (sorted vertex/distance pairs)
// and d[c] (weights keyed by distance from c); for every child centroid t it
// fills D[t] (weights of t's component keyed by distance from c).
// Returns the centroid of this component.
int build(int k, int p, int fa)
{
    const int centroid = dfs(k, k, p);
    pre[centroid] = fa;
    vis[centroid] = 1;

    int depth = getlen(centroid, centroid, 0, centroid);
    sort(dis[centroid].begin(), dis[centroid].end());
    d[centroid].resize(d[centroid].size() + depth + 1, 0);
    put(centroid, centroid, 0, d[centroid]);

    for (int edge = s[centroid]; edge != -1; edge = nt[edge])
    {
        const int next = e[edge];
        if (vis[next]) continue;
        // sum[] still holds the sizes from the dfs rooted at k: a neighbour
        // with a smaller size lies below the centroid, otherwise it is the
        // side containing k and its size is whatever remains.
        const int part = (sum[next] < sum[centroid]) ? sum[next] : p - sum[centroid];
        const int sub = build(next, part, centroid);
        depth = getlen(next, centroid, 1, 0);
        D[sub].resize(D[sub].size() + depth + 1, 0);
        put(next, centroid, 1, D[sub]);
    }

    // Unmark so that enclosing build calls can traverse through this vertex again.
    vis[centroid] = 0;
    return centroid;
}
// Answers a "? v d" query by climbing the centroid tree.
//   rt - the queried vertex
//   x  - centroid currently being processed (callers start with x == rt)
//   y  - distance limit
// Returns the total weight contributed by x's level and all levels above it.
int solve(int rt, int x, int y)
{
    int total = 0;
    if (x == rt)
    {
        // Own component: everything within y of rt, plus rt itself.
        const vector<int>& own = d[rt];
        const int last = own.size() - 1;
        for (int pos = min(y, last); pos != 0; pos -= lowbit(pos)) total += own[pos];
        total += w[x];
    }

    const int up = pre[x];
    if (up == -1) return total;

    // Distance from the parent centroid to rt, looked up in its sorted table.
    const vector<pair<int, int> >& table = dis[up];
    const int k = lower_bound(table.begin(), table.end(), make_pair(rt, 0))->second;
    if (k <= y)
    {
        const int reach = y - k;
        total += w[up];

        const vector<int>& whole = d[up];
        const int wholeLast = whole.size() - 1;
        for (int pos = min(reach, wholeLast); pos != 0; pos -= lowbit(pos)) total += whole[pos];

        // Remove the vertices of x's own component: they were already counted
        // at a lower level.
        const vector<int>& part = D[x];
        const int partLast = part.size() - 1;
        for (int pos = min(reach, partLast); pos != 0; pos -= lowbit(pos)) total -= part[pos];
    }
    return total + solve(rt, up, y);
}
// Applies a weight change of vertex rt to every Fenwick array that covers it.
//   rt - vertex whose weight changes
//   x  - centroid to start climbing from (callers pass x == rt)
//   y  - amount added to the weight (new weight minus old weight)
// For each ancestor centroid a of x, the entry at distance(a, rt) is adjusted
// in d[a] and in D[child], where child is the centroid just below a on the path.
void update(int rt, int x, int y)
{
    for (int cur = x; pre[cur] != -1; cur = pre[cur])
    {
        const int up = pre[cur];
        const vector<pair<int, int> >& table = dis[up];
        const int k = lower_bound(table.begin(), table.end(), make_pair(rt, 0))->second;

        vector<int>& whole = d[up];
        const int wholeSize = whole.size();
        for (int pos = k; pos < wholeSize; pos += lowbit(pos)) whole[pos] += y;

        vector<int>& part = D[cur];
        const int partSize = part.size();
        for (int pos = k; pos < partSize; pos += lowbit(pos)) part[pos] += y;
    }
}
// Entry point. Reads test cases until the "n q" header can no longer be read.
// For each test it loads the weights and edges, builds the centroid tree, and
// then processes q operations: "? v d" prints the total weight within distance
// d of v, "! v x" changes the weight of v to x.
// Returns 0; processing stops early if the input ends or is malformed midway.
int main()
{
    // Compare against 2 instead of testing ~scanf(...): the latter only stops
    // on EOF and would loop forever if scanf returned 0 on a malformed header.
    while (scanf("%d %d", &n, &q) == 2)
    {
        memset(s, -1, sizeof s);
        // Index 0 is the "no candidate" entry used by dfs; give it the worst score.
        mx[cnt = 0] = INF;
        for (int i = 1; i <= n; i++) dis[i].clear(), d[i].clear(), D[i].clear();
        for (int i = 1; i <= n; i++)
        {
            if (scanf("%d", &w[i]) != 1) return 0;
        }
        for (int i = 1; i < n; i++)
        {
            if (scanf("%d%d", &x, &y) != 2) return 0;
            // Store the undirected edge as two directed slots.
            nt[cnt] = s[x], s[x] = cnt, e[cnt++] = y;
            nt[cnt] = s[y], s[y] = cnt, e[cnt++] = x;
        }
        build(1, n, -1);
        while (q--)
        {
            // ch has room for 5 bytes, so cap the token at 4 characters.
            if (scanf("%4s%d%d", ch, &x, &y) != 3) return 0;
            if (ch[0] == '?') printf("%d\n", solve(x, x, y));
            else
            {
                // Propagate the difference first, then record the new weight.
                update(x, x, y - w[x]);
                w[x] = y;
            }
        }
    }
    return 0;
}