【题目链接】
【思路要点】
- 首先,数集\(\{A,B\}\)等价于数集\(\{A,A\ Xor\ B\}\),且数集\(\{A,0\}\)等价于\(\{A\}\)。
- 因此,我们可以先构建原数集的线性基,并删去多余的0。令\(M\)为线性基的元素个数,则\(M\)是\(O(LogMax\{a_i\})\)级别的。
- 注意到题目保证答案小于等于\(2^{63}\),那么\(Max\{a_i\}\)应当是\(O(2^{\frac{63}{k}})\)级别的。
- 也就是说,\(M\)是\(O(\frac{63}{k})\)的。
- 当\(k≥3\),\(M\)最大约为22,\(O(k*2^M)\)的复杂度是可以接受的,因此我们直接暴力计算该线性基的答案即可。由于计算的中间过程量可能会达到\(O(2^{63+M})\)的级别,需要手动实现一个小高精度类。并且几组数据测下来发现答案的小数部分只有可能是\(.0\)或者\(.5\),这一点笔者并不会证明。
- 当\(k=1\),我们可以分开考虑每一个二进制位。考虑一个数值\(x\),\(x\)在第\(i\)个二进制位上为1,那么无论\(x\)以外的数如何选取,\(x\)选与不选都对应了结果第\(i\)个二进制位上为0和1两种情况。因此,一旦存在一个数值\(x\)在第\(i\)个二进制位上为1,第\(i\)个二进制位对答案的贡献就为\(2^{i-1}\),否则为0。因此,答案实际上就是所有数二进制或的结果除以二。
- 当\(k=2\),我们仍然可以分开考虑答案的每一个二进制位。枚举数位\(i\)、\(j\),那么它们对答案的贡献应当为\(\frac{func(i,j)*2^{i+j}}{2^M}\),其中\(func(i,j)\)代表在所有\(2^M\)中选取方法中得到的结果在数位数位\(i\)、\(j\)上均为1的方案数,这个值可以通过一个\(O(2^k*M)\)的DP得到。时间复杂度\(O(2^k*M^{k+1})\),事实上,这个做法同样可以推广到\(k≥3\)的情况。
- 时间复杂度\(O(Min(2^k*M^{k+1},k*2^M))\),其中\(M=O(\frac{63}{k})\)。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 100005; const int MAXLOG = 63; template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int n, k, cnt; unsigned long long val[MAXLOG]; struct ExtendedLongLong { unsigned long long P, x, y; void init() { P = 1ll << 62; } void print(int cnt) { for (int i = 1; i <= cnt - 1; i++) { if (x % 2 == 1) y += P; x /= 2; y /= 2; } y += x * P; if (y % 2 == 0) printf("%llu\n", y / 2); else printf("%llu.5\n", y / 2); } } ans; ExtendedLongLong operator + (ExtendedLongLong a, ExtendedLongLong b) { a.y += b.y; a.x += b.x + a.y / a.P; a.y %= a.P; return a; } ExtendedLongLong operator * (ExtendedLongLong a, unsigned long long b) { if (b == 0) return (ExtendedLongLong) {a.P, 0, 0}; if (a.P / b > a.y) { a.y *= b; return a; } ExtendedLongLong tmp = a * (b / 2); if (b % 2 == 0) return tmp + tmp; else return tmp + tmp + a; } void add(unsigned long long x) { for (int i = MAXLOG - 1; i >= 0; i--) { unsigned long long tmp = 1ll << i; if (tmp & x) { if (val[i]) x ^= val[i]; else {val[i] = x; cnt++; return; } } } } void calc(unsigned long long x) { ExtendedLongLong tmp; tmp.init(); tmp.x = 0; tmp.y = 1; for (int i = 1; i <= k; i++) tmp = tmp * x; ans = ans + tmp; } void work(int pos, unsigned long long now) { if (pos == -1) calc(now); else if (val[pos]) { work(pos - 1, now); work(pos - 1, now ^ val[pos]); } else work(pos - 1, now); } unsigned long long func(int x, int y) { unsigned long long dp[2][4]; memset(dp, 0, sizeof(dp)); dp[0][0] = 1; int pos = 0; for (int i = 1; i <= cnt; i++) { while (val[pos] == 0) pos++; int now = i & 1, last = now ^ 1; int tmp = 2 * ((val[pos] & (1ll << x)) != 0) + 1 * ((val[pos] & (1ll << y)) != 0); for (int j = 0; j <= 3; j++) dp[now][j] = dp[last][j] + dp[last][j ^ tmp]; pos++; } return dp[cnt & 1][3]; } int main() { read(n), read(k); if (k == 1) { unsigned long long tmp = 0; for (int i = 1; i <= n; i++) { unsigned long long x; read(x); tmp |= x; } if (tmp % 2 == 0) printf("%llu\n", tmp / 2); else printf("%llu.5\n", tmp / 2); return 0; } for (int i = 1; i <= n; i++) { unsigned long long x; read(x); add(x); } ans.init(); if (k == 2) { for (int i = 0; i < MAXLOG; i++) for (int x = 0, y = i; x <= i; x++, y--) { ExtendedLongLong tmp; tmp.init(); tmp.x = 0; tmp.y = 1ll << i; ans = ans + tmp * func(x, y); } ans.print(cnt); return 0; } work(MAXLOG - 1, 0); ans.print(cnt); return 0; }