题目链接:https://vjudge.net/problem/Kattis-prefixfreecode
转自:https://blog.csdn.net/Q755100802/article/details/99856499
原博主的代码有些问题,做了些修改。
题意:输入n,k,输入n个字符串,输入目标字符串,让你取k个字符串进行排列组合,求目标字符串在其中排第几。
思路:首先应该排序后用字典树给每个字符串编号(为了排列组合方便),然后把目标字符串的序列号求出。
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e6+10;
const int mod=1e9+7;
typedef long long ll;
int idx(char c)
{
return c-'a';
}
struct Trie//字典树
{
int ch[maxn][26];
int val[maxn];
int sz;
void clear()
{
sz=1;
memset(ch[0],0,sizeof(ch[0]));
memset(val,0,sizeof(val));
}
void Insert(string s,int v)
{
int u=0,n=s.size();
for(int i=0; i<n; i++)
{
int c=idx(s[i]);
if(!ch[u][c])
{
memset(ch[sz],0,sizeof(ch[sz]));
val[sz]=0;
ch[u][c]=sz++;
}
u=ch[u][c];
}
val[u]=v;
}
int Find(const char *s,int& st,int len)//一点点的分割目标字符串
{
int u = 0;
for(int i = st; i < len; i++)
{
if(s[i] == '\0')
break;
int c = idx(s[i]);
if(!ch[u][c])
break;
u = ch[u][c];
if(val[u])
{
st=i+1;
return val[u];
}
}
return 0;
}
};
Trie trie;
char str[maxn];
int n,k;
vector<int>v;
int c[maxn];
int num[maxn];
//树状数组维护逆序对
int lowbit(int x)
{
return x&-x;
}
void add(int x,int d)
{
while(x<=n)
{
c[x]+=d;
x+=lowbit(x);
}
}
int sum(int x)
{
int res=0;
while(x>0)
{
res+=c[x];
x-=lowbit(x);
}
return res;
}
long long inv[maxn+5];
long long f[maxn+5];
long long qsm(long long a,long long b)
{
long long ans=1,base=a;
while(b)
{
if(b&1)
{
ans=ans*base%mod;
}
base=base*base%mod;
b>>=1;
}
return ans%mod;
}
void init()//线性推逆元
{
f[0]=f[1]=1;
for(int i=2; i<=maxn; i++)
{
f[i]=f[i-1]*i%mod;
}
inv[maxn]=qsm(f[maxn],mod-2);
for(long long i=maxn; i>=1; i--)
{
inv[i-1]=inv[i]*i%mod;
}
}
long long A(int a,int b)//求排列数
{
if(a==0||b==0)
return 1;
return f[a]*inv[a-b]%mod;
}
string temp[maxn];
int main()
{
init();
ios::sync_with_stdio(false);
while(cin>>n>>k)
{
trie.clear();
v.clear();
for(int i=1; i<=n; i++)
{
cin>>temp[i];
}
sort(temp+1,temp+n+1);//将字符串排序
for(int i=1; i<=n; i++)
{
trie.Insert(temp[i],i);//对字符串编号
}
cin>>str;
int len=strlen(str);
int st=0;
while(st<len)
{
v.push_back(trie.Find(str,st,len));//将目标字符串转为编号序列
}
memset(c,0,sizeof(c));
for(int i=1; i<=v.size(); i++)//计算前面有多少比它小的数
{
add(v[i-1],1);
num[i]=sum(v[i-1])-1;
}
ll ans=0;
for(int i=1; i<=v.size(); i++)//统计答案
{
ans+=1ll*(v[i-1]-num[i]-1)*A(n-i,k-i)%mod;
ans%=mod;
}
cout<<(ans+1)%mod<<endl;
}
return 0;
}