该程序是介绍了一个孪生神经网络,大致就是给出两张图片,比较两张图片的相似性,比如人脸对比等
这里的数据集是mnist,代码中首先会建立一些图片对,就是pairs,如果是同类的图片,则把y值设置为 1,如果是不同类的图片,则把 y 值设置为 0;y值就是相似度
基础神经网络结构为:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 28, 28) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 784) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 100480
_________________________________________________________________
dropout_1 (Dropout) (None, 128) 0
_________________________________________________________________
dense_2 (Dense) (None, 128) 16512
_________________________________________________________________
dropout_2 (Dropout) (None, 128) 0
_________________________________________________________________
dense_3 (Dense) (None, 128) 16512
=================================================================
Total params: 133,504
Trainable params: 133,504
Non-trainable params: 0
_________________________________________________________________
None
就是给一个数字图片,经过一堆计算之后,会生成一个128维的向量,
然后计算两张图片的运算结果向量之间的欧氏距离,也就是函数 euclidean_distance,向量平方和再开方,其值肯定是一个整数
我们猜想一下,如果两张图片,送的数据一模一样,那么其距离结果肯定为 0,如果完全不同的两张图片,那么运算结果也应该是一个很大的数
下面是一个自定义的损失函数 contrastive_loss,损失函数的参数为 预期值 和 实际运算结果;
这里再强调一下,y_pred 的值肯定不会小于 0,因为是平方和,再开方,肯定大于等于 0
损失函数的返回值为:
return y_true * square_pred + (1 - y_true) * margin_square
因为 y_true 是我们在数据集预处理中设置的,其值可能为 0 ,或 1,
如果 y_true 为 1,也就是两张图片预期一样时,则返回结果可以简化为 square_pred
square_pred 就是欧氏距离;
如果 y_true 为 0, 也就是两张图片预期不一样时,则返回结果可以简化为 margin_square;
margin_square = K.square(K.maximum(margin - y_pred, 0))
损失函数就是,在结果符合预期时,损失值很小,在不符合预期时,损失值很大;
margin_square 就是 (1 - 欧氏距离)的平方,比如传了一个1,一个2,那么欧氏距离应该是很大,比如 y_pred 值为0.9,那就是符合预期,(1 - y_pred)**2,结果为 0.01,损失就很小;
而如果传了不同的两个值,一个1,一个2,结果欧式距离很小,比如 y_pred 值为 0.01,那么损失函数应该很大,这里确实很大,0.99 ** 2,差不多为 1 了,需要神经网络进行优化;
但如果传入的是两个不同的图片,计算结果 y_pred 为 10,那其实也是符合预期,反正两张不同的图片 y_pred 越大越好,那么这时候计算 K.maximum(margin - y_pred, 0) 就是 K.maximum(-9, 0),也就是 0,损失为0
其他就没有什么难点了