题目大意
给出一棵有n个节点树,树的编号为i节点有一个权值a[i],且边也有权值c[num],每一个点对的贡献为p[i,j]=dis[i,j]*(a[i] xor a[j]),同时有t次修改操作,每次求改一个节点的权值,每次修改后要输出
∑n−1i=1∑nj=i+1p[i,j]
数据范围a[i]<16384= 214 ,c[num]<=30000,n<=30000,t<=30000
异或——拆位
把a[i]拆成14个位,每个位为0或1,这样处理xor操作就方便多了。
点分治
构造点分治树,
记录的信息有每一位为0和1的数量,和到子树根结点的距离总和。
还要记录每个点在每一层属于哪个树。
需要记录的根和根儿子的信息。
每次修改,我们一层一层的修改同时更新答案。
有序点对没有关系,求出无序的ans,输出ans/2就好了。
代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
const int maxn=30005;
int sum=0,cf[maxn],n,t,root,next[maxn*2],k[maxn],g[maxn*2],a[maxn],c[maxn*2],s[maxn],f[maxn],num,be[maxn][16],sbe[maxn][16];
long long ans,tree[2][maxn][15][2],d[maxn][15],l,r,cnt[2][maxn][15][2];
bool bz[maxn];
void add(int x,int y,int z)
{
next[++num]=k[x];
k[x]=num;
g[num]=y;
c[num]=z;
}
void dfs(int x,int y){
int i=k[x];
s[x]=1;f[x]=0;
while (i>0){
if (g[i]!=y&&!bz[g[i]]){
dfs(g[i],x);
s[x]+=s[g[i]];
f[x]=max(f[x],s[g[i]]);
}
i=next[i];
}
f[x]=max(f[x],f[0]-1-s[x]);
if (f[x]<f[root]) root=x;
}
void put(int b,int k,int x,long long y,int bk){
for (int i=0;i<14;i++){
tree[b][k][i][x%2]+=bk*y;
cnt[b][k][i][x%2]+=bk;
y*=2;x/=2;
}
}
void make(int p,int t,int x,int y){
put(0,t,a[x],d[x][p],1);be[x][p]=t;cf[++num]=x;
sbe[x][p]=sbe[y][p];if (y==t) sbe[x][p]=++sum;
s[x]=1;
if (y>0)
{ put(1,sbe[x][p],a[x],d[x][p],1);
}
int i=k[x];
while (i>0){
if (g[i]!=y&&!bz[g[i]])
{
d[g[i]][p]=d[x][p]+c[i];
make(p,t,g[i],x);s[x]+=s[g[i]];
}
i=next[i];
}
}
void count(int x,int y,long long g){
int z=a[x];long long w=d[x][y];
for(int i=0;i<14;i++){
int p=1-z%2;
ans+=g*(tree[0][be[x][y]][i][p]-tree[1][sbe[x][y]][i][p]);
ans+=g*w*(cnt[0][be[x][y]][i][p]-cnt[1][sbe[x][y]][i][p]);
z/=2;w*=2;
}
}
void fen(int x,int t){
d[x][t]=0;
bz[x]=1;num=0;
make(t,x,x,0);
for (int i=1;i<=num;i++) count(cf[i],t,1);
int i=k[x];
while (i>0){
if (!bz[g[i]]){
f[root=0]=s[g[i]]+1;
dfs(g[i],x);
fen(root,t+1);
}
i=next[i];
}
}
main(){
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d",&n);num=0;
memset(tree,0,sizeof(tree));memset(cnt,0,sizeof(cnt));
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
for (int i=1;i<=n-1;i++){
int x,y,z;scanf("%d%d%d",&x,&y,&z);
add(x,y,z);add(y,x,z);
}f[root=0]=n+1;
dfs(1,0);
ans=0;
fen(root,0);
ans/=2;
scanf("%d",&t);
for (int i=1;i<=t;i++){
int x,y;scanf("%d%d",&x,&y);
for (int j=0;j<=15;j++){
count(x,j,-1);
put(0,be[x][j],a[x],d[x][j],-1);
put(1,sbe[x][j],a[x],d[x][j],-1);
if (be[x][j]==x) break;
}a[x]=y;
for (int j=0;j<=15;j++){
count(x,j,1);
put(0,be[x][j],a[x],d[x][j],1);
put(1,sbe[x][j],a[x],d[x][j],1);
if (be[x][j]==x) break;
}
printf("%lld\n",ans);
}return 0;
}