题目
思路来源
jiangly、kilo、starsilk、Splashing代码
题解
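赛中的想法是:枚举 l 和 r 的二进制表示中最高的不同位是哪一位,然后左右分治、向上 merge 答案。
赛后发现即使两个区间不等长也可以 merge,所以本质上就是类似线段树区间合并的操作:每个区间维护长度 len、最长合法前缀 l、最长合法后缀 r 和区间内的答案 ans,合并时 ans = a.ans + b.ans + a.r × b.l。
把 [0, n) 按 n 的二进制位拆成若干个长度为 2 的幂次的区间,每个这样的区间的信息都可以预处理(或记忆化)得到,然后依次 merge 即可。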
代码1(参考starsilk)
用记忆化搜索实现数位 dp:dfs 返回当前子区间的信息(len、l、r、ans),再用 operator+ 做区间合并。
//#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<vector>
#include<map>
#include<queue>
#include<set>
using namespace std;
// Inclusive ascending / descending integer loops.
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
// Short type aliases.
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
// Shorthands for pair members and vector::push_back.
#define fi first
#define se second
#define pb push_back
// Debug printing to stderr: dbg keeps the line open, dbg2 ends it.
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
// Container size as int.
#define SZ(a) (int)(a.size())
// scanf/printf shorthands: read an int, print an int (with or without newline), print a long long.
#define sci(a) scanf("%d",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
// N bounds both DP dimensions (bit index and remaining-ones budget); mod is the output modulus.
const int N=65,mod=1e9+7;
// t: number of test cases; k: per-test limit on the number of 1 bits.
int t,k;
// n: per-test upper bound; sol() decrements it right after reading, so dfs treats it as inclusive.
ll n;
// vis[x][y] is set once dp[x][y] holds the memoised result of dfs(x,y,false).
bool vis[N][N];
// Summary of one contiguous block of consecutive integers, as combined by operator+:
//   len - size recorded for the block (0 means "empty block");
//   l   - length of the run of acceptable numbers at the left end;
//   r   - length of the run of acceptable numbers at the right end;
//   ans - number of subarrays made only of acceptable numbers, kept modulo `mod`.
struct Info{
    ll len;
    ll l;
    ll r;
    ll ans;
};
// dp[x][y] caches dfs(x,y,false); an entry is meaningful only once vis[x][y] is set.
Info dp[N][N];
// Concatenates block `a` (on the left) with block `b` (on the right).
// A block with len==0 acts as the identity element.
Info operator+(const Info &a,const Info &b){
    if(a.len==0)return b;
    if(b.len==0)return a;
    const bool leftAllGood=(a.l==a.len);
    const bool rightAllGood=(b.r==b.len);
    // Subarrays that cross the seam: start anywhere in a's good suffix,
    // end anywhere in b's good prefix.
    const ll crossing=1ll*(a.r%mod)*(b.l%mod)%mod;
    Info merged;
    merged.len=a.len+b.len;
    // The good prefix extends into b only when all of a is good (and symmetrically for the suffix).
    merged.l=leftAllGood?a.len+b.l:a.l;
    merged.r=rightAllGood?a.r+b.len:b.r;
    merged.ans=((crossing+a.ans)%mod+b.ans)%mod;
    return merged;
}
// Digit DP over the bits of the global n (already decremented by sol(), so it is inclusive).
// dfs(x,y,lim) summarises, in increasing order, the block of numbers obtained by
// choosing bits x..0:
//   lim==false : all 2^(x+1) combinations of those bits;
//   lim==true  : only the combinations not exceeding the low x+1 bits of n.
// y is the number of 1 bits that may still be placed; y<0 means the budget is already
// exceeded, so every number in the block is unacceptable.
// Results for lim==false depend only on (x,y) and are memoised in dp/vis.
// Precondition: x<=62 when lim is true and x<=61 when lim is false, so lengths fit in ll.
Info dfs(int x,int y,bool lim){
    if(x<0){
        // A single, fully determined number. Checking x first avoids the shift by a
        // negative amount (undefined behaviour) that the y<0 case would otherwise perform.
        if(y<0)return {1,0,0,0};
        return {1,1,1,1};
    }
    if(y<0){
        // Whole block is unacceptable; record its true size so that len is never 0.
        ll len;
        if(lim){
            const unsigned long long mask=(1ull<<(x+1))-1;
            len=(ll)((unsigned long long)n&mask)+1;
        }
        else{
            len=1ll<<(x+1);
        }
        return {len,0,0,0};
    }
    if(!lim){
        if(vis[x][y])return dp[x][y];
        vis[x][y]=1;
        // Bit x = 0 keeps the budget, bit x = 1 spends one unit of it.
        return dp[x][y]=dfs(x-1,y,0)+dfs(x-1,y-1,0);
    }
    if(n>>x&1){
        // n has a 1 here: the 0-branch becomes unrestricted, the 1-branch stays tight.
        return dfs(x-1,y,0)+dfs(x-1,y-1,1);
    }
    // n has a 0 here: only bit 0 is allowed and the bound stays tight.
    return dfs(x-1,y,1);
}
int sol(){
scanf("%lld%d",&n,&k);
n--;
return dfs(62,k,1).ans;
}
// Reads the number of test cases and prints one answer per line.
int main(){
    scanf("%d",&(t));
    for(;t--;){
        printf("%d\n",sol());
    }
    return 0;
}
代码2(jiangly代码)
与代码1的区别在于这里是递推预处理:f[i][j] 表示长度为 2^i、最多还能填 j 个 1 的整块区间的信息;查询时按 n 的二进制位从高到低,把对应的整块依次合并。
#include <bits/stdc++.h>
using i64 = long long;

// Binary exponentiation: returns a raised to the power b using operator*= of T.
// T must be constructible from 1. A negative b is treated like its absolute value,
// because the loop only looks at b % 2 and b / 2.
// The base is squared only while exponent bits remain, so no extra (possibly
// overflowing) squaring happens after the last multiplication.
template<class T>
constexpr T power(T a, i64 b) {
    T res {1};
    while (b != 0) {
        if (b % 2 != 0) {
            res *= a;
        }
        b /= 2;
        if (b != 0) {
            a *= a;
        }
    }
    return res;
}
// Returns a * b mod p for 64-bit operands whose product does not fit in 64 bits.
// The quotient floor(a*b/p) is estimated in long double; the difference
// a*b - q*p is then taken modulo 2^64 and corrected into [0, p).
// The wrap-around is done in unsigned arithmetic, where it is well defined,
// instead of relying on signed overflow.
constexpr i64 mul(i64 a, i64 b, i64 p) {
    using u64 = unsigned long long;
    const i64 quotient = i64(1.L * a * b / p);
    i64 res = i64(u64(a) * u64(b) - u64(quotient) * u64(p));
    res %= p;
    if (res < 0) {
        res += p;
    }
    return res;
}
template<i64 P>
struct MInt {
i64 x;
constexpr MInt() : x {0} {}
constexpr MInt(i64 x) : x {norm(x % getMod())} {}
static i64 Mod;
constexpr static i64 getMod() {
if (P > 0) {
return P;
} else {
return Mod;
}
}
constexpr static void setMod(i64 Mod_) {
Mod = Mod_;
}
constexpr i64 norm(i64 x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr i64 val() const {
return x;
}
constexpr MInt operator-() const {
MInt res;
res.x = norm(getMod() - x);
return res;
}
constexpr MInt inv() const {
return power(*this, getMod() - 2);
}
constexpr MInt &operator*=(MInt rhs) & {
if (getMod() < (1ULL << 31)) {
x = x * rhs.x % int(getMod());
} else {
x = mul(x, rhs.x, getMod());
}
return *this;
}
constexpr MInt &operator+=(MInt rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MInt &operator-=(MInt rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MInt &operator/=(MInt rhs) & {
return *this *= rhs.inv();
}
friend constexpr MInt operator*(MInt lhs, MInt rhs) {
MInt res = lhs;
res *= rhs;
return res;
}
friend constexpr MInt operator+(MInt lhs, MInt rhs) {
MInt res = lhs;
res += rhs;
return res;
}
friend constexpr MInt operator-(MInt lhs, MInt rhs) {
MInt res = lhs;
res -= rhs;
return res;
}
friend constexpr MInt operator/(MInt lhs, MInt rhs) {
MInt res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
i64 v;
is >> v;
a = MInt(v);
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
return os << a.val();
}
friend constexpr bool operator==(MInt lhs, MInt rhs) {
return lhs.val() == rhs.val();
}
friend constexpr bool operator!=(MInt lhs, MInt rhs) {
return lhs.val() != rhs.val();
}
friend constexpr bool operator<(MInt lhs, MInt rhs) {
return lhs.val() < rhs.val();
}
};
// Default runtime modulus for the dynamic-modulus instantiation MInt<0>.
template<>
i64 MInt<0>::Mod = 998244353;
// Modulus used by this solution.
constexpr int P = 1000000007;
// Z: integers modulo P.
using Z = MInt<P>;
// Summary of one contiguous block of consecutive integers, as combined by operator+:
// len is the block size (0 = empty), l / r are the lengths of the acceptable runs
// at the left / right end, and ans counts fully acceptable subarrays modulo P.
struct Info {
    i64 len {0};
    i64 l {0};
    i64 r {0};
    Z ans {};
};
// Concatenates block `a` (left) with block `b` (right); an empty block is the identity.
Info operator+(const Info &a, const Info &b) {
    if (a.len == 0) {
        return b;
    }
    if (b.len == 0) {
        return a;
    }
    Info merged;
    merged.len = a.len + b.len;
    // The acceptable prefix reaches into b only when all of a is acceptable.
    merged.l = a.l;
    if (a.l == a.len) {
        merged.l += b.l;
    }
    // Symmetrically for the acceptable suffix.
    merged.r = b.r;
    if (b.r == b.len) {
        merged.r += a.r;
    }
    // Crossing subarrays start in a's acceptable suffix and end in b's acceptable prefix.
    const Z crossing = Z(a.r) * b.l;
    merged.ans = a.ans + b.ans + crossing;
    return merged;
}
// f[n][k]: summary of a full block of 2^n consecutive values when k more 1 bits
// may still be used (filled in main()).
Info f[61][61];

// Returns f[n][k], or an all-unacceptable block of size 2^n when the budget k is negative.
Info get(int n, int k) {
    if (k < 0) {
        return { 1LL << n, 0, 0, 0 };
    }
    return f[n][k];
}
void solve() {
i64 n;
int k;
std::cin >> n >> k;
Info s;
int c = 0;
for (int i = 59; i >= 0; i--) {
if (n >> i & 1) {
s = s + get(i, k - c);
c++;
}
}
std::cout << s.ans << "\n";
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
for (int i = 0; i <= 60; i++) {
f[0][i] = {1, 1, 1, 1};
}
for (int i = 1; i <= 60; i++) {
for (int j = 0; j <= 60; j++) {
f[i][j] = get(i - 1, j) + get(i - 1, j - 1);
}
}
int t;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
代码3(分治)
dp[i][j] 表示 i 个二进制位(即 [0, 2^i) 这一整块)中,每个数的 1 的个数都严格小于 j 时的 (l,r) 方案数(代码读入后先做了 k++,所以 j 对应“最多 j-1 个 1”)。
先递推出 dp,然后按 n 所在的区间是否被中间的 011…1(低半段的最后一个数)隔断来分类讨论。
#include <bits/stdc++.h>
using namespace std;
// Output modulus.
#define mz 1000000007
// ans: running answer of the current test case.
long long ans;
// dp[i][j]: answer for the full block of i bits when fewer than j one-bits are allowed
// (filled in main()).
long long dp[66][66];

// Returns x*(x+1)/2 modulo mz, i.e. the number of subarrays of a run of x elements.
long long add(long long x) {
    const long long r = x % mz;
    const long long triangular = r * (r + 1) / 2;
    return triangular % mz;
}
// Adds to the global `ans` the contribution of the numbers 0..n (inclusive), viewed
// as `now`-bit values, when a number is acceptable iff it has fewer than k one-bits.
void f(int now, long long n, int k) {
    // No number has fewer than 0 one-bits.
    if (k == 0) {
        return;
    }
    // Smallest value with k one-bits; everything below it is acceptable.
    const long long firstBad = (1LL << k) - 1;
    if (n < firstBad) {
        ans = (ans + add(n + 1)) % mz;
        return;
    }
    const long long half = 1LL << (now - 1);
    // Top bit never set within 0..n: drop it.
    if (n < half) {
        f(now - 1, n, k);
        return;
    }
    // The range is exactly the full block of `now` bits.
    if (n == half * 2 - 1) {
        ans = (ans + dp[now][k]) % mz;
        return;
    }
    // Full lower half, then the partial upper half with one one-bit already used.
    ans = (ans + dp[now - 1][k]) % mz;
    f(now - 1, n - half, k - 1);
}
// Builds the dp table, then answers each test case.
int main() {
    for (int bits = 1; bits <= 60; ++bits) {
        // Fewer allowed one-bits than bits: lower half keeps the limit,
        // upper half has already spent one.
        for (int limit = 1; limit < bits; ++limit) {
            dp[bits][limit] = (dp[bits - 1][limit] + dp[bits - 1][limit - 1]) % mz;
        }
        // limit == bits: one run of 2^bits - 1 values.
        dp[bits][bits] = add((1LL << bits) - 1);
    }

    int t, k;
    long long n;
    scanf("%d", &t);
    for (; t--; ) {
        scanf("%lld%d", &n, &k);
        ans = 0;
        // From here on k is an exclusive limit and n an inclusive upper bound.
        ++k;
        --n;
        const long long firstBad = (1LL << k) - 1;
        if (n >= firstBad) {
            f(60, n, k);
        } else {
            ans = add(n + 1);
        }
        cout << ans << endl;
    }
    return 0;
}
/*
*/
代码4(自己的乱搞)
数位 dp + 分治的一个乱搞做法:枚举数对 (l,r) 的二进制表示最高在哪一位出现不同。
#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<vector>
#include<map>
#include<queue>
#include<set>
using namespace std;
// Inclusive ascending / descending integer loops.
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
// Short type aliases.
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
// Shorthands for pair members and vector::push_back.
#define fi first
#define se second
#define pb push_back
// Debug printing to stderr: dbg keeps the line open, dbg2 ends it.
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
// Container size as int.
#define SZ(a) (int)(a.size())
// scanf/printf shorthands: read an int, print an int (with or without newline), print a long long.
#define sci(a) scanf("%d",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
// N bounds the table dimensions; mod is the output modulus.
const int N=65,mod=1e9+7;
// t: number of test cases; k: per-test limit on 1 bits;
// dp[x][y][lim]: memo for dfs, -1 meaning "not computed" (reset in sol()).
int t,k,dp[N][N][2];
// n: per-test upper bound (decremented in sol()); f and g are filled by init().
ll n,f[N][N],g[N][N];
// f[i][j]: with i binary digits and at most j ones in total, how many consecutive
//          numbers can be taken going downward from (0)111...
// g[i][j]: with i binary digits and at most j ones in total, how many consecutive
//          numbers can be taken going upward from (1)000...
// Scratch enumeration: 0 1 01 10 011 100 000 001 010 011 100 101 110 111
// Bounds-safe read of g[x][y]: a negative budget y yields 0.
ll G(int x,int y){
    return y<0?0:g[x][y];
}
int dfs(int x,int y,bool lim){//在x 还能填y个1的方案
if(y<0)return 0;
if(x==-1)return y>=0;
if(~dp[x][y][lim])return dp[x][y][lim];
int &ans=dp[x][y][lim];ans=0;
int up=lim?(n>>x&1):1;
for(int i=0;i<=up;++i){
ans=(ans+dfs(x-1,y-i,lim && (i==up)))%mod;
}
if(!lim){//这一位一个填0 一个填1 做merge 后面还能填x位
//printf("x:%d y:%d f:%lld g:%lld add:%lld\n",x,y,f[x][y],G(x,y-1),1ll*f[x][y]%mod*G(x,y-1)%mod);
ans=(ans+1ll*f[x][y]%mod*(G(x,y-1)%mod)%mod)%mod;
}
else{
if(up==1){
ll w=n&((1ll<<x)-1);
ll z=min(w+1,G(x,y-1));
//printf("x:%d y:%d w+1:%lld G(x,y-1):%d\n",x,y,(w+1)%mod,G(x,y-1));
ans=(ans+1ll*z%mod*(f[x][y]%mod)%mod)%mod;
}
}
//if(ans<0)
//printf("x:%d y:%d lim:%1d ans:%d\n",x,y,lim,ans);
//printf("dp[%d][%d][%1d][%1d]:%lld\n",x,trail,one,lim,dp[x][trail][one][lim]);
return ans;
}
// Handles one test case: clears the memo, reads n and k, and returns the answer.
int sol(){
    // All bytes 0xFF makes every int entry equal to -1 ("not computed").
    memset(dp,-1,sizeof(dp));
    scanf("%lld%d",&n,&k);
    // dfs works with an inclusive upper bound.
    --n;
    const int result=dfs(62,k,true);
    return result;
}
void init(){
rep(i,0,N-1)f[0][i]=g[0][i]=1;//不填数了
// f[1][0]=0;g[1][0]=1;
// f[1][1]=2;g[1][1]=2;
rep(i,1,N-1){
f[i][0]=0;g[i][0]=1;
rep(j,1,i){
if(j<i)f[i][j]=0;
else f[i][j]=2ll*f[i-1][j-1];
if(j<i){
if(j==i-1)g[i][j]=g[i-1][j]+g[i-1][j-1];
else g[i][j]=g[i-1][j];
}
else g[i][j]=2ll*g[i-1][j-1];
//printf("i:%d j:%d f:%lld g:%lld\n",i,j,f[i][j],g[i][j]);
}
rep(j,i+1,N-1){
f[i][j]=f[i][j-1];
g[i][j]=g[i][j-1];
//printf("i:%d j:%d f:%lld g:%lld\n",i,j,f[i][j],g[i][j]);
}
}
}
// Builds the tables once, then prints one answer per test case.
int main(){
    init();
    scanf("%d",&(t));
    for(;t--;){
        printf("%d\n",sol());
    }
    return 0;
}