本周需要学习如何使用scipy。scipy中包含了许多跟numpy一样的函数,因此使用起来会有许多相似之处。
Exercise 10.1: Least squares
生成一个m行n列的矩阵,并要求m>n。同时生成一个m维向量。
求解x = arg minx ||Ax - b||2.
该题可以使用lstsq来求解,scipy.linalg.lstsq使用参数跟numpy.linalg.lstsq一致。
import numpy
import numpy.matlib
import scipy.linalg
m = 5
n = 3
A = numpy.matlib.rand((m, n))
b = numpy.matlib.rand((m, 1))
x = scipy.linalg.lstsq(A, b)[0]
print("A")
print(A)
print("b")
print(b)
print("x")
print(x)
结果如下
A
[[0.92962385 0.05363313 0.50548348]
[0.95964889 0.56054093 0.58568573]
[0.32413643 0.20750581 0.26737541]
[0.77044965 0.53036492 0.04242422]
[0.30828577 0.24769184 0.88993455]]
b
[[0.7871204 ]
[0.87042653]
[0.7409384 ]
[0.490697 ]
[0.87699647]]
x
[[0.3817007 ]
[0.28595985]
[0.80332776]]
Exercise 10.2: Optimization
寻找函数f(x) = sin^2(x - 2)e^-x2的最大值。
scipy.optimize.minimize可以求解出某个函数的最小值,然而本题需要求解的是最大值,因此需要将其取反,此时最小值就是原函数的最大值了。并且该方法需要使用一个可调用函数来作为参数,一般情况下函数需要使用def来定义,然而也可以使用lambda来实现matlab中的@(x)f(x)一样的匿名函数。
import math
import scipy.optimize
f = lambda x: - pow(math.sin(x - 2), 2) * math.exp(-pow(x, 2))
y = scipy.optimize.minimize(f, 0)
print(y)
结果如下
fun: -0.9116854118471545
hess_inv: array([[0.2680098]])
jac: array([-1.49011612e-08])
message: 'Optimization terminated successfully.'
nfev: 18
nit: 4
njev: 6
status: 0
success: True
x: array([0.21624132])
由于我们对函数进行了取反,因此fun显示的最小值取反后为原函数的最大值。
可以看到,当x=0.21624132时,f取最大值,为0.9116854118471545。
Exercise 10.3: Pairwise distances
生成一个n行m列的矩阵,计算每两行之间的距离。
给出一个应用例子就是,由n座城市,用两列的方式给定他们的坐标,使用一个表格来记录每两座城市之间的距离。
计算距离通常使用几何距离,也就是欧几里得距离。
import numpy
import numpy.matlib
import scipy.spatial.distance
n = 5
m = 6
X = numpy.matlib.rand((n, m))
dist = numpy.zeros((n, n))
for i in range(0, n):
for j in range(0, n):
dist[i][j] = scipy.spatial.distance.euclidean(X[i], X[j])
print("X")
print(X)
print("")
print("dist")
print(dist)
结果如下
X
[[0.20273567 0.25231566 0.44017816 0.70797027 0.1422286 0.42647038]
[0.1075678 0.498664 0.14429584 0.76133804 0.48465502 0.37378304]
[0.40225683 0.97926513 0.68925144 0.85595932 0.10464522 0.44246778]
[0.29966126 0.54135749 0.66086263 0.78189069 0.07700883 0.78110347]
[0.7007273 0.4551945 0.69823013 0.48250564 0.27664712 0.14748214]]
dist
[[0. 0.52931154 0.80862285 0.52642555 0.70886063]
[0.52931154 0. 0.87913248 0.79879386 0.91257761]
[0.80862285 0.87913248 0. 0.56922661 0.78732718]
[0.52642555 0.79879386 0.56922661 0. 0.83704073]
[0.70886063 0.91257761 0.78732718 0.83704073 0. ]]