Description
一棵二叉搜索树,插入n次,第i次插入的节点权值为(a+bn)%m,问第n次插入的点的深度
T<=5e4,n<=1e16,a,b,m<=1e8
Solution
定义val(n)表示第n个数的权值,suf(v)表示所有的(a+bn)%m中,大于v的最小的数,pre(v)表示小于v的最大的数
当n>m/gcd(m,b)时,后面的点构成循环,只需要计算第一层的val(n)或suf(val(n))
考虑这样一个过程,以权值为下标,插入的时间为单调栈,从0往val(n)和从m-1往val(n)做两个单调递增的栈,那么最后单调栈的大小就是答案
打表可以发现左右的单调栈都可以拆成log段等差数列
也就是说我们只需要支持求区间[L,R]的权值的最小值即可
设g(M,D,L,R)表示,最小的x,满足L<=Dx%M<=R
当L=0,答案为0
当(L-1)/gcd(M,D)>=R/gcd(M,D)时,无解
当[L,R]中有D的倍数时,返回最小的那个
当D>M-D时,答案为g(M,M-D,M-R,M-L)
否则,L mod D一定仍旧<=R mod D,考虑把问题范围mod D
假设存在解,设整数k,会存在一个x,满足L+kM<=Dx<=R+kM
=>L<=Dx-kM<=R => L-Dx<=-kM<=R-Dx => L mod D<=D-kM mod D<=R mod D
k=g(D,D-M,Lmod D,R mod D)
注意到每次M的范围至少/2,所以算g是一个log的
总复杂度O(Tlog^2m)
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long ll;
ll read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
ll x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
void write(ll x) {
if (!x) {puts("0");return;}
char ch[20];int tot=0;
for(;x;x/=10) ch[++tot]=x%10+'0';
fd(i,tot,1) putchar(ch[i]);
puts("");
}
const int N=2e6+5;
ll n;
int gcd(int x,int y) {return y?gcd(y,x%y):x;}
int a,b,m;
int val(ll n) {return (a+b*n)%m;}
ll g(ll m,ll d,ll L,ll R) {
if (!L) return 0;
int r=gcd(m,d);
if ((L-1)/r>=R/r) return -1;
int x=(L-1)/d+1;
if (d*x<=R) return x;
if (d>m-d) return g(m,m-d,m-R,m-L);
int k=g(d,d-m%d,L%d,R%d);
if (k==-1) return -1;
L+=k*m,R+=k*m;
return (L-1)/d+1;
}
ll f(int L,int R) {
L=(L-a+m)%m;R=(R-a+m)%m;
if (L<=R) return g(m,b,L,R);
return 0;
}
int calc(int v) {
if (v>m-1) return 0;
int ret=0;
int x=val(f(0,v));ret+=x<v;
while (x<v) {
int y=f(x+1,v);
int d=val(y)-x;
int L=(v-x)/d;ret+=L;x+=d*L;
if (x==v) ret--;
}
x=val(f(v,m-1));ret+=x>v;
while (x>v) {
int y=f(v,x-1);
int d=x-val(y);
int L=(x-v)/d;ret+=L;x-=d*L;
if (x==v) ret--;
}
return ret;
}
int main() {
freopen("fuwafuwa.in","r",stdin);
freopen("fuwafuwa.out","w",stdout);
for(int ty=read();ty;ty--) {
a=read();b=read();m=read();n=read();
a%=m;b%=m;(a+=b)%=m;n--;
if (!b) {write(n);continue;}
int g=gcd(m,b);int len=m/g;
ll ans=n/len;n%=len;
int ret=calc(val(n));
if (ans) ret=max(ret,calc(val(n)+g));
write(ans+ret);
}
return 0;
}