题目来源于牛客竞赛:https://ac.nowcoder.com/acm/contest/discuss
题目描述:
输入描述:
输出描述:
示例1:
题解:
代码:
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Shorthand macros for pair members and two common standard-library calls.
#define fi first
#define se second
#define mp make_pair
#define pb push_back
// Short aliases for the arithmetic and container types used in this file.
typedef long long ll;
typedef pair<int,int> ii;
typedef vector<int> vi;
typedef long double ld;
// GNU pb_ds red-black tree keyed by int pairs, with order-statistics node updates.
typedef tree<ii, null_type, less<ii>, rb_tree_tag, tree_order_statistics_node_update> pbds;
// Modulus applied by add() and mult(); the literal 1e9+7 is a double converted to int here.
const int MOD = 1e9+7;
// Modular addition: returns (a + b) brought into the range [0, MOD).
ll add(ll a, ll b)
{
    ll sum = (a + b) % MOD;
    // operator% keeps the sign of the dividend, so a negative remainder is lifted by MOD.
    return sum < 0 ? sum + MOD : sum;
}
// Modular multiplication: reduces both operands modulo MOD first so the
// intermediate product stays within 64 bits, then reduces the product.
ll mult(ll a, ll b)
{
    const ll lhs = a % MOD;
    const ll rhs = b % MOD;
    return lhs * rhs % MOD;
}
// Brute-force reference: accumulates (i*m)&m for every i in 1..n, modulo MOD.
// Runs in time linear in n; used to validate solve_fast.
ll solve_naive(ll n, ll m)
{
    ll total = 0;
    ll i = 1;
    while (i <= n)
    {
        const ll term = (i * m) & m;
        total = add(total, term);
        ++i;
    }
    return total;
}
// Table of shifted copies of m read by solve_powers. solve_fast fills P2[i] with
// (m << i) truncated to the bit length of m before calling solve_powers.
ll P2[101];
// Returns, modulo MOD, the sum of (v & m) over the 2^d values
//   v = (ad + sum of P2[i] for i in S) mod 2^dd,  S ranging over all subsets of {0..d-1},
// where dd is the bit length of m. With P2 filled as in solve_fast this equals the sum of
// ((ad + k*m) mod 2^dd) & m for k = 0 .. 2^d - 1.
//
// d  : number of table entries P2[0..d-1] to combine (d == 0 yields the single term ad & m).
// m  : the mask / multiplier.
// ad : constant offset added to every value.
//
// The entries are split into a low group P2[0..mid] and a high group P2[mid+1..d-1]
// (meet-in-the-middle): each group is enumerated separately and the two are combined
// bit by bit through counting instead of enumerating all 2^d sums.
// NOTE: bits 0..mid of the result are taken from the low group alone and the high group
// is shifted right by mid+1, so the identity above needs P2[mid+1..] to have those bits
// clear, which holds for (m << i).
ll solve_powers(ll d, ll m, ll ad=0) //solve until (1<<(d-1))*m
{
// No table entries to combine: the only value is ad itself.
if(d==0) return (ad&m)%MOD;
// dd = smallest exponent with m < 2^dd; mod masks a value down to its low dd bits.
ll dd = 0;
while(m>=(1LL<<dd)) dd++;
ll mod = (1LL<<dd)-1;
// Index of the last entry in the low group: about half the bit length, capped at d-1.
int mid = min(dd/2, d-1);
// Enumerate every subset sum of P2[0..mid], offset by ad.
int num = mid + 1;
ll ans = 0;
vector<ll> B((1<<num),0);
B[0] = ad;
for(int i=1;i<=num;i++)
{
// Subsets whose highest chosen index is i-1: reuse the sum without that entry.
for(int j=(1<<(i-1));j<(1<<i);j++)
{
B[j] = B[j^(1<<(i-1))] + P2[i-1];
B[j]&=mod;
//cerr<<j<<' '<<B[j]<<'\n';
}
}
// The low group already covers all d entries: sum directly.
if(num==d)
{
for(int i=0;i<(1<<num);i++) ans=add(ans,(B[i]&m));
return ans;
}
// Contribution of bits 0..mid, taken from the low-group sums only.
for(int i=0;i<(1<<num);i++)
{
ans = add(ans, ((B[i]&((1LL<<(mid+1))-1))&m));
}
// The remaining work concerns bits mid+1 .. dd-1.
// Shift everything right by mid+1; the weight 2^(mid+1) is restored when adding to ans.
for(int i=0;i<(1<<num);i++) B[i]>>=mid+1;
// Enumerate every subset sum of the high group P2[mid+1..d-1].
int num2 = d - (mid+1);
// Each low-group value is paired with all 2^num2 high-group values.
ans = mult(ans, (1LL<<(num2)));
vector<ll> B2((1<<num2),0);
B2[0] = 0;
for(int i=1;i<=num2;i++)
{
for(int j=(1<<(i-1));j<(1<<i);j++)
{
B2[j] = B2[j^(1<<(i-1))] + P2[i-1+mid+1];
B2[j]&=mod;
}
}
for(int i=0;i<(1<<num2);i++) B2[i]>>=mid+1;
// Disabled debug dump of both tables.
/*
for(int v:B) cerr<<v<<' ';
cerr<<'\n';
for(int v:B2) cerr<<v<<' ';
cerr<<'\n';
*/
// For each remaining bit position (i counts from mid+1) that is set in m, count the
// (B entry, B2 entry) pairs whose sum has that bit set.
for(int i=0;i<dd-(mid+1);i++)
{
if(!(m&(1LL<<(i+mid+1)))) continue;
ll cnt = 0;
// dp[t][r]: number of B2 entries whose bit i equals t and whose low i bits equal r;
// turned into prefix counts over r below.
vector<ll> dp[2]; //sort by bit 2^i
dp[0].resize((1LL<<i),0); dp[1].resize((1LL<<i),0);
for(int j=0;j<(1<<num2);j++)
{
int type = (((1LL<<i)&B2[j])?1:0);
dp[type][(B2[j]&((1LL<<i)-1))]++;
}
// Prefix sums: dp[t][r] now counts entries of class t with low i bits <= r.
for(int j=0;j<2;j++)
{
for(int k=1;k<(1LL<<i);k++)
{
dp[j][k]+=dp[j][k-1];
}
}
for(int j=0;j<(1<<num);j++)
{
ll cur = B[j];
ll bit = ((cur&(1LL<<i))?1:0);
ll rem = cur&((1LL<<i)-1);
// Bit i of cur + x is bit ^ k ^ carry, where the carry out of the low i bits is set
// exactly when rem + (low i bits of x) exceeds 2^i - 1.
for(int k=0;k<2;k++)
{
ll tot = dp[k][(1LL<<i)-1];
// Largest low part of x that does not produce a carry.
ll maxbound = (1LL<<i) - 1 - rem;
if(bit^k)
{
// Bits differ: the result bit is set only without a carry.
cnt+=dp[k][maxbound];
}
else
{
// Bits agree: the result bit is set only with a carry.
cnt+=tot-dp[k][maxbound];
}
}
}
// Restore the weight of this bit position.
ans = add(ans, mult((1LL<<(i+mid+1)), cnt));
}
return ans;
}
// Computes the sum of (i*m)&m for i = 1..n, modulo MOD, without iterating over every i.
// Values of n at or above 2^bits (bits = bit length of m) are reduced to whole
// periods plus a remainder; the remainder is handled by decomposing n+1 into powers
// of two and delegating each power-of-two block to solve_powers.
ll solve_fast(ll n, ll m)
{
    // bits = smallest exponent with m < 2^bits.
    ll bits = 0;
    while (m >= (1LL << bits)) ++bits;
    const ll period = 1LL << bits;
    const ll mask = period - 1;

    if (n >= period)
    {
        // Only the low `bits` bits of i*m survive the AND with m, so the summand
        // repeats with period 2^bits in i; the i = 0 term of a period is zero.
        const ll wholePeriods = n / period;
        const ll leftover = n % period;
        const ll onePeriod = solve_fast(period - 1, m);
        const ll scaled = mult(onePeriod, wholePeriods);
        return add(scaled, solve_fast(leftover, m));
    }

    // Rebuild the global table: P2[i] = (m << i) kept to `bits` bits, zero elsewhere.
    memset(P2, 0, sizeof(P2));
    for (int i = 0; i < bits; i++)
    {
        P2[i] = (i == 0 ? m : (P2[i-1] << 1)) & mask;
    }

    // Cover the multipliers 0..n (n+1 of them) with one block per set bit of n+1,
    // highest bit first; `offset` is the masked running sum of P2 over bits already used.
    const ll limit = n + 1;
    ll offset = 0;
    ll total = 0;
    for (int i = bits; i >= 0; i--)
    {
        if (!(limit & (1LL << i))) continue;
        total = add(total, solve_powers(i, m, offset));
        offset = (offset + P2[i]) & mask;
    }
    return total;
}
void run_once()
{
ll n,m;
cin>>n>>m;
ll a1 = solve_naive(n,m);
ll a2 = solve_fast(n,m);
cerr<<"NAIVE : "<<a1<<'\n';
cerr<<"FAST : "<<a2<<'\n';
}
// Debug driver: prints solve_fast(i, m) for i = 1..n to stderr as one comma-separated line.
void run_row(int n, int m)
{
    const char* separator = "";
    for (ll i = 1; i <= n; ++i)
    {
        // The separator is empty before the first value and "," afterwards.
        cerr << separator << solve_fast(i, m);
        separator = ",";
    }
    cerr << '\n';
}
void run_all(int C)
{
//ll n,m;
//cin>>n>>m;
for(ll n=1;n<=C;n++)
{
for(ll m=1;m<=C;m++)
{
ll a1 = solve_naive(n,m);
ll a2 = solve_fast(n,m);
cerr<<"NAIVE : "<<a1<<'\n';
cerr<<"FAST : "<<a2<<'\n';
if(a1!=a2)
{
freopen("sum_of_km_and_m.out","w",stdout);
cout<<n<<' '<<m<<'\n';
return ;
}
cerr<<"SOLVED "<<n<<' '<<m<<'\n';
}
}
}
// File-wide Mersenne Twister engine, seeded with the steady clock's current tick count.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Draws an integer uniformly from the closed range [0, x] using the shared engine.
int rnd(int x)
{
    uniform_int_distribution<int> pick(0, x);
    return pick(rng);
}
void stress_test()
{
for(int cc=1;;cc++)
{
ll n=rnd((1<<21));
ll m=rnd((1<<21));
ll a1 = solve_naive(n,m);
ll a2 = solve_fast(n,m);
cerr<<"NAIVE : "<<a1<<'\n';
cerr<<"FAST : "<<a2<<'\n';
if(a1!=a2)
{
freopen("sum_of_km_and_m.out","w",stdout);
cout<<n<<' '<<m<<'\n';
return ;
}
cerr<<"Case #"<<cc<<" complete\n";
}
}
void run_solve()
{
ll n,m; cin>>n>>m;
cout<<solve_fast(n,m)<<'\n';
}
// Entry point: switches off C stdio synchronisation, unties cin from cout, then runs
// the submission driver.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    run_solve();

    // Alternative drivers kept for local debugging (all disabled):
    //   run_once();
    //   run_all(256);
    //   ll n,m; cin>>n>>m; run_row(n,m);
    //   stress_test();
    //   printing solve_naive((1LL<<(d-1))-1, m) for m = 2..n, d = bit length of m
    return 0;
}
更多问题,更详细题解可关注牛客竞赛区,一个刷题、比赛、分享的社区。
传送门:https://ac.nowcoder.com/acm/contest/discuss