题意:给定一个长度为n的数组(数组元素∈{1,2,...,n}),对k∈{1,2,…,n},求最小的ans[k],使得数组可以分为连续的ans[k]段,每段内不相同的元素个数都不超过k。
分析:为了使ans[k]最小,我们可以贪心地来分段,即对起点l,取最大的r,使得a[l],...,a[r]中不相同的元素个数不超过k。可以来估计一下总段数(即ans[1]+...+ans[n])的上界。显然有ans[k]≤n/k,所以ans[1]+...+ans[n]≤n(1/1+1/2+...+1/n)≈nlgn。因此只要能找到一种对固定的起点l,快速(如O(lgn))求出终点r的算法,我们就能较快地(如O(n(lgn)^2))解决这个问题。
基于主席树的做法是比较容易想到的,所以这里不再赘述。接下来详细讲解另一种和主席树时空复杂度相同但代码量小且常数小的做法。
假设对于k,数组分成的ans[k]段为[l_k1,r_k1]、[l_k2,r_k2]、...、[l_kans[k],r_kans[k]]。那么我们的程序执行过程中就需要求对起点l_kj的最大终点r_kj使得[l_kj,r_kj]是不相同元素个数不超过k的以l_kj为起点的最大区间。考虑换一种顺序来求解这一系列问题(一系列问题指一系列求终点的问题)。先解决以1为起点的所有问题(此时要考虑的k取遍1到n),再解决以2为起点的所有问题(此时要考虑的k不一定取遍1到n),依次类推。改变问题的处理顺序后,我们就可以简单地用一个树状数组来求解这些问题。具体做法是:在解决了以1~i-1为起点的问题后,用一个数组array[]来标记a[i]、...、a[n]中每个数第一次出现的位置(如a[i]第一次出现在位置i,则令array[i]=1),则array[]的前r项和即为a[l],...,a[r]中不相同的元素个数(注意array[1],...,array[i-1]均为0)。用树状数组维护array[],那么就可以通过倍增的方式O(lgn)地求出起点为i,不相同元素不超过k的最大终点r(而不是二分再用树状数组求前缀和)。容易想到以O(lgn)的时间代价将以i为起点的树状数组转化为以i+1为起点的树状数组的做法。显然这个做法的时空复杂度也是O(n(lgn)^2)。
代码(主席树)
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
struct tnode
{
int lc,rc,s;
}rt[maxn],nd[maxn*40],temp;
int n,a[maxn],sz,Last[maxn];
void build(tnode &o,int l,int r)
{
o.s=0;
if (l==r) return ;
int mid=(l+r)/2;
o.lc=++sz;o.rc=++sz;
build(nd[o.lc],l,mid);build(nd[o.rc],mid+1,r);
}
void updata(tnode &o1,tnode o2,int l,int r,int p,int val)
{
o1.s=o2.s+val;
if (l==r) return ;
int mid=(l+r)/2;
if (p<=mid)
{
o1.rc=o2.rc;
o1.lc=++sz;
updata(nd[o1.lc],nd[o2.lc],l,mid,p,val);
}
else
{
o1.lc=o2.lc;
o1.rc=++sz;
updata(nd[o1.rc],nd[o2.rc],mid+1,r,p,val);
}
}
int query(tnode o,int l,int r,int k)
{
if (o.s<=k) return r;
if (l==r) return l-1;
int mid=(l+r)/2;
if (nd[o.lc].s<=k) return query(nd[o.rc],mid+1,r,k-nd[o.lc].s);
else return query(nd[o.lc],l,mid,k);
}
int main()
{
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
build(rt[n+1],1,n);
for (int i=n;i>=1;i--)
{
if (Last[a[i]])
{
updata(temp,rt[i+1],1,n,Last[a[i]],-1);
updata(rt[i],temp,1,n,i,1);
}
else updata(rt[i],rt[i+1],1,n,i,1);
Last[a[i]]=i;
}
for (int k=1;k<=n;k++)
{
int ans=0;
for (int st=1;st<=n;)
{
//cout<<st<<" ";system("pause");
ans++;
st=query(rt[st],1,n,k)+1;
}
printf("%d ",ans);
}
return 0;
}
代码(另解)
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
int n,a[maxn],c[maxn],ans[maxn];
set<int> pos[maxn];
vector<int> S[maxn];
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int val)
{
while (x<=n)
{
c[x]+=val;
x+=lowbit(x);
}
}
int query(int k)
{
int ret=0;
for (int l=20;l>=0;l--)
if (ret+(1<<l)<=n&&c[ret+(1<<l)]<=k)
{
ret+=(1<<l);k-=c[ret];
}
return ret;
}
void work(int x)
{
if (!pos[x].empty())
{
add(*pos[x].begin(),1);
pos[x].erase(pos[x].begin());
}
}
int main()
{
cin>>n;
for (int i=1;i<=n;i++) scanf("%d",&a[i]),pos[a[i]].insert(i);
for (int i=1;i<=n;i++)
{
S[1].push_back(i);
work(i);
}
for (int i=1;i<=n;i++)
{
for (int j=0;j<S[i].size();j++)
{
int x=S[i][j];
int y=query(x)+1;
S[y].push_back(x);
ans[x]++;
}
add(i,-1);
work(a[i]);
}
for (int i=1;i<=n;i++) printf("%d ",ans[i]);
return 0;
}