题解:
第一个数据n=1让我调了一年。。。
首先
a
n
s
=
∑
x
μ
(
x
)
f
(
x
)
ans=\sum_x\mu(x)f(x)
ans=∑xμ(x)f(x),
f
(
x
)
f(x)
f(x)表示有多少条路径的值全为
x
x
x的倍数。
显然路径上的边权值肯定是
x
x
x的倍数,我们对于每个
x
x
x把所有他倍数的边加入图中。
因为只用管
μ
(
x
)
̸
=
0
\mu(x) \not =0
μ(x)̸=0的值,每个边最多位于
2
7
2^7
27个图中。我们先把那些不会被修改的边加进去,修改的
O
(
Q
2
2
c
)
O(Q^2 2^c)
O(Q22c)加入即可。注意要支持加边、删边,并查集需要按秩合并。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int,int> pii;
const int RLEN=1<<18|1;
inline char nc() {
static char ibuf[RLEN],*ib,*ob;
(ib==ob) && (ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
return (ib==ob) ? -1 : *ib++;
}
inline int rd() {
char ch=nc(); int i=0,f=1;
while(!isdigit(ch)) {if(ch=='-')f=-1; ch=nc();}
while(isdigit(ch)) {i=(i<<1)+(i<<3)+ch-'0'; ch=nc();}
return i*f;
}
const int N=1e5+50, L=1e6+50, Q=1e2+10;
struct Sieve {
int pr[N],npr[L],pt;
int mu[L],mnpr[L];
Sieve() {
mu[1]=1;
for(int i=2;i<=1e6;i++) {
if(!npr[i]) {pr[++pt]=i; mu[i]=-1; mnpr[i]=i;}
for(int j=1;j<=pt;j++) {
long long k=i*pr[j];
if(k>1e6) break;
npr[k]=1; mnpr[k]=pr[j];
if(i%pr[j]) mu[k]=-mu[i];
else {mu[k]=0; break;}
}
}
}
vector <int> cont,num;
inline void dfs(int p,int v) {
if(p==cont.size()) {num.push_back(v); return;}
dfs(p+1,v); dfs(p+1,v*cont[p]);
}
inline vector <int> get_num(int val) {
cont.clear(); num.clear();
while(val^1) {
int t=mnpr[val];
cont.push_back(t);
while(!(val%t)) val/=t;
} dfs(0,1);
return num;
}
} C;
int n,q,a[N],b[N],w[N];
int id[Q],w2[Q],last[N];
int anc[N],sze[N],rk[N];
LL ori_ans[L],ans[Q];
set <int> vis;
vector <int> ori_edge[L];
vector <pii> new_edge[L];
struct data {
int x,y,rc;
} stk[N<<1];
int top; LL sum;
inline int ga(int x) {return (anc[x]==x) ? x : ga(anc[x]);}
inline void merge(int x,int y) {
x=ga(x), y=ga(y); data t;
if(rk[x]<rk[y]) swap(x,y);
t.x=x; t.y=y;
t.rc=(rk[x]==rk[y]) ? 1 : 0;
rk[x]+=t.rc;
sum+=(LL)sze[x]*sze[y];
sze[x]+=sze[y]; anc[y]=x;
stk[++top]=t;
}
inline void erase(data &t) {
anc[t.y]=t.y;
sze[t.x]-=sze[t.y];
sum-=(LL)sze[t.x]*sze[t.y];
rk[t.x]-=t.rc;
}
inline void solve(int x) {
for(auto v:ori_edge[x]) merge(a[v],b[v]);
ori_ans[x]=sum;
for(int i=0,j;i<new_edge[x].size();i=j+1) {
j=i;
while(j+1<new_edge[x].size() && new_edge[x][i].first==new_edge[x][j+1].first) ++j;
for(int k=i;k<=j;k++) merge(a[new_edge[x][k].second],b[new_edge[x][k].second]);
ans[new_edge[x][i].first]+=C.mu[x]*sum;
for(int k=i;k<=j;k++) erase(stk[top--]);
} while(top) erase(stk[top--]);
}
int main() {
n=rd();
if(n==1) {puts("0"); return 0;}
for(int i=1;i<n;i++) a[i]=rd(), b[i]=rd(), w[i]=rd();
for(int i=1;i<=n;i++) anc[i]=i, sze[i]=1;
q=rd()+1;
for(int i=1;i<=q;i++) {
id[i]=(i==1) ? 1 : rd(), w2[i]=(i==1) ? w[1] : rd();
vis.insert(id[i]);
}
for(int i=1;i<=q;i++) {
last[id[i]]=i;
for(auto j:vis) {
int val=(last[j] && last[j]<=i) ? w2[last[j]] : w[j];
vector <int> vec=C.get_num(val);
for(auto k:vec) new_edge[k].push_back(pii(i,j));
}
}
for(int i=1;i<n;i++) if(!vis.count(i)) {
vector <int> vec=C.get_num(w[i]);
for(auto k:vec) ori_edge[k].push_back(i);
}
for(int i=1;i<=1e6;i++)
if(C.mu[i] && (ori_edge[i].size() || new_edge[i].size())) solve(i);
for(int i=1;i<=1e6;i++) sum+=C.mu[i]*ori_ans[i];
memset(last,0,sizeof(last));
for(int i=1;i<=q;i++) {
set <int> vs; last[id[i]]=i;
for(auto j:vis) {
int val=(last[j] && last[j]<=i) ? w2[last[j]] : w[j];
vector <int> vec=C.get_num(val);
for(auto k:vec) vs.insert(k);
}
ans[i]+=sum;
for(auto j:vs) ans[i]-=C.mu[j]*ori_ans[j];
cout<<ans[i]<<'\n';
}
}