题目链接:B—合约数
题意:一棵树,有n个节点,从1编号到n。根节点的编号为p。给出每个节点的val[i]值,定义f(i)为以编号i为根节点的子树中(包括根节点),所有val[j]是合数并且是val[i]的约数的节点个数。求所有f(i)的和。答案对1e9+7取模。
题解:学长说可以用set的启发式合并或者splay的启发式合并。emmm,我太菜了不会这些怎么写。我的做法是:预处理范围内的数字是否是合数,然后对1-10000的数字用vector数组存下各自的因数。然后用dfs序,将各子树变成一个区间问题,用莫队算法解决。n是20000,每个数字不超过10000,这样复杂度就是 20000*(200+约数的个数,没几十个)。
ac代码:
#include<iostream>
#include<stdio.h>
#include<string>
#include<string.h>
#include<algorithm>
#include<vector>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<math.h>
using namespace std;
int val[20000+10];
int block;
int caz=0;
long long mod =1000000007;
int vis[20000+10];
int cishu[20000+10];
int num[20000+10];
int notprime[20000+10];
int prime[20000+10];
vector<int> vt[10000+10];
int cnt=0;
struct Query
{
int l,r,pos;
}query[20000+10];
bool cmp(Query a,Query b)
{
if(a.l/block!=b.l/block){
return a.l<b.l;
}
return a.r<b.r;
}
struct Node
{
int v,nne;
}node[40000+10];
int head[20000+10];
void build(int u,int v)
{
node[cnt].v=v;
node[cnt].nne=head[u];
head[u]=cnt++;
}
void dfs(int rt)
{
vis[rt]=caz;
num[++cnt]=val[rt];
query[rt].l=cnt;
query[rt].pos=rt;
for(int i=head[rt];i;i=node[i].nne){
if(vis[node[i].v]<caz){
dfs(node[i].v);
}
}
query[rt].r=cnt;
}
int main()
{
int t,n,p;
//prime
for(int i=2;i<=20000;i++){
if(notprime[i]==0){
prime[cnt++]=i;
}
for(int j=0;j<cnt&&prime[j]*i<=20000;j++){
notprime[i*prime[j]]=1;
if(i%prime[j]==0) break;
}
}
//
for(int i=1;i<=10000;i++){
if(notprime[i]){
double sq=sqrt(i);
for(int j=1;j<=sq;j++){
if(i%j==0){
if(notprime[j])vt[i].push_back(j);
if(notprime[i/j]&&i/j!=j)vt[i].push_back(i/j);
}
}
}
}
//
scanf("%d",&t);
while(t--){
long long ans=0,temp=0;
scanf("%d%d",&n,&p);
block=(int)sqrt(n);
memset(cishu,0,sizeof(cishu));
memset(head,0,sizeof(head));
caz++;
cnt=1;
int a,b;
for(int i=1;i<n;i++){
scanf("%d%d",&a,&b);
build(a,b);
build(b,a);
}
for(int i=1;i<=n;i++){
scanf("%d",&val[i]);
}
cnt=0;
dfs(p);
sort(query+1,query+1+n,cmp);
int L=query[1].l,R=query[1].r;
for(int i=L;i<=R;i++){
cishu[num[i]]++;
}
temp=0;
vector<int>::iterator it=vt[val[query[1].pos]].begin();
for(;it!=vt[val[query[1].pos]].end();++it){
temp+=cishu[*it];
}
ans=(ans+temp*query[1].pos%mod)%mod;
printf("%lld\n",ans);
for(int i=2;i<=n;i++){
if(query[i].l>L){
while(query[i].l>L){
cishu[num[L]]--;
L++;
}
}
else if(query[i].l<L){
while(query[i].l<L){
L--;
cishu[num[L]]++;
}
}
if(query[i].r>R){
while(query[i].r>R){
R++;
cishu[num[R]]++;
}
}
else if(query[i].r<R){
while(query[i].r<R){
cishu[num[R]]--;
R--;
}
}
temp=0;
for(it=vt[val[query[i].pos]].begin();it!=vt[val[query[i].pos]].end();it++){
temp+=cishu[*it];
}
ans=(ans+temp*query[i].pos%mod)%mod;
printf("%lld %lld\n",temp,ans);
}
printf("%lld\n",ans);
}
}