分析:树形dp
dp[i][v]代表以i为根的子树中权值为v的节点到i的路径和
c[i][v]代表以i为根的子树中权值为v的节点个数
对于一个节点 有很多孩子分支,从最左边的分支开始,算跨过当前根节点的
每次分支算一次
最后再算当前根节点i与它子树互质的
每次算的方法就是先算出不互质的,然后总的减
每次分支算一次
最后再算当前根节点i与它子树互质的
每次算的方法就是先算出不互质的,然后总的减
复杂度O(n*m*sqrt(m)*k*2^k) m=max(a[i]),i属于1到n
#include <cstdio> #include <cstring> #include <algorithm> #include <iostream> #include <cstdlib> #include <vector> #include <cmath> using namespace std; typedef long long LL; const int N=1e4+5; struct Edge { int v,next; } edge[N<<1]; int head[N],tot; void add(int u,int v) { edge[tot].v=v; edge[tot].next=head[u]; head[u]=tot++; } vector<int>fac[505]; LL dp[N][505],c[N][505],a[N],n,ret,cnt[505],mx,cnt2[505]; void cal(int rt) { memset(cnt,0,sizeof(cnt)); memset(cnt2,0,sizeof(cnt2)); for(int i=1; i<=mx; ++i) { for(int j=1; j*j<=i; ++j) { if(i%j)continue; cnt[j]+=dp[rt][i]; cnt2[j]+=c[rt][i]; if(i/j!=j){ cnt2[i/j]+=c[rt][i]; cnt[i/j]+=dp[rt][i]; } } } } void treedp(int u,int f) { for(int i=head[u]; ~i; i=edge[i].next) { int v=edge[i].v; if(v==f)continue; treedp(v,u); LL sum=0,sum2=0; for(int j=1; j<=mx; ++j)sum+=dp[u][j],sum2+=c[u][j]; cal(u); for(int j=1; j<=mx; ++j) { if(!c[v][j])continue; int l=(1<<(fac[j].size())); LL tot=0,tot2=0; for(int k=1; k<l; ++k) { int tmp=1,t1=0; for(int p=0; p<fac[j].size(); ++p) if(k&(1<<p))++t1,tmp*=fac[j][p]; if(t1&1)tot+=cnt[tmp],tot2+=cnt2[tmp]; else tot-=cnt[tmp],tot2-=cnt2[tmp]; } ret+=(sum-tot)*c[v][j]+(sum2-tot2)*(dp[v][j]+c[v][j]); } for(int j=1; j<=mx; ++j) { dp[u][j]+=dp[v][j]+c[v][j]; c[u][j]+=c[v][j]; } } cal(u); LL sum=0,tot=0; for(int j=1; j<=mx; ++j)sum+=dp[u][j]; int l=(1<<(fac[a[u]].size())); for(int i=1; i<l; ++i) { int tmp=1,t1=0; for(int j=0; j<fac[a[u]].size(); ++j) if(i&(1<<j))++t1,tmp*=fac[a[u]][j]; if(t1&1)tot+=cnt[tmp]; else tot-=cnt[tmp]; } ret+=(sum-tot); c[u][a[u]]++; } int main() { for(int i=1; i<=500; ++i)fac[i].clear(); for(int i=2; i<=500; ++i) { int t=i; for(int j=2; j<=t; ++j) { if(t%j)continue; fac[i].push_back(j); while(t%j==0)t/=j; } } while(~scanf("%lld",&n)) { mx=0; for(int i=1; i<=n; ++i)scanf("%lld",&a[i]),mx=max(mx,a[i]); memset(head,-1,sizeof(head)); tot=0; memset(dp,0,sizeof(dp)); memset(c,0,sizeof(c)); for(int i=1; i<n; ++i) { int u,v; scanf("%d%d",&u,&v); add(u,v),add(v,u); } ret=0; treedp(1,0); printf("%lld\n",ret); } return 0; }