原形网络(Prototypical Networks)基于PyTorch的实现

我在Jay2coomzz的基础上修改了数据处理方式和evaluation_model()方法,作为初学者若有不当之处恳请大家批评指正,原文链接如下:
https://blog.csdn.net/weixin_38471579/article/details/102631018

数据集和源码链接将放到评论区

1 数据集

一个简单的数据集介绍

Omniglot数据集一共包含1623 类手写体,每一类中包含20 个样本。其中这 1623 个手写体类来自 50 个不同地区(或文明)的 alphabets,如:Latin 文明包含 26 个alphabets,Greek 包含 24 个alphabets。如images_background/Greek文件夹下的 24个 希腊字母,代表 Greek 文明下的 24 个字母类,每个字母只有 20 个样本。
一般用于训练的是 images_background文件夹下的964 类(30个地区的字母),用于测试的是images_evaluation文件夹下 659 类 (20个地区的字母)。
训练的目的是,用 964 个类来训练模型,识别 659 个新的类。测试集与训练集完全分开,问题相似但从未遇见,这正是元学习learning to learn的含义。

数据集处理

先按照文件夹名称将测试集和训练集导入

import os 
import matplotlib.image as mpimg
import numpy as np
import csv

#将图片数据转化为numpy,每一个类得数据被为训练集和测试集,并存储在字典中

os.chdir('E:/pytorch/prototypical_network/scripts')
def load_data():
	#验证集
	labels_trainData = {}
	label = 0
	for file in os.listdir('../data/images_background'):
		for dir in os.listdir('../data/images_background/' + file):
			labels_trainData[label] = []
			data = []
			for png in os.listdir('../data/images_background/' + file +'/' + dir):
				image_np = mpimg.imread('../data/images_background/' + file +'/' + dir+'/' +png)
				image_np.astype(np.float64)
				data.append(image_np)
			labels_trainData[label] = np.array(data)
			label += 1
	#测试集
	labels_testData = {}
	label = 0
	for file in os.listdir('../data/images_evaluation'):
		for dir in os.listdir('../data/images_evaluation/' + file):
			labels_testData[label] = []
			data = []
			for png in os.listdir('../data/images_evaluation/' + file +'/' + dir):
				image_np = mpimg.imread('../data/images_evaluation/' + file +'/' + dir+'/' +png)
				image_np.astype(np.float64)
				data.append(image_np)
			labels_testData[label] = np.array(data)
			label += 1
	return labels_trainData,labels_testData

图片会以字典的形式存于labels_trainData和labels_testData中。其种类(lable,也就是字典的键值)为0,1,2,3,…,962存储。(训练集963个类别,测试集658个类别)

labels_trainData ,labels_testData = load_data()

不妨查看一下数据的格式(以测试集第0类为例)

print(labels_testData[0].shape)

输出(20, 105, 105),可以看到这类文字共有20个样例,每个样例时105*105的矩阵,如果你再打印其中的一个图片,你将会看到这个图片被读取成了0,1这样的二值矩阵。(它甚至不是灰度图,省事儿多了)
打印一个图片看看:

import matplotlib.pyplot as plt
plt.imshow(labels_testData[0][3])

在这里插入图片描述
下来给数据集增加一个通道,以便之后送入神经网络

wide = labels_trainData[0][0].shape[0]
length = labels_trainData[0][0].shape[1]
	
for label in labels_trainData.keys():
	labels_trainData[label] = np.reshape(labels_trainData[label], [-1,1,wide, length])

for label in labels_testData.keys():
	labels_testData[label] = np.reshape(labels_testData[label], [-1,1,wide, length])

2 搭建网络

一些引用:

import os
import numpy as np
import h5py
import random
import csv

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable

CNN

class CNNnet(torch.nn.Module):
	def __init__(self,input_shape,outDim):	
		super(CNNnet,self).__init__()
		self.conv1 = torch.nn.Sequential(
			torch.nn.Conv2d(in_channels=input_shape[0],
							out_channels=16,
							kernel_size=3,
							stride=1,
							padding=1),
			torch.nn.BatchNorm2d(16),
			torch.nn.MaxPool2d(2),
			torch.nn.ReLU()
		)
		self.conv2 = torch.nn.Sequential(
			torch.nn.Conv2d(16,32,3,1,1),
			torch.nn.BatchNorm2d(32),
			nn.MaxPool2d(2),
			torch.nn.ReLU()
		)
		self.conv3 = torch.nn.Sequential(
			torch.nn.Conv2d(32,64,3,1,1),
			torch.nn.BatchNorm2d(64),
			nn.MaxPool2d(2),
			torch.nn.ReLU()
		)
		self.conv4 = torch.nn.Sequential(
			torch.nn.Conv2d(64,64,3,1,1),
			torch.nn.BatchNorm2d(64),
			#nn.MaxPool2d(2)
			torch.nn.ReLU()
		)
		self.conv5 = torch.nn.Sequential(
			torch.nn.Conv2d(64,64,3,1,1),
			torch.nn.BatchNorm2d(64),
			#nn.MaxPool2d(2)
			torch.nn.ReLU()
		)
		self.mlp1 = torch.nn.Linear(10816,125)		#'''此处修改torch.nn.Linear(x,125)中的x位置'''
		self.mlp2 = torch.nn.Linear(125,outDim)
		
	def forward(self, x):	#'''根据__init__做相应修改'''
		x = self.conv1(x)
		x = self.conv2(x)
		x = self.conv3(x)
		x = self.conv4(x)
		x = self.conv5(x)
		x = self.mlp1(x.view(x.size(0),-1))
		x = self.mlp2(x)
		return x

这是一个5层的CNN,是训练的重点,我们的目的就是得到一个具有泛化能力的CNN,将相似类型的问题映射到一个out_dim维度的空间,这个空间应该将同类图片聚的尽可能的近,将不同类的图片聚的尽可能的远.
所以损失函数和预测函数需要我们自定义.

Protonets

CNN只是Protonets中的一部分,下面给出Protonets类的定义:

def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算
	return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)

class Protonets(object):
	def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):
		#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型
		self.input_shape = input_shape
		self.outDim = outDim
		self.batchSize = 1
		self.Ns = Ns
		self.Nq = Nq
		self.Nc = Nc
		if trainval == False:
			#若训练一个新的模型,初始化CNN和中心点
			self.center = {}
			self.model = CNNnet(input_shape,outDim).cuda()
		else:
			#否则加载CNN模型和中心点
			self.center = {}
			self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''
			self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''
	
	def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data)).cuda()
			data = self.model(data)[0]	#将查询点嵌入另一个空间
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#网络的训练
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			self.center[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#优化器
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()
	
	def loss(self,sample):	#自定义loss
		loss_1 = autograd.Variable(torch.FloatTensor([0])).cuda()
		for i in range(self.Nc):
			query_dataSet = sample['xq'][i]
			for n in range(self.Nq):
				data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
				data = Variable(torch.from_numpy(data)).cuda()
				data = self.model(data)[0]	#将查询点嵌入另一个空间
				#查询点与每个中心点逐个计算欧氏距离
				predict = 0
				for j in range(self.Nc):
					center_j = sample['xc'][j]
					if j == 0:
						predict = eucli_tensor(data,center_j)
					else:
						predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
				#为loss叠加
				loss_1 += -1*F.log_softmax(predict,dim=0)[i]
		loss_1 /= self.Nq*self.Nc
		return loss_1
	
	def randomSample(self,D_set): #从D_set随机取支持集和查询集(20个类中的其中一个类,shape为[20,105,105])
		index_list = list(range(D_set.shape[0]))#20个图片中选5个
		random.shuffle(index_list)
		support_data_index = index_list[:self.Ns]
		query_data_index = index_list[self.Ns:self.Ns + self.Nq]
		support_set = []
		query_set = []
		for i in support_data_index:
			support_set.append(D_set[i])
		for i in query_data_index:
			query_set.append(D_set[i])
		return support_set,query_set
	
	def evaluation_model(self,labels_data,class_number):
		test_accury = []
		center_for_test={}
		class_index = list(range(class_number))#600多类
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			center_for_test[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(center_for_test[label])	#list
			sample['xq'].append(query_set)
		
		for i in range(self.Nc):
			query_dataSet = sample['xq'][i]
			for n in range(self.Nq):
				data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
				data = torch.from_numpy(data).cuda()
				data = self.model(data)[0]	#将查询点嵌入另一个空间
				#查询点与每个中心点逐个计算欧氏距离
				predict = 0
				for j in range(self.Nc):
					center_j = sample['xc'][j]
					if j == 0:
						predict = eucli_tensor(data,center_j)
					else:
						predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
				y_pre_j = int(torch.argmax(F.log_softmax(predict,dim=0)))	#离第j个中心最近
				test_accury.append(1 if y_pre_j == i else 0)
		return sum(test_accury)/len(test_accury)
	
	def save_center(self,path):
		datas = []
		for label in self.center.keys():
			datas.append([label] + list(self.center[label].cpu().detach().numpy()))
		with open(path,"w", newline="") as datacsv:
			csvwriter = csv.writer(datacsv,dialect = ("excel"))
			csvwriter.writerows(datas)
	
	def load_center(self,path):
		csvReader = csv.reader(open(path))
		for line in csvReader:
			label = int(line[0])
			center = [ float(line[i]) for i in range(1,len(line))]
			center = np.array(center)
			center = Variable(torch.from_numpy(center)).cuda()
			self.center[label] = center

大体思路是:
1.只用训练集训练网络,从训练集labels_trainData中选出Nc个类,这其中每个类有20个样本,再在这20个样本中抽取Ns个支持集,Nq个查寻集,将支持集通过CNN,取其结果计算均值,获得中心点,然后用验证集计算loss更新梯度.再进行下一轮学习
2.每经过50轮学习,使用验证集计算模型的准确度,评估方案如下:在labels_testData中中选出Nc个类,这其中每个类有20个样本,再在这20个样本中抽取Ns个支持集,Nq个查寻集,将支持集通过CNN,取其结果计算均值,获得中心点,然后用验证集计算其到每个中心点的距离,取最近的点为预测值,由此评估网络性能
训练部分如下:
以20way-5shot为例:

protonets = Protonets((1,wide,length),10,5,5,20,'../log/',50)
for n in range(10000):	 ##随机选取x个类进行一个episode的训练
		protonets.train(labels_trainData,class_number_train)                                                                                                                                                                                                                                                                                                                                                                                
		if n % 50 == 0 and n != 0:	#每50次存储一次模型,并测试模型的准确率,训练集的准确率和测试集的准确率被存储在model_step_eval.txt中
			torch.save(protonets.model, '../log/model_net_'+str(n)+'.pkl')
			protonets.save_center('../log/model_center_'+str(n)+'.csv')
			test_accury = protonets.evaluation_model(labels_testData,class_number_test)
			print(test_accury)
			str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'
			with open('../log/model_step_eval.txt', "a") as f:
				f.write(str_data)
		print(n)

训练收敛极快,下面列出部分结果:(保存在model_step_eval.txt中)

50,       test_accury     0.72
100,       test_accury     0.86
150,       test_accury     0.88
200,       test_accury     0.76
250,       test_accury     0.77
300,       test_accury     0.88
350,       test_accury     0.86
400,       test_accury     0.89
450,       test_accury     0.95
500,       test_accury     0.96
550,       test_accury     0.87
600,       test_accury     0.93
650,       test_accury     0.9
700,       test_accury     0.98
750,       test_accury     0.98
...........
...........
3400,       test_accury     0.9
3450,       test_accury     0.99
3500,       test_accury     0.92
3550,       test_accury     1.0
  • 12
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 24
    评论
评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值