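I'm writing this code down to share with everyone. Main purpose: the next time I go on a blind date, I can confidently throw in one more blog post as part of the betrothal gift. The code is fairly short, but it still took me a while to understand. I haven't tested it yet, so I don't know whether it has bugs.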
Reference: "PyTorch 实现孪生网络识别面部相似度" (Recognizing facial similarity with a Siamese network in PyTorch), PyTorch 中文网 https://www.pytorchtutorial.com/pytorch-one-shot-learning/#i-4
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 24 10:00:24 2018
Paper: Siamese Neural Networks for One-shot Image Recognition
links: https://www.cnblogs.com/denny402/p/7520063.html
"""
import torch
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as plt
root = r'E:/siamese/faces'
train_txt_root = r'E:/siamese/faces/train.txt'
test_txt_root = r'E:/siamese/faces/test.txt'
train_batch_size = 32
train_number_epochs = 100
# Image visualization helpers
def imshow(img, text=None, should_save=False):
    # img is a (C, H, W) tensor; matplotlib expects (H, W, C)
    npimg = img.numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def show_plot(iteration, loss):
    # Plot the loss values against the iteration counts
    plt.plot(iteration, loss)
    plt.show()


def convert(train=True):
    if (train):
        f = open(train_txt_root, 'w')  # write the image paths to train_txt_root (train.txt)
        data_path = root
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i in range(40):
            for j in range(8):
                img_path = data_path + '/s' + str(i + 1) + '/' + str(j + 1) &#