给出一个大小为n的字符串集合A,需要确定一个大小为m的字符串集合B,要保证能把A集合分成恰好m份非空集合,要求前缀为bi的字符串都被分在第i个集合,同时一个字符串不能分在多个集合。求有多少符合要求的B,同时给出一种方案.
第一种思路是建出字母树,那么问题转化为在树上选m个独立的子树覆盖所有的叶子,这个直接做是o(n*m*m)的,所以要转到dfs序上,压缩字母树的边后dp就可以了,这是一队的算法.
第二种思路是我自己想的.由于第一种思路直接转移会T,当时只想到用FFT来快速转移,所以就另辟蹊径,先将所有的串排序并求出height,这样任意两个串的最长公共前缀就是之间的最小height,考虑一个集合如果用某个字符串bi作为前缀,那么它一定可以拓展到某两个字符串的最长公共前缀,如果继续拓展那么就会导致集合变小,那么我们来考虑以某个height作为关键值,他能拓展的就是一段连续的区间,这段区间以这个height作为前缀一定是最大可行值,考虑将这个前缀缩小且不改变所覆盖的区间,那就是这个height与区间左右的height的一个差值,这些长度的前缀对应的都是这段区间,那么现在的问题就是选择恰好m个区间以恰好不重的覆盖n个点,这个dp显然是o(n*m),这个思路可以当做是取出了字母树的一个类似后缀树取后缀数组的东西.
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <iostream>
const int mo=1000000007;
using namespace std;
struct line{
int l,r,w;
}a[500000];
int ans,f[5000][2010],b[5000][2010],n,m,tot,u[500000],ht[500000],g[5000][2010],b1[5000][2010];
int v[500000];
char ch[5000][500];
int ask(int e,int x)
{
int sum=b[e][x];
// for (;x;x-=(x & (-x)))
// (sum+=b[e][x])%=mo;
return sum;
}
void change(int e,int x,int sum)
{
// for (;x<=n;x+=(x & (-x)))
// (b[e][x]+=sum)%=mo;
(b[e][x]+=sum)%=mo;
}
bool cmp(int i,int j)
{
int l1=strlen(ch[i]+1),l2=strlen(ch[j]+1);
for (int k=1;k<=min(l1,l2);k++)
if (ch[i][k]!=ch[j][k]) return ch[i][k]<ch[j][k];
return l1<l2;
}
bool cmp2(line i,line j)
{
if (i.l!=j.l) return i.l<j.l;
return i.r<j.r;
}
void getout(int i)
{
if (a[i].l==a[i].r) {
printf("%s\n",ch[u[a[i].l]]+1);
return ;
}
int x=u[a[i].l];
int sum=100000000;
for (int j=a[i].l;j<=a[i].r-1;j++) sum=min(sum,ht[j]);
for (int j=1;j<=sum;j++)
printf("%c",ch[x][j]);
printf("\n");
return ;
}
bool pd(int i,int j)
{
int l1=strlen(ch[i]+1),l2=strlen(ch[j]+1);
if (l1>l2) return 1;
for (int k=1;k<=l1;k++)
if (ch[i][k]!=ch[j][k]) return 1;
return 0;
}
int main()
{
scanf("%d%d\n",&n,&m);
for (int i=1;i<=n;i++)
scanf("%s",ch[i]+1);
for (int i=1;i<=n;i++) u[i]=i;
sort(u+1,u+n+1,cmp);
tot=0;
for (int i=1;i<=n;i++) {
// int ne=u[i];
// if (1==i || (pd(u[i-1],u[i]))) ++tot,a[tot].l=i,a[tot].r=i;
if (i==n) continue;
int len,l1=strlen(ch[u[i]]+1),l2=strlen(ch[u[i+1]]+1);
for (len=1;len<=min(l1,l2) && (ch[u[i]][len]==ch[u[i+1]][len]);len++) ;
len--;
ht[i]=len;
}
// for (int i=1;i<=n-1;i++) cout<<ht[i]<<' ';cout<<endl;
// for (int i=1;i<=n;i++) cout<<ch[u[i]]+1<<endl;//cout<<endl;
// for (int i=1;i<=tot;i++) printf("%d %d\n",a[i].l,a[i].r);
for (int i=1;i<=n-1;i++) {
int op,ed;
if (!ht[i]) continue;
++tot;
for (op=i;(op) && (ht[op]>=ht[i]);op--) ;
a[tot].l=op+1;
for (ed=i;(ed<n) && (ht[ed]>=ht[i]);ed++) ;
a[tot].r=ed;
if (op) a[tot].w=ht[op];else a[tot].w=0;
a[tot].w=ht[i]-max(a[tot].w,ht[ed]);
}
for (int i=1;i<=n;i++) {
int len=strlen(ch[u[i]]+1);
if (len>ht[i-1] && len>ht[i])
++tot,a[tot].l=a[tot].r=i,a[tot].w=len-max(ht[i-1],ht[i]);
}
sort(a+1,a+tot+1,cmp2);
// printf("%d\n",tot);
// for (int i=1;i<=tot;i++) printf("%d %d %d\n",a[i].l,a[i].r,a[i].w);
memset(f,0,sizeof(f));
memset(v,0,sizeof(v));
for (int i=2;i<=tot;i++)
if (a[i].l==a[i-1].l && a[i].r==a[i-1].r) v[i]=1;
for (int i=1;i<=tot;i++)
if (a[i].l==1 && !v[i]) f[i][1]=a[i].w,g[i][1]=1;
for (int i=1;i<=tot;i++) {
if (v[i]) continue;
for (int j=1;j<=m;j++) {
(f[i][j]+=(((long long)a[i].w)*ask(j-1,a[i].l-1))%mo)%=mo;
if (b1[j-1][a[i].l-1]) g[i][j]=1;
change(j,a[i].r,f[i][j]);
if (g[i][j]) b1[j][a[i].r]=1;
}
}
ans=0;
int ans1=0;
for (int i=1;i<=tot;i++)
if (a[i].r==n&&g[i][m]){
(ans+=f[i][m])%=mo;
ans1=1;
}
printf("%d\n",ans);
if (ans1)
for (int i=n,j,k=m;i;) {
for (j=1;j<=tot;j++)
if ((a[j].r==i) && (g[j][k])) break;
// cout<<j<<endl;
getout(j);
i=a[j].l-1,k--;
}
return 0;
}