题目大意:
给了n个不同的正整数序列a。定义了一个无限集合S,满足 ∀ x ∈ S \forall x \in S ∀x∈S,至少符号下列条件之一:
- ∃ i , s t : x = a i \exist i,st:x=a_i ∃i,st:x=ai
- x = 2 ∗ y + 1 & & y ∈ S x=2*y+1\, \&\& \,y\in S x=2∗y+1&&y∈S
- x = 4 ∗ y & & y ∈ S x=4*y\,\&\&\,y\in S x=4∗y&&y∈S
对于给定的序列a,以及正整数p,有多少S集合中的数严格小于 2 p 2^p 2p
范围: 1 ≤ n , p ≤ 2 ⋅ 1 0 5 , 1 ≤ a i ≤ 1 0 9 1\le n,p\le 2\cdot 10^5,1\le a_i\le 10^9 1≤n,p≤2⋅105,1≤ai≤109
解题思路
-
打表找规律:
-
前提:假定序列中只有一个数,设其为x;且p已给定
-
经过打表可以发现,只有x为2的某次幂,所求答案才会发生变化。
且当 2 k ≤ x < 2 k + 1 2^k\le x < 2^{k+1} 2k≤x<2k+1时, a n s ( x ) = a n s ( 2 k ) ans(x)=ans(2^k) ans(x)=ans(2k)
-
然后又发现 a n s ( 2 k ) = a n s ( 2 k + 1 ) + a n s ( 2 k + 2 ) + 1 ans(2^k)=ans(2^{k+1})+ans(2^{k+2})+1 ans(2k)=ans(2k+1)+ans(2k+2)+1,而 a n s ( 2 p − 1 ) = 1 ans(2^{p-1})=1 ans(2p−1)=1
-
所以对于x,利用上述两条规律,就能得到ans
-
-
如果依照上述规律直接算,必然会算多。
例如:a={2,5},那么5其实是包含在2的答案中的,所以5是无用的数据,应该删除
-
找无用的数据:
- 先从小到大排序
- 如果正序遍历的话,那么必然空间以及时间都承受不了;但如果逆序遍历,则只要log复杂度就行。
AC代码:
#include <bits/stdc++.h>
#define ft first
#define sd second
#define IOS ios::sync_with_stdio(false), cin.tie(0), cout.tie(0)
#define seteps(N) fixed << setprecision(N)
#define endl "\n"
const int maxn = 2e5 + 10;
using namespace std;
typedef long long ll;
typedef double db;
typedef pair<int, int> pii;
const ll mod = 1e9 + 7;
ll n, p, mx, ans = 0, a[maxn], g[maxn];
bool vis[maxn];
map <ll, bool> dvis;
void dfs(ll x) { //p很小的情况,暴力遍历
if (dvis.count(x) || x >= (1ll << p)) return;
dvis[x] = true;
dfs(2 * x + 1);
dfs(4 * x);
}
void ndfs(ll x, bool &ok) { //逆序遍历,ok为true说明x不会被某一个序列中的元素表示,否则会
if (dvis.count(x)) ok = false;
if (!ok) return;
if ((x & 1) && x != 1) ndfs((x - 1) / 2, ok);
if (x / 4 * 4 == x && ok) ndfs(x / 4, ok);
}
int resolve(ll x) { //求k,2^k <= x < 2^{k+1}
if (x == 1) return 0;
else if (x == 2) return 1;
else {
int res = 0;
while (x != 1) x /= 2, res++;
return res;
}
}
int main() {
cin >> n >> p;
for (int i = 1; i <= n; i++) cin >> a[i];
sort (a + 1, a + n + 1);
// cout << a[1] << endl;
if (p <= 3) {
for (int i = 1; i <= n; i++) dfs(a[i]);
cout << dvis.size() << endl;
return 0;
}
g[p - 1] = 1;
g[p - 2] = 2;
for (int i = p - 3; i >= 0; i--) g[i] = (g[i + 1] + g[i + 2] + 1) % mod;
for (int i = 1; i <= n; i++) {
vis[i] = true;
ndfs(a[i], vis[i]);
dvis[a[i]] = true;
}
for (int i = 1; i <= n; i++) {
if (!vis[i]) continue;
ans = (ans + g[resolve(a[i])]) % mod;
}
cout << ans << endl;
return 0;
}