题意
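给定 $n$ 个数 $a_1, a_2, \dots, a_n$($1 \le a_i \le 10^9$),求 $\sum_{1 \le l \le r \le n} (a_l \,\&\, a_{l+1} \,\&\, \cdots \,\&\, a_r)\,(a_l \mid a_{l+1} \mid \cdots \mid a_r)$,结果对 $10^9+7$ 取模。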
思路
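记 $a_i^{(x)}$ 为 $a_i$ 二进制表示的第 $x$ 位,$\overline{a_i^{(y)}} = 1 - a_i^{(y)}$ 为其取反,$[P]$ 表示 $P$ 成立时取 1、否则取 0。由于 $a_i \le 10^9 < 2^{30}$,只需考虑 $0 \le x, y < 30$。第四步利用了"若干位的或为 1 当且仅当它们不全为 0"。
$$
\begin{aligned}
&\sum_{1 \le l \le r \le n} (a_l \,\&\, a_{l+1} \,\&\, \cdots \,\&\, a_r)(a_l \mid a_{l+1} \mid \cdots \mid a_r) \\
&= \sum_{1 \le l \le r \le n} \Big(\sum_{0 \le x < 30} 2^x \big[a_l^{(x)} \,\&\, \cdots \,\&\, a_r^{(x)}\big]\Big) \Big(\sum_{0 \le y < 30} 2^y \big[a_l^{(y)} \mid \cdots \mid a_r^{(y)}\big]\Big) \\
&= \sum_{0 \le x < 30} \sum_{0 \le y < 30} 2^{x+y} \sum_{1 \le l \le r \le n} \big[a_l^{(x)} \,\&\, \cdots \,\&\, a_r^{(x)}\big] \big[a_l^{(y)} \mid \cdots \mid a_r^{(y)}\big] \\
&= \sum_{0 \le x < 30} \sum_{0 \le y < 30} 2^{x+y} \sum_{1 \le l \le r \le n} \big[a_l^{(x)} \,\&\, \cdots \,\&\, a_r^{(x)}\big] \Big(1 - \big[\overline{a_l^{(y)}} \,\&\, \cdots \,\&\, \overline{a_r^{(y)}}\big]\Big) \\
&= \sum_{0 \le x < 30} \sum_{0 \le y < 30} 2^{x+y} \sum_{1 \le l \le r \le n} \Big(\big[a_l^{(x)} \,\&\, \cdots \,\&\, a_r^{(x)}\big] - \big[a_l^{(x)} \,\&\, \cdots \,\&\, a_r^{(x)}\big] \big[\overline{a_l^{(y)}} \,\&\, \cdots \,\&\, \overline{a_r^{(y)}}\big]\Big) \\
&= \sum_{0 \le x < 30} \sum_{0 \le y < 30} 2^{x+y} \sum_{1 \le l \le r \le n} \Big(\big[a_l^{(x)} \,\&\, \cdots \,\&\, a_r^{(x)}\big] - \big[(a_l^{(x)} \,\&\, \overline{a_l^{(y)}}) \,\&\, \cdots \,\&\, (a_r^{(x)} \,\&\, \overline{a_r^{(y)}})\big]\Big)
\end{aligned}
$$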
枚举 $x, y$,令 $b_i = a_i^{(x)}$,$c_i = a_i^{(x)} \,\&\, \overline{a_i^{(y)}}$,于是
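$$
\begin{aligned}
&\sum_{1 \le l \le r \le n} (a_l \,\&\, a_{l+1} \,\&\, \cdots \,\&\, a_r)(a_l \mid a_{l+1} \mid \cdots \mid a_r) \\
&= \sum_{0 \le x < 30} \sum_{0 \le y < 30} 2^{x+y} \sum_{1 \le l \le r \le n} \big([b_l \,\&\, \cdots \,\&\, b_r] - [c_l \,\&\, \cdots \,\&\, c_r]\big) \\
&= \sum_{0 \le x < 30} \sum_{0 \le y < 30} 2^{x+y} \Big(\sum_{1 \le l \le r \le n} [b_l \,\&\, \cdots \,\&\, b_r] - \sum_{1 \le l \le r \le n} [c_l \,\&\, \cdots \,\&\, c_r]\Big)
\end{aligned}
$$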
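所以只需要解决这样一个子问题:对一个 01 数组 $d_1, \dots, d_n$,求满足 $d_l \sim d_r$ 全为 1 的 $(l, r)$ 对数。设数组中各个极长连续 1 段的长度依次为 $s_1, s_2, \dots$,那么答案就是 $\frac{s_1(s_1+1)}{2} + \frac{s_2(s_2+1)}{2} + \cdots$。对每一对 $(x, y)$ 都直接按段统计,总复杂度是 $O(n \log^2 X)$。优化办法是把 01 数组每 20 位压成一个整数,并预处理每个 20 位整数的前缀连续 1 长度、后缀连续 1 长度以及其内部所有连续 1 段的贡献;这样数组长度缩小为原来的 $1/20$,复杂度约为 $O(n \log^2 X / 20)$。由于 $\log X \approx 30$ 与 20 同阶,实际开销接近 $O(n \log X)$。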
代码
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#include "local.h"
#endif // LOCAL
typedef long long ll;
const ll mod = 1000000007;
const ll N = 100007;  // capacity of a[]: the input size n must not exceed this
// n, a[0..n-1]: the input values.
// pow2[k] = 2^k mod `mod`; Log2[2^k] = k (only powers of two are filled in).
// t / pre / suf: per-word run-length tables, filled in main() and read by calc(vector<ll>&).
ll n, a[N], pow2[77], Log2[1 << 22], t[1 << 21], pre[1 << 21], suf[1 << 21];
// Isolates the lowest set bit of x; yields 0 when x == 0.
inline ll lowbit(ll x) {
    const ll negated = ~x + 1;  // two's-complement negation of x
    return x & negated;
}
// Triangular number x(x+1)/2: the number of non-empty contiguous
// sub-ranges inside a run of length x.
inline ll calc(ll x) {
    const ll product = x * (x + 1);
    return product / 2;
}
// Counts the pairs (l, r), l <= r, for which every bit l..r of a packed bit
// string is 1. The string is stored 20 bits per word, words in order, with the
// earliest bit of each word in bit 19. Relies on the tables built in main():
//   pre[w] - length of the run of ones starting at bit 19 (20 means w is all ones)
//   suf[w] - length of the run of ones starting at bit 0
//   t[w]   - sum of L(L+1)/2 over every maximal run of ones (length L) inside w
// Every word is expected to be below 2^20 (the tables are only built for that range).
ll calc(const vector<ll> &words) {
    ll total = 0;
    ll run = 0;  // length of the run of ones still open at the current word boundary
    for (size_t k = 0; k < words.size(); k++) {
        const ll w = words[k];
        const ll head = pre[w];
        run += head;
        if (head == 20) continue;  // all-ones word: the open run just keeps growing
        const ll tail = suf[w];
        total += calc(run);                      // the open run ends inside this word
        total += t[w] - calc(head) - calc(tail); // runs lying strictly inside w
        run = tail;                              // trailing ones open the next run
    }
    total += calc(run);  // close whatever run reaches the end of the string
    return total;
}
// Counts the all-ones windows of the bit string (lhs AND NOT rhs), i.e. of the
// positions where lhs has a 1 and rhs has a 0. Both arguments use the packed
// layout expected by calc(vector<ll>&) and must hold the same number of words.
ll calc(const vector<ll> &lhs, const vector<ll> &rhs) {
    assert(lhs.size() == rhs.size());  // rhs is indexed in lockstep with lhs
    vector<ll> masked(lhs.size());
    for (size_t k = 0; k < lhs.size(); k++) {
        masked[k] = lhs[k] & ~rhs[k];
    }
    return calc(masked);
}
vector<ll> bl[30];
int main() {
pow2[0] = 1;
for (ll i = 1; i < 77; i++) pow2[i] = pow2[i - 1] * 2 % mod;
for (ll i = 0; i < 22; i++) Log2[1 << i] = i;
for (ll i = 1; i < (1 << 21); i++) {
ll lastb1 = Log2[lowbit(i)], lastb2 = Log2[lowbit(i + lowbit(i))], lastl = lastb2 - lastb1, lastv = pow2[lastb2] - pow2[lastb1];
t[i] = t[i - lastv] + calc(lastl);
if (i & 1) suf[i] = lastl;
if ((i >> 19) & 1) pre[i] = pre[(i - (1 << 19)) << 1] + 1;
}
cin >> n;
for (ll i = 0; i < n; i++) {
scanf("%lld", a + i);
}
ll ans = 0;
for (ll i = 0; i < 30; i++) {
bl[i].resize(n / 20 + (n % 20 > 0));
}
for (ll i = 0; i < 30; i++) {
for (ll id = 0; id < n; id++) {
if (id % 20 == 0) bl[i][id / 20] = 0;
ll &p = bl[i][id / 20];
p = p * 2 + ((a[id] >> i) & 1);
}
bl[i][bl[i].size() - 1] <<= (20 - n % 20) % 20;
}
for (ll i = 0; i < 30; i++) {
ll buf = calc(bl[i]);
for (ll j = 0; j < 30; j++) {
ans = (ans + pow2[i + j] * (buf - calc(bl[i], bl[j]))) % mod;
}
}
cout << ans << endl;
return 0;
}