1. 问题描述:
给定一个长度为 n 的整数数组 a1,a2,…,an;请你构造长度为 n 的整数数组 b1,b2,…,bn,要求数组 b 满足:
- b1 = 0;
- 对于任意一对索引 i 和 j(1 ≤ i,j ≤ n),如果 ai = aj 则 bi = bj(注意,如果 ai ≠ aj,则 bi 和 bj 相等与否随意);
- 对于任意索引 i(i ∈ [1,n − 1]),要么满足 bi = bi+1,要么满足 bi+1 = bi + 1。
请计算,一共可以构造出多少个不同的满足条件的数组 b。由于答案可能很大,你只需要输出对 998244353 取模后的结果。例如,如果 a = [1,2,1,2,3],则一共有 2 个满足条件的数组 b,分别是 b = [0,0,0,0,0] 和 b = [0,0,0,0,1];
输入格式
第一行包含一个整数 n。第二行包含 n 个整数 a1,a2,…,an。
输出格式
一个整数,表示对 998244353 取模后的结果。
数据范围
前 3 个测试点满足 2 ≤ n ≤ 5;
所有测试点满足 2 ≤ n ≤ 2 × 10 ^ 5,1 ≤ ai ≤ 10 ^ 9;
输入样例1:
5
1 2 1 2 3
输出样例1:
2
输入样例2:
2
100 1
输出样例2:
2
输入样例3:
4
1 3 3 7
输出样例3:
4
来源:https://www.acwing.com/problem/content/description/4415/
2. 思路分析:
分析题目可以知道数组 a 是对于数组 b 的相对限制,对于数组 b 来说需要满足 bi+1 - bi = 0 或者 bi+1 - bi = 1,所以对于数组 b 来说非严格单调递增,由于数组 a 对于数组 b 的是相对限制,所以理论上对于数组 b 来说有无限种可能,所以为了求解数组 b 在满足上述限制的情况下的方案数目所以添加了一个限制 b[1] = 0,有了这个限制之后那么对于数组 b 来说方案数目就是确定的,由下图可知对于任意两个不相等的 i,j,如果 bi == bj 且数组 b 非严格单调递增所以 bi == bi+1 == ... bj,所以数组 b 就被划分为若干段,并且每一段都是相等的,若有 m 段那么总共的方案数目有 2 ^ (m - 1)种,所以问题的关键是如何求解数组 b 划分的段数。
如何求解划分的若干段区间呢?对于这道题目来说其实有三种方案,由下图知,第一种比较简单的方法是区间合并,将有交集的区间合并为一个区间,第二种方法是离散化 + 差分(离散化其实是数组到下标的映射),这种方法比较难理解;第三种方法是并查集,一开始数组 b 的祖先节点都是自己,我们在合并区间的时候每一个点都指向右端点,对于区间 [l,r],先找到 l 和 r 的祖先节点,将 l 的祖先节点指向 r 的祖先节点;
对于这道题目来说会卡哈希表 unordered_map,因为哈希表在均摊的情况下时间复杂度为 O(1),最坏情况下为 O(n),因为 unordered_map 有一个哈希函数,出题人可以根据这个哈希函数将数据置为哈希函数对应质数的倍数这样就会经常发生冲突重建哈希表的情况,所以就会变得很慢,其中一个解决方法是将哈希表初始化为一个比较大的长度,这样出题人就不好猜这个质数;下面采用的是比较好写的区间合并方法来解决这个问题。
3. 代码如下:
python:
class Solution:
def process(self):
n = int(input())
a = list(map(int, input().split()))
# L记录a[i]最左边的位置, R记录a[i]最右边的位置
L, R = dict(), dict()
for i in range(n):
R[a[i]] = i
if a[i] not in L:
L[a[i]] = i
q = list()
for i in range(n):
q.append((L[a[i]], R[a[i]]))
# 根据左端点排序, 左端点相同按照右端点排序
q.sort(key=lambda x: (x[0], x[1]))
st = ed = -1
count = 0
# 区间合并
for i in range(n):
if q[i][0] <= ed:
ed = max(q[i][1], ed)
else:
count += 1
st, ed = q[i][0], q[i][1]
res = 1
mod = 998244353
# 计算方案数目
for i in range(count - 1):
res = (res * 2) % mod
print(res)
if __name__ == '__main__':
Solution().process()
go:
package main
import (
"bufio"
"fmt"
"io"
"os"
"sort"
)
func max(a, b int) int {
if a > b {
return a
}
return b
}
func run(r io.Reader, w io.Writer) {
// 使用 bufio.NewReader()等等函数优化读取数据和写入数据的效率
in := bufio.NewReader(r)
out := bufio.NewWriter(w)
// 将缓存数据写入到标准输出中
defer out.Flush()
var n int
fmt.Fscan(in, &n)
a := make([]int, n+10)
L, R := make(map[int]int), make(map[int]int)
for i := 0; i < n; i++ {
fmt.Fscan(in, &a[i])
R[a[i]] = i
if _, flag := L[a[i]]; !flag {
L[a[i]] = i
}
}
var q [][]int
for i := 0; i < n; i++ {
t := []int{L[a[i]], R[a[i]]}
q = append(q, t)
}
// 对二维切片进行排序
sort.Slice(q, func(i, j int) bool {
return q[i][0] <= q[j][0]
})
ed := -1
count := 0
for i := 0; i < n; i++ {
if q[i][0] <= ed {
ed = max(ed, q[i][1])
} else {
count += 1
ed = q[i][0]
}
}
res := 1
mod := 998244353
for i := 0; i < count-1; i++ {
res = (res * 2) % mod
}
fmt.Fprintln(out, res)
}
func main() {
run(os.Stdin, os.Stdout)
}