题目链接:点击打开链接
题目大意:给出n个数a1,a2,....an,和m个数b1,b2,...bn,问b数组中有多少个数可以由a数组中一个或两个的和组成
思路:
构造两个相同的多项式,指数分别为a1,a2...an,系数表示存不存在(1或0),然后用FFT相乘,得到结果多项式,系数大于等于1的表示能够得到该项。指数表示某两个ai+aj(1<=i,j<=n)的和。
所以将这些系数大于等于1的指数记录下来,这些都是a中任意两个的和,最后在加上一个的和,即为最后所有可以得到的情况,对b数组中的数逐一判断即可.
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef long double LB;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
typedef vector<int> VII;
const int INF = 0x3f3f3f3f;
const LL INFL = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-8;
const double PI = acos(-1.0);
const int MAXN = 262144*2 + 5; /// 数组大小应为2^k
//typedef complex<double> CP;
struct CP {
double x, y;
CP() {}
CP(double x, double y) : x(x), y(y) {}
inline double real() { return x; }
inline CP operator * (const CP& r) const { return CP(x * r.x - y * r.y, x * r.y + y * r.x); }
inline CP operator - (const CP& r) const { return CP(x - r.x, y - r.y); }
inline CP operator + (const CP& r) const { return CP(x + r.x, y + r.y); }
};
CP a[MAXN], b[MAXN];
int r[MAXN], res[MAXN];
void fft_init(int nm, int k) {
for(int i = 0; i < nm; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1)); /// Rader操作
}
void fft(CP ax[], int nm, int op) {
for(int i = 0; i < nm; ++i) if(i < r[i]) swap(ax[i], ax[r[i]]);
for(int h = 2, m = 1; h <= nm; h <<= 1, m <<= 1) { /// 枚举长度
CP wn = CP(cos(op * 2 * PI / h), sin(op * 2 * PI / h));
for(int i = 0; i < nm; i += h) { /// 枚举所有长度为h的区间
CP w(1, 0); /// 旋转因子
for(int j = i; j < i + m; ++j, w = w * wn) { /// 枚举角度
CP t = w * ax[j + m]; /// 蝴蝶操作
ax[j + m] = ax[j] - t;
ax[j] = ax[j] + t;
}
}
}
if(op == -1) for(int i = 0; i < nm; ++i) ax[i].x /= nm;
}
void trans(int ax[], int bx[], int n, int m) {
int nm = 1, k = 0;
while(nm < 2 * n || nm < 2 * m) nm <<= 1, ++k;
for(int i = 0; i < n; ++i) a[i] = CP(ax[i], 0);
for(int i = 0; i < m; ++i) b[i] = CP(bx[i], 0);
for(int i = n; i < nm; ++i) a[i] = CP(0, 0);
for(int i = m; i < nm; ++i) b[i] = CP(0, 0);
fft_init(nm, k);
fft(a, nm, 1); fft(b, nm, 1);
for(int i = 0; i < nm; ++i) a[i] = a[i] * b[i];
fft(a, nm, -1);
nm = n + m - 1;
/*for(int i = 0; i < nm; ++i)
res[i] = (int)(a[i].real() + 0.5), print(res[i]), putchar(" \n"[i == nm - 1]);*/
for(int i=0;i<nm;i++)
{
res[i]=(int)(a[i].real() + 0.5);
}
}
int ax[MAXN], bx[MAXN], n, m;
int num[MAXN];
int main() {
while(~scanf("%d",&n))
{
int u;
memset(a,0,sizeof a);
memset(b,0,sizeof b);
memset(res,0,sizeof res);
int anum=0;
for(int i=0;i<n;i++)
{
scanf("%d",&u);
num[i]=u;
ax[u]=1;
anum=max(anum,u);
}
trans(ax,ax,anum+1,anum+1);
scanf("%d",&m);
int cnt=0;
for(int i=0;i<n;i++)
{
res[num[i]]=1;
}
for(int i=0;i<m;i++)
{
scanf("%d",&u);//printf("%d %d\n",u,res[u]);
if(res[u]>=1)
{
cnt++;
}
}
printf("%d\n",cnt);
}
return 0;
}