NTT经典例题
CCPC-Winter-Camp-day6-A——NTT经典例题
对于上面的式子,如果想求出每个 $i$ 对应的值,可以用卷积求出:与 $j$ 有关的项(如 $j!$)和与 $i-j$ 有关的项(如 $(i-j)!$)相乘时,两者下标之和恰为 $j+(i-j)=i$,这正是卷积的形式。
补充:二次剩余(下面的代码用 Cipolla 算法求解 $x^2 \equiv n \pmod p$)
P5491 【模板】二次剩余 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
//#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<numeric>
#include<cstring>//rfind("string"),s.find(string,begin)!=s.npos,find_first _of(),find_last_of()
#include<string>//to_string(value),s.substr(int begin, int length);
#include<cstdio>
#include<cmath>
#include<vector>//res.erase(unique(res.begin(), res.end()), res.end()),resize(n)//size of vector,vector<int>().swap(at[mx])
#include<queue>//priority_queue(big) /priority_queue<int, vector<int>, greater<int>> q(small)
#include<stack>
#include<map>
#include<set>
#include<unordered_map>
#include<unordered_set>
#include<bitset>
#include<random>
#include<chrono>
//#include<ext/pb_ds/assoc_container.hpp>//gp_hash_table
//#include<ext/pb_ds/hash_policy.hpp>
//using namespace __gnu_pbds;
// 64-bit Mersenne Twister seeded from the steady clock (not referenced by the Cipolla program below, which uses rand()).
std::mt19937_64 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
using namespace std;
// Every `int` after this line is really `long long` (hence `signed main()` below).
// The trailing note refers to GCC's __int128, whose maximum is 2^127-1.
#define int long long//__int128 2^127-1(GCC)
// Shorthand for a pair of (64-bit, see above) ints; not referenced by the program below.
#define PII pair<int,int>
// An element x + y*sqrt(w) of the quadratic extension F_p[sqrt(w)] used by Cipolla's algorithm.
struct num {
    int x, y;  // x: real part; y: coefficient of the "imaginary unit" sqrt(w)
};

// Global state shared by the functions below:
//   t    - number of queries
//   w    - a^2 - n chosen by cipolla(); the square of the imaginary unit, read by mul()
//   n, p - current query (shadowed by the parameters of the same name in the helpers)
int t;
int w;
int n;
int p;
// Multiplies two elements of F_p[sqrt(w)]:
//   (a.x + a.y*i)(b.x + b.y*i) = (a.x*b.x + a.y*b.y*w) + (a.x*b.y + a.y*b.x)*i,  with i^2 = w (global).
// Both result components are normalised into [0, p).
num mul(num a, num b, int p) {
    const int real_part = a.x * b.x % p + a.y * b.y % p * w % p;
    const int imag_part = a.x * b.y % p + a.y * b.x % p;
    num product;
    product.x = (real_part % p + p) % p;
    product.y = (imag_part % p + p) % p;
    return product;
}
// Computes a^b mod p by binary exponentiation over the integers (the "real" counterpart of qpow_i).
// The base is not reduced first; callers pass a already in [0, p).
int qpow_r(int a, int b, int p) {
    int acc = 1;
    for (; b; b >>= 1, a = a * a % p)
        if (b & 1) acc = acc * a % p;
    return acc;
}
// Raises a to the b-th power in F_p[sqrt(w)] by binary exponentiation (multiplication via mul()).
// Only the real part is returned: in Cipolla's algorithm (a + sqrt(w))^((p+1)/2) has a zero
// imaginary part, so the real part is the square root.
int qpow_i(num a, int b, int p) {
    num acc = {1, 0};
    for (; b; b >>= 1) {
        if (b & 1) acc = mul(acc, a, p);
        a = mul(a, a, p);
    }
    return acc.x % p;
}
// Cipolla's algorithm: returns some x in [0, p) with x^2 ≡ n (mod p), or -1 if n is a
// quadratic non-residue. p is assumed to be prime. Sets the global w (read by mul()) as a side effect.
int cipolla(int n, int p) {
    n %= p;
    if (n < 0) n += p;      // normalise negative input into [0, p)
    // x = 0 is the root. This must be handled here: for n ≡ 0 every candidate w = a^2 is a square,
    // so the non-residue search below would never terminate.
    if (n == 0) return 0;
    // For p = 2 the Euler exponent (p - 1) / 2 is 0, which breaks the test below; 1^2 ≡ 1 (mod 2).
    if (p == 2) return n;
    // Euler's criterion: n^((p-1)/2) ≡ -1 (mod p) iff n is a non-residue.
    if (qpow_r(n, (p - 1) / 2, p) == p - 1) return -1;
    int a;
    while (true) {
        // Pick a random a until w = a^2 - n is a non-residue; this happens for about half of all a.
        a = rand() % p;
        w = ((a * a % p - n) % p + p) % p;
        if (qpow_r(w, (p - 1) / 2, p) == p - 1) break;
    }
    // (a + sqrt(w))^((p+1)/2) is a square root of n and lies in F_p.
    num x = { a, 1 };
    return qpow_i(x, (p + 1) / 2, p);
}
// Reads t queries (n, p) and prints the square roots of n modulo p in increasing order:
// "0" for n == 0, "Hola!" when none exist, one value when both roots coincide, otherwise two values.
signed main() {
    srand(time(0));
    cin >> t;
    while (t--) {
        cin >> n >> p;
        if (!n) {
            printf("0\n");
            continue;
        }
        int ans1 = cipolla(n, p);
        if (ans1 == -1) {   // no solution
            printf("Hola!\n");
            continue;
        }
        // The other root is -ans1 mod p. Reducing with % keeps it in [0, p) when ans1 == 0,
        // where the unreduced form would give p.
        int ans2 = (p - ans1) % p;
        if (ans1 > ans2) swap(ans1, ans2);
        if (ans1 == ans2) printf("%lld\n", ans1);
        else printf("%lld %lld\n", ans1, ans2);
    }
    return 0;
}
NTT背包合并
PowerPoint 演示文稿 (nowcoder.com)
思路类似数位 DP:其中需要的"背包合并"可以用多项式乘法(NTT)完成;若要合并 n 个背包,可以借鉴线段树式的分治(或启发式合并)思想,每次把两个规模相近的多项式相乘。
//#include<bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<numeric>
#include<cstring>//rfind("string"),s.find(string,begin)!=s.npos,find_first _of(),find_last_of()
#include<string>//to_string(value),s.substr(int begin, int length);
#include<cstdio>
#include<cmath>
#include<vector>//res.erase(unique(res.begin(), res.end()), res.end()),resize(n)//size of vector,vector<int>().swap(at[mx])
#include<queue>//priority_queue(big) /priority_queue<int, vector<int>, greater<int>> q(small)
#include<stack>
#include<map>
#include<set>
#include<unordered_map>
#include<unordered_set>
#include<bitset>
#include<random>
#include<chrono>
//#include<ext/pb_ds/assoc_container.hpp>//gp_hash_table
//#include<ext/pb_ds/hash_policy.hpp>
//using namespace __gnu_pbds;
// 64-bit Mersenne Twister seeded from the steady clock (not referenced by the NTT program below).
std::mt19937_64 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
using namespace std;
// Every `int` after this line is really `long long` (hence `signed main()` below).
// The trailing note refers to GCC's __int128, whose maximum is 2^127-1.
#define int long long//__int128 2^127-1(GCC)
// Shorthand for a pair of (64-bit, see above) ints; not referenced by the program below.
#define PII pair<int,int>
// N:   capacity of the scratch arrays in namespace ntt.
// mod: 998244353 = 119 * 2^23 + 1, the NTT modulus.
// NOTE(review): ntt::init rounds the transform length up to a power of two, which exceeds N
// once a product degree reaches 2^21 — confirm this cannot happen under the problem constraints.
const int N = 3000005;
const int mod = 998244353;
namespace ntt {
// Number-theoretic transform and polynomial multiplication modulo `mod`.
// All scratch state (a, b, r, tot, bit) is namespace-global, so operator* is not reentrant.
const int g = 3;        // primitive root used for the transform (for mod = 998244353)
int a[N], b[N];         // scratch buffers for operator*
int r[N], tot, bit;     // r: bit-reversal permutation of length tot = 2^bit (built by init)

// Returns a^b mod `mod` by binary exponentiation.
int qpow(int a, int b) {
    int res = 1;
    while (b) {
        if (b & 1) res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return res;
}

// Inverse of g (Fermat). Initialised directly instead of through a reserved-name (`__`) dummy variable.
int invg = qpow(g, mod - 2);

// a = (a + b) mod `mod`; a single subtraction suffices only when both inputs are already in [0, mod).
void add(int& a, int b) {
    a += b;
    if (a >= mod) a -= mod;
}

// In-place iterative NTT of a[0..tot); tot and r must have been prepared by init().
// inv == 1: forward transform (root g). inv == -1: inverse transform (root g^-1, then scaling by 1/tot).
void NTT(int a[], int inv) {
    for (int i = 0; i < tot; i++)
        if (i < r[i])
            swap(a[i], a[r[i]]);
    for (int mid = 1; mid < tot; mid <<= 1) {
        // Primitive (2*mid)-th root of unity for this butterfly level.
        int g1 = qpow(inv == 1 ? g : invg, (mod - 1) / (mid << 1));
        for (int i = 0; i < tot; i += mid << 1) {
            for (int j = 0, gk = 1; j < mid; j++, gk = 1ll * gk * g1 % mod) {
                int x = a[i + j], y = 1ll * gk * a[i + j + mid] % mod;
                a[i + j] = (x + y) % mod, a[i + j + mid] = (x - y + mod) % mod;
            }
        }
    }
    if (inv == -1) {
        int invtot = qpow(tot, mod - 2);
        for (int i = 0; i < tot; i++) {
            a[i] = 1ll * a[i] * invtot % mod;
        }
    }
}

// Dense polynomial with coefficients coef[0..deg]; deg == -1 denotes the empty polynomial.
struct Poly {
    vector<int> coef;
    int deg;
    int& operator[](int x) {
        return coef[x];
    }
    const int& operator[](int x) const {
        return coef[x];
    }
    // Zero polynomial with room for degree `deg` (deg must be >= -1).
    Poly(int deg = -1) : deg(deg) {
        coef = vector<int>(deg + 1, 0);
    }
    // Resizes to degree `deg`, keeping the low coefficients and zero-filling any new ones.
    void norm(int deg) {
        this->deg = deg;
        coef.resize(deg + 1);
    }
};

// Sets tot to the smallest power of two strictly greater than len, clears a/b[0..tot),
// and builds the bit-reversal table r.
// NOTE(review): requires tot <= N; with N = 3e6 + 5 this fails once len >= 2^21 — confirm constraints.
void init(int len) {
    bit = tot = 0;
    while ((1ll << bit) <= len) bit++;
    tot = 1ll << bit;
    for (int i = 0; i < tot; i++) a[i] = b[i] = 0;
    for (int i = 1; i < tot; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

// Returns f * g (coefficients mod `mod`). If either degree is <= 8 the product is computed by the
// schoolbook method, otherwise by NTT.
// If exactly one operand is empty, the result is an all-zero polynomial of degree f.deg + g.deg; that
// degree is kept as before because callers size their buffers from it. Two empty operands give an empty result.
Poly operator*(const Poly& f, const Poly& g) {
    // Clamp at -1: two empty operands would otherwise request degree -2, i.e. a vector of size -1.
    Poly res(std::max<int>(f.deg + g.deg, -1));
    if (f.deg <= 8 || g.deg <= 8) {
        for (int i = 0; i <= f.deg; i++)
            for (int j = 0; j <= g.deg; j++)
                add(res[i + j], 1ll * f[i] * g[j] % mod);
        return res;
    }
    init(res.deg);
    copy(f.coef.begin(), f.coef.end(), a);
    copy(g.coef.begin(), g.coef.end(), b);
    NTT(a, 1), NTT(b, 1);
    for (int i = 0; i < tot; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1);
    copy(a, a + res.deg + 1, res.coef.begin());
    return res;
}
}
using namespace ntt;
signed main()
{
ios_base::sync_with_stdio(0); cin.tie(0), cout.tie(0);
int n, k;
int m;
cin >> n >> m >> k;
vector<int>a(n + 1);
vector<Poly>v;
int sum = 0;
for (int i = 1; i <= n; i++) {
cin >> a[i];
sum += a[i];
Poly f(a[i]);
f[0] = f[a[i]] = 1;
v.emplace_back(f);
}
auto solve = [&](auto self, int l, int r)->Poly {
if (l == r) return v[l];
int mid = l + r >> 1;
return self(self, l, mid) * self(self, mid + 1, r);
};
Poly f = solve(solve, 0, v.size() - 1);
//assert(f.deg == sum);
vector<vector<int>>ban(60, vector<int>());
//array<vector<int>, 60>ban;
while (k--)
{
int b, c;
cin >> b >> c;
ban[c].push_back(b);
}
vector<Poly>dp(2);
dp[0].norm(0);
dp[0][0] = 1;
for (int i = 0; i < 60; i++) {
Poly g = f;
sort(ban[i].begin(), ban[i].end());
ban[i].erase(unique(ban[i].begin(), ban[i].end()), ban[i].end());
for (auto x : ban[i]) {
for (int j = a[x]; j <= sum; j++) {
g[j] -= g[j - a[x]];
if (g[j] < 0)g[j] += mod;
}
}
vector<Poly>f(2), ndp(2);
f[0] = dp[0] * g;
f[1] = dp[1] * g;
for (auto t : { 0,1 }) ndp[t].norm(f[t].deg / 2);
for (auto t : { 0,1 }) {
for (int j = 0; j <= f[t].deg; j++) {
if (j % 2 == (m >> i & 1)) {
add(ndp[t][j / 2], f[t][j]);
}
else if (j % 2 > (m >> i & 1)) {
add(ndp[1][j / 2], f[t][j]);
}
else {
add(ndp[0][j / 2], f[t][j]);
}
}
}
dp = ndp;
}
cout << dp[0][0] << "\n";
}