链接:https://ac.nowcoder.com/acm/contest/9115/B
题意:给定一棵n个节点的树,并且根节点的编号为p,第i个节点有属性值vali, 定义F(i): 在以i为根的子树中,属性值是vali的合约数的节点个数。y 是 x 的合约数是指 y 是合数且 y 是 x 的约数。小埃想知道对1000000007取模后的结果。
思路1:反向思考,对于一个编号u,他自己对其子树中是他合约数的节点的贡献为u。用素数筛的思想,预处理每个数的约数以及是否为合数。假设dfs到一个节点u,先将val[u]的约数都加上其贡献u,然后判断val[u]是否为合数,如果是就将当前val[u]得到的贡献都加在答案上。之后dfs子树,搜完之后要将val[u]的对其约数贡献u去掉,因为u只对其子树的合约数有贡献,避免u对他的兄弟也产生贡献。注意初始化!
代码1:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e4+10;
const int mod = 1e9+7;
bool book[N],vis[N<<1];
vector<int> fac[N],g[N<<1];
int t,n,root,val[N<<1],pri[N<<1],cnt=0;
ll ans,f[N];
void init(){
book[1]=0;
for(int i=2;i<N;i++){
if(!book[i]) pri[++cnt]=i;
for(int j=1;j<=cnt&&i*pri[j]<N;j++){
book[i*pri[j]]=1;
if(i%pri[j]==0) break;
}
}
for(int i=1;i<N;i++){
for(int j=i;j<N;j+=i){
fac[j].push_back(i);
}
}
}
void dfs(int u,int fa){
for(int i=0;i<fac[val[u]].size();i++){
int v = fac[val[u]][i];
f[v]=(f[v]+u)%mod;
}
if(book[val[u]]) ans=(ans+f[val[u]])%mod;
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==fa) continue;
dfs(v,u);
}
for(int i=0;i<fac[val[u]].size();i++){
int v = fac[val[u]][i];
f[v]=(f[v]-u)%mod;
}
}
int main(void){
scanf("%d",&t);
init();
while(t--){
scanf("%d%d",&n,&root);
ans=0;
for(int i=1;i<=n;i++) g[i].clear();
for(int i=1;i<N;i++) f[i]=0;
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
dfs(root,-1);
printf("%lld\n",ans);
}
return 0;
}
思路2:dsu on tree,还是要预处理是否为合数以及每个数的因子。维护每个数的个数,统计答案时,看一下每个数的合约数有多少就行了。
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e4+10;
const ll mod = 1e9+7;
bool book[N];
int pri[N],cnt;
vector<int> fac[N];
void prime(){
book[1]=0;
for(int i=2;i<N;i++){
if(!book[i]) pri[++cnt]=i;
for(int j=1;j<=cnt&&i*pri[j]<N;j++){
book[i*pri[j]]=1;
if(i%pri[j]==0) break;
}
}
for(int i=1;i<N;i++)
for(int j=i;j<N;j+=i)
fac[j].push_back(i);
}
int n,p,val[N<<1];
vector<int> g[N<<1];
ll ans;
int num[N];
bool big[N<<1];
void init(){
ans=0;
memset(num,0,sizeof num);
for(int i=1;i<=n;i++){
g[i].clear();
big[i]=0;
}
}
int sz[N<<1];
void getsz(int u,int fa){
sz[u]=1;
int siz=g[u].size();
for(int i=0;i<siz;i++){
int v=g[u][i];
if(v==fa) continue;
getsz(v,u);
sz[u]+=sz[v];
}
}
void dfs1(int u,int fa,int x){
num[val[u]]+=x;
int siz=g[u].size();
for(int i=0;i<siz;i++){
int v=g[u][i];
if(v==fa||big[v]) continue;
dfs1(v,u,x);
}
}
void dfs(int u,int fa,bool keep){
int mx=-1,bigc=-1;
int siz=g[u].size();
for(int i=0;i<siz;i++){
int v=g[u][i];
if(v==fa) continue;
if(sz[v]>mx) mx=sz[v],bigc=v;
}
for(int i=0;i<siz;i++){
int v=g[u][i];
if(v==fa||v==bigc) continue;
dfs(v,u,0);
}
if(bigc!=-1){
dfs(bigc,u,1);
big[bigc]=1;
}
dfs1(u,fa,1);
siz=fac[val[u]].size();
for(int i=0;i<siz;i++)
if(book[fac[val[u]][i]]){
ans=(ans+u*num[fac[val[u]][i]])%mod;
}
if(bigc!=-1) big[bigc]=0;
if(keep==0)
dfs1(u,fa,-1);
}
int main(void){
int t;
prime();
scanf("%d",&t);
while(t--){
scanf("%d%d",&n,&p);
init();
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
getsz(p,0);
dfs(p,0,0);
printf("%lld\n",ans);
}
return 0;
}