Description
给出正整数n和k,计算j(n, k)=k mod 1 + k mod 2 + k mod 3 + … + k mod n的值,其中k mod i表示k除以i的余数。例如j(5, 3)=3 mod 1 + 3 mod 2 + 3 mod 3 + 3 mod 4 + 3 mod 5=0+1+0+3+3=7
Input
输入仅一行,包含两个整数n, k。
Output
输出仅一行,即j(n, k)。
Sample Input
Sample Output
HINT
50%的数据满足:1<=n, k<=1000 100%的数据满足:1<=n ,k<=10^9
分析:
一开始试图反演,但是得到的式子无法简单计算
后来才发现这不是一道标准的反演
而是打标找规律
有一个很显然的结论:
如果n>m,则有一部分的贡献确定的:(n-m)*m
所以我们的问题就是如何计算n<=m这一部分的贡献
简单的打表:
可以初步发现序列实际上可以看作是若干个等差数列求和
于是我就暴力计算每一个等差数列的首项和项数,求Σ
(相当于把整个序列分成了若干块)
但是在对拍的时候出现了bug:
也就是说,每个等差数列的最后一项是不确定的,
我们并不明确这个等差数列到底会出现多少项
所以只利用这一个性质是不行的
在收受到题解的启发后,发现自己的想法实际上是对的(等差数列),但是实现的方法要改进一下:
考虑i<=m的时候
m mod i = m - trunc(m / i) * i
会发现有一些连续的i的trunc(m/i)的值会相等
那么这一个区间的余数就应该是一个等差数列
我们可以通过等差序列把这些trunc(m/i)相同的值求出来
于是我们就可以找这些trunc(m/i)相同的组,
用二分查找这些组的右端点来完成i<=m情况的计算
假设我们现在已经知道了一个区间的长度len,以及这一段的两端点(l,r),
那么这一部分的贡献就是:
(首项+末项)* len / 2
Q.
复杂度呢?
A.
我们会发现,trunc(m / i)的值的个数不会超过sqrt(m)个
因为以sqrt(m)为分界点,trunc(m/i) (i<=sqrt(m))的取值肯定不会超过sqrt(m)个
而trunc(k/i) (i>sqrt(m))的取值肯定只有1一个
那么我们就可知,这些trunc(k/i)相同的组的个数不会超过sqrt(k)个
总复杂度O(sqrt(m)*logm)
//这里写代码片
#include<cstdio>
#include<cstring>
#include<iostream>
#define ll long long
using namespace std;
ll n,m,d;
int last,a;
int solve(int x) //trunc(m/last)=tranc(m/x)
{
int l=x,r=n;
while (r-l>1)
{
int mid=(l+r)>>1;
if (m/mid!=m/x) r=mid; //查找最右端,相同的tranc(m/x)一定是连续的一段
else l=mid;
}
if (m/r==m/x) return r;
else return l;
}
int main()
{
ll ans=0;
scanf("%lld%lld",&n,&m);
for (int i=1;i<=min(m,n);i=last+1)
{
last=solve(i);
a=m/i; //trunc(m/i)
int cnt=last-i+1; //项数
ans+=(ll)(2*m-a*(i+last))*cnt/2;
}
if (n>m) ans+=(ll)(n-m)*m;
printf("%lld",ans);
return 0;
}
刚刚舒老师来慰问,并表示这道题ta的做法是分块
实际上这就牵扯到莫比乌斯反演的经典操作
基本所有反演的求和操作都是一个分块,并且都有这句话:
这句话可以感性的理解为:
由i求得last,得到一段序列(i,last)
其中 trunc(n/i)=trunc(n/(i+1))=…=trunc(n/last)
正好是本题的精髓
//舒老师的代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
long long ans,k,n,check;
long long f(long long l,long long r)
{
l--;
return (r+1)*r/2-(l+1)*l/2;
}
int main()
{
scanf("%lld%lld",&n,&k);
for(long long i=1,last;i<=min(n,k);i=last+1)
{
last=min(n,k/(k/i));
ans+=(k/i)*f(i,last);
}
printf("%lld\n",k*n-ans);
}