题目分析
给定两个长度为 n n n 的数组 a , b a, b a,b。
寻找一个最长的序列 a i 1 a_{i_1} ai1, b i 2 b_{i_2} bi2, a i 3 a_{i_3} ai3, … \dots …,满足:
- i i < i 2 < i 3 < … i_i < i_2 < i_3<\dots ii<i2<i3<…。
- 序列中所有相邻的元素最少要拥有 { 0 , 2 , 4 } \{0, 2, 4\} {0,2,4} 中的一个相同的数字。
解题思路
考虑起点,起点可以是 a 1 a_1 a1, a 2 a_2 a2, a 3 a_3 a3, … \dots …, a n a_n an,起点较多,实现较为复杂,我们可以考虑终点。
考虑终点,终点可以是 a 1 a_1 a1, a 2 a_2 a2, a 3 a_3 a3, … \dots …, a n a_n an, b 2 b_2 b2, a 3 a_3 a3, a 4 a_4 a4, … \dots …, b n b_n bn,更多更复杂,所以我们可以在 a , b a, b a,b 的最后加上一个终点元素 420 420 420,那么终点就一定是 a n + 1 a_{n+1} an+1 或者 b n + 1 b_{n+1} bn+1。
设 f a [ i ] fa[i] fa[i] 表示以 a i a_i ai 为起点的符合要求的最大数组长度, f b [ i ] fb[i] fb[i] 表示以 b i b_i bi 为起点的符合要求的最大数组长度。
对于 f a [ i ] fa[i] fa[i]:
- 若 a i a_i ai 中有 0 0 0, f a [ i ] = max ( f b [ j 1 ] , f b [ j 2 ] , f b [ j 3 ] , … ) + 1 fa[i] = \max(fb[j_1],fb[j_2],fb[j_3],\dots) + 1 fa[i]=max(fb[j1],fb[j2],fb[j3],…)+1,其中 b [ j 1 ] , b [ j 2 ] , b [ j 3 ] , … b[j_1],b[j_2],b[j_3],\dots b[j1],b[j2],b[j3],… 中含有 0 0 0,且 min ( j 1 , j 2 , j 3 , … ) > i \min(j_1,j_2,j_3,\dots) > i min(j1,j2,j3,…)>i。
2 , 4 2,4 2,4 同理,我们应当在所有的 f b [ j ] fb[j] fb[j] 中去最大值。
由上可知,计算 f a [ i ] fa[i] fa[i] 时,需要用到 f b [ j ] fb[j] fb[j],其中 j > i j > i j>i,计算 f b fb fb 时同理,故我们需要在计算 f a [ i ] fa[i] fa[i] 前计算 f b [ j ] fb[j] fb[j],所以顺序应当由后往前。
为了优化动态转移的计算,我们可以定义一个长度为 3 3 3 的数组 p a pa pa,其中 p a [ 0 ] pa[0] pa[0] 表示截至当前的情况下,含有 0 0 0 的 a k a_k ak,使得 f a [ k ] fa[k] fa[k] 达到最大值, p b pb pb 同理。
那么,若
a
i
a_i
ai 中含有
0
,
2
,
4
0,2,4
0,2,4,则有 fa[i] = max(fb[pb[0]], fb[pb[1], fb[pb[2]) + 1
。
答案为 f a fa fa 中的最大值。
import sys
sys.setrecursionlimit(1000000)
input = lambda:sys.stdin.readline().strip()
def get(x):
a, b, c = False, False, False
while x > 0:
t = x % 10
if t == 0:
a = True
elif t == 2:
b = True
elif t == 4:
c = True
x //= 10
return [a, b, c]
n = int(input())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
fa = [0 for i in range(n + 1)]
fb = [0 for i in range(n + 1)]
pa = [n, n, n]
pb = [n, n, n]
for i in range(n - 1, 0 - 1, -1):
x = get(a[i])
for j in range(3):
if x[j]:
fa[i] = max(fa[i], fb[pb[j]] + 1)
y = get(b[i])
for j in range(3):
if y[j]:
fb[i] = max(fb[i], fa[pa[j]] + 1)
for j in range(3):
if x[j] and fa[i] > fa[pa[j]]:
pa[j] = i
if y[j] and fb[i] > fb[pb[j]]:
pb[j] = i
print(max(fa))