FFT例题

模板

const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m;
int rev[N], tot, bit;
Complex f[N], g[N];

void getrev() {
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)	a[i].x /= tot;
		// int(a[i].x + 0.5) is  Polynomial coefficient
	}
}

例题

2018宁夏网络赛H题 Rock Paper Scissors Lizard Spock.

题意
给定类似石头剪刀布游戏的五种手势和十种克制关系。每种手势克制其他两种手势并被另外两种手势克制。
如图是克制关系,以及各字母代表的手势。
在这里插入图片描述
现在,给你两个字符串 S 1 , S 2 S_1,S_2 S1,S2分别表示 A A A B B B的手势顺序,且 ∣ S 1 ∣ ≥ ∣ S 2 ∣ |S_1| \ge |S_2| S1S2。你可以随意挪动 S 2 S_2 S2相对于 S 1 S_1 S1的位置,求 S 2 S_2 S2最多能赢多少次?
1 ≤ ∣ S ∣ ≤ 1 e 6 1\le|S|\le1e6 1S1e6

做法
固定 S 1 S_1 S1串,将 S 2 S_2 S2逆置。然后 S 2 S_2 S2串在某一位置上与 S 1 S_1 S1有一个对应位置关系。这个关系可以用多项式乘法得到对应。

例如
S 1 = a b c d e f S_1=abcdef S1=abcdef S 2 = g h i S_2=ghi S2=ghi
逆置 S 2 S_2 S2得到 S 2 = i h g S_2=ihg S2=ihg

然后将其看为多项式并标上幂次。
S 1 = a x 5 + b x 4 + c x 3 + d x 2 + e x + f S_1=ax^5+bx^4+cx^3+dx^2+ex+f S1=ax5+bx4+cx3+dx2+ex+f
S 2 = i x 2 + h x + g S_2=ix^2+hx+g S2=ix2+hx+g
如果我们将其进行多项式的乘法。
得到的 x 5 x^5 x5次的系数为 i i i c c c h h h b b b g g g a a a的值。
得到的 x 4 x^4 x4次的系数为 i i i d d d h h h c c c g g g b b b的值。
类似的,还有 x 3 . . . x^3... x3...

也就是,每个不同幂次前面的系数都对应了 S 2 S_2 S2 S 1 S_1 S1中不同位置匹配的所有结果。

如何让系数体现为赢的次数呢?
我们可以枚举 S 2 S_2 S2所有能赢的情况共5种:

  1. S > P   a n d   S > L S > P\ and\ S > L S>P and S>L
  2. P > R   a n d   P > K P > R\ and\ P > K P>R and P>K
  3. R > S   a n d   R > L R > S\ and\ R > L R>S and R>L
  4. L > P   a n d   L > K L > P\ and\ L > K L>P and L>K
  5. K > R   a n d   K > S K > R\ and\ K > S K>R and K>S

我们不可能一次将这5中情况全部算出来,但是可以一种一种的算,然后将他们累加起来。

只看第一种 S > P   a n d   S > L S > P\ and\ S > L S>P and S>L:
我们希望当 S S S匹配到 P P P S S S匹配到 L L L时为1,其余为0。
对于 S 1 S_1 S1,如果当前为的字符是 P   o r   L P\ or\ L P or L时,让该项的系数为 1 1 1,否则为 0 0 0
对于 S 2 S_2 S2,如果当前为的字符是 S S S时,让该项的系数为 1 1 1,否则为 0 0 0

然后我们用 F F T FFT FFT得到这两个多项式相乘后得到的系数。将合法位置系数累加到对应位置上即可。

求完5种情况之后,我们在所有的合法位置上取一个最大值即可得到答案。

#include <bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 4e6 + 10, M = 1e6 + 10;

const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m, ans[M];
int rev[N], tot, bit;
Complex f[N], g[N], h[N];
char s1[M], s2[M];

void getrev() {
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)	a[i].x /= tot;
		// int(a[i].x + 0.5) is  Polynomial coefficient
	}
}


//Rock-R Paper-P Scissors-S Lizard-L Spock-K
int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> s1 >> s2;
	n = strlen(s1);
	m = strlen(s2);
	reverse(s2, s2 + m);
	n --, m --;
	getrev();

	auto clear = [&]() {  // 数组清空
		for (int i = 0; i < tot; i ++)	f[i] = g[i] = {0, 0};
	};
	auto calc = [&]() {
		FFT(f, 1); FFT(g, 1);
		for (int i = 0; i < tot; i ++)	h[i] = h[i] + f[i] * g[i];
		// FFT(f, -1);
		// for (int i = m; i <= n; i ++)	ans[i] += int(f[i].x + 0.5);
	};

	// S > P and S > L
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = (s1[i] == 'P' or s1[i] == 'L') ? 1 : 0;
	for (int i = 0; i <= m; i ++)	g[i].x = (s2[i] == 'S') ? 1 : 0;
	calc();
	// P > R and P > K
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = (s1[i] == 'R' or s1[i] == 'K') ? 1 : 0;
	for (int i = 0; i <= m; i ++)	g[i].x = (s2[i] == 'P') ? 1 : 0;
	calc();
	// R > S and R > L
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = (s1[i] == 'S' or s1[i] == 'L') ? 1 : 0;
	for (int i = 0; i <= m; i ++)	g[i].x = (s2[i] == 'R') ? 1 : 0;
	calc();
	// L > P and L > K
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = (s1[i] == 'P' or s1[i] == 'K') ? 1 : 0;
	for (int i = 0; i <= m; i ++)	g[i].x = (s2[i] == 'L') ? 1 : 0;
	calc();
	// K > R and K > S
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = (s1[i] == 'R' or s1[i] == 'S') ? 1 : 0;
	for (int i = 0; i <= m; i ++)	g[i].x = (s2[i] == 'K') ? 1 : 0;
	calc();
	
	FFT(h, -1);
	for (int i = m; i <= n; i ++)	ans[i] += int(h[i].x + 0.5);
	cout << *max_element(ans + m, ans + n + 1) << '\n';
	
	return 0;
}


残缺的字符串

题意
给你两个字符串原串 B B B和匹配串 A A A,串里只包含小写字母和∗,∗可以表示任意字符,问匹配串在原串不同位置中出现多少次,起始位置不同即不同位置。

做法
A A A的长度为 m + 1 m+1 m+1,即 A = s 0 s 1 . . . s m A=s_0s_1...s_m A=s0s1...sm
B B B的长度为 n + 1 n+1 n+1,即 B = s 0 s 1 . . . s n B=s_0s_1...s_n B=s0s1...sn

定义一个字符串的匹配函数 C ( i , j ) = [ A ( i ) − B ( j ) ] 2 ∗ A ( i ) ∗ B ( j ) C(i,j)=[A(i)-B(j)]^2*A(i)*B(j) C(i,j)=[A(i)B(j)]2A(i)B(j)。如果是字符*我们零其值为0,否则令其值为 s i − s_i- si a a a + 1 +1 +1
定义完全匹配函数 P ( x ) = ∑ i = 0 m C ( i , x − m + i ) P(x)=\sum_{i=0}^mC(i,x-m+i) P(x)=i=0mC(i,xm+i)。表示在 B B B中以位置 x x x为结尾长度为 m + 1 m+1 m+1的这一段的匹配值。显然,如果 P ( x ) = 0 P(x)=0 P(x)=0。那么这个 x x x位置就是合法的位置。

我们将 A A A串进行逆置,将他们的对应关系变一下。
P ( x ) = ∑ i = 0 m C ( m − i , x − m + i ) P(x)=\sum_{i=0}^mC(m-i,x-m+i) P(x)=i=0mC(mi,xm+i)。可以发现 ( m − i ) + ( x − m + i ) (m-i)+(x-m+i) (mi)+(xm+i)等于定值 x x x
展开可以得到
P ( x ) = ∑ i = 0 m A 3 ( m − i ) B ( x − m + i ) + ∑ i = 0 m A ( m − i ) B 3 ( x − m + i ) − 2 ∑ i = 0 m A 2 ( m − i ) B 2 ( x − m + i ) P(x)=\sum_{i=0}^mA^3(m-i)B(x-m+i)+\sum_{i=0}^mA(m-i)B^3(x-m+i)-2\sum_{i=0}^mA^2(m-i)B^2(x-m+i) P(x)=i=0mA3(mi)B(xm+i)+i=0mA(mi)B3(xm+i)2i=0mA2(mi)B2(xm+i)

其中的每一项都是一个多项式乘法,用 F F T FFT FFT求一下即可,将结果累加到 P ( x ) P(x) P(x)
代码不开 O 2 O^2 O2会TLE。



#include <bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 1.2e6 + 10, M = 3e5 + 10;
const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
	Complex operator * (const double t)	const	{ return {x * t, y * t}; }
};

int n, m;
int rev[N], tot, bit;
Complex f[N], g[N], h[N];
char A[M], B[M];

void getrev() {
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)	a[i].x /= tot;
		// int(a[i].x + 0.5) is  Polynomial coefficient
	}
}


int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> m >> n;
	cin >> A >> B;
	reverse(A, A + m);
	n --, m --;
	getrev();

	auto clear = [&]() {
		for (int i = 0; i < tot; i ++)	f[i] = g[i] = {0, 0};
	};
	auto calc = [&](double sgn) {
		FFT(f, 1); FFT(g, 1);
		for (int i = 0; i < tot; i ++)	h[i] = h[i] + f[i] * g[i] * sgn;
		// FFT(f, -1);
		// for (int i = m; i <= n; i ++)	P[i] += sgn * ll(f[i].x + 0.5);
	};

	for (int i = 0; i <= n; i ++)	B[i] = (B[i] == '*') ? 0 : (B[i] - 'a' + 1);
	for (int i = 0; i <= m; i ++)	A[i] = (A[i] == '*') ? 0 : (A[i] - 'a' + 1);

	// +A^3(m-i)*B(x-m+i)
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = B[i];
	for (int i = 0; i <= m; i ++)	g[i].x = A[i] * A[i] * A[i];
	calc(1);
	// +A(m-i)*B^3(x-m+i)
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = B[i] * B[i] * B[i];
	for (int i = 0; i <= m; i ++)	g[i].x = A[i];
	calc(1);
	// -2A^2(m-i)B^2(x-m+i)
	clear();
	for (int i = 0; i <= n; i ++)	f[i].x = B[i] * B[i];
	for (int i = 0; i <= m; i ++)	g[i].x = A[i] * A[i];
	calc(-2);

	FFT(h, -1);  // 把前面的结果累积到h上面,最后只需要对h进行一次逆变换就好了

	vector<int> path;
	for (int i = m; i <= n; i ++)
		if (int(h[i].x + 0.5) == 0)	path.push_back(i - m + 1);
	cout << path.size() << '\n';
	for (int i = 0; i < path.size(); i++)	cout << path[i] << " \n"[i == path.size()];
	return 0;
}

2016香港网络赛 A+B Problem

题意
给定一个长度为 N N N的数组 a a a − 50000 ≤ a i ≤ 50000 -50000 \le a_i \le 50000 50000ai50000,求有多少对 ( i , j , k ) (i,j,k) (i,j,k)满足 a i + a j = a k a_i+a_j=a_k ai+aj=ak
( 1 ≤ N ≤ 2 e 5 ) (1\le N \le2e5) (1N2e5)

做法
我们令数组a中的值作为指数,该值在数组中出现的次数作为系数,可以得到一个多项式。
令该多项式与自己相乘。如果两个相同的多项式相乘,得到的指数即为所有可能加法的结果,而对应的系数即为加法结果出现次数。
本题还有一些其他细节需要处理。
首先我们得到的是 a i + a j a_i+a_j ai+aj的所有可能结果及其出现次数。但是题目要求 i ≠ j i \ne j i=j,所有我们求出来之后,还要减去自己与自己相加贡献的次数。
另外,题目中存在负数,我们需要将每一个值加上一个大数使其为非负数。
还有,原数组中的 0 0 0是比较特殊的,它也会对出现次数造成影响。
对于不等于 0 0 0的数 x x x来说, x = 0 + x = x + 0 x=0+x=x+0 x=0+x=x+0,因此我们要减去0的个数的二倍。
对于等于 0 0 0的数 0 0 0来说, 0 = 0 + 0 = 0 + 0 0=0+0=0+0 0=0+0=0+0,因为它自己本上要占用一个 0 0 0,所以只需要减去( 0 0 0的个数 − 1 -1 1)的二倍。

#include <bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 3e5 + 10;
const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m, a[N], delta = 50000;
int rev[N], tot, bit;
Complex f[N], g[N];
ll cnt[N];

void getrev() {
	bit = 0;
	while ((1 << bit) < m + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)	a[i].x /= tot;
		// int(a[i].x + 0.5) is  Polynomial coefficient
	}
}


int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> n;
	for (int i = 0; i < n; i ++)	cin >> a[i];
	sort(a, a + n);
	m = a[n - 1] + delta;
	getrev();
	int zero = 0;
	for (int i = 0; i < n; i ++) {
		if (a[i] == 0)	zero ++;
		cnt[a[i] + delta] ++;
	}

	for (int i = 0; i < n; i ++)	f[a[i] + delta].x = cnt[a[i] + delta];
	FFT(f, 1);
	for (int i = 0; i < tot; i ++)	f[i] = f[i] * f[i];
	FFT(f, -1);
	
	for (int i = 0; i < n; i ++)	cnt[a[i] + delta] --;  // 清空cnt
	for (int i = 0; i < tot; i ++)	cnt[i] = ll(f[i].x + 0.5);
	for (int i = 0; i < n; i ++)	cnt[(a[i] + delta) * 2] --;  // 自己与自己相加

	ll ans = 0;
	for (int k = 0; k < n; k ++) {
		ans += cnt[a[k] + 2 * delta];
		if (a[k] == 0)	ans -= (zero - 1) * 2;
		else	ans -= zero * 2;
	}
	cout << ans << '\n';
	
	return 0;
}

P3723 [AH2017/HNOI2017] 礼物

在这里插入图片描述
做法:
注意:我写的 n n n不是数组a的长度, ( n + 1 ) (n+1) (n+1)才是数组的长度(个人习惯)也就是数组 a : a 0 , a 1 , . . . , a n a: a_0,a_1,...,a_n a:a0,a1,...,an ( n + 1 ) (n+1) (n+1)个。

我们先不改变它们的亮度,也不旋转,求一下差异值的表达式。
∑ i = 0 n ( a i − b i ) 2 = ∑ i = 0 n a i 2 + ∑ i = 0 n b i 2 − 2 ∑ i = 0 n a i b i \sum_{i=0}^n(a_i-b_i)^2=\sum_{i=0}^na_i^2+\sum_{i=0}^nb_i^2-2\sum_{i=0}^na_ib_i i=0n(aibi)2=i=0nai2+i=0nbi22i=0naibi

假设这里给 a a a增加了 c c c的亮度(给 a a a b b b增加亮度都是一样的,因为 c c c可正可负),将 a i → a i + c a_i \to a_i+c aiai+c,表达式变为:

∑ i = 0 n ( a i + c ) 2 + ∑ i = 0 n b i 2 − 2 ∑ i = 0 n ( a i + c ) b i \sum_{i=0}^n(a_i+c)^2+\sum_{i=0}^nb_i^2-2\sum_{i=0}^n(a_i+c)b_i i=0n(ai+c)2+i=0nbi22i=0n(ai+c)bi

展开化简一下得到:
∑ i = 0 n a i 2 + ∑ i = 0 n b i 2 + ( n + 1 ) c 2 + 2 c ∑ i = 0 n ( a i − b i ) − 2 ∑ i = 0 n a i b i \sum_{i=0}^na_i^2+\sum_{i=0}^nb_i^2+(n+1)c^2+2c\sum_{i=0}^n(a_i-b_i)-2\sum_{i=0}^na_ib_i i=0nai2+i=0nbi2+(n+1)c2+2ci=0n(aibi)2i=0naibi

其中和 c c c有关的两项构成了一个开口向上的抛物线,我们可以直接求得它的最小值。
然后我们发现,除了最后一项其余的都是相当于是定值。

单独拿出来最后一项
− 2 ∑ i = 0 n a i b i -2\sum_{i=0}^na_ib_i 2i=0naibi
我们将 b b b数组翻转
− 2 ∑ i = 0 n a i b n − i -2\sum_{i=0}^na_ib_{n-i} 2i=0naibni

就变成了可以用FFT解决的经典问题。

#include <bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 3e5 + 10, M = 5e4 + 10;
const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m, a[M], b[M];
int rev[N], tot, bit;
Complex f[N], g[N];

void getrev(int n, int m) {
	bit = 0;
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)
			a[i].x /= tot, a[i].y /= tot;
	}
}


int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> n >> m;
	ll ans = 0, t = 0;  // t存一下\sum{ai-bi}
	n --;
	for (int i = 0; i <= n; i ++) {
		cin >> a[i];
		ans += a[i] * a[i];
		t += a[i];
	}
	for (int i = 0; i <= n; i ++) {
		cin >> b[i];
		ans += b[i] * b[i];
		t -= b[i];
	}
	// ---FFT板子---
	getrev(2 * n, n);

	for (int i = 0; i <= n; i ++)	f[i].x = f[i + n + 1].x = a[i];
	for (int i = 0; i <= n; i ++)	f[i].y = b[n - i];
	FFT(f, 1);
	for (int i = 0; i < tot; i ++)	f[i] = f[i] * f[i];
	FFT(f, -1);
	// ---
	
	int mx = 0;  // \sum{ai*bi}的最大值
	for (int i = n; i <= 2 * n; i ++)	mx = max(mx, int(f[i].y / 2 + 0.5));
	ans -= 2 * mx;
	// 抛物线中点
	int mid = -t / (n + 1);
	auto func = [&](int x) {
		return (n + 1) * x * x + 2 * t * x;
	};
	// 实际的mid可能是小数,所以取一下两边的点的值
	ans += min(func(mid), min(func(mid + 1), func(mid - 1)));
	cout << ans << '\n';

	return 0;
}

ABC196 F. Substring 2

给定两个只包含01的字符串 S S S T T T,最少需要改变多少个字符可以使得 T T T成为 S S S的子串。
1 ≤ ∣ T ∣ ≤ ∣ S ∣ ≤ 1 e 6 1 \le |T| \le |S| \le 1e6 1TS1e6

做法:
规定 S = s 0 s 1 . . . s n S=s_0s_1...s_n S=s0s1...sn T = t 0 t 1 . . . t m T=t_0t_1...t_m T=t0t1...tm

定义 C ( i , j ) = ( S i − T j ) 2 C(i, j)=(S_i-T_j)^2 C(i,j)=(SiTj)2,由于只包含01,所以该值是0或1,为0表示相同,为1表示不同。
定义 P ( x ) = ∑ i = 0 m C ( x − m + i , i ) P(x)=\sum_{i=0}^mC(x-m+i, i) P(x)=i=0mC(xm+i,i),表示在 S S S中以位置 x x x结尾长度为 ∣ T ∣ |T| T的字串与 T T T相差几个字符。
展开有 P ( x ) = ∑ i = 0 m T i 2 + ∑ i = 0 m S x − m + i 2 − 2 ∑ i = 0 m S x − m + i T i P(x)=\sum_{i=0}^mT_i^2+\sum_{i=0}^mS_{x-m+i}^2-2\sum_{i=0}^mS_{x-m+i}T_i P(x)=i=0mTi2+i=0mSxm+i22i=0mSxm+iTi
前一个求和式是一个定值,后面两个求和式与我们选择的位置 x x x有关。
其中 ∑ i = 0 m S x − m + i 2 \sum_{i=0}^mS_{x-m+i}^2 i=0mSxm+i2很好求,我们只需要一个 S 2 S_2 S2的前缀和数组就可以解决。
∑ i = 0 m S x − m + i T i \sum_{i=0}^mS_{x-m+i}T_i i=0mSxm+iTi就需要用到我们的FFT了。
反转 T T T串,式子变为 ∑ i = 0 m S x − m + i T m − i \sum_{i=0}^mS_{x-m+i}T_{m-i} i=0mSxm+iTmi,变成了我们熟悉的形式之后,就可以愉快的FFT了。

#include <bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 2.1e6 + 10, M = 1e6 + 10;
const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m, pre[M];
int rev[N], tot, bit;
Complex f[N];
char s[M], t[M];

void getrev() {
	bit = 0;
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)
			a[i].x /= tot, a[i].y /= tot;
	}
}


int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> s >> t;
	n = strlen(s) - 1;
	m = strlen(t) - 1;
	reverse(t, t + m + 1);

	getrev();

	for (int i = 0; i <= n; i ++)	f[i].x = s[i] - '0';
	for (int i = 0; i <= m; i ++)	f[i].y = t[i] - '0';
	FFT(f, 1);
	for (int i = 0; i < tot; i ++)	f[i] = f[i] * f[i];
	FFT(f, -1);
	// 因为下标是从0开始的,求前缀和不方便,因此pre数组往前错一位
	for (int i = 0; i <= n; i ++) 
		pre[i + 1] = pre[i] + (s[i] - '0');

	int ans = INF, ts = 0;
	for (int i = 0; i <= m; i ++)	ts += (t[i] - '0');
	for (int i = m; i <= n; i ++) {  // 枚举位置,找到一个最小值
		int v = int(f[i].y / 2 + 0.5);
		ans = min(ans, ts + pre[i + 1] - pre[i - m] - 2 * v);
	}

	cout << ans << '\n';
	
	return 0;
}

CF528D Fuzzy Search

在这里插入图片描述
做法:
同样是一个匹配问题,本题是有一个界限值 k k k。比如我的 s i s_i si要与 t j t_j tj匹配,如何 s i ≠ t j s_i\ne t_j si=tj但是存在另一个位置 p p p,满足 s p = t j s_p=t_j sp=tj并且 a b s ( i − p ) ≤ k abs(i-p)\le k abs(ip)k,那么也算做 s i s_i si t j t_j tj匹配成功。

本题与上面的那个石头剪刀布的题目很像。因为只有 A G C T AGCT AGCT这几个字母。我们直接枚举每一个字母成功匹配了几个位置,最后加起来看一下匹配成功的总次数是否等于 ∣ T ∣ |T| T
这里用了一下 N T T NTT NTT来求多项式乘法( 390 m s 390ms 390ms)。但是优化后的 F F T FFT FFT跑的也很快( 436 m s 436ms 436ms)。

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const double PI = acos(-1);
const int N = 1e6 + 10, M = 2e5 + 10;

const int MOD = 998244353, G = 3;  // 1004535809

int n, m, K, p[M];
int rev[N], tot, bit;
int f[N], g[N];
char s[M], t[M];
bool ok[M];

ll ksm(ll a, ll b, ll p) {
    ll res = 1;
    while (b) {
        if (b & 1)    res = res * a % p;
        a = a * a % p; b >>= 1;
    }
    return res;
}

void getrev() {
    bit = 0;
    while ((1 << bit) < n + m + 1)    bit ++;
    tot = 1 << bit;
    for (int i = 0; i < tot; i ++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void NTT(int a[], int inv) {
    for (int i = 0; i < tot; i ++)
        if (i < rev[i])    swap(a[i], a[rev[i]]);

    for (int h = 2; h <= tot; h <<= 1) {
        int gn = ksm(G, (MOD - 1) / h, MOD);
        if (inv == -1)    gn = ksm(gn, MOD - 2, MOD);
        for (int j = 0; j < tot; j += h) {
            ll g = 1;
            for (int k = j; k < j + h / 2; k ++, g = g * gn % MOD) {
                int u = a[k], v = g * a[k + h / 2] % MOD;
                a[k] = (u + v) % MOD, a[k + h / 2] = (u - v + MOD) % MOD;
            }
        }
    }
    if (inv == -1) {
        ll tot_inv = ksm(tot, MOD - 2, MOD);  // 乘法逆元
        for (int i = 0; i < tot; i ++)
            a[i] = (a[i] * tot_inv) % MOD;
    }
}



int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    
    cin >> n >> m >> K;
    cin >> s >> t;
    reverse(t, t + m);
    n --, m --;
    getrev();

    auto work = [&](char c) {
        for (int i = 0; i < tot; i ++)  f[i] = g[i] = 0;

        memset(ok, false, sizeof ok);
        for (int i = 0, last = -INF; i <= n; i ++) {
            if (s[i] == c or (i - last) <= K)   ok[i] = true;
            if (s[i] == c)  last = i;
        }
        for (int i = n, last = INF; i >= 0; i --) {
            if ((last - i) <= K)    ok[i] = true;
            if (s[i] == c)  last = i;
        }
        for (int i = 0; i <= n; i ++)   f[i] = ok[i];
        for (int i = 0; i <= m; i ++)   g[i] = (t[i] == c);
        NTT(f, 1), NTT(g, 1);
        for (int i = 0; i < tot; i ++)  f[i] = (ll)f[i] * g[i] % MOD;
        NTT(f, -1);
        for (int i = m; i <= n; i ++)   p[i] += f[i];
    };

    work('A');
    work('T');
    work('C');
    work('G');
    int ans = 0;
    for (int i = m; i <= n; i ++)   ans += (p[i] == m + 1);
    cout << ans << '\n';
    
    return 0;
}


CF1709F Multiset of Strings

在这里插入图片描述
做法:
我们可以在一个 01 T i r e 01Tire 01Tire上考虑这个问题。并把题意抽象成为一个流量问题。
在一颗高度为 n + 1 n+1 n+1的满二叉树上。每条边有一个最大流量 c ∈ [ 0 , k ] c\in[0,k] c[0,k]。求从根节点到所有叶子节点的流量之和恰好是 f f f的方案数。
然后我们可以在这棵树上进行树形 d p dp dp
定义 d p [ x ] [ y ] dp[x][y] dp[x][y]表示在以 x x x为根的子树,从 x x x的父节点连向 x x x的边成功流到叶子节点的流量是 y y y的方案数。

l s ls ls表示 x x x的左儿子, r s rs rs表示 x x x的右儿子。
转移方程为
d p x , y = ∑ i + j > y i + j ≤ 2 k ( d p l s , i × d p r s , j ) + ( k − y + 1 ) × ∑ i + j = y ( d p l s , i × d p r s , j ) dp_{x,y}=\sum_{i+j>y}^{i+j\le2k}(dp_{ls,i}\times dp_{rs,j})+(k-y+1)\times\sum_{i+j=y}(dp_{ls,i}\times dp_{rs,j}) dpx,y=i+j>yi+j2k(dpls,i×dprs,j)+(ky+1)×i+j=y(dpls,i×dprs,j)

第一项表示从节点 l s , r s ls,rs ls,rs成功流到叶节点的流量 i + j > y i+j>y i+j>y,所以,要想从 x x x节点成功流到叶节点的流量是 y y y,那么它的父边的最大流量只能是 y y y,否则成功流到叶节点的流量就会大于 y y y
第二项表示从节点 l s , r s ls,rs ls,rs成功流到叶节点的流量恰好等于 y y y,所有, x x x的父边最大流量只要满足大于等于 y y y即可,有 k − y + 1 k-y+1 ky+1种选择。

值得注意的是,对于同一层上的节点 x , y x,y x,y,以 x x x为根的子树和以 y y y为根的子树,它们的结构是相同的,它们的 d p dp dp值也是相同的,即 ∀ i ∈ [ 0 , k ] \forall i\in[0,k] i[0,k] d p [ x ] [ i ] = d p [ y ] [ i ] dp[x][i]=dp[y][i] dp[x][i]=dp[y][i]。所以每一层我们只需要计算一次。

需要注意一点,在从第 2 2 2层向第 1 1 1层转移时,根节点的父边是我们虚拟的,它的流量可以看作无穷大,这个时候的转移方程是 d p x , y = ∑ i + j = y ( d p l s , i × d p r s , j ) dp_{x,y}=\sum_{i+j=y}(dp_{ls,i}\times dp_{rs,j}) dpx,y=i+j=y(dpls,i×dprs,j)

d p dp dp转移式是求多项式相乘,做 n n n N T T NTT NTT即可。

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 1e6 + 10;
const int MOD = 998244353, G = 3;  // 1004535809

int n, k, f;
int rev[N], tot, bit;
int dp[N];

ll ksm(ll a, ll b, ll p) {
	ll res = 1;
	while (b) {
		if (b & 1)	res = res * a % p;
		a = a * a % p; b >>= 1;
	}
	return res;
}

void getrev() {
	bit = 0;
	while ((1 << bit) < k + k + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void NTT(int a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		int gn = ksm(G, (MOD - 1) / h, MOD);
		if (inv == -1)	gn = ksm(gn, MOD - 2, MOD);
		for (int j = 0; j < tot; j += h) {
			ll g = 1;
			for (int k = j; k < j + h / 2; k ++, g = g * gn % MOD) {
				int u = a[k], v = g * a[k + h / 2] % MOD;
				a[k] = (u + v) % MOD, a[k + h / 2] = (u - v + MOD) % MOD;
			}
		}
	}
	if (inv == -1) {
		ll tot_inv = ksm(tot, MOD - 2, MOD);  // 乘法逆元
		for (int i = 0; i < tot; i ++)
			a[i] = (a[i] * tot_inv) % MOD;
	}
}

int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> n >> k >> f;
	getrev();
	// 一开始存储第n+1层的dp值,然后不断向上求
	for (int i = 0; i <= k; i ++)	dp[i] = 1;
	for (int l = n; l >= 1; l --) {
		NTT(dp, 1);
		for (int i = 0; i < tot; i ++)	dp[i] = (ll)dp[i] * dp[i] % MOD;
		NTT(dp, -1);
		//  大于2k无意义
		for (int i = 2 * k + 1; i < tot; i ++)	dp[i] = 0;
		if (l == 1)	break;
		ll sum = 0;  // 第一项的和
		for (int i = 2 * k; i >= 0; i --) {
			ll value = dp[i];
			if (i <= k)	dp[i] = (sum + max(0, k - i + 1) * value) % MOD;
			else	dp[i] = 0;
			sum = (sum + value) % MOD;
		}
	}
	cout << dp[f] << '\n';
	
	return 0;
}


2022年“图森未来杯” I. 01串

在这里插入图片描述
1 ≤ n , q ≤ 1 e 5 1\le n,q \le 1e5 1n,q1e5

解法:
f [ i ] f[i] f[i]表示前缀值为i的前缀个数。
(注意 f [ 0 ] f[0] f[0]初始为 1 1 1

然后对于恰好包含 k k k个1的所有子串,其个数为 a n s [ k ] = ∑ i = 0 n f [ i ] × f [ i + k ] ans[k]=\sum_{i=0}^{n}f[i]\times f[i+k] ans[k]=i=0nf[i]×f[i+k]
然后我们令 g g g数组表示 f f f数组的翻转,即 g [ i ] = f [ n − i ] g[i]=f[n-i] g[i]=f[ni]
于是 a n s [ k ] = ∑ i = 0 n f [ i ] × g [ n − i + k ] ans[k]=\sum_{i=0}^nf[i]\times g[n-i+k] ans[k]=i=0nf[i]×g[ni+k]
使用FFT解决即可。

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 1e5 + 10;

const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m, cnt[N];
ll ans[N];
char str[N];
int rev[N], tot, bit;
Complex f[N];

void getrev() {
	bit = 0;
	while ((1 << bit) < n + n + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)
			a[i].x /= tot, a[i].y /= tot;
	}
}

int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> n;
	cin >> str + 1;
	cnt[0] = 1;
	for (int i = 1, s = 0; i <= n; i ++) {
		s += str[i] - '0';
		cnt[s] ++;
	}
	getrev();

	for (int i = 0; i <= n; i ++)	f[i].x = cnt[i], f[i].y = cnt[n - i];
	FFT(f, 1);
	for (int i = 0; i < tot; i ++)	f[i] = f[i] * f[i];
	FFT(f, -1);

	for (int i = 0; i <= n; i ++)	ans[i] = ll(f[n - i].y / 2 + 0.5);
	cin >> m;
	while (m --) {
		int x; cin >> x;
		cout << ans[x] << '\n';
	}

	return 0;
}


update: FFT优化模版

具体就是将A放到实部上,B放到虚部上。得到 A + B i A+Bi A+Bi
用FFT计算平方得到 ( A + B i ) ( A + B i ) = ( A 2 − B 2 ) + 2 A B i (A+Bi)(A+Bi)=(A^2-B^2)+2ABi (A+Bi)(A+Bi)=(A2B2)+2ABi
最后将虚部除以 2 2 2就是答案。可以少做一次FFT。

以多项式乘法为例题。
在这里插入图片描述

#include <bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define all(x) begin(x),end(x)
#define debug(x) cout<<#x<<": "<<x<<endl;
using namespace std;
using ll = long long;
const int N = 2.1e6 + 10;
const double PI = acos(-1);

struct Complex {
	double x, y;
	Complex operator + (const Complex &t) const	{ return {x + t.x, y + t.y}; }
	Complex operator - (const Complex &t) const { return {x - t.x, y - t.y}; }
	Complex operator * (const Complex &t) const { return {x * t.x - y * t.y, x * t.y + y * t.x}; }
};

int n, m;
int rev[N], tot, bit;
Complex f[N], g[N];

void getrev() {
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(Complex a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		Complex wn = {cos(2 * PI/ h), sin(inv * 2 * PI / h)};
		for (int j = 0; j < tot; j += h) {
			Complex w = {1, 0};
			for (int k = j; k < j + h / 2; k ++, w = w * wn) {
				Complex u = a[k], t = w * a[k + h / 2];
				a[k] = u + t, a[k + h / 2] = u - t;
			}
		}
	}
	if (inv == -1) {
		for (int i = 0; i < tot; i ++)
			a[i].x /= tot, a[i].y /= tot;
	}
}


int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	
	cin >> n >> m;
	for (int i = 0; i <= n; i ++)	cin >> f[i].x;
	for (int i = 0; i <= m; i ++)	cin >> f[i].y;

	getrev();

	FFT(f, 1);
	for (int i = 0; i < tot; i ++)	f[i] = f[i] * f[i];
	FFT(f, -1);
	
	for (int i = 0; i <= n + m; i ++)
		cout << int(f[i].y / 2 + 0.5) << " \n"[i == n + m];
	
	return 0;
}

NTT模版

const int MOD = 998244353, G = 3;  // 1004535809

int n, m;
int rev[N], tot, bit;
int f[N], g[N];

ll ksm(ll a, ll b, ll p) {
	ll res = 1;
	while (b) {
		if (b & 1)	res = res * a % p;
		a = a * a % p; b >>= 1;
	}
	return res;
}

void getrev() {
	bit = 0;
	while ((1 << bit) < n + m + 1)	bit ++;
	tot = 1 << bit;
	for (int i = 0; i < tot; i ++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void NTT(int a[], int inv) {
	for (int i = 0; i < tot; i ++)
		if (i < rev[i])	swap(a[i], a[rev[i]]);

	for (int h = 2; h <= tot; h <<= 1) {
		int gn = ksm(G, (MOD - 1) / h, MOD);
		if (inv == -1)	gn = ksm(gn, MOD - 2, MOD);
		for (int j = 0; j < tot; j += h) {
			ll g = 1;
			for (int k = j; k < j + h / 2; k ++, g = g * gn % MOD) {
				int u = a[k], v = g * a[k + h / 2] % MOD;
				a[k] = (u + v) % MOD, a[k + h / 2] = (u - v + MOD) % MOD;
			}
		}
	}
	if (inv == -1) {
		ll tot_inv = ksm(tot, MOD - 2, MOD);  // 乘法逆元
		for (int i = 0; i < tot; i ++)
			a[i] = (a[i] * tot_inv) % MOD;
	}
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值