Solution
显然,每次得到的集合都会是一个区间。
根据条件,每次区间会扩大一倍
对第 $i$ 个区间来说,左端点为 $2^{i}-1$,右端点为 $2^{i-1}\cdot n-2^{i-1}+1=(n-1)\cdot 2^{i-1}+1$。
因而可以依次枚举每个区间的左右端点,用高精度累加即可。
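其实,还有一种更优秀的方式。
可以发现,第 $i$ 个区间的大小可以表示为
$(n-3)\cdot 2^{i-1}+3$
那么前 $i$ 个区间的元素个数之和也可以直接算出:$S(i)=\sum_{j=1}^{i}\left((n-3)\cdot 2^{j-1}+3\right)=(n-3)(2^{i}-1)+3i$。
由于 $S(i)$ 随 $i$ 单调递增,可以二分出第 $k$ 个数所在的区间。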
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
// ---- Competitive-programming helper macros ----
// 2139062143 == 0x7F7F7F7F, a conventional "infinity"; not referenced in the code shown.
#define oo 2139062143
// NOTE: these function-like macros evaluate their arguments more than once, and
// max/min/abs replace every later token sequence spelled max(...)/min(...)/abs(...)
// (so e.g. std::max(...) cannot be used below this point). max is used by summing().
#define sqr(x) ((x)*(x))
#define lowbit(x) ((x)&(-x))
#define abs(x) (((x)>=0)?(x):(-(x)))
#define max(x,y) (((x)>(y))?(x):(y))
#define min(x,y) (((x)<(y))?(x):(y))
// Inclusive loops: i runs from x up to y (fo) or down to y (fd); y is re-evaluated each iteration.
#define fo(i,x,y) for (int i = (x);i <= (y);++ i)
#define fd(i,x,y) for (int i = (x);i >= (y);-- i)
using namespace std;
typedef double db;
typedef long long ll;
// Big integers are ll arrays: a[0] = number of limbs, a[1..a[0]] = base-JW limbs, least significant first.
// M:  capacity (in ll elements) of every big-integer array.
// JW: limb base 10^7, i.e. 7 decimal digits per limb (matches write()'s %07lld and ins()'s 7-digit groups).
// N:  size of main()'s input character buffer and of ans's first dimension.
const int M = 8080, JW = 10000000,N = 1010;
// n, m: values read in main(); ans: small-n lookup table consulted by main() when n <= 2.
// NOTE(review): main() only writes ans[1][1] and ans[2][1..3], yet ans reserves N*M ll (~65 MB).
ll n,m,ans[N][M];
// k: the big-integer query, parsed by ins() in main() and consumed by work().
ll k[M];
// Print big integer a to stdout without a trailing newline.
// The most significant limb is printed as-is; each lower limb is zero-padded to 7 digits
// so that the base-10^7 limbs concatenate into the correct decimal string.
void write(ll a[])
{
    const ll top = a[0];
    printf("%lld", a[top]);
    ll pos = top - 1;
    while (pos >= 1)
    {
        printf("%07lld", a[pos]);
        --pos;
    }
}
// Three-way comparison of normalized big integers (no leading zero limbs).
// Returns -1 when a == b, 0 when a < b, and 1 when a > b.
int cmp(long long a[],long long b[])
{
    // A longer limb count means a larger value.
    if (a[0] != b[0])
        return a[0] > b[0] ? 1 : 0;
    // Same length: the most significant differing limb decides.
    for (int pos = a[0]; pos >= 1; --pos)
    {
        if (a[pos] == b[pos])
            continue;
        return a[pos] > b[pos] ? 1 : 0;
    }
    return -1;
}
// c = a * b using schoolbook multiplication over base-JW limbs.
// The product is accumulated in a local buffer and copied out at the end, so any of
// a, b, c may refer to the same array. c must have room for M elements.
void mult(ll a[],ll b[],ll c[])
{
    ll prod[M] = {0};
    const ll lenA = a[0];
    const ll lenB = b[0];
    prod[0] = lenA + lenB;
    for (int x = 1; x <= lenA; ++x)
    {
        const ll digitA = a[x];
        for (int y = 1; y <= lenB; ++y)
            prod[x + y - 1] += digitA * b[y];
    }
    // Propagate carries so every limb ends up in [0, JW).
    for (int pos = 1; pos <= prod[0]; ++pos)
    {
        prod[pos + 1] += prod[pos] / JW;
        prod[pos] %= JW;
    }
    // Trim leading zero limbs, keeping at least one.
    while (prod[0] > 1 && prod[prod[0]] == 0)
        --prod[0];
    memcpy(c, prod, sizeof(prod));
}
// Set big integer a = 2^x by binary exponentiation (x <= 0 yields 1).
// Every limb of a is cleared first: summing() and minu() read limbs past a[0] and rely on
// them being 0. Without this, a result of 1 (x == 0) could carry stale high limbs.
// Assumes a has room for M limbs (every caller in this file passes an M-sized array).
void qsm(ll a[],ll x)
{
    ll base[M] = {0};
    base[0] = 1, base[1] = 2;      // base = 2, squared after each consumed bit
    memset(a, 0, sizeof(ll) * M);
    a[0] = a[1] = 1;               // a = 1
    while (x > 0)
    {
        if (x & 1) mult(a, base, a);
        x >>= 1;
        // Skip the squaring after the last bit; its result would never be used.
        if (x) mult(base, base, base);
    }
}
// Parse decimal string s into big integer a: a[0] = limb count, a[1..a[0]] = base-10^7 limbs,
// least significant first. Digits are grouped in 7s starting from the right end of s.
// Leading zero limbs are trimmed so a[0] is the true length. cmp() decides on a[0] first,
// so an untrimmed zero top limb would make the value compare as too large. Such a limb
// appears, for example, when the input length is a multiple of 7 or has leading zeros.
// Each limb is assigned outright, so a need not be pre-zeroed. Limbs above the parsed length
// are left untouched, though: callers that later use summing()/minu() should still pass a
// zeroed buffer.
void ins(char s[],long long a[])
{
    const int len = static_cast<int>(strlen(s));
    a[0] = 0;
    for (int end = len; end > 0; end -= 7)
    {
        const int begin = end > 7 ? end - 7 : 0;
        long long limb = 0;
        for (int pos = begin; pos < end; ++pos)
            limb = limb * 10 + (s[pos] - '0');
        a[++a[0]] = limb;
    }
    if (a[0] == 0) a[++a[0]] = 0;   // empty string parses as 0
    while (a[0] > 1 && a[a[0]] == 0) --a[0];
}
// c = a + b for big integers. Limbs are read up to the longer operand's length, so each
// operand's limbs above its own a[0]/b[0] must be 0. Any of a, b, c may alias, since the sum
// is built in a local buffer and copied out.
void summing(ll a[],ll b[],ll c[])
{
    ll res[M] = {0};
    const ll width = a[0] > b[0] ? a[0] : b[0];
    res[0] = width;
    for (int pos = 1; pos <= width; ++pos)
        res[pos] = a[pos] + b[pos];
    for (int pos = 1; pos <= width; ++pos)
    {
        res[pos + 1] += res[pos] / JW;
        res[pos] %= JW;
    }
    // A carry out of the top limb lengthens the result.
    while (res[res[0] + 1] != 0)
        ++res[0];
    memcpy(c, res, sizeof(res));
}
// c = a - b for big integers; requires a >= b. The result starts with a's limb count and is
// then trimmed. b's limbs are read at positions 1..a[0], so b's limbs above b[0] must be 0.
// b may hold an unnormalized limb (>= JW): sigc() stores the raw value n in tt[1].
// Each negative limb therefore borrows ceil(-t[i] / JW) units from the next limb, not just one.
// For normalized inputs this is always exactly one borrow, the same as a plain subtraction.
void minu(ll a[],ll b[],ll c[])
{
    ll t[M]= {0};
    t[0] = a[0];
    fo(i,1,t[0]) t[i] = a[i] - b[i];
    fo(i,1,t[0] - 1)
        if (t[i] < 0)
        {
            const ll borrow = (-t[i] + JW - 1) / JW;   // smallest count that makes t[i] >= 0
            t[i + 1] -= borrow;
            t[i] += borrow * JW;
        }
    while (!t[t[0]] && t[0] > 1) -- t[0];
    memcpy(c,t,sizeof(t));
}
// Halve big integer a in place (floor division by 2).
// Walking down from the top limb, an odd limb hands JW (its dropped unit, expressed in the
// next lower place) to the limb below before being shifted. The lowest limb's odd bit is discarded.
// NOTE(review): not called from the code shown.
void divt(ll a[])
{
    int pos = a[0];
    while (pos >= 2)
    {
        if ((a[pos] & 1) != 0)
            a[pos - 1] += JW;
        a[pos] >>= 1;
        --pos;
    }
    a[1] >>= 1;
    while (a[0] > 1 && a[a[0]] == 0)
        --a[0];
}
// Compute t = (n - 3) * 2^(mid-1) + 3*mid - n, where n is the global read in main().
// This equals the sum of (n-3)*2^(i-1) + 3 over i = 1..mid-1, i.e. the total size of the
// first mid-1 intervals.
// tt is scratch: only tt[1] is overwritten, so it must already satisfy tt[0] == 1
// (work() sets this before calling).
void sigc(ll t[],ll mid,ll tt[])
{
    const ll exponent = mid - 1;
    qsm(t, exponent);            // t = 2^(mid-1)

    tt[1] = n - 3;
    mult(t, tt, t);              // t = (n-3) * 2^(mid-1)

    // Add before subtracting: minu() expects its first operand to be the larger one.
    tt[1] = mid * 3;
    summing(t, tt, t);           // t += 3*mid

    tt[1] = n;
    minu(t, tt, t);              // t -= n
}
// Locate the k-th element (k = global big integer) and print it (no trailing newline).
// sigc(t, idx, tt) yields S(idx-1), the total size of the first idx-1 intervals. The search
// looks in [1, 5000] for the interval that contains position k.
void work()
{
    ll t[M] = {0}, tt[M] = {0};
    tt[0] = 1;                    // tt is a one-limb scalar scratch; only tt[1] changes below
    int lo = 1, hi = 5000;
    while (lo < hi)
    {
        const int probe = (lo + hi + 1) / 2;
        sigc(t, probe, tt);       // t = S(probe-1)
        const int rel = cmp(k, t);
        if (rel == -1)
        {
            // k is exactly the last position of interval probe-1:
            // print (n-1) * 2^(probe-2) + 1.
            qsm(t, probe - 2);
            tt[1] = n - 1;
            mult(t, tt, t);
            tt[1] = 1;
            summing(t, tt, t);
            write(t);
            return;
        }
        if (rel == 1)
            lo = probe;           // k > S(probe-1)
        else
            hi = probe - 1;       // k < S(probe-1)
    }
    // Print (k - S(hi-1)) + 2^hi - 2.
    sigc(t, hi, tt);
    minu(k, t, k);                // k = k - S(hi-1)
    qsm(t, hi);
    tt[1] = 2;
    minu(t, tt, t);               // t = 2^hi - 2
    summing(k, t, t);
    write(t);
}
int main()
{
ans[1][1] = 1;
ans[2][1] = 1,ans[2][2] = 2,ans[2][3] = 3;
char ch[N];
scanf("%lld", &n);
if (n < 4)
{
scanf("%lld",&m);
if (n <= 2)
if (ans[n][m] == 0) printf("-1");
else printf("%lld", ans[n][m]);
if (n == 3)
{
qsm(k,(m + 2) / 3);
if (m % 3 == 1) -- k[1];
if (m % 3 == 0) ++ k[1];
write(k);
}
return 0;
}
memset(k,0,sizeof k);
scanf("%s", ch);
ins(ch,k);
work();
printf("\n");
}