注:
代码中以 //... 这五个符号开头的注释表示该处需要用相应的代码来取代
作者:袁乐天
最后一次更新:2020-08-25
不允许转载
树部分
KD-Tree (k-dimension tree)
给定一个 $k$ 维点 $x$, 在 $n$ 个点中找到距离 $x$ 最近的点
kd-tree的本质仍是暴力
时间复杂度: $O(\log n) \sim O(n)$
https://blog.csdn.net/silangquan/article/details/41483689
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
namespace kd_tree
{
#define SIZE 50010
#define MAX_K 5
#define sqr(x) ((x)*(x))
// index: dimension compared by Node::operator< (set by build() before each nth_element).
// k: number of dimensions of every point; the caller must set it (1 <= k <= MAX_K).
int index, k;
struct Node
{
    int x[MAX_K];
    // Orders points by the current splitting dimension `index`.
    bool operator<(const Node &b) const
    {
        return x[index] < b.x[index];
    }
} p[SIZE];
// (squared distance to the query point, candidate point)
typedef pair<double, Node> dn;
// Max-heap of the best candidates found so far; q.top() is the farthest of them.
// query() never clears it, so empty it before starting a new query.
priority_queue<dn> q;
struct KD_tree
{
    // sz[i] == -1 marks an empty heap slot; otherwise slot i holds point nd[i].
    int sz[SIZE << 2];
    Node nd[SIZE << 2];
    // Builds the subtree at heap index i from p[l..r], splitting on dimension d % k.
    void build(int i, int l, int r, int d)
    {
        if(l > r) return;
        int mid = (l + r) >> 1;
        index = d % k;
        sz[i] = r - l;
        sz[i << 1] = sz[i << 1 | 1] = -1;  // children start empty; overwritten if they get built
        nth_element(p + l, p + mid, p + r + 1);  // median along this dimension becomes the node
        nd[i] = p[mid];
        build(i << 1, l, mid - 1, d + 1);
        build(i << 1 | 1, mid + 1, r, d + 1);
    }
    // Collects the (up to) m nearest points to `a` in the subtree at i into the heap q.
    // d is the depth of node i (the root is queried with d = 0, matching build()).
    void query(int i, int m, int d, Node a)
    {
        if(sz[i] == -1) return;
        dn tmp = dn(0, nd[i]);
        // Square in double: squaring an int difference overflows once |diff| > 46340.
        for(int j = 0; j < k; j++)
        {
            double diff = (double)nd[i].x[j] - a.x[j];
            tmp.first += diff * diff;
        }
        int lc = i << 1, rc = i << 1 | 1, dim = d % k, flag = 0;
        // Descend first into the side of the splitting plane that contains `a`.
        if(a.x[dim] >= nd[i].x[dim]) swap(lc, rc);
        if(~sz[lc]) query(lc, m, d + 1, a);
        if(q.size() < m) q.push(tmp), flag = 1;
        else
        {
            if(tmp.first < q.top().first) q.pop(), q.push(tmp);
            // Cross the splitting plane only if it is closer than the current m-th best.
            double planeDiff = (double)a.x[dim] - nd[i].x[dim];
            if(sqr(planeDiff) < q.top().first) flag = 1;
        }
        if(~sz[rc] && flag) query(rc, m, d + 1, a);
    }
};
}
kd_tree::KD_tree kdt;  // presumably global because it holds two arrays of SIZE << 2 entries (too large for the stack)
int main()
{
    //...Read the points into p: n points, k dimensions each (set kd_tree::k); points and dimensions are both 0-indexed
    kdt.build(1, 0, n - 1, 0);
    //...Read a query: find the m points closest to the target point tar; afterwards q.top() is the m-th closest
    // NOTE(review): query() only pushes/pops kd_tree::q and never clears it; empty it before each new query.
    kdt.query(1, m, 0, tar);
    output(kd_tree::q);  // output() is a placeholder for printing the result
}
可持久化线段树(主席树、函数式线段树)
记录线段树被修改的所有历史版本
时间复杂度: 建树 $O(n\log n)$, 添加 $O(\log n)$
https://www.luogu.com.cn/problem/P3834
#include <iostream>
#include <algorithm>
using namespace std;
namespace segment_tree
{
#define MAX_N 200010
#define INF 1e9
// Size of the node pool. build() allocates 2t - 1 < 2 * MAX_N nodes, and every insert()
// allocates one node per level it visits. That is at most 19 levels while the value range
// t <= MAX_N < 2^18, so the total is <= 21 * MAX_N. MAX_N * 20 is not enough once
// n = t = 2e5 (399999 + 200000 * 19 nodes).
#define MAX_NODES (MAX_N * 22)
// Persistent (functional) segment tree over the compressed value range; node 0 is never allocated.
struct Node
{
    int lc, rc;  // child node indices
    int sum;     // how many inserted values fall into this node's value range
} tree[MAX_NODES];
// tot: last allocated node index; a: coordinate-compression table used by main();
// root[i]: root node of version i.
int tot, a[MAX_N], root[MAX_N];
// Builds the empty version (all counts 0) over [l, r] and returns its root.
int build(int l, int r)
{
    int p = ++tot;
    tree[p].sum = 0;
    if(l == r) return p;
    int mid = (l + r) >> 1;
    tree[p].lc = build(l, mid);
    tree[p].rc = build(mid + 1, r);
    return p;
}
// Returns the root of a new version equal to version `now` with `delta` added at value x.
// Only the nodes on the path to x are copied; everything else is shared with `now`.
int insert(int now, int l, int r, int x, int delta)
{
    int p = ++tot;
    tree[p] = tree[now];
    if(l == r)
    {
        tree[p].sum += delta;
        return p;
    }
    int mid = (l + r) >> 1;
    if(x <= mid) tree[p].lc = insert(tree[now].lc, l, mid, x, delta);
    else tree[p].rc = insert(tree[now].rc, mid + 1, r, x, delta);
    tree[p].sum = tree[tree[p].lc].sum + tree[tree[p].rc].sum;
    return p;
}
// Static range k-th smallest: counts are taken as (version p) - (version q). With
// p = root[r] and q = root[l - 1] this is the k-th smallest compressed value in input[l..r].
int ask(int p, int q, int l, int r, int k)
{
    if(l == r) return l;
    int mid = (l + r) >> 1;
    int lcnt = tree[tree[p].lc].sum - tree[tree[q].lc].sum;
    if(k <= lcnt) return ask(tree[p].lc, tree[q].lc, l, mid, k);
    else return ask(tree[p].rc, tree[q].rc, mid + 1, r, k - lcnt);
}
}
using namespace segment_tree;
int main()
{
    //...First read the n numbers into array input, 1-indexed
    int t = 0;
    for(int i = 1; i <= n; i++)
        a[++t] = input[i];
    sort(a + 1, a + t + 1); // coordinate compression: sort the values...
    t = unique(a + 1, a + t + 1) - (a + 1);  // ...and drop duplicates; t = number of distinct values
    root[0] = build(1, t); // build version 0 over the compressed value range [1, t]
    for(int i = 1; i <= n; i++)// one version per input number
    {
        int x = lower_bound(a + 1, a + t + 1, input[i]) - a; // compressed value (rank) of input[i]
        root[i] = insert(root[i - 1], 1, t, x, 1); // version i = version i-1 plus one occurrence of x
    }
    // answer static range k-th smallest queries
    for(int i = 1; i <= m; i++)// m queries
    {
        int l, r, k;
        scanf("%d%d%d", &l, &r, &k);
        int ans = ask(root[r], root[l - 1], 1, t, k);  // versions r and l-1 differ exactly by input[l..r]
        printf("%d\n", a[ans]);  // map the compressed rank back to the original value
    }
}
可持久化数组
时间复杂度: 建树 $O(n)$ (共 $2n-1$ 个结点), 查询 $O(\log n)$, 更改 $O(\log n)$
https://www.luogu.com.cn/problem/P3919
#include <iostream>
using namespace std;
namespace persistent_array
{
#define MAX_N 1000010
#define INF 1e9
// Size of the node pool. build() allocates 2n - 1 < 2 * MAX_N nodes, and every update()
// copies one node per level, at most 21 while n <= MAX_N < 2^20. With fewer than MAX_N
// versions that totals < 23 * MAX_N nodes (about 276 MB). MAX_N * 20 overflows for
// n = m = 1e6 (1999999 + 1000000 * 21 nodes).
#define MAX_NODES (MAX_N * 23)
struct Node
{
    int lc, rc;  // child node indices
    int val;     // element value (meaningful in leaves only)
} tree[MAX_NODES];
// tot: last allocated node; a: initial array (1-indexed) read by build(); root[v]: root node of version v.
int tot, a[MAX_N], root[MAX_N];
// Builds the tree for a[l..r] and returns its root.
int build(int l, int r)
{
    int p = ++tot;
    if(l == r)
    {
        tree[p].val = a[l];
        return p;
    }
    int mid = (l + r) / 2;
    tree[p].lc = build(l, mid);
    tree[p].rc = build(mid + 1, r);
    return p;
}
// Returns the root of a new version equal to the tree rooted at nd, with position k set to val.
// Only the nodes on the path to k are copied.
int update(int nd, int l, int r, int k, int val)
{
    int p = ++tot;
    tree[p] = tree[nd];
    if(l == r)
    {
        tree[p].val = val;
        return p;
    }
    int mid = (l + r) / 2;
    if(k <= mid) tree[p].lc = update(tree[p].lc, l, mid, k, val);
    else tree[p].rc = update(tree[p].rc, mid + 1, r, k, val);
    return p;
}
// Returns element k of the version whose root node is v (call with v = root[version]).
int ask(int v, int l, int r, int k)
{
    if(l == r) return tree[v].val;
    int mid = (l + r) / 2;
    if(k <= mid) return ask(tree[v].lc, l, mid, k);
    else return ask(tree[v].rc, mid + 1, r, k);
}
}
using namespace persistent_array;
int main()
{
    //...First read the n numbers into array a, 1-indexed
    root[0] = build(1, n); // build version 0
    // Set element k of version v to val and store the result as version i
    // (i, v, k and val are placeholders supplied by the surrounding input handling)
    root[i] = update(root[v], 1, n, k, val);
    // Query element k of version v
    printf("%d\n", ask(root[v], 1, n, k));
}
Splay Tree(伸展树)
平衡树的一种,用于减少增删改查的时间复杂度
代码不会少于200行,请在使用之前慎重考虑能否用别的数据结构替代
时间复杂度: 增删改查均摊 $O(\log n)$
https://users.cs.fiu.edu/~weiss/dsaa_c++4/code/SplayTree.h
#include <iostream>
using namespace std;
namespace splay_tree// top-down splay tree
{
// Top-down splay tree adapted from Weiss, "Data Structures and Algorithm Analysis in C++".
// Requirements on Comparable, as used below: operator< (ordering), operator== (contains),
// a default constructor and copy assignment (sentinel/header nodes store key copies),
// and operator<< for printTree.
//
// ******************PUBLIC OPERATIONS*********************
// void insert( x )       --> Insert x
// void remove( x )       --> Remove x
// bool contains( x )     --> Return true if x is present
// Comparable findMin( )  --> Return smallest item
// Comparable findMax( )  --> Return largest item
// bool isEmpty( )        --> Return true if empty; else false
// void makeEmpty( )      --> Remove all items
// void printTree( )      --> Print tree in sorted order
template<typename Comparable>
class SplayTree
{
public:
    // Empty tree: root points at the sentinel nullNode, whose children point to itself.
    SplayTree()
    {
        nullNode = new BinaryNode;
        nullNode->left = nullNode->right = nullNode;
        root = nullNode;
    }
    // Deep copy; the copy gets its own sentinel.
    SplayTree(const SplayTree &rhs)
    {
        nullNode = new BinaryNode;
        nullNode->left = nullNode->right = nullNode;
        root = clone(rhs.root);
    }
    // Steals rhs's nodes; rhs is left with null pointers (its destructor then does nothing).
    SplayTree(SplayTree &&rhs) : root{rhs.root}, nullNode{rhs.nullNode}
    {
        rhs.root = nullptr;
        rhs.nullNode = nullptr;
    }
    ~SplayTree()
    {
        makeEmpty();
        delete nullNode;
    }
    // Copy-and-swap: std::swap goes through the move constructor and move assignment.
    SplayTree &operator=(const SplayTree &rhs)
    {
        SplayTree copy = rhs;
        std::swap(*this, copy);
        return *this;
    }
    // Exchanges contents; rhs's destructor releases this tree's former nodes.
    SplayTree &operator=(SplayTree &&rhs)
    {
        std::swap(root, rhs.root);
        std::swap(nullNode, rhs.nullNode);
        return *this;
    }
    // Smallest element; it is splayed to the root. Throws a C string if the tree is empty.
    const Comparable &findMin()
    {
        if(isEmpty())
            throw "UnderflowException";
        BinaryNode *ptr = root;
        while(ptr->left != nullNode)
            ptr = ptr->left;
        splay(ptr->element, root);
        return ptr->element;
    }
    // Largest element; it is splayed to the root. Throws a C string if the tree is empty.
    const Comparable &findMax()
    {
        if(isEmpty())
            throw "UnderflowException";
        BinaryNode *ptr = root;
        while(ptr->right != nullNode)
            ptr = ptr->right;
        splay(ptr->element, root);
        return ptr->element;
    }
    // Splays x (or the last node on its search path) to the root, then checks the root.
    bool contains(const Comparable &x)
    {
        if(isEmpty()) return false;
        splay(x, root);
        return root->element == x;
    }
    bool isEmpty() const
    {
        return root == nullNode;
    }
    // Prints the elements in sorted order, one per line.
    void printTree() const
    {
        if(isEmpty()) cout << "Empty tree" << endl;
        else printTree(root);
    }
    // Removes every element one at a time through remove(); reclaimMemory() below is not used.
    void makeEmpty()
    {
        while(!isEmpty())
        {
            findMax();
            remove(root->element);
        }
    }
    // Inserts x unless it is already present. x ends up at the root.
    void insert(const Comparable &x)
    {
        // Preallocated node, kept across calls (and shared by every SplayTree<Comparable>)
        // when x turns out to be a duplicate; not thread-safe.
        static BinaryNode *newNode = nullptr;
        if(newNode == nullptr)
            newNode = new BinaryNode;
        newNode->element = x;
        if(root == nullNode)
        {
            newNode->left = newNode->right = nullNode;
            root = newNode;
        }
        else
        {
            splay(x, root);
            if(x < root->element)
            {
                // The old root and its right subtree are greater than x.
                newNode->left = root->left;
                newNode->right = root;
                root->left = nullNode;
                root = newNode;
            }
            else if(root->element < x)
            {
                // The old root and its left subtree are less than x.
                newNode->right = root->right;
                newNode->left = root;
                root->right = nullNode;
                root = newNode;
            }
            else
                return;  // duplicate: keep newNode for the next insert
        }
        newNode = nullptr;
    }
    // Removes x if present.
    void remove(const Comparable &x)
    {
        if(!contains(x))
            return;
        // contains() splayed x to the root.
        BinaryNode *newTree;
        if(root->left == nullNode)
            newTree = root->right;
        else
        {
            // Every key in the left subtree is < x, so splaying x there brings its maximum
            // to the top with an empty right child; hang the old right subtree there.
            newTree = root->left;
            splay(x, newTree);
            newTree->right = root->right;
        }
        delete root;
        root = newTree;
    }
private:
    struct BinaryNode
    {
        Comparable element;
        BinaryNode *left;
        BinaryNode *right;
        BinaryNode() : left{nullptr}, right{nullptr}
        {}
        BinaryNode(const Comparable &theElement, BinaryNode *lt, BinaryNode *rt)
            : element{theElement}, left{lt}, right{rt}
        {}
    };
    BinaryNode *root;
    BinaryNode *nullNode;  // sentinel leaf: its left/right point to itself
    // Recursively deletes subtree t; `t == t->left` identifies the sentinel. (Unused.)
    void reclaimMemory(BinaryNode *t)
    {
        if(t != t->left)
        {
            reclaimMemory(t->left);
            reclaimMemory(t->right);
            delete t;
        }
    }
    // In-order print of subtree t.
    void printTree(BinaryNode *t) const
    {
        if(t != t->left)
        {
            printTree(t->left);
            cout << t->element << endl;
            printTree(t->right);
        }
    }
    // Copies subtree t (from another tree), mapping that tree's sentinel to this tree's nullNode.
    BinaryNode *clone(BinaryNode *t) const
    {
        if(t == t->left)
            return nullNode;
        else
            return new BinaryNode{t->element, clone(t->left), clone(t->right)};
    }
    // Single right rotation; k2 is replaced by its former left child.
    void rotateWithLeftChild(BinaryNode *&k2)
    {
        BinaryNode *k1 = k2->left;
        k2->left = k1->right;
        k1->right = k2;
        k2 = k1;
    }
    // Single left rotation; k1 is replaced by its former right child.
    void rotateWithRightChild(BinaryNode *&k1)
    {
        BinaryNode *k2 = k1->right;
        k1->right = k2->left;
        k2->left = k1;
        k1 = k2;
    }
    // Top-down splay: makes x (or the last node on x's search path) the root of subtree t.
    void splay(const Comparable &x, BinaryNode *&t)
    {
        // Nodes smaller than x are collected in a "left tree" and larger ones in a "right tree",
        // both hung off the static header (shared by all instances; not thread-safe).
        BinaryNode *leftTreeMax, *rightTreeMin;
        static BinaryNode header;
        header.left = header.right = nullNode;
        leftTreeMax = rightTreeMin = &header;
        // Putting x in the sentinel guarantees the descent stops when it reaches nullNode.
        nullNode->element = x;
        for(;;)
            if(x < t->element)
            {
                if(x < t->left->element)
                    rotateWithLeftChild(t);  // zig-zig: rotate before linking
                if(t->left == nullNode)
                    break;
                // Link right: t and its right subtree are greater than x.
                rightTreeMin->left = t;
                rightTreeMin = t;
                t = t->left;
            }
            else if(t->element < x)
            {
                if(t->right->element < x)
                    rotateWithRightChild(t);  // zig-zig: rotate before linking
                if(t->right == nullNode)
                    break;
                // Link left: t and its left subtree are less than x.
                leftTreeMax->right = t;
                leftTreeMax = t;
                t = t->right;
            }
            else
                break;
        // Reassemble: t's subtrees go to the ends of the left/right trees, which become t's children.
        leftTreeMax->right = t->left;
        rightTreeMin->left = t->right;
        t->left = header.right;
        t->right = header.left;
    }
};
}
数论部分
Pollard’s Rho质因数分解
时间复杂度: $O(n^{1/4})$ 左右
#include<iostream>
#include<algorithm>
#include<queue>
using namespace std;
namespace pollards_rho
{
typedef long long ll;
queue<ll> res;  // prime factors found by find(), with multiplicity, in discovery order
ll min(ll a, ll b)
{
    if(a < b) return a;
    else return b;
}
// (a + b) mod p for 0 <= a, b < p. Never forms a + b, so it cannot overflow even for
// moduli above 2^62.
ll addmod(ll a, ll b, ll p)
{
    return a >= p - b ? a - (p - b) : a + b;
}
// a * b mod p by double-and-add ("slow multiplication"), avoiding 64-bit overflow for
// every modulus p <= LLONG_MAX. Assumes a, b >= 0.
ll multi(ll a, ll b, ll p)
{
    ll ans = 0;
    a %= p;
    while(b)
    {
        if(b & 1LL) ans = addmod(ans, a, p);
        a = addmod(a, a, p);
        b >>= 1;
    }
    return ans;
}
// a^b mod p by binary exponentiation (p >= 2).
ll fastPow(ll a, ll b, ll p)
{
    ll ans = 1;
    while(b)
    {
        if(b & 1LL) ans = multi(ans, a, p);
        a = multi(a, a, p);
        b >>= 1;
    }
    return ans;
}
// Probabilistic primality test with 20 random bases (Miller-Rabin).
bool millerRabin(ll n)
{
    ll x[105];
    if(n < 2) return false;
    if(n % 2 == 0) return n == 2;  // deterministic answer for even n
    int s = 20, i, t = 0;
    ll u = n - 1;
    while(!(u & 1))
    {
        t++;
        u >>= 1;
    }
    // n - 1 = u * 2^t with u odd
    while(s--)
    {
        ll a = rand() % (n - 2) + 2;
        x[0] = fastPow(a, u, n);
        for(i = 1; i <= t; i++)
        {
            x[i] = multi(x[i - 1], x[i - 1], n);
            // A non-trivial square root of 1 proves n composite.
            if(x[i] == 1 && x[i - 1] != 1 && x[i - 1] != n - 1) return false;
        }
        if(x[t] != 1) return false;  // Fermat test failed
    }
    return true;
}
ll gcd(ll a, ll b)
{
    if(b == 0) return a;
    else return gcd(b, a % b);
}
// One Pollard's rho run with f(x) = x^2 + c (mod n). Returns a non-trivial divisor of n,
// or n itself when the sequence cycles without finding one (the caller retries with another c).
ll pollardsRho(ll n, int c)
{
    ll cc = c % n;
    if(cc < 0) cc += n;  // find() decrements c, so it may become negative
    ll i = 1, k = 2, x = rand() % (n - 1) + 1, y = x;
    while(1)
    {
        i++;
        x = addmod(multi(x, x, n), cc, n);
        // |y - x| has the same gcd with n as (y - x) mod n and cannot overflow.
        ll p = gcd(y > x ? y - x : x - y, n);
        if(p != 1 && p != n) return p;
        if(y == x) return n;
        if(i == k)
        {
            // Brent-style: refresh the saved point at powers of two.
            y = x;
            k <<= 1;
        }
    }
}
// Pushes the prime factorization of n into res; c is the starting rho constant.
void find(ll n, int c)
{
    if(n <= 1) return;
    if(n % 2 == 0)
    {
        // Factors of 2 are split off directly, so the randomized search only sees odd numbers.
        res.push(2);
        find(n / 2, c);
        return;
    }
    if(millerRabin(n))
    {
        res.push(n);
        return;
    }
    ll p = n, k = c;
    while(p >= n)
        p = pollardsRho(p, c--);  // retry with a different constant until a proper divisor appears
    find(p, k);
    find(n / p, k);
}
}
using namespace pollards_rho;
int main()
{
ll n;
scanf("%lld", &n);//读入大数字
find(n, 107);//开始质因数分解
while(!res.empty())//n的所有质因数存储在res中
{
printf("%lld ", res.front());
res.pop();
}
return 0;
}
离散代数部分
线性基
对于整数集合 $S$, 找到整数集合 $B$, 使得 $B$ 中的元素在异或运算下线性无关, 且 $S$ 中的每个数都能由 $B$ 中若干元素异或得到 (即 $B$ 是 $S$ 的基)
时间复杂度: 插入一个数 $O(L)$, 对 $n$ 个数建立线性基 $O(nL)$ ($L$ 为二进制位数)
https://oi.men.ci/linear-basis-notes/
#include <iostream>
namespace linear_basis
{
// Highest bit index kept in the basis. Shifting by 64 or more is undefined behaviour, so
// bits 0..63 (the full width of long long) are the most that can be handled.
#define MAX_L 63
// XOR linear basis (Gaussian elimination over GF(2)).
struct LinearBasis
{
    // a[j] != 0 is the basis vector whose highest set bit is j. The basis is kept fully
    // reduced: no other stored vector has bit j set.
    long long a[MAX_L + 1];
    // Whether bit j (0..63) of v is set; the unsigned shift keeps bit 63 well defined.
    static bool hasBit(long long v, int j)
    {
        return (static_cast<unsigned long long>(v) >> j) & 1ULL;
    }
    LinearBasis()
    {
        for(long long &v : a) v = 0;
    }
    // Builds the basis of x[1..n] (note: 1-indexed).
    LinearBasis(long long *x, int n)
    {
        build(x, n);
    }
    // Adds t to the spanned space; a value already representable leaves the basis unchanged.
    void insert(long long t)
    {
        for(int j = MAX_L; j >= 0; j--)
        {
            if(!t) return;
            if(!hasBit(t, j)) continue;
            if(a[j]) t ^= a[j];
            else
            {
                // t becomes a new pivot at bit j: clear its lower pivot bits...
                for(int k = 0; k < j; k++) if(hasBit(t, k)) t ^= a[k];
                // ...and clear bit j from the higher basis vectors, keeping the basis reduced.
                for(int k = j + 1; k <= MAX_L; k++) if(hasBit(a[k], j)) a[k] ^= t;
                a[j] = t;
                return;
            }
        }
    }
    // Rebuilds the basis from scratch from x[1..n] (1-indexed).
    void build(long long *x, int n)
    {
        for(long long &v : a) v = 0;
        for(int i = 1; i <= n; i++)
        {
            insert(x[i]);
        }
    }
    // Maximum XOR of any subset of the inserted values. Because the basis is reduced this is
    // the XOR of all basis vectors (maximal as an unsigned value; it is negative as a long
    // long only if some input had bit 63 set).
    long long queryMax()
    {
        long long res = 0;
        for(int i = 0; i <= MAX_L; i++) res ^= a[i];
        return res;
    }
    // Adds every vector of `other` into this basis.
    void mergeFrom(const LinearBasis &other)
    {
        for(int i = 0; i <= MAX_L; i++) insert(other.a[i]);
    }
    // Basis of the union of the spaces spanned by a and b.
    static LinearBasis merge(const LinearBasis &a, const LinearBasis &b)
    {
        LinearBasis res = a;
        for(int i = 0; i <= MAX_L; i++) res.insert(b.a[i]);
        return res;
    }
};
}
快速傅里叶变换(FFT)
快速进行多项式相乘
时间复杂度: $O(n\log n)$
https://www.luogu.com.cn/problem/P3803
#include <iostream>
#include <complex>
#include <cmath>
using namespace std;
namespace FFT
{
#define MAX_N 2097153
#define PI (3.14159265358979323846)
#define C complex<double>
// Working buffers for the two polynomials (coefficients live in the real parts).
C a[MAX_N], b[MAX_N];
// rev[i] = i with its lowest k bits reversed (filled in by init).
int rev[MAX_N];
// Precomputes the bit-reversal permutation for transforms of length 2^k (k >= 1).
void init(int k)
{
    const int topShift = k - 1;
    const int size = 1 << k;
    int i = 0;
    while(i < size)
    {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << topShift);
        ++i;
    }
}
// In-place iterative Cooley-Tukey transform of a[0..n), where n is the length passed to init.
// flag = 1: coefficients -> point values; flag = -1: inverse transform, including the 1/n scaling.
void fft(C *a, int n, int flag)
{
    // Put the input in bit-reversed order so the butterflies can work in place.
    for(int i = 0; i < n; i++)
    {
        const int j = rev[i];
        if(i < j) swap(a[i], a[j]);
    }
    for(int half = 1; half < n; half <<= 1)
    {
        // Primitive (2 * half)-th root of unity (its conjugate for the inverse transform).
        const C step = exp(C(0, flag * PI / half));
        const int span = half << 1;
        for(int start = 0; start < n; start += span)
        {
            C w(1, 0);
            for(int pos = start; pos < start + half; ++pos)
            {
                const C even = a[pos];
                const C odd = w * a[pos + half];
                a[pos] = even + odd;
                a[pos + half] = even - odd;
                w *= step;
            }
        }
    }
    if(flag == -1)
    {
        for(int i = 0; i < n; i++) a[i] /= n;
    }
}
}
using namespace FFT;
int main()
{
    //...Read the coefficients into the real parts of a[i] and b[i]; each array holds n + 1 coefficients (degree n), 0-indexed
    // The product has degree 2n, i.e. 2n + 1 coefficients, so the transform length s = 2^k must be
    // at least 2n + 1. A shorter length makes the highest coefficients wrap around onto the lowest.
    int k = 1, s = 2;
    while(s < 2 * n + 1)
        k++, s <<= 1;
    init(k);
    // FFT: coefficient representation of a -> point-value representation
    fft(a, s, 1);
    // FFT: coefficient representation of b -> point-value representation
    fft(b, s, 1);
    // Multiply the two polynomials pointwise
    for(int i = 0; i < s; i++)
        a[i] *= b[i];
    // IFFT: point-value representation -> coefficient representation
    fft(a, s, -1);
    // Print the 2n + 1 coefficients of the product on one line
    // (int(x + 0.5) rounds correctly only for non-negative coefficients).
    for(int i = 0; i <= n + n; i++)
    {
        int ans = int(a[i].real() + 0.5);
        printf("%d%c", ans, i == n + n ? '\n' : ' ');
    }
    return 0;
}
计算几何部分
半平面交
时间复杂度: $O(n\log n)$
https://www.luogu.com.cn/problem/P4196
#include <iostream>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
namespace half_plane_intersection
{
#define N 1000
#define EPS 1E-10
#define UNSET (-999)
int n;  // number of lines stored in b[0..n); halfPlaneIntersection() overwrites it
// Sign of x with tolerance EPS: -1, 0 or 1.
inline int cmp(double x) { return (x > EPS) - (x < -EPS); }
// 2-D vector; `*` between two vectors is the cross product, `%` is the dot product.
struct vec
{
    double x, y;
    vec(double a = 0, double b = 0) : x(a), y(b) {}
    inline double len_sq() const { return x * x + y * y; }
    inline double len() const { return sqrt(len_sq()); }
    inline vec norm() const { return cmp(len_sq()) ? *this / len() : vec(); }
    inline vec perp() const { return vec(-y, x).norm(); }
    inline vec operator+(const vec &v) const { return vec(x + v.x, y + v.y); }
    inline vec operator-(const vec &v) const { return vec(x - v.x, y - v.y); }
    inline vec operator*(double v) const { return vec(x * v, y * v); }
    inline vec operator/(double v) const { return vec(x / v, y / v); }
    inline double operator*(const vec &v) const { return x * v.y - y * v.x; }
    inline double operator%(const vec &v) const { return x * v.x + y * v.y; }
} a[N], c[N];
// Directed line through p with direction v; its half-plane is the region to the left of v.
struct line
{
    vec p, v;
    line(vec a = vec(), vec b = vec()) : p(a), v(b) {}
    // True if point a lies in this half-plane (points on the line count as inside).
    inline bool on_left(const vec &a) const { return cmp((a - p) * v) <= 0; }
    // Intersection point with l (the two lines must not be parallel).
    vec intersect(const line &l) const
    {
        double t = (l.p - p) * l.v / (v * l.v);
        return p + v * t;
    }
    double deg = UNSET;  // polar angle of v; halfPlaneIntersection() fills it in before sorting
    // Polar angle of v, computed on the fly if deg has not been filled in.
    double angle() const { return cmp(deg - UNSET) == 0 ? atan2(v.y, v.x) : deg; }
    // Sorts by polar angle. For equal angles, the line whose half-plane is innermost comes last,
    // which is the one halfPlaneIntersection()'s deduplication keeps. Pure const comparison:
    // std::sort forbids comparators that modify the elements.
    bool operator<(const line &t) const
    {
        double d1 = angle(), d2 = t.angle();
        if(cmp(d1 - d2) == 0) return cmp(v * (t.v + t.p - p)) > 0;
        return cmp(d1 - d2) < 0;
    }
} b[N];
// True if the intersection point of a and b lies inside c's half-plane.
inline bool check(const line &a, const line &b, const line &c)
{
    return c.on_left(a.intersect(b));
}
// Area of the intersection of the (left) half-planes of b[0..n); 0 when the intersection has
// fewer than 3 vertices. NOTE(review): an unbounded intersection is not detected.
double halfPlaneIntersection()//computes the area
{
    // Fill in every polar angle up front, so the comparator never has to cache anything and the
    // deduplication below never sees an UNSET angle.
    for(int i = 0; i < n; i++) b[i].deg = atan2(b[i].v.y, b[i].v.x);
    sort(b, b + n);
    // Keep one line per direction (the innermost, which sorts last among parallels).
    int cnt = 0;
    for(int i = 1; i < n; i++)
    {
        if(cmp(b[i].deg - b[i - 1].deg) != 0)cnt++;
        b[cnt] = b[i];
    }
    n = cnt + 1;
    deque<int> q;
    for(int i = 0; i < n; i++)
    {
        // Drop lines at either end whose corner falls outside the new half-plane.
        while(q.size() > 1 && !check(b[q[q.size() - 2]], b[q[q.size() - 1]], b[i]))
            q.pop_back();
        while(q.size() > 1 && !check(b[q[1]], b[q[0]], b[i]))
            q.pop_front();
        q.push_back(i);
    }
    // Close the polygon: trim lines made redundant by the wrap-around.
    while(q.size() > 2 && !check(b[q[q.size() - 2]], b[q[q.size() - 1]], b[q[0]]))
        q.pop_back();
    while(q.size() > 2 && !check(b[q[1]], b[q[0]], b[q[q.size() - 1]]))
        q.pop_front();
    if(q.size() <= 2)return 0;
    const int m = (int)q.size();
    for(int i = 0; i < m; i++)
        c[i] = b[q[i]].intersect(b[q[(i + 1) % m]]);
    // Fan triangulation from c[0] (shoelace formula).
    double ans = 0;
    for(int i = 1; i < m - 1; i++)
        ans += (c[i] - c[0]) * (c[i + 1] - c[i]);
    return ans / 2;
}
}
using namespace half_plane_intersection;
int main()
{
//读入n
scanf("%d", &n);
//读入n条线,不用确保逆时针方向,内部会进行极角排序
for(int i = 0; i < n; i++)
{
double x1, y1, x2, y2;
scanf("%lf%lf%lf%lf", &x1, &y1, &x2, &y2);
b[i] = {{x1, y1}, vec{x2, y2} - vec{x1, y1}};
}
printf("%lf", halfPlaneIntersection());
}
辛普森积分(Simpson积分)
求 $\int_a^b f(x)\,\mathrm{d}x$ 的近似值
设所要求的精度为 EPS, 时间复杂度为 $O\left(\log\left|\frac{b-a}{EPS}\right|\right)$ (粗略估计, 实际取决于 $f$ 的光滑程度)
#include <iostream>
#include <cmath>
using namespace std;
namespace simpson_integration
{
#define EPS 1E-10
double (*_f)(double);  // integrand used by simpson(); set by the function-pointer calc() overload
// Simpson's rule on [l, r] for the current integrand _f.
double simpson(double l, double r)
{
    double mid = (l + r) / 2;
    return ((*_f)(l) + 4 * (*_f)(mid) + (*_f)(r)) * (r - l) / 6;
}
// Adaptive Simpson step on [l, r]. a is simpson(l, r) and eps is the error budget for [l, r].
// Each half gets half of the budget, so the estimated errors of all accepted pieces stay
// within eps overall. A fixed per-piece tolerance would let them accumulate.
double calc(double l, double r, double a, double eps)
{
    double mid = (l + r) / 2;
    double u = simpson(l, mid), v = simpson(mid, r);
    // The 15 * eps threshold and the /15 correction come from Richardson extrapolation.
    if(fabs(u + v - a) <= 15 * eps)return u + v + (u + v - a) / 15;
    return calc(l, mid, u, eps / 2) + calc(mid, r, v, eps / 2);
}
// Same as calc(l, r, a, EPS); kept for existing callers.
double calc(double l, double r, double a)
{
    return calc(l, r, a, EPS);
}
// Approximates the integral of f over [l, r] with an absolute error target of eps.
double calc(double l, double r, double (*f)(double), double eps = EPS)
{
    _f = f;
    return calc(l, r, simpson(l, r), eps);
}
}
I/O部分
快读
// Reads one decimal integer from stdin, skipping any non-digit characters before it; a '-'
// anywhere in that skipped prefix makes the result negative. Returns 0 if end of input is
// reached before any digit.
inline int read()
{
    int s = 0, w = 1;
    int ch = getchar();  // int, not char, so EOF can be told apart from real characters
    while(ch < '0' || ch > '9')
    {
        if(ch == EOF) return 0;  // without this check the loop never ends at end of input
        if(ch == '-')w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')s = s * 10 + ch - '0', ch = getchar();
    return s * w;
}