概述
本题目对数学基本知识带有一定的考察,首先,需要知道矩阵的乘法怎么算、一个向量点乘矩阵怎么算,具体的知识在线性代数有详细讲述,编码实现这俩功能。然后,在算
Q
×
K
T
)
)
×
V
Q \times K^T)) \times V
Q×KT))×V的过程中,通过对元素的预处理、交换运算次序的方式,对程序进行改进以减少时间使用、空间使用。题目版权归CCF所有,真题跳转官网查看。
真题来源: 矩阵运算
官网地址:www.cspro.org(模拟考试入口)
题目分析
有大小为 n 行 d 列( n × d n \times d n×d)的三个矩阵 Q、K、V ,以及长度为n的一维行向量W,在依次输入 Q、K、V ,W 的每个元素后,计算 ( W ⋅ ( Q × K T ) ) × V (W \cdot (Q \times K^T)) \times V (W⋅(Q×KT))×V并输出计算结果,其中 K T K^T KT为 K 的转置。
暴力题解
- 第一步,接受参数 n 行 d 列,构造并初始化矩阵 Q、K、V ,由于是无序的,可统一存在矩阵字典 matrix 中,键为矩阵名,值是二维列表。
- 第二步,构造并初始化 W 和 K T K^T KT 。W 直接存在一维列表; K T K^T KT需要对 K 遍历,将 K [ i ] [ j ] K[i][j] K[i][j]存储在 K T [ j ] [ i ] K^T[j][i] KT[j][i]处。
- 第三步,实现 一维向量 W · 点乘的功能 _dot。参数 _a 表示 W ,_b 表示 ( Q × K T ) (Q \times K^T) (Q×KT)部分,两个for循环,使得 ( Q × K T ) (Q \times K^T) (Q×KT)第i行每一个值✖️W第i个值。
- 第四步,实现矩阵叉乘的功能 _cro。参数 _a 和 _b 分别乘号代表左、右两边的矩阵,设 _a 为 n × d n \times d n×d 的矩阵、 _b 为 d × m d \times m d×m 的矩阵。按照矩阵乘法运算公式:新矩阵 _c 的行数 = 左边矩阵 _a 的行数,新矩阵 _c 的列数 = 右边矩阵 ** b** 的列数;新矩阵的元素 $ _c[i][j] = \displaystyle \sum^{d-1}{y = 0}a{ik}b_{kj} = a_{i0}b_{0j} + a_{i1}b_{1j} + …+a_{i(d-1)} b_{(d-1)j}$ ,i和j都是从0开始。
- 第五步,按照既定公式,从左到右,由内而外,依次进行计算。然后,将计算的结果按照给定格式进行输出,这里使用了字符串拼接的方法。
具体实现:
def _dot(n, _a, _b): # 点乘
for i in range(n): # n == len(_a)
for j in range(len(_b[i])):
_b[i][j] *= _a[i]
return _b
def _cro(_a, _b): # 叉乘
_c = [] # 相乘后的矩阵
for i in range(len(_a)): # _a的行数决定了_c的行数
t1 = []
for j in range(len(_b[0])): # _b的列数决定了_c的列数
t2 = 0
for k in range(len(_b)):
t2 += _a[i][k] * _b[k][j]
t1.append(t2)
_c.append(t1)
return _c
def main():
n, d = map(int, input().split())
matrix = {}
matrix_key = ["Q", "K", "V"]
for i in range(n * 3): # 生成Q,K,V矩阵
site = int(i / n)
line = list(map(int, input().split()))
if matrix_key[site] not in matrix.keys():
matrix[matrix_key[site]] = []
matrix[matrix_key[site]].append(line)
W = list(map(int, input().split())) # 算W向量
KT = [] # K的转置
for i in range(d): # 算K的转置
KT.append([matrix["K"][j][i] for j in range(n)])
# 直接按照公式依次计算
res1 = _cro(matrix["Q"], KT)
res2 = _dot(n, W, res1)
res3 = _cro(res2, matrix["V"])
for item in res3:
st = ""
for it in item:
st += "{} ".format(it)
s = st[:-1]
print(s)
if __name__ == "__main__":
main()
提交结果:
满分题解
利用矩阵的特性,减少循环次数,进而降低时间复杂度。 ( W ⋅ ( Q × K T ) ) × V (W \cdot (Q \times K^T)) \times V (W⋅(Q×KT))×V实际上就是 1 ∗ n 的矩阵 × n ∗ d 的矩阵 × d ∗ n 的矩阵 × n ∗ d 的矩阵 1 *n的矩阵\times n*d的矩阵 \times d*n的矩阵 \times n*d的矩阵 1∗n的矩阵×n∗d的矩阵×d∗n的矩阵×n∗d的矩阵,可结合矩阵的运算法则 (AB)C = A(BC) 。使得问题整体从 O ( d n 2 ) O(dn^2) O(dn2)转为 O ( n d 2 ) O(nd^2) O(nd2)的。
- 先算 K T × V K^T \times V KT×V得到 d ∗ d d*d d∗d 的矩阵 res1 。
- 再算 W ⋅ Q W \cdot Q W⋅Q 得到 1 ∗ d 1*d 1∗d 的矩阵 res2。
- 然后算 ( W ⋅ ( Q × K T ) ) × V (W \cdot (Q \times K^T)) \times V (W⋅(Q×KT))×V 即 r e s 1 × r e s 2 res1 \times res2 res1×res2,得到 1 ∗ d 1*d 1∗d 的结果。
- 最后按行解包并输出。
具体实现:
def _dot(n, _a, _b):
for i in range(n):
for j in range(len(_b[i])):
_b[i][j] *= _a[i]
return _b
def _cro(_a, _b):
_c = []
for i in range(len(_a)):
t1 = []
for j in range(len(_b[0])):
t2 = 0
for k in range(len(_b)):
t2 += _a[i][k] * _b[k][j]
t1.append(t2)
_c.append(t1)
return _c
def main():
n, d = map(int, input().split())
matrix = {}
matrix_key = ["Q", "K", "V"]
for i in range(n * 3):
site = int(i / n)
line = list(map(int, input().split()))
if matrix_key[site] not in matrix.keys():
matrix[matrix_key[site]] = []
matrix[matrix_key[site]].append(line)
W = list(map(int, input().split()))
KT = []
for i in range(d):
KT.append([matrix["K"][j][i] for j in range(n)])
# 调换计算的顺序
res1 = _cro(KT, matrix["V"])
res2 = _cro(matrix["Q"], res1)
res3 = _dot(n, W, res2)
for item in res3:
print(*item) # *具有解包的作用
if __name__ == "__main__":
main()
提交结果:
简化编码
在考场上,尽可能以最短的时间内拿到最高的分为目的,可利用高阶语法和函数、使用给定的已知参数,节省编码时间。
- 自右向左利用结合率先算矩阵乘法,最后算 一维向量 × 矩阵 一维向量 \times 矩阵 一维向量×矩阵
- 嵌套使用列表表达式,结合***map()***记录矩阵 Q、K、V、W。
- 利用题目明确给出的n行、d列,只针对本题的 Q、K、V、W 实现矩阵的乘法。
具体实现:
n, d = map(int, input().split())
Q = [[i for i in map(int, input().split())] for j in range(n)]
K = [[i for i in map(int, input().split())] for j in range(n)]
V = [[i for i in map(int, input().split())] for j in range(n)]
W = [i for i in map(int, input().split())]
tmp = []
res = []
# K的转置 * V => tmp
for i in range(d):
tmp.append([])
for j in range(d):
tmp[i].append(0)
for k in range(n):
tmp[i][j] += K[k][i]*V[k][j]
# Q * tmp => res
for i in range(n):
res.append([])
for j in range(d):
res[i].append(0)
for k in range(d):
res[i][j] += Q[i][k]*tmp[k][j]
res[i][j] *= W[i] # 对第i行所有元素乘w[i]
# 按格式输出
for i in range(n):
print(*res[i])