题目来源:码蹄集
题目描述:
Java代码实现:
import java.util.Scanner;
/**
 * Reads two integers {@code n} and {@code p} from standard input and prints {@code f(n) mod p},
 * where {@code f(1) = 1} and, for {@code x > 1},
 *
 * <pre>
 *   f(x) = sum_{i < x} f(i)  +  sum_{i < x} f(i) * (floor(x / i) - floor(x / (i + 1)))
 * </pre>
 *
 * The second sum is maintained through a difference array: every {@code f(i)} is added at each
 * multiple of {@code i} and subtracted at each multiple of {@code i + 1}, so a running prefix sum
 * of that array yields the weighted term.
 *
 * <p>NOTE(review): the problem statement is not part of this file; the recurrence above is read
 * directly off the code.
 */
public class Main {

    /**
     * Reads {@code n} and {@code p} from {@code System.in} and prints the answer followed by a
     * newline. All working storage is local to the call, so the method may be invoked repeatedly.
     *
     * @throws IllegalArgumentException if {@code n} is negative or too large to index an array,
     *     or if {@code p} is not in {@code [1, 2^62 - 1]}
     */
    public static void solve() {
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner scanner = new Scanner(System.in);
        long n = scanner.nextLong();
        long p = scanner.nextLong();
        if (n < 0 || n > Integer.MAX_VALUE - 1) {
            throw new IllegalArgumentException("n out of range: " + n);
        }
        // Intermediate sums stay below 2 * p, so p must leave one bit of headroom in a long.
        if (p < 1 || p > (Long.MAX_VALUE >> 1)) {
            throw new IllegalArgumentException("p out of range: " + p);
        }
        System.out.println(countWays((int) n, p));
    }

    /**
     * Evaluates the recurrence described on the class for a single {@code n}.
     *
     * @param n index of the value to compute; values below 1 yield 0
     * @param p modulus, {@code 1 <= p < 2^62}
     * @return {@code f(n) mod p}
     */
    private static long countWays(int n, long p) {
        if (n < 1) {
            return 0;
        }
        long[] ways = new long[n + 1];
        long[] diff = new long[n + 1];
        long waysPrefix = 0; // sum of ways[1..i-1] mod p
        long diffPrefix = 0; // sum of diff[1..i-1] mod p
        ways[1] = 1;
        for (int i = 1; i <= n; i++) {
            ways[i] = (ways[i] + ((waysPrefix + diffPrefix) % p + diff[i]) % p) % p;
            long current = ways[i];
            // long counters: j may step past n by up to i + 1 before the loop ends.
            for (long j = i; j <= n; j += i) {
                diff[(int) j] = (diff[(int) j] + current) % p;
            }
            for (long j = i + 1L; j <= n; j += i + 1L) {
                diff[(int) j] = (diff[(int) j] - current + p) % p;
            }
            waysPrefix = (waysPrefix + current) % p;
            // diff[i] is final here: later iterations only touch indices greater than i.
            diffPrefix = (diffPrefix + diff[i]) % p;
        }
        return ways[n];
    }

    public static void main(String[] args) {
        solve();
    }
}
C++代码实现:
参考链接:https://yxsmarter.blog.csdn.net/article/details/128211350?spm=1001.2014.3001.5502
#include <bits/stdc++.h>
// ---- Shorthand macros -------------------------------------------------------
// Every macro argument is parenthesised so that operator precedence at the call
// site cannot change the meaning of the expansion.

// Turns off synchronisation between C++ streams and C stdio.
#define IO ios::sync_with_stdio(false)
// Reads one long long into z with scanf.
#define sc(z) scanf("%lld", &(z))
// Inclusive ascending loop: i = a, a+1, ..., b.
#define _ff(i, a, b) for (int i = (a); i <= (b); ++i)
// Inclusive descending loop: i = b, b-1, ..., a.
#define _rr(i, a, b) for (int i = (b); i >= (a); --i)
// Half-open ascending loop: i = a, ..., b-1.
#define _f(i, a, b) for (int i = (a); i < (b); ++i)
// Half-open descending loop: i = b-1, ..., a.
#define _r(i, a, b) for (int i = (b) - 1; i >= (a); --i)
#define mkp make_pair
// Replaces every `endl` token with a plain newline string (no stream flush).
#define endl "\n"
#define pii pair<int, int>
#define fi first
#define se second
// Lowest set bit of x. Fully parenthesised: the previous expansion `x & (-x)`
// made `lowbit(v) == k` parse as `v & ((-v) == k)`.
#define lowbit(x) ((x) & (-(x)))
#define pb push_back
using namespace std;

// ---- Type aliases -----------------------------------------------------------
typedef double db;
typedef long long ll;
typedef long double ld;

// ---- Constants --------------------------------------------------------------
// Capacity of the global arrays below.
const int N = 4e6 + 5;
const ll M = 1e5 + 5;
const ll mod = 998244353;
const int inf = 1e9;
const double eps = 1e-9;
const double PI = acos(-1.0);

// Zero-initialised global scratch arrays with N entries each.
ll f[N], g[N];
/**
 * Reads two integers n and p from standard input and prints f(n) mod p, where
 * f(1) = 1 and, for x > 1,
 *
 *   f(x) = sum_{i < x} f(i) + sum_{i < x} f(i) * (floor(x / i) - floor(x / (i + 1))).
 *
 * The weighted sum is maintained with a difference array: each f(i) is added at
 * every multiple of i and subtracted at every multiple of i + 1, and a running
 * prefix sum of that array is folded into the next value.
 *
 * Storage is local and sized to n + 1, so the function does not depend on any
 * fixed-capacity global buffer and can be called more than once.
 *
 * Prints 0 when n < 1. Throws std::invalid_argument when the input cannot be
 * parsed or p < 1. Intermediate sums reach 2 * p, so p is assumed to stay well
 * below the range of long long.
 *
 * Note: identifiers such as `fi` and `endl` are avoided on purpose because this
 * file redefines them as macros.
 */
void solve()
{
    long long n = 0, p = 0;
    if (!(std::cin >> n >> p) || p < 1)
        throw std::invalid_argument("expected two integers n and p with p >= 1");

    if (n < 1)
    {
        std::cout << 0 << '\n';
        return;
    }

    const std::size_t last = static_cast<std::size_t>(n);
    std::vector<long long> ways(last + 1, 0);
    std::vector<long long> diff(last + 1, 0);
    long long waysPrefix = 0; // sum of ways[1..i-1] mod p
    long long diffPrefix = 0; // sum of diff[1..i-1] mod p
    ways[1] = 1;

    for (std::size_t i = 1; i <= last; ++i)
    {
        ways[i] = (ways[i] + ((waysPrefix + diffPrefix) % p + diff[i]) % p) % p;
        const long long cur = ways[i];

        for (std::size_t j = i; j <= last; j += i)
            diff[j] = (diff[j] + cur) % p;
        for (std::size_t j = i + 1; j <= last; j += i + 1)
            diff[j] = (diff[j] - cur + p) % p;

        waysPrefix = (waysPrefix + cur) % p;
        // diff[i] is final here: later iterations only touch indices above i.
        diffPrefix = (diffPrefix + diff[i]) % p;
    }

    std::cout << ways[last] << '\n';
}
/**
 * Program entry point: runs the solver once and reports success to the host.
 */
int main()
{
    solve();
    return EXIT_SUCCESS;
}
Python代码实现(能通过4个测试用例,但会超时):
def solve():
    """Read ``n`` and ``p`` from one input line and print ``f(n) mod p``.

    ``f(1) = 1`` and, for ``x > 1``::

        f(x) = sum(f(i) for i < x)
             + sum(f(i) * (x // i - x // (i + 1)) for i < x)

    The weighted sum is tracked with a difference array: each ``f(i)`` is
    added at every multiple of ``i`` and subtracted at every multiple of
    ``i + 1``; the running prefix sum of that array feeds the next value.

    Both work lists are sized ``n + 1`` rather than a fixed capacity, so small
    inputs do not pay for millions of unused slots and large inputs cannot
    index past the end. Prints ``0`` when ``n < 1``.
    """
    n, p = map(int, input().split())
    if n < 1:
        print(0)
        return

    ways = [0] * (n + 1)
    diff = [0] * (n + 1)
    ways_prefix = 0  # sum of ways[1..i-1] mod p
    diff_prefix = 0  # sum of diff[1..i-1] mod p
    ways[1] = 1

    for i in range(1, n + 1):
        cur = (ways[i] + ways_prefix + diff_prefix + diff[i]) % p
        ways[i] = cur
        for j in range(i, n + 1, i):
            diff[j] = (diff[j] + cur) % p
        # Python's % already returns a non-negative result for positive p,
        # so no "+ p" correction is needed after the subtraction.
        for j in range(i + 1, n + 1, i + 1):
            diff[j] = (diff[j] - cur) % p
        ways_prefix = (ways_prefix + cur) % p
        # diff[i] is final here: later iterations only touch indices above i.
        diff_prefix = (diff_prefix + diff[i]) % p

    print(ways[n])
# Run the solver only when executed as a script, so that importing this module
# (e.g. from a test) does not block waiting on standard input.
if __name__ == "__main__":
    solve()
代码提交测试结果:
附B站老师思路讲解,可供参考:https://www.bilibili.com/video/BV1ih4y1x7su/?t=484.3