题解 | KM and M-2019牛客暑期多校训练营第九场I题

题目来源于牛客竞赛:https://ac.nowcoder.com/acm/contest/discuss
题目描述:
在这里插入图片描述
输入描述:
在这里插入图片描述
输出描述:
在这里插入图片描述
示例1:
在这里插入图片描述
题解:
在这里插入图片描述
代码:

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
  
using namespace std;
using namespace __gnu_pbds;
  
#define fi first
#define se second
#define mp make_pair
#define pb push_back
  
typedef long long ll;
typedef pair<int,int> ii;
typedef vector<int> vi;
typedef long double ld;
typedef tree<ii, null_type, less<ii>, rb_tree_tag, tree_order_statistics_node_update> pbds;
 
const int MOD = 1e9+7;
 
ll add(ll a, ll b)
{
    a+=b;
    a%=MOD;
    if(a<0) a+=MOD;
    return a;
}
 
ll mult(ll a, ll b)
{
    a%=MOD; b%=MOD;
    return (a*b)%MOD;
}
 
ll solve_naive(ll n, ll m)
{
    ll ans = 0;
    for(ll i=1;i<=n;i++)
    {
        ans=add(ans,(i*m)&m);
    }
    return ans;
}
 
ll P2[101];
 
ll solve_powers(ll d, ll m, ll ad=0) //solve until (1<<(d-1))*m
{
    if(d==0) return (ad&m)%MOD;
    ll dd = 0;
    while(m>=(1LL<<dd)) dd++;
    ll mod = (1LL<<dd)-1;
    int mid = min(dd/2, d-1);
    //brute P2[0] to P2[mid]
    int num = mid + 1;
    ll ans = 0;
    vector<ll> B((1<<num),0);
    B[0] = ad;
    for(int i=1;i<=num;i++)
    {
        for(int j=(1<<(i-1));j<(1<<i);j++)
        {
            B[j] = B[j^(1<<(i-1))] + P2[i-1];
            B[j]&=mod;
            //cerr<<j<<' '<<B[j]<<'\n';
        }
    }
    if(num==d)
    {
        for(int i=0;i<(1<<num);i++) ans=add(ans,(B[i]&m));
        return ans;
    }
    for(int i=0;i<(1<<num);i++)
    {
        ans = add(ans, ((B[i]&((1LL<<(mid+1))-1))&m));
    }
    //solve for bits mid+1 to dd-1
    //right shift everything by mid+1, multiply by 2^{mid+1} later
    for(int i=0;i<(1<<num);i++) B[i]>>=mid+1;
    //brute P2[mid+1] to P2[d-1]
    int num2 = d - (mid+1);
    ans = mult(ans, (1LL<<(num2)));
    vector<ll> B2((1<<num2),0);
    B2[0] = 0;
    for(int i=1;i<=num2;i++)
    {
        for(int j=(1<<(i-1));j<(1<<i);j++)
        {
            B2[j] = B2[j^(1<<(i-1))] + P2[i-1+mid+1];
            B2[j]&=mod;
        }
    }
    for(int i=0;i<(1<<num2);i++) B2[i]>>=mid+1;
    /*
    for(int v:B) cerr<<v<<' ';
    cerr<<'\n';
    for(int v:B2) cerr<<v<<' ';
    cerr<<'\n';
    */
    for(int i=0;i<dd-(mid+1);i++)
    {
        if(!(m&(1LL<<(i+mid+1)))) continue;
        ll cnt = 0;
        vector<ll> dp[2]; //sort by bit 2^i
        dp[0].resize((1LL<<i),0); dp[1].resize((1LL<<i),0);
        for(int j=0;j<(1<<num2);j++)
        {
            int type = (((1LL<<i)&B2[j])?1:0);
            dp[type][(B2[j]&((1LL<<i)-1))]++;
        }
        for(int j=0;j<2;j++)
        {
            for(int k=1;k<(1LL<<i);k++)
            {
                dp[j][k]+=dp[j][k-1];
            }
        }
        for(int j=0;j<(1<<num);j++)
        {
            ll cur = B[j];
            ll bit = ((cur&(1LL<<i))?1:0);
            ll rem = cur&((1LL<<i)-1);
            for(int k=0;k<2;k++)
            {
                ll tot = dp[k][(1LL<<i)-1];
                ll maxbound = (1LL<<i) - 1 - rem;
                if(bit^k)
                {
                    cnt+=dp[k][maxbound];
                }
                else
                {
                    cnt+=tot-dp[k][maxbound];
                }
            }
        }
        ans = add(ans, mult((1LL<<(i+mid+1)), cnt));
    }
    return ans;
}
 
ll solve_fast(ll n, ll m)
{
    ll d = 0;
    while(m>=(1LL<<d)) d++;
    if(n>=(1LL<<d))
    {
        ll q = n/(1LL<<d); ll r = n%(1LL<<d);
        ll full = solve_fast((1LL<<d) - 1, m);
        ll ans = mult(full, q);
        ans = add(ans, solve_fast(r, m));
        return ans;
    }
    ll ad = 0;
    memset(P2,0,sizeof(P2));
    for(int i=0;i<d;i++)
    {
        if(i==0) P2[i]=m;
        else P2[i]=(P2[i-1]<<1);
        P2[i]&= (1LL<<d)-1;
    }
    ll ans = 0;
    n++;
    for(int i=d;i>=0;i--)
    {
        if(n&(1LL<<i))
        {
            ans = add(ans, solve_powers(i, m, ad));
            ad += P2[i];
            ad &= (1LL<<d)-1;
        }
    }
    return ans;
}
 
void run_once()
{
    ll n,m;
    cin>>n>>m;
    ll a1 = solve_naive(n,m);
    ll a2 = solve_fast(n,m);
    cerr<<"NAIVE : "<<a1<<'\n';
    cerr<<"FAST : "<<a2<<'\n';
}
 
void run_row(int n, int m)
{
    for(ll i=1;i<=n;i++)
    {
        cerr<<solve_fast(i,m);
        if(i+1<=n) cerr<<",";
    }
    cerr<<'\n';
}
 
void run_all(int C)
{
    //ll n,m;
    //cin>>n>>m;
    for(ll n=1;n<=C;n++)
    {
        for(ll m=1;m<=C;m++)
        {
            ll a1 = solve_naive(n,m);
            ll a2 = solve_fast(n,m);
            cerr<<"NAIVE : "<<a1<<'\n';
            cerr<<"FAST : "<<a2<<'\n';
            if(a1!=a2)
            {
                freopen("sum_of_km_and_m.out","w",stdout);
                cout<<n<<' '<<m<<'\n';
                return ;
            }
            cerr<<"SOLVED "<<n<<' '<<m<<'\n';
        }
    }
}
 
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
int rnd(int x)
{
    return uniform_int_distribution<int>(0, x)(rng);
}
void stress_test()
{
    for(int cc=1;;cc++)
    {
        ll n=rnd((1<<21));
        ll m=rnd((1<<21));
        ll a1 = solve_naive(n,m);
        ll a2 = solve_fast(n,m);
        cerr<<"NAIVE : "<<a1<<'\n';
        cerr<<"FAST : "<<a2<<'\n';
        if(a1!=a2)
        {
            freopen("sum_of_km_and_m.out","w",stdout);
            cout<<n<<' '<<m<<'\n';
            return ;
        }
        cerr<<"Case #"<<cc<<" complete\n";
    }
}
 
void run_solve()
{
    ll n,m; cin>>n>>m;
    cout<<solve_fast(n,m)<<'\n';
}
 
int main()
{
    ios_base::sync_with_stdio(0); cin.tie(0);
    /*
    ll n; cin>>n;
    for(ll m=2;m<=n;m++)
    {
        ll d = 0;
        while(m>=(1LL<<d)) d++;
        cout<<solve_naive((1LL<<(d-1))-1,m)<<'\n';
    }
    */
    run_solve();
    //run_once();
    //run_all(256);
    //ll n,m; cin>>n>>m;
    //run_row(n,m);
    //stress_test();
}

更多问题,更详细题解可关注牛客竞赛区,一个刷题、比赛、分享的社区。
传送门:https://ac.nowcoder.com/acm/contest/discuss

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值