题意:初始给定一个带权的树。有两种操作,一种是(a,b,w),表示将(a,b)之间的权值改成w(这里保证ab之间有一条边)。第二种操作是查询树的所有节点对的距离之和。数据范围:2<=节点数<=100,000, 1<=查询<=50,000
思路:容易想到,在最后的距离和中,每条边的计算次数是去掉这条边形成的两个连通分量点数的乘积。所以我们可以事先求出每条边需要计算的次数,并维护这个总的距离之和。每次更新边权的时候,总距离和中只需要改变和这条边相关的量即可。
具体的,先求出每条边应该计算的次数,用类似树形dp的思路。然后用map保存边的信息,每次查询实时更新即可。
一开始表示边的数组开小了,但是平台不报RE而报TLE,所以我以为是STL浪费时间就又改了一版本,实际上STL也没问题,见下面两个版本:
版本1,用STL存储:
#include <cstdio>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <queue>
#include <algorithm>
#include <cmath>
#include <map>
#include <iostream>
#define N 100005
#define INF 0x3fffffff
using namespace std;
struct edge{
int y,next,w;
}e[N<<1];
int first[N],top;
map<pair<int, int>, pair<int, int>> lookup;
void add(int x,int y,int w){
e[top].y = y;
e[top].w = w;
e[top].next = first[x];
first[x] = top++;
}
int dfs(int x,int fa){
int sum = 1;
for(int i = first[x];i!=-1;i=e[i].next){
int y = e[i].y;
if(y != fa){
int tmp = dfs(y, x);
lookup[make_pair(min(x,y), max(x,y))] = make_pair(tmp, e[i].w);
sum += tmp;
}
}
return sum;
}
int main(){
int n,m,a,b,w;
scanf("%d %d",&n,&m);
memset(first, -1, (n+1)*sizeof(int));
for(int i = 1;i<n;i++){
scanf("%d %d %d",&a,&b,&w);
add(a,b,w);
add(b,a,w);
}
dfs(1,-1);
long long res = 0;
map<pair<int, int>,pair<int, int> >::iterator it;
for(it = lookup.begin();it!=lookup.end();++it)
res += (long long)it->second.first * (n-it->second.first) * it->second.second;
while(m--){
string cmd;
cin >> cmd;
if(cmd == "QUERY"){
printf("%lld\n",res);
}else{
scanf("%d %d %d",&a,&b,&w);
it = lookup.find(make_pair(min(a,b), max(a,b)));
res += (long long)(w-it->second.second)*it->second.first*(n-it->second.first);
it->second.second = w;
}
}
return 0;
}
版本2:用数组存储:
#include <cstdio>
#include <cstring>
#include <string>
#include <cstdlib>
#include <vector>
#include <queue>
#include <algorithm>
#include <cmath>
#include <map>
#include <unordered_map>
#include <iostream>
#define N 100005
#define INF 0x3fffffff
using namespace std;
struct edge{
int y,next,w;
}e[N<<1];
int first[N],top=0;
int lookup[N][2];
unordered_map<long long, int> hh;
int id = 1;
void add(int x,int y,int w){
e[top].y = y;
e[top].w = w;
e[top].next = first[x];
first[x] = top++;
}
int test(int x,int y){
if(x>y)
swap(x,y);
long long tmp = (long long)x*N+y;
if(hh.find(tmp) == hh.end())
hh[tmp] = id++;
return hh[tmp];
}
int dfs(int x,int fa){
int sum = 1;
for(int i = first[x];i!=-1;i=e[i].next){
int y = e[i].y;
if(y != fa){
int tmp = dfs(y, x);
int eid = test(x,y);
lookup[eid][0] = tmp;
lookup[eid][1] = e[i].w;
sum += tmp;
}
}
return sum;
}
int main(){
int n,m,a,b,w;
scanf("%d %d",&n,&m);
memset(first, -1, (n+1)*sizeof(int));
for(int i = 1;i<n;i++){
scanf("%d %d %d",&a,&b,&w);
add(a,b,w);
add(b,a,w);
}
dfs(1,-1);
long long res = 0;
for(int i = 1;i<id;i++)
res += (long long)lookup[i][0]*(n-lookup[i][0])*lookup[i][1];
while(m--){
string cmd;
cin >> cmd;
if(cmd == "QUERY"){
printf("%lld\n",res);
}else{
scanf("%d %d %d",&a,&b,&w);
int tmpid = test(a,b);
res += (long long)(w-lookup[tmpid][1])*lookup[tmpid][0]*(n-lookup[tmpid][0]);
lookup[tmpid][1] = w;
}
}
return 0;
}