1012
dp
题意
给出长度为
n
n
n 的全排列
p
,
q
p,q
p,q,还有一个由
p
,
q
p,q
p,q 组成的长度为
2
×
n
2\times n
2×n 的序列
S
S
S 。
现在有一个空序列
R
R
R ,每次可以从
p
p
p 或
q
q
q 的开头取出一个数字并加到
R
R
R 的末尾,问有多少种取法使得
R
=
S
R = S
R=S 。
思路
记
d
p
[
x
]
[
y
]
dp[x][y]
dp[x][y] 表示
p
,
q
p,q
p,q 中分别取前
x
,
y
x,y
x,y 位时的方案数。则
d
p
[
x
]
[
y
]
=
d
p
[
x
+
1
]
[
y
]
+
d
p
[
x
]
[
y
+
1
]
dp[x][y] = dp[x+1][y]+dp[x][y+1]
dp[x][y]=dp[x+1][y]+dp[x][y+1] ,前者需满足
p
[
x
+
1
]
=
s
[
x
+
y
+
1
]
p[x+1]=s[x+y+1]
p[x+1]=s[x+y+1] ,后者需满足
q
[
y
+
1
]
=
s
[
x
+
y
+
1
]
q[y+1] = s[x+y+1]
q[y+1]=s[x+y+1] 。
但因为空间不够,无法开二维数组(当然也可以用 std::map,但常数比较大),所以改成
d
p
[
2
×
m
a
x
n
]
[
2
]
dp[2\times maxn][2]
dp[2×maxn][2] ,其中
d
p
[
i
]
[
0
/
1
]
dp[i][0/1]
dp[i][0/1] 表示两数组一共匹配了
i
i
i 位且目前最后一位匹配的是
p
/
q
p/q
p/q 的方案数。转移也类似写就可以,代码如下:
int dfs(int x, int y, int t) {
if(x == n && y == n) return 1;
if(dp[x + y + 1][t] != -1) return dp[x + y + 1][t];
int ans = 0;
if(p[x + 1] == s[x+y+1] && x < n) {
ans += dfs(x + 1, y, 0);
ans %= P;
}
if(q[y + 1] == s[x+y+1] && y < n) {
ans += dfs(x, y + 1, 1);
ans %= P;
}
return dp[x+y+1][t] = ans;
}
//调用
int ans = 0;
if (p[1] == s[1])ans += dfs(1, 0, 0), ans %= P;
if (q[1] == s[1])ans += dfs(0, 1, 1), ans %= P;
这个代码看起来是 O ( n 2 ) O(n^2) O(n2) ,实际上是 O ( n ) O(n) O(n) ,因为 s s s 中每个数都只出现了两次(如果不是直接0),且在 p , q p,q p,q 中各出现一次,也就是 s s s 中同一个位置最多和两个位置(p,q中对应的位置)匹配,因为 s s s 长度为 2 × n 2\times n 2×n ,所以时间复杂度为 O ( 4 n ) O(4n) O(4n)。
代码
int n;
int dp[maxn << 1][2], p[maxn], q[maxn];
int s[maxn << 1];
int dfs(int x, int y, int t) {
if(x == n && y == n) return 1;
if(dp[x + y + 1][t] != -1) return dp[x + y + 1][t];
int ans = 0;
if(p[x + 1] == s[x+y+1] && x < n) {
ans += dfs(x + 1, y, 0);
ans %= P;
}
if(q[y + 1] == s[x+y+1] && y < n) {
ans += dfs(x, y + 1, 1);
ans %= P;
}
return dp[x+y+1][t] = ans;
}
void solve() {
cin >> n;
for(int i = 1; i <= n * 2; i++) dp[i][0] = dp[i][1] = -1;
for(int i = 1;i <= n; i++) cin >> p[i];
for(int i = 1;i <= n; i++) cin >> q[i];
for(int i = 1;i <= n * 2; i++) cin >> s[i];
int ans = 0;
if (p[1] == s[1])ans += dfs(1, 0, 0), ans %= P;
if (q[1] == s[1])ans += dfs(0, 1, 1), ans %= P;
cout << ans << endl;
}