[BZOJ3658]Jabberwocky
试题描述
平面上有n个点,每个点有k种颜色中的一个。
你可以选择一条水平的线段获得在其上方或其下方的所有点,如图所示:
请求出你最多能够得到多少点,使得获得的点并不包含所有的颜色。
输入
包含多组测试数据,第一行输入一个数T表示测试数据组数。
接下来T组测试数据,对于每组测试数据,第一行输入两个数n,k,分别表示点的个数和颜色数。
接下来n行每行描述一个点,前两个数z,y(lxl,lyl≤2^32-1)描述点的位置,最后一个数z(1≤z≤K)描述点的颜色。
接下来T组测试数据,对于每组测试数据,第一行输入两个数n,k,分别表示点的个数和颜色数。
接下来n行每行描述一个点,前两个数z,y(lxl,lyl≤2^32-1)描述点的位置,最后一个数z(1≤z≤K)描述点的颜色。
输出
对于每组数据输出一行,每行一个数ans,表示答案。
输入示例
1 10 3 1 2 3 2 1 1 2 4 2 3 5 3 4 4 2 5 1 2 6 3 1 6 7 1 7 2 3 9 4 2
输出示例
5
数据规模及约定
N<=100000,K<=100000,T<=3
题解
题目要求不包含所有颜色,即至少有一个颜色不被包含,我们可以枚举这个不被包含的颜色。
先考虑某条线段下的情况,假设当前枚举的不包含的颜色为 x,接下来想象一条扫描线往上走,遇到颜色 x 的点就得被这个点劈成两半(同时在这个时刻统计并更新答案),接下来这两半分别进行同样的操作。
对于某条线段上的情况,做法是对称的。
我们可以用分治的方法模拟这个过程,因为上面的描述显然符合子问题和当前问题一模一样的条件,需要做的就是用数据结构优化这些操作:
1.) 找到最先碰到的颜色为 x 的点,可以用线段树,按照 x 坐标存点,维护每个区间内点的 y 坐标的最大、最小值。
2.) 求一个矩形内部点的个数,可以用主席树,每个 y 坐标作为一个版本,按 y 从小到大维护 x 坐标的个数。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
#include <queue>
#include <cstring>
#include <string>
#include <map>
#include <set>
using namespace std;
const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
if(Head == Tail) {
int l = fread(buffer, 1, BufferSize, stdin);
Tail = (Head = buffer) + l;
}
return *Head++;
}
int read() {
int x = 0, f = 1; char c = Getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
return x * f;
}
#define maxn 100010
#define maxnode 6666666
#define oo 2147483647
int n, num[maxn<<1], cntn, head[maxn], next[maxn];
struct Point {
int x, y;
Point() {}
Point(int _, int __): x(_), y(__) {}
} ps[maxn], p2[maxn];
bool cmpx(Point a, Point b) { return a.x < b.x; }
bool cmpy(Point a, Point b) { return a.y < b.y; }
int ToT, rt[maxn<<1], sumv[maxnode], lc[maxnode], rc[maxnode];
void update(int& y, int x, int l, int r, int p) {
sumv[y = ++ToT] = sumv[x] + 1;
if(l == r) return ;
int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x];
if(p <= mid) update(lc[y], lc[x], l, mid, p);
else update(rc[y], rc[x], mid + 1, r, p);
return ;
}
void build() {
ToT = 0;
memset(sumv, 0, sizeof(sumv));
memset(lc, 0, sizeof(lc));
memset(rc, 0, sizeof(rc));
sort(p2 + 1, p2 + n + 1, cmpy);
for(int y = 1, i = 1; y <= cntn; y++) {
rt[y] = rt[y-1];
while(i <= n && p2[i].y == y) update(rt[y], rt[y], 1, cntn, p2[i].x), i++;
}
return ;
}
int query(int o, int l, int r, int ql, int qr) {
if(!o) return 0;
if(ql <= l && r <= qr) return sumv[o];
int mid = l + r >> 1, ans = 0;
if(ql <= mid) ans += query(lc[o], l, mid, ql, qr);
if(qr > mid) ans += query(rc[o], mid + 1, r, ql, qr);
return ans;
}
Point mxp[maxn<<3], mnp[maxn<<3];
void init_seg(int L, int R, int o) {
if(L == R) mxp[o] = Point(L, -233), mnp[o] = Point(R, oo);
else {
int M = L + R >> 1, lc = o << 1, rc = lc | 1;
init_seg(L, M, lc); init_seg(M+1, R, rc);
mxp[o] = Point(L, -233); mnp[o] = Point(R, oo);
}
return ;
}
void modify(int L, int R, int o, Point p) {
if(L == R) {
if(p.y > mxp[o].y) mxp[o] = p;
if(p.y < mnp[o].y) mnp[o] = p;
}
else {
int M = L + R >> 1, lc = o << 1, rc = lc | 1;
if(p.x <= M) modify(L, M, lc, p);
else modify(M+1, R, rc, p);
if(mxp[lc].y > mxp[rc].y) mxp[o] = mxp[lc]; else mxp[o] = mxp[rc];
if(mnp[lc].y < mnp[rc].y) mnp[o] = mnp[lc]; else mnp[o] = mnp[rc];
}
return ;
}
void clear(int L, int R, int o, Point p) {
if(L == R) mxp[o] = Point(L, -233), mnp[o] = Point(R, oo);
else {
int M = L + R >> 1, lc = o << 1, rc = lc | 1;
if(p.x <= M) clear(L, M, lc, p);
else clear(M+1, R, rc, p);
mxp[o] = Point(L, -233); mnp[o] = Point(R, oo);
}
return ;
}
Point querymx(int L, int R, int o, int ql, int qr) {
if(ql <= L && R <= qr) return mxp[o];
int M = L + R >> 1, lc = o << 1, rc = lc | 1;
Point res(L, -233), tmp;
if(ql <= M) {
tmp = querymx(L, M, lc, ql, qr);
if(res.y < tmp.y) res = tmp;
}
if(qr > M) {
tmp = querymx(M+1, R, rc, ql, qr);
if(res.y < tmp.y) res = tmp;
}
return res;
}
Point querymn(int L, int R, int o, int ql, int qr) {
if(ql <= L && R <= qr) return mnp[o];
int M = L + R >> 1, lc = o << 1, rc = lc | 1;
Point res(R, oo), tmp;
if(ql <= M) {
tmp = querymn(L, M, lc, ql, qr);
if(res.y > tmp.y) res = tmp;
}
if(qr > M) {
tmp = querymn(M+1, R, rc, ql, qr);
if(res.y > tmp.y) res = tmp;
}
return res;
}
void build_seg(int col) {
for(int i = head[col]; i; i = next[i]) modify(1, cntn, 1, ps[i]);
return ;
}
void remove_seg(int col) {
for(int i = head[col]; i; i = next[i]) clear(1, cntn, 1, ps[i]);
return ;
}
int ans;
void Solve_down(int l, int r) {
if(l > r) return ;
Point tmp = querymn(1, cntn, 1, l, r);
// matrix: x[l, r], y[1, tmp.y - 1]
if(tmp.y < oo) ans = max(ans, query(rt[tmp.y-1], 1, cntn, l, r));
else {
ans = max(ans, query(rt[cntn], 1, cntn, l, r));
return ;
}
Solve_down(l, tmp.x - 1); Solve_down(tmp.x + 1, r);
return ;
}
void solve_down(int col) {
build_seg(col);
Solve_down(1, cntn);
remove_seg(col);
return ;
}
void Solve_up(int l, int r) {
if(l > r) return ;
Point tmp = querymx(1, cntn, 1, l, r);
// matrix: x[l, r], y[tmp.y + 1, cntn]
// printf("query_up: %d %d\n", tmp.x, tmp.y);
if(tmp.y >= 0) ans = max(ans, query(rt[cntn], 1, cntn, l, r) - query(rt[tmp.y], 1, cntn, l, r));
else {
ans = max(ans, query(rt[cntn], 1, cntn, l, r));
return ;
}
// printf("ans: %d\n", ans);
Solve_up(l, tmp.x - 1); Solve_up(tmp.x + 1, r);
return ;
}
void solve_up(int col) {
build_seg(col);
Solve_up(1, cntn);
remove_seg(col);
return ;
}
int main() {
int T = read();
while(T--) {
memset(head, 0, sizeof(head));
n = read(); int col = read(); cntn = 0;
for(int i = 1; i <= n; i++) {
int x = read(), y = read(), c = read();
num[++cntn] = x; num[++cntn] = y;
ps[i] = Point(x, y); next[i] = head[c]; head[c] = i;
}
sort(num + 1, num + cntn + 1);
cntn = unique(num + 1, num + cntn + 1) - num - 1;
for(int i = 1; i <= n; i++)
ps[i].x = lower_bound(num + 1, num + cntn + 1, ps[i].x) - num,
ps[i].y = lower_bound(num + 1, num + cntn + 1, ps[i].y) - num,
p2[i] = ps[i];
/*for(int i = 1; i <= n; i++) printf("%d %d\n", ps[i].x, ps[i].y);
for(int i = 1; i <= col; i++) {
printf("%d:", i);
for(int j = head[i]; j; j = next[j]) printf(" (%d, %d)", ps[j].x, ps[j].y);
putchar('\n');
}*/
build();
ans = 0;
init_seg(1, cntn, 1);
// puts("here");
for(int i = 1; i <= col; i++) solve_down(i);
for(int i = 1; i <= col; i++) solve_up(i);
printf("%d\n", ans);
}
return 0;
}