[Rust刷题模板] 组合数取模,线性求逆元
一、 算法&数据结构
1. 描述
组合数是一个经常用到的公式,但是由于除法不满足同余性质,因此需要通过逆元来转换。
这里就涉及到线性求逆元的问题。
因此封装了一个类。
2. 复杂度分析
- 预处理, O(MAX_N+log(MAX_N))
- 查询,O(1)
3. 常见应用
- 各种组合数学题。
4. 常用优化
- 手写快速幂,while代替递归。
- 逆元使用费马小定理求,扩展欧几里得据说快一点,但是本文只用一次,不涉及。
- 本文只涉及线性求逆元,优先计算inv_f[-1]然后倒推,因此复杂度是n+lg,而不是n*lg。
二、 模板代码
1. 幂取模、组合数取模。
例题: F - Strivore
/***
https://atcoder.jp/contests/abc171/tasks/abc171_f
输入 k(≤1e6) 和一个长度不超过 1e6 的字符串 s,由小写字母组成。
你需要在 s 中插入恰好 k 个小写字母。
输出你能得到的字符串的个数,模 1e9+7。
输入
5
oof
输出 575111451
输入
37564
whydidyoudesertme
输出 318008117
https://atcoder.jp/contests/abc171/submissions/36296507
设 s 的长度为 n。
提示 1:如何避免重复统计?做一个规定,插入在 s[i] 左侧的字符,不能和 s[i] 相同,这不会影响答案的正确性。
提示 2:枚举最后一个字符的右侧插入了多少个字符,设为 i,这些字符没有限制,有 26^i 种方案。
提示 3:剩下 (n-1) + (k-i) 个字符,我们需要考虑其中 n-1 个字符的位置,这就是 C(n-1+k-i, n-1)。
提示 4:其余插入字符的方案数就是 25^(k-i)。
因此答案为 ∑26^i * C(n-1+k-i, n-1) * 25^(k-i), i=[0,k]
不知道组合数怎么算的,需要学一下逆元。
*/
use proconio::{input, marker::Bytes};
// use std::collections::VecDeque;
//
use itertools::Itertools;
// use petgraph::{Graph, data::Build};
use superslice::Ext;
const MOD:usize = 1000000000+7;
fn quick_pow_mod(mut a:i64,mut b:i64,p:i64)->i64{
let mut ans = 1i64;
while b>0 {
if b&1>0{
ans = (ans*a)%p;
}
a = a*a%p;
b >>= 1;
}
return ans;
}
struct ModComb {
fact: Vec<usize>,
inv_f: Vec<usize>,
}
impl ModComb {
fn new(n:usize) -> ModComb {
let mut fact = vec![1;n+1];
let mut inv_f = vec![1;n+1];
for i in 2..=n {
fact[i] = i * fact[i-1] % MOD;
}
inv_f[n] = quick_pow_mod(fact[n] as i64, (MOD - 2) as i64, MOD as i64) as usize;
for i in (1..=n).rev(){
inv_f[i-1] = i * inv_f[i] %MOD;
}
ModComb {fact,inv_f}
}
fn comb(&self, m:usize,r:usize) -> usize{
if m <r || r < 0{
return 0;
}
// 公式C(m,r) = m!/(r!*(m-r)!)
return self.fact[m]*self.inv_f[r]%MOD*self.inv_f[m-r]%MOD;
}
}
fn main() {
input!{
k:usize,
s:Bytes,
}
let n = s.len();
let mut p25 = vec![1usize;k+1];
let mut p26 = vec![1usize;k+1];
for i in 1..=k{
p25[i] = p25[i-1]*25%MOD;
p26[i] = p26[i-1]*26%MOD
}
let mc = ModComb::new(n+k);
let mut ans = 0usize;
for i in 0..=k {
ans = (ans + p26[i]%MOD * mc.comb(n-1+k-i, n-1)%MOD * p25[k-i]%MOD)%MOD
}
println!("{}",ans)
}