动态DP——广义矩阵乘法—狭义版
广义矩阵乘法 (抄来的结论
矩阵乘法中存在两个运算符:乘法( o p t 1 opt_1 opt1),加法( o p t 2 opt_2 opt2)。
我们也可以将矩阵乘法中的这两个运算符替换为别的算子(乘法,加法,除法,减法,max,min等)。
但是由于矩阵乘法满足结合律,如果我们要用矩阵乘法加速维护某些信息,那么我们当我们替换算子的时候需要满足 o p t 1 opt_1 opt1 对 o p t 2 opt_2 opt2 满足分配律。
举个例子:在朴素矩阵乘法中,乘法对加法满足分配律: a ∗ ( b + c ) = a ∗ b + a ∗ c a*(b + c) = a*b+a*c a∗(b+c)=a∗b+a∗c
再举个例子:广义矩阵乘法一般用来维护动态DP,更具体的,将 o p t 1 opt_1 opt1 替换为加法,将 o p t 2 opt_2 opt2 替换为取min/max。
我们很容易知道, o p t 1 opt_1 opt1 对 o p t 2 opt_2 opt2 满足分配律: a + m i n ( b , c ) = m i n ( a + b , a + c ) a+min(b, c) = min(a+b,a+c) a+min(b,c)=min(a+b,a+c)
那么在第二个例子中,就是我们常用的广义形式。
第二个例子的单位矩阵形式为:
[
0
inf
inf
inf
0
inf
inf
inf
0
]
\left[ \begin{matrix} 0 & \inf & \inf\\ \inf & 0 & \inf \\ \inf & \inf & 0 \end{matrix} \right]
⎣⎡0infinfinf0infinfinf0⎦⎤
乘法代码:
c
=
a
∗
b
c=a*b
c=a∗b
for (int i = 0; i < 5; ++i)
for (int j = 0; j < 5; ++j)
for (int k = 0; k < 5; ++k)
c[i][j] = min(c[i][j], a[i][k] + b[k][j]);
要是取max,把矩阵的 i n f inf inf 改为 − i n f -inf −inf 就好了
例题,CF750E New Year and Old Subsequence 题目链接
题目:给定一个字符串,每次询问一个子串,问最少删除多少个字符后,可以保证存在子序列 “ 2017 2017 2017” 但不存在子序列 “ 2016 2016 2016”。
考虑对于每一个子串怎么做:
显然直接做子序列匹配即可,定义 dp[i][0/1/2/3/4]
为到第i个字符匹配到了 空/2/20/201/2017
的最小需要删除的字符数量。
转移方式显然:c
是当前位置的字符
d
p
[
i
]
[
0
]
=
d
p
[
i
−
1
]
[
0
]
+
[
c
=
=
2
]
d
p
[
i
]
[
1
]
=
m
i
n
(
d
p
[
i
−
1
]
[
1
]
+
[
c
=
=
0
]
,
d
p
[
i
−
1
]
[
0
]
+
[
c
=
=
2
?
0
:
i
n
f
]
)
d
p
[
i
]
[
2
]
=
m
i
n
(
d
p
[
i
−
1
]
[
2
]
+
[
c
=
=
1
]
,
d
p
[
i
−
1
]
[
1
]
+
[
c
=
=
0
?
0
:
i
n
f
]
)
d
p
[
i
]
[
3
]
=
m
i
n
(
d
p
[
i
−
1
]
[
3
]
+
[
c
=
=
6
∣
∣
c
=
=
7
]
,
d
p
[
i
−
1
]
[
2
]
+
[
c
=
=
1
?
0
:
i
n
f
]
)
d
p
[
i
]
[
4
]
=
m
i
n
(
d
p
[
i
−
1
]
[
4
]
+
[
c
=
=
6
]
,
d
p
[
i
−
1
]
[
3
]
+
[
c
=
=
7
?
0
:
i
n
f
]
)
\begin{aligned} dp[i][0] &= dp[i-1][0] + [c==2] \\ dp[i][1] &= min(dp[i-1][1] + [c==0], dp[i-1][0] + [c==2?0:inf]) \\ dp[i][2] &= min(dp[i-1][2] + [c==1], dp[i-1][1] + [c==0?0:inf]) \\ dp[i][3] &= min(dp[i-1][3] + [c==6||c==7], dp[i-1][2] + [c==1?0:inf]) \\ dp[i][4] &= min(dp[i-1][4] + [c==6], dp[i-1][3] + [c==7?0:inf]) \\ \end{aligned}
dp[i][0]dp[i][1]dp[i][2]dp[i][3]dp[i][4]=dp[i−1][0]+[c==2]=min(dp[i−1][1]+[c==0],dp[i−1][0]+[c==2?0:inf])=min(dp[i−1][2]+[c==1],dp[i−1][1]+[c==0?0:inf])=min(dp[i−1][3]+[c==6∣∣c==7],dp[i−1][2]+[c==1?0:inf])=min(dp[i−1][4]+[c==6],dp[i−1][3]+[c==7?0:inf])
现在我们考虑快速查询一个子串的dp值。
我们观察这个式子,只存在取min和加法,那么我们可以考虑用广义矩阵转移。
然后用线段树维护一个矩阵乘法即可。
初始矩阵为 [ 0 , i n f , i n f , i n f , i n f ] [0,inf,inf,inf,inf] [0,inf,inf,inf,inf] 然后乘上这段区间矩阵。
有矩阵乘法的基础的就不用多说。
关于每个值的矩阵,构造也很显然,根据转移构造即可,例如2的矩阵:
[
1
0
inf
inf
inf
inf
0
inf
inf
inf
inf
inf
0
inf
inf
inf
inf
inf
0
inf
inf
inf
inf
inf
0
]
\left[ \begin{matrix} 1 & 0 & \inf & \inf & \inf\\ \inf & 0 & \inf & \inf & \inf\\ \inf & \inf & 0 & \inf & \inf\\ \inf & \inf & \inf & 0 & \inf\\ \inf & \inf & \inf & \inf & 0\\ \end{matrix} \right]
⎣⎢⎢⎢⎢⎡1infinfinfinf00infinfinfinfinf0infinfinfinfinf0infinfinfinfinf0⎦⎥⎥⎥⎥⎤
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
typedef long long ll;
struct maxtr
{
int a[5][5];
maxtr()
{
memset(a, 0x3f, sizeof(a));
}
void init()
{
for (int i = 0; i < 5; ++i)
a[i][i] = 0;
}
maxtr operator * (const maxtr &x) const
{
maxtr res;
for (int i = 0; i < 5; ++i)
for (int j = 0; j < 5; ++j)
for (int k = 0; k < 5; ++k)
res.a[i][j] = min(res.a[i][j], a[i][k] + x.a[k][j]);
return res;
}
void show()
{
for (int i = 0; i < 4; ++i)
{
for (int j = 0; j < 4; ++j)
cout << a[i][j] << ' ';
cout << '\n';
}
}
};
const int N = 2e5+100;
char c[N];
maxtr val[N << 2], ST;
int n, m;
void exchange(maxtr &x, char c)
{
x.init();
if (c == '2') x.a[0][0] = 1, x.a[0][1] = 0;
else if (c == '0') x.a[1][1] = 1, x.a[1][2] = 0;
else if (c == '1') x.a[2][2] = 1, x.a[2][3] = 0;
else if (c == '6') x.a[3][3] = 1, x.a[4][4] = 1;
else if (c == '7') x.a[4][4] = 1, x.a[3][4] = 0;
}
void build(int id = 1, int l = 1, int r = n)
{
if (l == r)
{
exchange(val[id], c[l]);
return;
}
int mid = (l + r) >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
val[id] = val[id << 1] * val[id << 1 | 1];
}
maxtr query(int L, int R, int id = 1, int l = 1, int r = n)
{
if (L > r || R < l) return ST;
if (L <= l && r <= R) return val[id];
int mid = (l + r) >> 1;
maxtr a = query(L, R, id << 1, l, mid);
maxtr b = query(L, R, id << 1 | 1, mid + 1, r);
return a * b;
}
void sol()
{
ST.init();
cin >> n >> m;
cin >> c + 1;
build();
int l, r;
while (m--)
{
cin >> l >> r;
auto t = query(l, r);
if (t.a[0][4] > 1e6) t.a[0][4] = -1;
cout << t.a[0][4] << '\n';
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int T = 1;
while (T--) sol();
return 0;
}