Serval and Music Game
题面翻译
题目描述
给定整数
n
n
n 和长度为
n
n
n 的递增序列
s
s
s。
定义
f
(
x
)
f(x)
f(x) 为满足下列要求的整数
i
(
1
≤
i
≤
n
)
i(1\leq i\leq n)
i(1≤i≤n) 的数量:
- 存在非负整数 p i , q i p_i,q_i pi,qi 使得 s i = p i ⌊ s n x ⌋ + q i ⌈ s n x ⌉ s_i=p_i\bigg\lfloor\dfrac{s_n}{x}\bigg\rfloor+q_i\bigg\lceil\dfrac{s_n}{x}\bigg\rceil si=pi⌊xsn⌋+qi⌈xsn⌉。
你需要求出
∑
x
=
1
s
n
x
×
f
(
x
)
\sum_{x=1}^{s_n}x\times f(x)
∑x=1snx×f(x) 对
998244353
998244353
998244353 取模后的值。
每个测试点包含
t
t
t 组数据。
输入格式
第一行输入一个整数
t
(
1
≤
t
≤
1
0
4
)
t(1\leq t\leq10^4)
t(1≤t≤104) 表示数据组数,接下来对于每组数据:
第一行输入一个整数
n
(
1
≤
n
≤
1
0
6
)
n(1\leq n\leq10^6)
n(1≤n≤106)。
接下来输入一行
n
n
n 个整数表示序列
s
(
1
≤
s
1
<
s
2
<
⋯
<
s
n
≤
1
0
7
)
s(1\leq s_1<s_2<\cdots<s_n\leq10^7)
s(1≤s1<s2<⋯<sn≤107)。
单个测试点内所有组数据对应的
n
n
n 之和不超过
1
0
6
10^6
106,对应的
s
n
s_n
sn 之和不超过
1
0
7
10^7
107。
输出格式
对于每组数据:
输出一行一个整数表示
∑
x
=
1
s
n
x
×
f
(
x
)
\sum_{x=1}^{s_n}x\times f(x)
∑x=1snx×f(x) 对
998244353
998244353
998244353 取模后的值。
题目描述
Serval loves playing music games. He meets a problem when playing music games, and he leaves it for you to solve.
You are given n n n positive integers s 1 < s 2 < … < s n s_1 < s_2 < \ldots < s_n s1<s2<…<sn . f ( x ) f(x) f(x) is defined as the number of i i i ( 1 ≤ i ≤ n 1\leq i\leq n 1≤i≤n ) that exist non-negative integers p i , q i p_i, q_i pi,qi such that:
s i = p i ⌊ s n x ⌋ + q i ⌈ s n x ⌉ s_i=p_i\left\lfloor{s_n\over x}\right\rfloor + q_i\left\lceil{s_n\over x}\right\rceil si=pi⌊xsn⌋+qi⌈xsn⌉
Find out ∑ x = 1 s n x ⋅ f ( x ) \sum_{x=1}^{s_n} x\cdot f(x) ∑x=1snx⋅f(x) modulo $998,244,353 $ .
As a reminder, ⌊ x ⌋ \lfloor x\rfloor ⌊x⌋ denotes the maximal integer that is no greater than x x x , and ⌈ x ⌉ \lceil x\rceil ⌈x⌉ denotes the minimal integer that is no less than x x x.
输入格式
Each test contains multiple test cases. The first line contains the number of test cases t t t ( 1 ≤ t ≤ 1 0 4 1\leq t\leq 10^4 1≤t≤104 ). The description of the test cases follows.
The first line of each test cases contains a single integer n n n ( 1 ≤ n ≤ 1 0 6 1\leq n\leq 10^6 1≤n≤106 ).
The second line of each test case contains n n n positive integers s 1 , s 2 , … , s n s_1,s_2,\ldots,s_n s1,s2,…,sn ( 1 ≤ s 1 < s 2 < … < s n ≤ 1 0 7 1\leq s_1 < s_2 < \ldots < s_n \leq 10^7 1≤s1<s2<…<sn≤107 ).
It is guaranteed that the sum of n n n over all test cases does not exceed 1 0 6 10^6 106 , and the sum of s n s_n sn does not exceed 1 0 7 10^7 107 .
输出格式
For each test case, print a single integer in a single line — the sum of x ⋅ f ( x ) x\cdot f(x) x⋅f(x) over all possible x x x modulo 998 244 353 998\,244\,353 998244353 .
样例 #1
样例输入 #1
4
3
1 2 4
4
1 2 7 9
4
344208 591000 4779956 5403429
5
1633 1661 1741 2134 2221
样例输出 #1
26
158
758737625
12334970
提示
Solution
考虑每个 x x x 对答案的贡献。
分类讨论:
若 x ∣ s n x|s_n x∣sn ,则可以枚举 s n / x s_n/x sn/x 的倍数计算满足要求的 s i s_i si 的个数。
若 x ∤ s n x\nmid s_n x∤sn ,首先有 ⌈ s n x ⌉ = ⌊ s n x ⌋ + 1 \left\lceil\dfrac{s_n}{x}\right\rceil=\left\lfloor\dfrac{s_n}{x}\right\rfloor+1 ⌈xsn⌉=⌊xsn⌋+1
设 ⌊ s n x ⌋ = k \left\lfloor\dfrac{s_n}{x}\right\rfloor = k ⌊xsn⌋=k ,则 s i = ( p i + q i ) k + q i s_i=(p_i+q_i)k+q_i si=(pi+qi)k+qi ,由于 p i , q i ∈ N p_i,q_i\in \mathbb{N} pi,qi∈N ,
则若 s i s_i si 满足要求,必然有 ⌊ s i k ⌋ ≥ q i \left\lfloor\dfrac{s_i}{k}\right\rfloor\ge q_i ⌊ksi⌋≥qi ,不难证明 q i = s i m o d k q_i=s_i\bmod k qi=simodk ,
此时枚举 j = ⌊ s i k ⌋ j=\left\lfloor\dfrac{s_i}{k}\right\rfloor j=⌊ksi⌋ ,那么 s i s_i si 满足要求的充要条件为 s i ∈ [ j k , j k + j ] s_i\in [jk,jk+j] si∈[jk,jk+j] 。
可以发现当 0 ≤ q i < k 0\le q_i<k 0≤qi<k ,故当 j ≥ k j \ge k j≥k 时,此后的 s i s_i si 都会满足要求。
因此我们只需枚举 j < k j<k j<k ,剩下的统一处理(见代码)。
对于一个区间内满足要求的 s i s_i si ,可以用一个桶存储出现次数并将其进行前缀和,记为 c n t cnt cnt 数组,
那么在 [ l , r ] [l,r] [l,r] 中的 s i s_i si 的个数就是 c n t r − c n t l − 1 cnt_r - cnt_{l-1} cntr−cntl−1 。
这样我们可以 O ( s n ) O(\sqrt {s_n}) O(sn) 求出 f ( x ) f(x) f(x) 的值。
考虑到 k = ⌊ s n x ⌋ k=\left\lfloor\dfrac{s_n}{x}\right\rfloor k=⌊xsn⌋ 可以数论分块,于是可以在 O ( s n ) O(s_n) O(sn) 求出答案。
Code
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 1e7 + 3;
const LL MOD = 998244353;
inline void read(int &x)
{
int sgn = 1; x = 0;
char ch = getchar();
while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if(ch == '-') sgn = -1, ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 1) + (x << 3), x += (ch ^ '0'), ch = getchar();
x *= sgn;
}
int T, cnt[N], n, s[N], m;
int main()
{
read(T);
while(T -- )
{
read(n); m = 0;
for(int i = 1; i <= n; i ++ ) read(s[i]), m = max(m, s[i]);
for(int i = 0; i <= m; i ++ ) cnt[i] = 0;
for(int i = 1; i <= n; i ++ ) cnt[s[i]] ++ ;
for(int i = 1; i <= m; i ++ ) cnt[i] += cnt[i - 1];
LL res = 0, tp = 0;
for(int i = 1; i <= m; i ++ )
{
LL sum = 0; int k = m / i;
if(m % i)
if(m % (i - 1) && m / (i - 1) == k) sum = tp;
else
{
for(int j = 1; j < k && j * k <= m; j ++ ) sum += cnt[min(j * k + j, m)] - cnt[j * k - 1];
if(1ll * k * k <= 1ll * m) sum += cnt[m] - cnt[k * k - 1];
}
else for(int j = k; j <= m; j += k) sum += cnt[j] - cnt[j - 1];
res += i * sum % MOD; res %= MOD;
tp = sum;
}
printf("%lld\n", res);
}
return 0;
}