目录
一. 问题描述
Problem Description
给你n个字符串,给你一个k,意思是你能任选k个字符串组成一个长字符串。再给你一个长字符串问你这个字符串在所有任选k个字符串组合中字典序排第几。(所有字符串长度之和不大于1e6 , 要求结果对1e9 + 7取模)
二. 题解代码
1. 思路分析
字典树+树状数组+求逆元。由康拓展开的思路联想到,如果我们把所有字符串由小到大排序后标为1~n,然后再把要求的字符串映射为对应数列,比如:
5 3
a
b
c
d
e
cad
在这个样例中,cad映射为3 1 4,接下来找所有比314字典序小的数列;由组合数学可知,第 i 位上比当前字典序小的数列共有:getNum(ans[i]) * C(n-i,m-i) * (m-i) !,其中getNum(ans[i])表示比第i位数字小的没有用过的数字个数。进一步推得:
于是第i位上的比他小的数字有 getNum(ans[i]) * A(n-i,n-m);这样最后的累和ans就是所有比要求字符串小的字符串数目,他的位数当然是ans+1,思路分析完毕。
2. 代码实现
(1)分配下标:这题肯定会卡超时,所以我们在字符串上的处理要使用字典树,因为所有字符串都不是另一个的前缀,我们把所有输入的字符串插入字典树里,排序后分别给他们标上下标。
(2)数列映射:然后把长字符串(也就是我们要求的)转化为数列,怎么转化呢?还是用字典树不断搜索,当遇到该处judge = true;说明该处构成单词了,把该处下标保存,然后cnt = 0(字典树下标归零),重新找下一个单词。
(3)逆序数:计算getNum(ans[i])时,表示比当前数字小的还未使用的数字数目,为了加快速度,这里使用树状数组处理。
(4)求逆元:计算(n - i)! / (n - m)! % mod 因为取模运算分配率不满足除法,所以这里必须来求逆元(不然会出现小数或者溢出),所以这里采用费马小定理来求逆元:(n - i)! / (n - m)! = (n - i)! * inv((n - m)!)% mod
2.1 string + sort排序(132ms)
#include <iostream>
#include<cstdio>
#include<string>
#include<algorithm>
#include<cstring>
#include<map>
#include<vector>
#define lowbit(x) x&(-x)
#define mod 1000000007
using namespace std;
typedef long long LL;
const int maxn = 1000002;
int n,k,tot,number,l;
LL f[maxn];
LL com[maxn],ans[maxn];
string str[maxn];
struct Node
{
int next[26];
bool judge;
int sign;
};
Node node[maxn];
int CreatTree()
{
memset(node[tot].next,0,sizeof(node[tot].next));
node[tot].sign = 0;
node[tot].judge = false;
return tot++;
}
LL power(LL a,LL b)//快速幂
{
a%=mod;
LL aans = 1;
while(b){
if(b&1)aans = (aans*a)%mod;
b>>=1;
a = (a*a)%mod;
}
return aans;
}
void insertTree(string s)
{
int len = s.size();
int cnt = 0;
for(int i = 0;i<len;i++){
int k = s[i] - 'a';
if(node[cnt].next[k]==0){
node[cnt].next[k] = CreatTree();
}
cnt = node[cnt].next[k];
}
node[cnt].sign = ++number;//分配数字
node[cnt].judge = true;
}
void findTree(string s)
{
int len = s.size();
int cnt = 0;
for(int i = 0;i<len;i++){
int k = s[i] - 'a';
cnt = node[cnt].next[k];
if(node[cnt].judge==true){//找到一个字符串,保存数字
ans[++l] = node[cnt].sign;
cnt = 0;//归零找下一个字符串
}
}
}
void Update(int x,int c)
{
for(int i = x;i<=n;i+=lowbit(i)){
com[i]+=c;
}
}
LL A(LL numA,LL numB)
{
if(numB<0)return 0;//费马小定理
return ((f[numA]%mod)*(power(f[numB],mod-2)%mod))%mod;
}
LL get_Num(int x)
{
LL p = 0;
for(int i = x;i>0;i-=lowbit(i)){
p+=com[i];
}
return p;
}
int main()
{
tot = number = l = 0;
memset(com,0,sizeof(com));
scanf("%d%d",&n,&k);
CreatTree();//一开始忘了建立树根无限RE QAQ
for(int i = 0;i<n;i++){
cin>>str[i];
}
sort(str,str+n);//先排序再插入
for(int i = 0;i<n;i++)insertTree(str[i]);
string name;
cin>>name;
findTree(name);
f[0] = 1;
for(int i = 1;i<maxn;i++){f[i] =((f[i-1]%mod)*(i%mod))%mod;}//计算乘阶
for(int i = 1;i<=n;i++)Update(i,1);//刚开始所有数字都能用
LL sum = 0;
for(int i = 1;i<=l;i++){
sum = (sum + (A((LL)n-i,(LL)n-k)*(LL)(get_Num(ans[i])-1))%mod)%mod;
Update(ans[i],-1);//更新
}
printf("%lld\n",(sum+1)%mod);
return 0;
}
2.2 字符数组 + dfs排序(87ms)
#include <iostream>
#include<cstdio>
#include<string>
#include<algorithm>
#include<cstring>
#include<map>
#include<vector>
#define lowbit(x) x&(-x)
#define mod 1000000007
using namespace std;
typedef long long LL;
const int maxn = 1000007;
char str[maxn];
LL f[maxn];
int ans[maxn],sum[maxn],n,k,tot,lenth,number;
struct Node
{
int next[26];
bool judge;
int index;
};
Node node[maxn];
int CreatTree(){
node[tot].judge = false;
memset(node[tot].next,0,sizeof(node[tot].next));
node[tot].index = 0;
return tot++;
}
LL power(LL x,LL y){
x%=mod;
LL cnt = 1;
while(y){
if(y&1)cnt = (cnt*x)%mod;
y>>=1;
x = (x*x)%mod;
}
return cnt;
}
void insertTree(char *s){
int len = strlen(s);
int cnt = 0;
for(int i = 0;i<len;i++){
int k = s[i] - 'a';
if(node[cnt].next[k]==0){
node[cnt].next[k] = CreatTree();
}
cnt = node[cnt].next[k];
}
node[cnt].judge = true;
}
void findTree(char *s){
int len = strlen(s);
int cnt = 0;
for(int i = 0;i<len;i++){
int k = s[i] - 'a';
cnt = node[cnt].next[k];
if(node[cnt].judge){
ans[++lenth] = node[cnt].index;
cnt = 0;
}
}
}
void DfsOrder(int a){//排序分配数字
if(node[a].judge){node[a].index = ++number;return;}
for(int i = 0;i<26;i++){
if(node[a].next[i])DfsOrder(node[a].next[i]);
}
}
void Update(int x,int c){
for(int i = x;i<=n;i+=lowbit(i)){
sum[i]+=c;
}
}
int getNum(int x){
int cnt = 0;
for(int i = x;i>0;i-=lowbit(i)){
cnt+=sum[i];
}
return cnt;
}
LL A(int x,int y){
if(y<0)return 0;
return (f[x]*(power(f[y],mod-2)%mod))%mod;
}
int main(){
tot = lenth = number = 0;
memset(sum,0,sizeof(sum));
f[0] = 1;
for(int i = 1;i<maxn;i++){f[i] = (f[i-1]*i)%mod;}
scanf("%d%d",&n,&k);
CreatTree();
for(int i = 1;i<=n;i++)Update(i,1);
for(int i = 0;i<n;i++){
scanf("%s",str);
insertTree(str);//先插入
}
DfsOrder(0);//递归排序,因为我们循环肯定先从小的开始,字典序由有小到大的
scanf("%s",str);
findTree(str);
LL answer = 0;
for(int i = 1;i<=lenth;i++){
answer = (answer + (A(n-i,n-k)*(getNum(ans[i])-1))%mod)%mod;
Update(ans[i],-1);
}
printf("%lld\n",(answer+1)%mod);
return 0;
}