4548: 小奇的糖果
Time Limit: 10 Sec Memory Limit: 256 MBSubmit: 111 Solved: 52
[ Submit][ Status][ Discuss]
Description
有 N 个彩色糖果在平面上。小奇想在平面上取一条水平的线段,并拾起它上方或下方的所有糖果。求出最多能够拾
起多少糖果,使得获得的糖果并不包含所有的颜色。
Input
包含多组测试数据,第一行输入一个正整数 T 表示测试数据组数。
接下来 T 组测试数据,对于每组测试数据,第一行输入两个正整数 N、K,分别表示点数和颜色数。
接下来 N 行,每行描述一个点,前两个数 x, y (|x|, |y| ≤ 2^30 - 1) 描述点的位置,最后一个数 z (1 ≤ z ≤
k) 描述点的颜色。
对于 100% 的数据,N ≤ 100000,K ≤ 100000,T ≤ 3
Output
对于每组数据在一行内输出一个非负整数 ans,表示答案
Sample Input
1
10 3
1 2 3
2 1 1
2 4 2
3 5 3
4 4 2
5 1 2
6 3 1
6 7 1
7 2 3
9 4 2
10 3
1 2 3
2 1 1
2 4 2
3 5 3
4 4 2
5 1 2
6 3 1
6 7 1
7 2 3
9 4 2
Sample Output
5
HINT
Source
题解:线段树+树状数组
分成三种情况统计答案。该点与在该点下方的相同颜色的前驱后继围成的区域,该点与该点上方的相同颜色的前驱后继围成的区域,已经相邻最近的两个相同颜色之间的区域。
前驱后继可以用线段树维护,然后清数组的时候用lazy标记。
统计答案的时候用树状数组,如果统计下方的,就按照纵坐标从小到大向树状数组中该点对应的横坐标的位置加点,然后统计该点下方,在前驱后继之间的点。
统计上方反过来即可。
统计相邻两个之间的点,刚开始直接把所以点加入树状数组,然后查询两点之间的区域即可。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#define N 500003
using namespace std;
int n,m,cnt,tree[N],T,ans;
int tr[N],b[N],c[N],delta[N],tr1[N],delta1[N];
int next[N],head[N];
struct point
{
int x,y,num,col;
int pre,next,ans;
}a[N];
int cmp(int x,int y)
{
return c[x]<c[y];
}
int cmp1(point a,point b)
{
return a.col<b.col||a.col==b.col&&a.y<b.y;
}
void clear(int now)
{
delta[now]=1;
tr[now]=-1;
}
void pushdown(int now)
{
if (delta[now])
{
delta[now]=0;
clear(now<<1); clear(now<<1|1);
}
}
void update(int now)
{
tr[now]=max(tr[now<<1],tr[now<<1|1]);
}
void change(int now,int l,int r,int x)
{
if (x<1||x>cnt) return ;
if (l==r)
{
tr[now]=l;
return;
}
pushdown(now);
int mid=(l+r)/2;
if (x<=mid) change(now<<1,l,mid,x);
else change(now<<1|1,mid+1,r,x);
update(now);
}
int find(int now,int l,int r,int ll,int rr)
{
if (rr<ll) return -1;
if (ll<=l&&r<=rr) return tr[now];
pushdown(now);
int mid=(l+r)/2;
int ans=-1;
if (ll<=mid) ans=max(ans,find(now<<1,l,mid,ll,rr));
if (rr>mid) ans=max(ans,find(now<<1|1,mid+1,r,ll,rr));
return ans;
}
void clear1(int now)
{
delta1[now]=1;
tr1[now]=cnt+1;
}
void pushdown1(int now)
{
if (delta1[now])
{
delta1[now]=0;
clear1(now<<1); clear1(now<<1|1);
}
}
void update1(int now)
{
tr1[now]=min(tr1[now<<1],tr1[now<<1|1]);
}
void change1(int now,int l,int r,int x)
{
if (x<0||x>cnt) return;
if (l==r)
{
tr1[now]=l;
return;
}
pushdown1(now);
int mid=(l+r)/2;
if (x<=mid) change1(now<<1,l,mid,x);
else change1(now<<1|1,mid+1,r,x);
update1(now);
}
int find1(int now,int l,int r,int ll,int rr)
{
if (rr<ll) return -1;
if (ll<=l&&r<=rr) return tr1[now];
pushdown1(now);
int mid=(l+r)/2;
int ans=cnt+1;
if (ll<=mid) ans=min(ans,find1(now<<1,l,mid,ll,rr));
if (rr>mid) ans=min(ans,find1(now<<1|1,mid+1,r,ll,rr));
return ans;
}
int cmp2(point x,point y)
{
return x.y<y.y;
}
int lowbit(int x)
{
return x&(-x);
}
int sum(int x)
{
if (x==0) return 0;
int ans=0;
for (int i=x;i>=1;i-=lowbit(i))
ans+=tree[i];
return ans;
}
void add(int x)
{
for (int i=x;i<=cnt;i+=lowbit(i))
tree[i]++;
}
int cmp5(point x,point y)
{
return x.x<y.x;
}
void solve()
{
sort(a+1,a+n+1,cmp5);
for (int i=1;i<=n;i++)
add(a[i].x);
memset(head,0,sizeof(head));
for (int i=1;i<=n;i++)
{
int l=head[a[i].col];
int r=a[i].x;
if (r-1>=l) ans=max(ans,sum(r-1)-sum(l));
head[a[i].col]=a[i].x;
}
for (int i=1;i<=m;i++)
if(head[i]<=n)
ans=max(ans,sum(cnt)-sum(head[i]));
//cout<<ans<<endl;
}
void solve1()
{
sort(a+1,a+n+1,cmp1);
int i=1;
clear(1),clear1(1);
while(i<=n)
{
if (a[i].col!=a[i-1].col)
clear(1),clear1(1);
int j=i;
while (a[j].y==a[i].y&&a[j].col==a[i].col&&j<=n) {
int t=find(1,1,cnt,1,a[j].x-1);
int t1=find1(1,1,cnt,a[j].x,cnt);
if (t!=-1) a[j].pre=t;
else a[j].pre=0;
if (t1!=-1) a[j].next=t1;
else a[j].next=cnt;
j++;
}
for (int k=i;k<j;k++)
change(1,1,cnt,a[k].x),change1(1,1,cnt,a[k].x),i++;
}
memset(tree,0,sizeof(tree));
sort(a+1,a+n+1,cmp2);
i=1;
while (i<=n)
{
int j=i;
while (a[j].y==a[i].y&&j<=n)
{
a[j].ans=max(a[j].ans,sum(a[j].next-1)-sum(a[j].pre));
j++;
}
for (int k=i;k<j;k++)
add(a[k].x),i++;
}
}
int main()
{
freopen("candy.in","r",stdin);
freopen("candy.out","w",stdout);
scanf("%d",&T);
for (int t=1;t<=T;t++)
{
scanf("%d%d",&n,&m); ans=0;
for (int i=1;i<=n;i++)
{
scanf("%d%d%d",&a[i].x,&a[i].y,&a[i].col);
a[i].num=i; a[i].ans=0;
b[i]=i; c[i]=a[i].x;
}
sort(b+1,b+n+1,cmp);
cnt=0;
for (int i=1;i<=n;i++)
if (c[b[i]]!=c[b[i-1]]||i==1)
a[b[i]].x=++cnt;
else a[b[i]].x=cnt;
memset(tree,0,sizeof(tree));
solve();
solve1();
for (int i=1;i<=n;i++)
a[i].y=-a[i].y;
solve1();
for (int i=1;i<=n;i++)
ans=max(ans,a[i].ans);
printf("%d\n",ans);
}
}