题目(困难)
给你一个整数 n,请你帮忙统计一下我们可以按下述规则形成多少个长度为 n 的字符串:
字符串中的每个字符都应当是小写元音字母(‘a’, ‘e’, ‘i’, ‘o’, ‘u’)
每个元音 ‘a’ 后面都只能跟着 ‘e’
每个元音 ‘e’ 后面只能跟着 ‘a’ 或者是 ‘i’
每个元音 ‘i’ 后面 不能 再跟着另一个 ‘i’
每个元音 ‘o’ 后面只能跟着 ‘i’ 或者是 ‘u’
每个元音 ‘u’ 后面只能跟着 ‘a’
由于答案可能会很大,所以请你返回 模 10^9 + 7 之后的结果。
示例 1:
输入:n = 1
输出:5
解释:所有可能的字符串分别是:“a”, “e”, “i” , “o” 和 “u”。
示例 2:
输入:n = 2
输出:10
解释:所有可能的字符串分别是:“ae”, “ea”, “ei”, “ia”, “ie”, “io”, “iu”, “oi”, “ou” 和 “ua”。
示例 3:
输入:n = 5
输出:68
提示:
1 <= n <= 2 * 10^4
解题思路
手写分析下每个字母可由哪些字母得来,就可以写出状态转移了,
a ———— eiu
e ———— ai
i ———— eo
o ———— i
u ———— io
用两个数组交替更新次数,最后总次数就是答案
代码
class Solution {
public:
int countVowelPermutation(int n) {
vector<long long> hash(5, 1);
int a = 0, e = 1, i = 2, o = 3, u = 4;
long long ans = 0;
long long MOD = 1e9 + 7;
n--;
while(n--) {
vector<long long> tmp(5);
tmp[a] = (hash[e] + hash[i] + hash[u]) % MOD;
tmp[e] = (hash[a] + hash[i]) % MOD;
tmp[i] = (hash[e] + hash[o]) % MOD;
tmp[o] = hash[i] % MOD;
tmp[u] = (hash[i] + hash[o]) % MOD;
hash = tmp;
}
for(auto x :hash) ans += x % MOD;
return ans % MOD;
}
};
另外可以用矩阵快速幂来加快计算速度
问题转化为计算矩阵的幂,套用矩阵快速幂的模板即可
using LL = long long;
using Mat = vector<vector<LL>>;
class Solution {
public:
Mat multiply(const Mat & matrixA, const Mat & matrixB, LL mod) {
int m = matrixA.size();
int n = matrixB[0].size();
Mat res(m, vector<LL>(n, 0));
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < matrixA[i].size(); ++k) {
res[i][j] = (res[i][j] + matrixA[i][k] * matrixB[k][j]) % mod;
}
}
}
return res;
}
Mat fastPow(const Mat & matrix, LL n, LL mod) {
int m = matrix.size();
Mat res(m, vector<LL>(m, 0));
Mat curr = matrix;
for (int i = 0; i < m; ++i) {
res[i][i] = 1;
}
for (int i = n; i != 0; i >>= 1) {
if (i & 1) {
res = multiply(curr, res, mod);
}
curr = multiply(curr, curr, mod);
}
return res;
}
int countVowelPermutation(int n) {
LL mod = 1e9 + 7;
Mat factor =
{
{0, 1, 0, 0, 0},
{1, 0, 1, 0, 0},
{1, 1, 0, 1, 1},
{0, 0, 1, 0, 1},
{1, 0, 0, 0, 0}
};
Mat res = fastPow(factor, n - 1, mod);
long long ans = 0;
for (int i = 0; i < 5; ++i) {
ans = (ans + accumulate(res[i].begin(), res[i].end(), 0LL)) % mod;
}
return ans;
}
};