题目链接:点击查看
题目大意:给出公式,其中,则,现在给出n(<=1e12),求出答案对998244353取模后的答案
题目分析:若用暴力实现上述公式,只需要两层for循环,时间复杂度为n*n,也就是1e24。。所以我们需要想办法优化,其实上面的公式看着吓人,我们稍微化简一下:
这样看起来就亲切多了,再其实仔细观察一下就会发现,因为第二层循环的b是从a开始的,所以一定是大于零并且小于等于1,这样向上取整就变为1了,乘上1就可以直接约分了:
这样一来这个公式就比较好想了,因为n给的是1e12,这个数字很容易让人联想到sqrt(n),也就是1e6,所以我们不妨分一下块,对第一层的a简单分为[1,sqrt(n)]+[sqrt(n)+1,n],因为上面公式的b最大为1e12,所以只要a大于1e6,那么向下取整答案就是1了,也就是说对于[sqrt(n)+1,n]这一部分,我们可以稍微化简一下:
现在利用等差数列求和公式以及平方和求和公式,就可以O(1)计算出区间[sqrt(n)+1,n]之间的答案了
而对于区间[1,sqrt(n)]中的答案,我们可以直接O(n)跑了,对于其中第二层循环b,虽然看着是要从a跑到1e12,但其实稍微思考一下就能发现,因为是要求的值作为答案,而这个值向下取整后,在a确定的情况下,的值在一段连续的区间上肯定是相同的,根据这个性质我们可以将1e12的数据分块,分为一段一段的,每一段上的答案都为一个值,这样分块后我们就能很快的算出答案了,这个分块最大是当a以2为底时,也就才40最大了,时间复杂度可以大胆地放缩,因为最大也就才1e6*40,5e7左右,在评测机上的表现还是十分优秀的
话说回来,我们该怎么求这个区间呢,其实稍微将上面的式子转换一下:
等价于
所以我们可以从1开始枚举x,从而计算出b的区间在内的都等于x,既然都是一个常数了,那么我们直接根据公式就能直接计算贡献了,前面的公式也就是a*x*区间长度
最后需要注意一下关于取模时的一个巨坑,也就是遇到n就要模一下,不然1e12*1e9直接就把longlong爆掉了,还有就是在计算的过程中可能会出现负数,所以最后需要先模再加模最后再取模,基操基操
还有就是在计算等差数列求和公式和平方和求和公式时,会遇到除法,这个时候直接用费马小定理求一下逆元就可以解决了
代码:
#include<iostream>
#include<cstdlib>
#include<string>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<climits>
#include<cmath>
#include<cctype>
#include<stack>
#include<queue>
#include<list>
#include<vector>
#include<set>
#include<map>
#include<sstream>
#include<unordered_map>
using namespace std;
typedef long long LL;
const int inf=0x3f3f3f3f;
const int N=1e3+100;
const int mod=998244353;
LL inv2,inv6;//2的逆元和6的逆元
LL q_pow(LL a,LL b)
{
LL ans=1;
while(b)
{
if(b&1)
ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
LL get_ans1(LL n)//1~a的等差和
{
return (n%mod)*((n+1)%mod)%mod*inv2%mod;
}
LL get_ans2(LL n)//1~a*a的平方和
{
return (n%mod)*((n+1)%mod)%mod*(2*n%mod+1)%mod*inv6%mod;
}
int main()
{
// freopen("input.txt","r",stdin);
// ios::sync_with_stdio(false);
inv2=q_pow(2,mod-2);
inv6=q_pow(6,mod-2);
LL n;
cin>>n;
LL t=sqrt(n);
LL ans=(((n+1)%mod)*(get_ans1(n)-get_ans1(t))%mod-get_ans2(n)+get_ans2(t))%mod;
for(LL a=2;a<=t;a++)
{
LL x=1;//记录loga(b)的值,即b=a^x
LL b=a;
LL sum=0;
while(b<=n)
{
sum=(sum+(min(b*a-1,n)-b+1)%mod*x%mod)%mod;
b*=a;
x++;
}
ans=(ans+sum*a%mod)%mod;
}
cout<<(ans%mod+mod)%mod<<endl;
return 0;
}