树剖做法:先预处理出来轻重链,然后当修改某一个点的时候,只需要修改同一条链中与当前点相关的边(红色边), 而那些黑色边不需要维护,只需要查询的时候暴力搞一下就好了。
这也就是维护当前点和重儿子点的做法。
树链剖分
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include <map>
#include <list>
#include <queue>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <stack>
#pragma GCC optimize(3 , "Ofast" , "inline")
using namespace std ;
typedef long long ll ;
const double esp = 1e-6 , pi = acos(-1) ;
typedef pair<int , int> PII ;
const int N = 2e4 + 10 , INF = 0x3f3f3f3f , mod = 1e9 + 7;
int in()
{
int x = 0 , f = 1 ;
char ch = getchar() ;
while(!isdigit(ch)) {if(ch == '-') f = -1 ; ch = getchar() ;}
while(isdigit(ch)) x = x * 10 + ch - 48 , ch = getchar() ;
return x * f ;
}
int dfn[N] , son[N] , fa[N] , deep[N] , tp[N] , sz[N] , ans[N] ;
vector<int> g[N] ;
int cnt , a[N] ;
void dfs1(int u , int f)
{
fa[u] = f , sz[u] = 1 , deep[u] = deep[f] + 1;
for(auto v : g[u])
{
if(v == f) continue ;
dfs1(v , u) ;
sz[u] += sz[v] ;
if(sz[son[u]] < sz[v]) son[u] = v ;
}
}
void dfs2(int u , int top)
{
tp[u] = top ;
dfn[u] = ++ cnt ;
if(u != top) ans[dfn[u]] = __gcd(a[u] , a[fa[u]]) ;
if(son[u]) dfs2(son[u] , top) ;
for(auto v : g[u])
if(v != fa[u] && v != son[u])
dfs2(v , v) ;
}
int solve(int u , int v , int k)
{
int res = 0 ;
while(tp[u] != tp[v]) // 将u点跳,一直到u和v两点在同一条链上
{
if(deep[tp[u]] < deep[tp[v]])
swap(u , v) ;
for(int i = dfn[tp[u]] + 1 ;i <= dfn[u] ;i ++ ) // 根据dfs序进行跳
res += ans[i] <= k ;
res += __gcd(a[tp[u]] , a[fa[tp[u]]]) <= k ; // 链与链之间贡献
u = fa[tp[u]] ;
}
if(deep[u] < deep[v]) swap(u , v) ;
for(int i = dfn[v] + 1 ;i <= dfn[u] ;i ++ ) // 现在是同一条链,直接从深度小的跳到深度大的点
res += ans[i] <= k ;
return res ;
}
int main()
{
int n = in() , q = in() ;
for(int i = 1; i <= n ;i ++ ) a[i] = in() ;
for(int i = 1 ;i < n ;i ++ )
{
int u = in() , v = in() ;
g[u].push_back(v) , g[v].push_back(u) ;
}
dfs1(1 , 0) , dfs2(1 , 1) ;
while(q --)
{
int op = in() ;
if(op == 1)
{
int u = in() , k = in() ;
a[u] = k ;
if(tp[u] != u) ans[dfn[u]] = __gcd(a[u] , a[fa[u]]) ; // 如果当前点不是这个链的顶部
if(son[u]) ans[dfn[son[u]]] = __gcd(a[u] , a[son[u]]) ;
// 如果当前点存在重儿子,也相当于不是这个链的链尾
}
else
{
int u = in() , v = in() , k = in() ;
cout << solve(u , v , k) << endl ;
}
}
return 0 ;
}
还有一个暴力做法,超过一定度的点不予处理,度数小的点直接暴力修改 , 最后查询的时候,对超过一定度的点暴力查询;
暴力做法
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include <map>
#include <list>
#include <queue>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <stack>
#pragma GCC optimize(3 , "Ofast" , "inline")
using namespace std ;
typedef long long ll ;
const double esp = 1e-6 , pi = acos(-1) ;
typedef pair<int , int> PII ;
const int N = 1e5 + 10 , INF = 0x3f3f3f3f , mod = 1e9 + 7;
int in()
{
int x = 0 , f = 1 ;
char ch = getchar() ;
while(!isdigit(ch)) {if(ch == '-') f = -1 ; ch = getchar() ;}
while(isdigit(ch)) x = x * 10 + ch - 48 , ch = getchar() ;
return x * f ;
}
int e[N] , ne[N] , h[N] , a[N] , idx , n , q , fa[N][25] , deep[N] , fe[N] ;
void add(int a , int b)
{
e[idx] = b , ne[idx] = h[a] , h[a] = idx ++ ;
}
void dfs(int u , int f )
{
deep[u] = deep[f] + 1 ;
fa[u][0] = f ;
for(int i = 1 ; i <= 20 ; i ++ ) fa[u][i] = fa[fa[u][i - 1]][i - 1] ;
for(int i = h[u] ; ~i ;i = ne[i])
{
int v = e[i] ;
if(v == f) continue ;
fe[v] = __gcd(a[u] , a[v]) ;
dfs(v , u) ;
}
}
int lca(int a , int b)
{
if(deep[a] < deep[b]) swap(a , b) ;
for(int i = 20 ;i >= 0 ;i -- )
if(deep[fa[a][i]] >= deep[b])
a = fa[a][i] ;
if(a == b) return a ;
for(int i = 20 ;i >= 0 ;i -- )
if(fa[a][i] != fa[b][i])
a = fa[a][i] , b = fa[b][i] ;
return fa[a][0] ;
}
int out[N] , m = 150;
int solve(int u , int v , int k)
{
int ans = 0 ;
int pos = lca(u , v) ;
while(u != pos)
{
if(out[u] > m || out[fa[u][0]] > m)
ans += (__gcd(a[u] , a[fa[u][0]]) <= k) ;
else if(fe[u] <= k) ans ++ ;
u = fa[u][0] ;
}
while(v != pos)
{
if(out[v] > m || out[fa[v][0]] > m)
ans += (__gcd(a[v] , a[fa[v][0]]) <= k) ;
else if(fe[v] <= k) ans ++ ;
v = fa[v][0] ;
}
return ans ;
}
int main()
{
memset(h , -1 , sizeof h) ;
n = in() , q = in() ;
for(int i = 1; i <= n ;i ++ ) a[i] = in() ;
for(int i = 1 ; i < n ;i ++ )
{
int u = in() , v = in() ;
add(u , v) , add(v , u) ;
out[u] ++ , out[v] ++ ;
}
dfs(1 , 0) ;
int u , x ;
while(q --)
{
int op = in() ;
if(op == 1)
{
u = in() , x = in() ;
a[u] = x ;
if(out[u] > m) continue ;
for(int i = h[u] ; ~i ; i = ne[i])
{
int v = e[i] ;
if(out[v] > m) continue ;
if(fa[u][0] == v)
fe[u] = __gcd(a[u] , a[v]) ;
else
fe[v] = __gcd(a[u] , a[v]) ;
}
}
else
{
int u = in() , v = in() , k = in() ;
int res = solve(u , v , k) ;
cout << res << endl ;
}
}
return 0 ;
}