题意:
给定一棵有 $n$ 个结点的树,每个结点上有权值 $a_i$,问有多少条路径满足路径上的点权乘积为立方数。每个点权都可以表示为给定的 $k$ 个素数的乘积(即点权的素因子都在这 $k$ 个素数之中)。$(n \leq 5 \times 10^4,\ k \leq 30)$
链接:
https://vjudge.net/problem/HDU-4670
解题思路:
点分治:一条路径的点权乘积为立方数,当且仅当乘积中每个质因子的幂次都是 $3$ 的倍数。因此把每个点权中 $k$ 个素数的幂次分别模 $3$,压成一个 $k$ 位三进制数(状压);经过当前重心的路径,其两段状态按位相加模 $3$ 为 $0$ 时即为立方数,统计答案时查询与当前路径状态互补的状态的出现次数即可。
参考代码:
#include<bits/stdc++.h>
using namespace std;
// Short type aliases used throughout the solution.
using ll = long long;
using pii = std::pair<int, int>;

// Contest helper macros (kept so any code relying on them still compiles).
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)

// Array capacity, "infinity" sentinel and a (currently unused) modulus.
constexpr int maxn = 50005;
constexpr int inf = 0x3f3f3f3f;
constexpr int mod = 1000000007;

// mp: for the centroid being processed, count of endpoint masks seen so far.
std::map<ll, int> mp;
// G: adjacency lists of the tree, vertices numbered 1..n.
std::vector<int> G[maxn];
// pw[i] = 3^i (filled in main); pi = the k input primes;
// a = vertex weights, replaced in main by their base-3 exponent masks;
// dis = scratch buffer of path masks collected by dfs.
ll pw[maxn], pi[maxn], a[maxn], dis[maxn];
// siz: subtree sizes from the latest getRt DFS; vis: 1 while a vertex is an
// active centroid (blocks traversal through it).
int siz[maxn], vis[maxn];
// tot: number of entries in dis; tn: size of the component searched by getRt;
// rt / rmn: best centroid found by getRt and its largest-piece size.
int n, k, tot, tn, rt, rmn;
// Number of cube-product paths for the current test case.
ll ans;
// Centroid search over the unvisited component containing u.
// Recursively fills siz[] (subtree sizes relative to this DFS) and, for every
// vertex, measures its largest remaining piece: the biggest child subtree or
// the part above it (tn - siz[u]). The vertex minimising that value is kept in
// rt, with the value in rmn. Callers in this file set tn to the component size
// and rmn to inf before calling.
void getRt(int u, int f){
    siz[u] = 1;
    int heaviest = 0;
    for(int v : G[u]){
        if(v != f && !vis[v]){
            getRt(v, u);
            siz[u] += siz[v];
            if(siz[v] > heaviest) heaviest = siz[v];
        }
    }
    const int upward = tn - siz[u];
    if(upward > heaviest) heaviest = upward;
    if(heaviest < rmn){
        rmn = heaviest;
        rt = u;
    }
}
// Digit-wise sum modulo 3 of two base-3 masks, over the lowest k digits
// (k and pw[] are globals). The result is the mask of the product of the two
// weights that x and y encode.
ll add(ll x, ll y){
    ll res = 0;
    int digit = 0;
    while(digit < k){
        const ll d = (x % 3 + y % 3) % 3;
        res += d * pw[digit];
        x /= 3;
        y /= 3;
        ++digit;
    }
    return res;
}
// Digit-wise difference modulo 3 (x - y) of two base-3 masks, over the lowest
// k digits (k and pw[] are globals). sub(0, y) yields the mask that, added to
// y, gives 0.
ll sub(ll x, ll y){
    ll res = 0;
    for(int digit = 0; digit < k; ++digit, x /= 3, y /= 3){
        const ll d = (x % 3 + 3 - y % 3) % 3;
        res += pw[digit] * d;
    }
    return res;
}
// Walks the unvisited part below u (never stepping back to f) and appends to
// dis[1..tot] the mask of each path from the walk's starting vertex down to
// every reached vertex. val is the mask for the path ending at u; a child's
// mask is val combined with a[child].
void dfs(int u, int f, ll val){
    tot += 1;
    dis[tot] = val;
    for(size_t i = 0; i < G[u].size(); ++i){
        const int nxt = G[u][i];
        if(nxt != f && !vis[nxt]) dfs(nxt, u, add(val, a[nxt]));
    }
}
void cal(int u){
if(!a[u]) ++ans;
++mp[0ll];
ll msk = sub(0ll, a[u]);
for(auto v : G[u]){
if(vis[v]) continue;
tot = 0;
dfs(v, u, a[v]);
for(int i = 1; i <= tot; ++i){
ll tmp = sub(msk, dis[i]);
if(mp.find(tmp) != mp.end()) ans += mp[tmp];
}
for(int i = 1; i <= tot; ++i){
++mp[dis[i]];
}
}
mp.clear();
}
// Centroid-decomposition driver. u is the centroid that getRt just selected
// for a component of tn vertices. dfz counts the paths through u, then recurses
// into each remaining piece around u using that piece's own centroid.
void dfz(int u){
    // getRt left tn = size of u's component and siz[] relative to the DFS root
    // it used. Save tn now, because the recursive calls below overwrite it.
    const int total = tn;
    vis[u] = 1;
    cal(u);
    for(auto v : G[u]){
        if(vis[v]) continue;
        // If v was u's child in that DFS, siz[v] is exactly the size of v's
        // piece. Otherwise v was u's parent (then siz[v] > siz[u]) and its
        // piece is everything outside u's DFS subtree. Using the stale siz[v]
        // in that case overstates the size, so getRt may pick a vertex that is
        // not a true centroid.
        tn = siz[v] < siz[u] ? siz[v] : total - siz[u];
        rmn = inf;
        getRt(v, u);
        dfz(rt);
    }
    // Restore the flag so vis[] is all zero again once the top-level call
    // returns. main() never clears vis between test cases.
    vis[u] = 0;
}
int main(){
ios::sync_with_stdio(0); cin.tie(0);
pw[0] = 1;
for(int i = 1; i <= 30; ++i) pw[i] = pw[i - 1] * 3;
while(cin >> n){
ans = 0;
for(int i = 1; i <= n; ++i) G[i].clear();
cin >> k;
for(int i = 0; i < k; ++i) cin >> pi[i];
for(int i = 1; i <= n; ++i) cin >> a[i];
for(int i = 1; i < n; ++i){
int u, v; cin >> u >> v;
G[u].pb(v), G[v].pb(u);
}
for(int i = 1; i <= n; ++i){
ll tmp = a[i], sum = 0;
for(int j = 0; j < k; ++j){
int cnt = 0;
while(tmp % pi[j] == 0){
++cnt;
tmp /= pi[j];
}
cnt %= 3;
sum += cnt * pw[j];
}
a[i] = sum;
}
tn = n, rmn = inf, getRt(1, 0);
dfz(rt);
cout << ans << "\n";
}
return 0;
}