这道题当时在比赛的时候瞎搞搞出来了,后面复盘的时候发现其实思维更高的一个层次可以省下很多代码量。
首先声明,全子集异或和这个结论我事先是不知道的。
然后我就想着要把这个序列给构造出来。
常规思路:按位思考
对于每一个l,r,k,我开一个数字组cnt[i][j]表示第i个数二进制位j位置上是0还是1.
然后对于每一个l,r的或为k,如果k的第j位有1则利用差分的思想cnt[l][j]+=1
,cnt[r+1][j]-=1;如果
k的第j位没有1,则说明无论如何这里都不能有1,因此为了防止差分之后前缀和的+的效果把这里不能有1的限制给覆盖掉,因此有如下代码:cnt[l][j]-=INF
,cnt[r+1][j]+=INF
此处令INF=200010
这样的话,我们最后统计完做一遍cnt[i][j]+=cnt[i-1][j]
,然后对于每一个cnt[i][j]
如果大于0则说明这个位置为1,否则为0
从而变相的,那个序列我们也通过cnt[i][j]
构造出来了
然后我们只需要利用组合数推一下公式即可:
对于第j个位置上,有p个数为1,n-p个数为0,有:
2
j
∗
2
n
−
p
∗
(
C
p
1
+
C
p
3
+
.
.
.
)
2^j*2^{n-p}*(C_p^1+C_p^3+...)
2j∗2n−p∗(Cp1+Cp3+...)=
2
j
∗
2
n
−
p
∗
2
p
−
1
2^j*2^{n-p}*2^{p-1}
2j∗2n−p∗2p−1即可
Code:
const int N = 200010, mod = 1e9+7,INF = 200010;
LL T,n,m,k,l,r,a[N], pos[35],cnt[N][35];
LL fpower(LL a,LL b){
LL ans = 1;
while(b){
if(b&1) ans = a*ans%mod;
a = a*a%mod;
b >>= 1;
}
return ans;
}
void calc(int k,int l,int r){
for(int j=0;j<=31;++j){
if((k>>j)&1) cnt[l][j]+=1, cnt[r+1][j]-=1;
else cnt[l][j] -= INF, cnt[r+1][j] += INF;
}
}
//=================================
int main(){
T = read();
while(T--){
LL ans = 0;
n = read(), m=read();
rep(i,1,m){
l = read(), r = read(), k=read();
calc(k,l,r);
}
rep(i,0,32) rep(j,1,n) cnt[j][i] += cnt[j-1][i];
rep(i,0,31) rep(j,1,n) {
if(cnt[j][i]<=0) cnt[j][i] = 0;
else cnt[j][i] = 1;
}
memset(pos,0,sizeof pos);
rep(i,1,n) rep(j,0,31) pos[j] += cnt[i][j];
rep(i,0,31){
if(pos[i] == 0) continue;
ans = (ans + fpower(2,i)*fpower(2,pos[i]-1)%mod*fpower(2,n-pos[i])%mod)%mod;
} print(ans);
// rep(i,1,n){
// for(int j=0;j<=31;++j) a[i] += (LL)cnt[i][j] << j;//debug(a[i]);
// }
// rep(i,1,n) printf("%d ",a[i]);puts("");
// rep(i,0,33) printf("%d ",pos[i]); puts("");
rep(i,1,n+3) {rep(j,0,31) cnt[i][j] = 0;a[i] = 0;}
}
return 0;
}
但是后来接着发现了一个神奇的事情:
2
j
∗
2
n
−
p
∗
(
C
p
1
+
C
p
3
+
.
.
.
)
2^j*2^{n-p}*(C_p^1+C_p^3+...)
2j∗2n−p∗(Cp1+Cp3+...)=
2
j
∗
2
n
−
p
∗
2
p
−
1
2^j*2^{n-p}*2^{p-1}
2j∗2n−p∗2p−1这里可以继续化简:
=
2
j
∗
2
n
−
1
=2^j*2^{n-1}
=2j∗2n−1
也就是说,和p并没有关系,也就是说只有某一二进制位存在,它的贡献就是
2
j
∗
2
n
−
1
2^j*2^{n-1}
2j∗2n−1这里n-1就是可以看作常数了。。。靠,这样代码就能巨短无比了:
const int mod = 1e9+7;
int _,n,m,ans=0,k;
void solve(){
ans=0;
cin >> n >> m;
rep(i,1,m) cin>>k>>k>>k,ans|=k;
print((LL)fpower(2,n-1,mod)*ans%mod);
}
int main(){
for(scanf("%d",&_);_;_--){
solve();
}
return 0;
}
后来挖出来了Leetcode上一道模板题。。
https://leetcode-cn.com/problems/sum-of-all-subset-xor-totals/