http://acm.hdu.edu.cn/showproblem.php?pid=6683
题意:问1-n这些数字中有多少子序列是等比数列。
做法:这道题有点悬。。。。。
我们接下来来一下,数学推导(瞎JB乱搞):
我们设等比数列的公比,等比数列的长度为,首项末项
这个必定是一个整数,所以可以得到,这个是显而易见的。但不过这个好像不好弄。
因为如果你统计p的数目,这只是n范围内的p,并没有说明q啊。
但不过也简单,你把这个式子变一下形,就可以得到了,然后a的数目就是
然后对于每一个a我们还要找到与他互质的数b,而且还是小于他的。
所以对于一个a就是这样的。
然后枚举a所以:
这里的这个k是不会多大的,最多也就六十几。
到了这里怎么做呢。
首先对于k=1和k=2可以直接通过公式计算得到。
然后我们通过枚举k,在分块来做,我刚刚开始就是这样想的,而且还做了,结果,由于开的次方开得开大了,精度出了问题。
后来有修正了一下精度,本地,测了一发数据,好像很多不对。。。。。。。(我果然还是太菜了)
然后看了题解。
我们对于k=3的时候,我们可以分块做,只有一个平方而已(这个分块具体还是看代码吧,我就不说了,反证我也是跟着感觉走的)。
当k>3时,由于a不会大于,当k大于3时,我们可以,暴力算,不会很大的,也就,可以用计算器算一算。
这里的具体写法如下:
for (ll a = 2;; a++) {
ll now = a * a;
if (now > n / a) break;
now *= a;
while (now <= n) {
ans = (ans + n / now % mod * phi[a]) % mod;
if (now > n / a) break;
now *= a;
}
}
我们暴力遍历a,然后对于每一个a在遍历他的k次方,然后累加答案,这样的判断一定要这样写,不然会爆精度。。
这里的复杂的我认为没有多大的,我就不写了。。。。。。
然后分块求k=3的时候,需要用到杜教筛,这里杜教筛我觉得有点玄。。反正起码要开40000000吧。否则会T,肯刚刚开始的那几个比较大吧。
然后就是内存,不要全部数组开long long会炸的。。。
其他都很简单,都是数学题的套路而已
#include "bits/stdc++.h"
using namespace std;
inline int read() {
int x = 0;
bool f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
if (f) return x;
return 0 - x;
}
typedef long long ll;
const int maxn = 40000000 + 10;
const int mod = 998244353;
int pri[maxn], vis[maxn], cnt = 0, phi[maxn], sum[maxn];
void init() {
vis[1] = phi[1] = 1;
cnt = 0;
for (int i = 2; i < maxn; i++) {
if (!vis[i]) {
pri[++cnt] = i;
phi[i] = i - 1;
}
for (int j = 1; j <= cnt && i * pri[j] < maxn; j++) {
vis[i * pri[j]] = 1;
if (i % pri[j] == 0) {
phi[i * pri[j]] = phi[i] * pri[j];
break;
}
phi[i * pri[j]] = phi[i] * (pri[j] - 1);
}
}
for (int i = 1; i < maxn; i++) {
sum[i] = (sum[i - 1] + phi[i]) % mod;
}
}
unordered_map<ll, int> mp;
ll get_s1(ll x) {
x %= mod;
ll tmp = x * (x + 1) / 2;
return tmp % mod;
}
ll get_s(ll x) {
if (x < maxn) return sum[x];
if (mp.count(x)) return mp[x];
ll ans = 0;
for (ll l = 2, r; l <= x; l = r + 1) {
r = x / (x / l);
ans = (ans + (r - l + 1) * get_s(x / l) % mod) % mod;
}
ans = (get_s1(x) - ans + mod) % mod;
return ans;
}
ll get_sqrt(ll a) {
ll x = (ll) sqrt(a);
return x;
}
ll solve(ll n) {
ll ans = (n % mod) * (n % mod + 1) / 2;
ans %= mod;
for (ll a = 2;; a++) {
ll now = a * a;
if (now > n / a) break;
now *= a;
while (now <= n) {
ans = (ans + n / now % mod * phi[a]) % mod;
if (now > n / a) break;
now *= a;
}
}
for (ll l = 2, r; l * l <= n; l = r + 1) {
r = get_sqrt(n / (n / l / l));
ans = (ans + n / l / l % mod * (get_s(r) - get_s(l - 1) + mod) % mod) % mod;
}
return ans;
}
int main() {
int T;
init();
scanf("%d", &T);
while (T--) {
ll n;
scanf("%lld", &n);
printf("%lld\n", solve(n));
}
return 0;
}