题目链接
爆搜即可。。。
#include<bits/stdc++.h>
#define N 1000006
typedef long long ll;
using namespace std;
const int inf=1e6;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
ll ca[N],cb[N];
int pri[N],v[N],tot;
inline void init(int n){
for(int i=2;i<=n;++i){
if(!v[i]){
v[i]=i;pri[++tot]=i;
}
for(int j=1;j<=tot;++j){
if(pri[j]*i>n||v[i]<pri[j])break;
v[i*pri[j]]=pri[j];
}
}
}
int p[N],k[N],pt;
ll ans=0;
inline ll power(ll x,int c){
ll now=1;
while(c){
if(c&1)now=now*x;
x=x*x;c>>=1;
}
return now;
}
void dfs(int now,ll x,ll y){
if(x>inf||y>inf)return ;
if(now==0){
ans=ans+1ll*ca[x]*cb[y];
return ;
}
ll ax=x*power(p[now],k[now]),ay=y;
for(int i=0;i<k[now];++i){
dfs(now-1,ax,ay);
ay*=p[now];
if(ay>inf)break;
}
ax=x;ay=y*power(p[now],k[now]);
for(int i=0;i<=k[now];++i){
dfs(now-1,ax,ay);
ax*=p[now];
if(ax>inf)break;
}
}
int main(){
int n=read(),m=read();
init(N-5);
for(int i=1;i<=n;++i){
int x=read();
ca[x]++;
}
for(int i=1;i<=m;++i){
int y=read();
cb[y]++;
}
for(int i=1;i<=N-5;++i){
pt=0;
int x=i,las=0;
while(x>1){
if(v[x]!=las){
k[pt]*=2;
p[++pt]=v[x];
las=v[x];
k[pt]=0;
}
k[pt]++;
x/=v[x];
}
k[pt]*=2;
dfs(pt,1,1);
}
printf("%lld\n",ans);
return 0;
}