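// Approach: discretize x and handle it with a Fenwick tree (binary indexed tree);
// sort by y and handle it with divide and conquer (CDQ); handle z by sorting inside each divide-and-conquer step.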
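// Details: first sort the points by y. solve(l,r) splits into solve(l,mid) and solve(mid+1,r).
// Because the array is sorted by y, every y value in [l,mid] is no greater than every y value in [mid+1,r].
// After solve(l,mid) has finished, sort both halves by z and sweep them in that order,
// inserting left-half points into the Fenwick tree keyed by x. This resolves the contribution
// of [l,mid] to [mid+1,r] in a single pass, after which solve(mid+1,r) runs.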
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of every 1-based global array below (indices 1..n must stay below this).
const int maxn = 100009;
// Modulus applied to every chain count.
const int mod = 1 << 30;
// Not referenced anywhere in the visible code; the Fenwick tree actually used is tr[].
int trsum[maxn];
int trmax[maxn];
// Number of points in the current test case (also the Fenwick tree size).
int n;
// One input point: its three coordinates plus `id`, the 1-based position in the
// input, which is used to index ans[]. `point` is the working array that gets
// reordered; `now` is scratch space solve() uses to save and restore a range.
struct P
{
    int x;
    int y;
    int z;
    int id;
} point[maxn], now[maxn];
// A DP value: `max` holds a chain length and `sum` the count associated with that
// length (kept reduced modulo `mod`). ans[id] is the value for the point with that
// input id; tr[] is the Fenwick tree over discretized x storing the same pairs.
struct A
{
    int max;
    int sum;
} ans[maxn], tr[maxn];
// Strict ordering of points by x only; main() uses it before discretizing x.
bool cmpx(const P a,const P b)
{
    return b.x > a.x;
}
bool cmpy(const P a,const P b)
{
return a.y<b.y;
}
// Strict ordering of points by z only; solve() uses it to sort each half before the merge sweep.
bool cmpz(const P a,const P b)
{
    return b.z > a.z;
}
// Lowest set bit of x (x & -x): the step between Fenwick tree nodes.
int lowbit(int x)
{
    const int lowest = x & (-x);
    return lowest;
}
void insert(int x,A tmp)
{
for(int i=x;i<=n;i+=lowbit(i))
{
if(tr[i].max==tmp.max)
{
tr[i].sum+=tmp.sum;
tr[i].sum%=mod;
}
else if(tr[i].max<tmp.max)
{
tr[i].sum=tmp.sum;
tr[i].max=tmp.max;
}
}
}
// Fenwick prefix query visiting nodes x, x - lowbit(x), ... down to 1.
// Returns two values:
//   - max: the largest `max` among the visited nodes of tr[];
//   - sum: the sum (mod `mod`) of the `sum` fields of every visited node holding that largest value.
// If no node is visited (x < 1) the result is {max = -1, sum = 0}.
A getsum(int x)
{
    A ret;
    ret.max = -1;
    ret.sum = 0;  // previously uninitialized: returned indeterminate when the loop never runs
    for (int i = x; i >= 1; i -= lowbit(i))
    {
        if (tr[i].max > ret.max)
        {
            ret.max = tr[i].max;
            ret.sum = tr[i].sum;
        }
        else if (tr[i].max == ret.max)
        {
            ret.sum += tr[i].sum;
            ret.sum %= mod;
        }
    }
    return ret;
}
// Resets every Fenwick node on the update path of position x to {0, 0}, so that
// anything insert(x, ...) stored there is removed before the next merge step.
void clear(int x)
{
    while (x <= n)
    {
        tr[x].max = 0;
        tr[x].sum = 0;
        x += lowbit(x);
    }
}
// CDQ divide and conquer over point[l..r] (1-based, inclusive), which on entry is
// in the order produced by main()'s cmpy sort.
// The statement order is essential:
//   1. the left half is solved first, so its ans[] values are final;
//   2. those values are propagated to the right half;
//   3. only then is the right half solved recursively, so it sees those contributions.
// NOTE: requires l <= r. Called with l > r (e.g. n == 0) the l == r base case is
// never reached and the recursion does not terminate.
void solve(int l,int r)
{
if(l==r) return ;
// mid = (l+r)/2: '+' binds tighter than '>>'.
int mid=l+r>>1;
solve(l,mid);
// Save the right half in its current order. The z-sort below scrambles it, and the
// recursive call on [mid+1,r] must see the original ordering again.
for(int i=mid+1;i<=r;i++)
now[i]=point[i];
sort(point+l,point+mid+1,cmpz);
sort(point+mid+1,point+r+1,cmpz);
// Two-pointer sweep in increasing z. Before handling right point i, insert every
// left point with z <= point[i].z into the Fenwick tree at its x. A prefix query
// at point[i].x then gives the best (max, sum) over left points that are <= point[i]
// in both x and z (y is already <= because of the cmpy order).
for(int i=mid+1,top=l;i<=r;i++)
{
while(top<=mid&&point[top].z<=point[i].z)
{
insert(point[top].x,ans[point[top].id]);
top++;
}
A ret=getsum(point[i].x);
// Extend the best predecessor chain by point i itself.
ret.max++;
// Equal length: add the counts. Longer: replace max and sum together.
if(ret.max==ans[point[i].id].max)
{
ans[point[i].id].sum+=ret.sum;
ans[point[i].id].sum%=mod;
}
else if(ret.max>ans[point[i].id].max)
{
ans[point[i].id]=ret;
}
}
// Reset the Fenwick paths of every left point (a superset of those inserted) so tr[]
// is all zeros again for the next merge step.
for(int i=l;i<=mid;i++) clear(point[i].x);
// Restore the right half's saved order, then solve it with the left contributions applied.
for(int i=mid+1;i<=r;i++)
point[i]=now[i];
solve(mid+1,r);
}
// Reads T test cases. Each case has n points given as "x y z", and the program prints
// the longest chain length (ans[].max) and the number of chains of that length
// (mod `mod`).
int main()
{
    // freopen("in.txt","r",stdin);
    int T = 0;
    // Without this check T would be read uninitialized on empty/malformed input.
    if (scanf("%d", &T) != 1) return 0;
    while (T--)
    {
        if (scanf("%d", &n) != 1) break;
        for (int i = 1; i <= n; i++)
        {
            scanf("%d %d %d", &point[i].x, &point[i].y, &point[i].z);
            point[i].id = i;
        }
        // Discretize x to ranks 1..num among the distinct x values, so x can index the Fenwick tree.
        sort(point + 1, point + 1 + n, cmpx);
        for (int i = 1, xx = point[1].x - 1, num = 0; i <= n; i++)
        {
            if (point[i].x != xx) num++, xx = point[i].x;
            point[i].x = num;
        }
        sort(point + 1, point + 1 + n, cmpy);
        // Every point on its own is a chain of length 1, counted once.
        for (int i = 1; i <= n; i++)
        {
            ans[i].max = 1;
            ans[i].sum = 1;
        }
        // solve(1, 0) would never reach its l == r base case, so skip empty input.
        if (n > 0) solve(1, n);
        // Best length over all points, and the total count (mod `mod`) of points' chains
        // reaching it. Starting from {0, 0} gives identical results for n >= 1
        // (every ans[i].max >= 1) and a defined "0 0" for n == 0.
        A ret;
        ret.max = 0;
        ret.sum = 0;
        for (int i = 1; i <= n; i++)
        {
            if (ret.max == ans[i].max)
            {
                ret.sum += ans[i].sum;
                ret.sum %= mod;
            }
            else if (ret.max < ans[i].max)
            {
                ret = ans[i];
            }
        }
        printf("%d %d\n", ret.max, ret.sum);
    }
    return 0;
}