题目链接:Assign the task
题解
我们不难发现给树的一个节点及其子节点染色,可以转化为线段树区间修改模型。
我们通过dfs,对每一个点进行重新编号。然后会发现重新编号后每个节点及其所有子节点的标号是连续的,所以我们可以转化为区间修改模型。剩下的就是线段树单点修改区间查询问题。
本题我认为关键点在于dfs序对节点进行重新编号,这种从无序转化为连续的思想值得我们学习、举一反三。
代码
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<bitset>
#include<cassert>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<ctime>
#include<deque>
#include<iomanip>
#include<list>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#include<unordered_set>
#include<unordered_map>
using namespace std;
//extern "C"{void *__dso_handle=0;}
typedef long long ll;
typedef long double ld;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define lowbit(x) x&-x
const double PI=acos(-1.0);
const double eps=1e-6;
const ll mod=1e9+7;
const int inf=0x3f3f3f3f;
const int maxn=5e4+10;
const int maxm=1e7+10;
#define ios ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
vector<int>g[maxn];
int tree[maxn<<4],addv[maxn<<4],ls[maxn<<4],rs[maxn<<4];
map<int,int> nt;
void pushdown(int p)
{
if(addv[p])
{
tree[p*2]=tree[p*2+1]=addv[p];
addv[p*2]=addv[p*2+1]=addv[p];
addv[p]=0;
}
}
void build(int l,int r,int p)
{
tree[p]=-1;
addv[p]=0;
if(l==r) { return ;}
int mid=(l+r)>>1;
build(l, mid, p*2);
build(mid+1,r,p*2+1);
}
void add(int l,int r,int p,int ql,int qr,int y)
{
if(ql<=l && r<=qr)
{
tree[p]=addv[p]=y;
return;
}
pushdown(p);
int mid=(l+r)>>1;
if(ql<=mid) add(l,mid,p*2,ql,qr,y);
if(qr>mid) add(mid+1, r, p*2+1,ql,qr, y);
}
int query(int l,int r,int p,int num)
{
if(addv[p]) pushdown(p);
if(l==r) return tree[p];
int mid=(l+r)>>1;
if(num<=mid) return query(l, mid, p*2, num);
else return query(mid+1, r, p*2+1, num);
}
int cnt=0;
void dfs(int u)
{
ls[u]=++cnt;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
dfs(v);
}
rs[u]=cnt;
}
int main()
{
int t;
scanf("%d",&t);
for(int cas=1;cas<=t;cas++)
{
printf("Case #%d:\n",cas);
cnt=0; nt.clear();
for(int i=1;i<maxn;i++) g[i].clear();
memset(ls, 0, sizeof(ls));
memset(rs, 0, sizeof(rs));
int n; scanf("%d",&n);
for(int i=1;i<n;i++)
{
int u,v; scanf("%d%d",&u,&v);
g[v].push_back(u); nt[u]=1;
}
build(1, n, 1);
for(int i=1;i<=n;i++)
{
if(!nt[i]) { dfs(i); break; }
}
int m; scanf("%d",&m);
while(m--)
{
char c[5];
scanf("%s",c);
if(c[0]=='C')
{
int num; scanf("%d",&num);
printf("%d\n",query(1, n, 1, ls[num]));
}
else
{
int x,y; scanf("%d%d",&x,&y);
add(1,n,1,ls[x],rs[x],y);
}
}
}
}