题意:
给定序列,从序列中选择 k(1≤k≤1e18) 个数(可以重复选择),使得得到的排列满足 xi与xi+1 异或的二进制表示中 1 的个数是 3 的倍数。问长度为 k 的满足条件的 序列有多少种?
分析:
首先每个元素自己构成一个长度为
1
的满足条件的序列。
其次我们可以预处理出满足条件的
vi,vj
,就可以得到一个横纵为
n
的
01
矩阵。这还是很显然的。此时我们得到了以
vi
开头,
vj
结尾的长度为
2
的序列个数。
接下来我们发现,两个矩阵相乘,矩阵
c
为新得到的矩阵,此时矩阵
a=b
,
c[i][j]=a[i][1]∗b[1][j]+a[i][2]∗b[2][j]+...+a[i][n]∗b[n][j]
,我们得到的即为以
ai
开头,
aJ
结尾的长度为
3
的序列个数!
接下来用矩阵
c
更新矩阵
a
,再与最初的
01
矩阵,即
b
相乘,得到的又为开头元素为
ai
,结尾元素为
aj
的长度为
4
的序列个数!
依次乘
k−1
次即得到结果,这部分可以用矩阵快速幂进行优化。
最后把得到的矩阵中的每个元素的值加起来即为长度为
k
的满足条件的序列个数!
实质上就是
floyd
求长度为
k
的道路。
巧妙的利用矩阵乘法的性质解决问题!这很矩阵!
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<queue>
#include<cstring>
#include<stack>
#include<vector>
#include<algorithm>
#include<map>
#include<cmath>
using namespace std;
#define pr(x) cout << #x << ": " << x << " "
#define pl(x) cout << #x << ": " << x << endl;
#define sa(x) scanf("%d",&(x))
#define sal(x) scanf("%I64d",&(x))
typedef long long ll;
const int maxn = 105, mod = 1e9 + 7;
ll a[maxn];
int n;
const int N = 105;
struct Matrix
{
int row,cal;
long long m[N][N];
};
Matrix init(Matrix a, long long t)
{
for(int i = 0; i < a.row; i++)
for(int j = 0; j < a.cal; j++)
a.m[i][j] = t;
return a;
}
Matrix mul(Matrix a,Matrix b)
{
Matrix ans;
ans.row = a.row, ans.cal = b.cal;
ans = init(ans,0);
for(int i = 0; i < a.row; i++)
for(int j = 0; j < b.cal; j++)
for(int k = 0; k < a.cal; k++)
ans.m[i][j] = (ans.m[i][j] + a.m[i][k] * b.m[k][j])%mod;
return ans;
}
Matrix quick_pow(long long k, Matrix A)
{
Matrix I;
I.row = n, I.cal = n;
I = init(I, 0);
for(int i = 0; i < n; i++){
I.m[i][i] = 1;
}
while(k){
if(k & 1) I = mul(I, A);
A = mul(A, A);
k >>= 1;
}
return I;
}
int count(ll a)
{
int ans = 0;
while(a){
if(a & 1) ans++;
a >>= 1;
}
return ans;
}
int main(void)
{
sa(n);
ll k;sal(k);
for(int i = 0; i < n; i++){
sal(a[i]);
}
Matrix A;
A.row = n, A.cal = n;
A = init(A, 0);
for(int i = 0 ; i < n; i++){
for(int j = 0; j < n; j++){
if(count(a[i] ^ a[j]) % 3 == 0){
A.m[i][j] = 1;
}
}
}
ll ans = 0;
A = quick_pow(k - 1, A);
for(int i = 0; i < n; i++){
for(int j = 0; j < n; j++){
(ans += A.m[i][j]) %= mod;
}
}
printf("%I64d\n", ans);
return 0;
}