原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

在这里插入图片描述


原型网络进行分类的基本流程

利用原型网络进行分类,基本流程如下:

1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。

一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)

def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算
	return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)

class Protonets(object):
	def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):
		#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型
		self.input_shape = input_shape
		self.outDim = outDim
		self.batchSize = 1
		self.Ns = Ns
		self.Nq = Nq
		self.Nc = Nc
		if trainval == False:
			#若训练一个新的模型,初始化CNN和中心点
			self.center = {}
			self.model = CNNnet(input_shape,outDim)
		else:
			#否则加载CNN模型和中心点
			self.center = {}
			self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''
			self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''
	
	def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data))
			data = self.model(data)[0]	#将查询点嵌入另一个空间
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#网络的训练
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			self.center[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#优化器
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()

二、每一行代码的详细解释

def eucli_tensor(x, y):
    return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)

这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1) 将结果转换成一个形状为 (1,) 的张量。

class Protonets(object):
    def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):
        self.input_shape = input_shape
        self.outDim = outDim
        self.batchSize = 1
        self.Ns = Ns
        self.Nq = Nq
        self.Nc = Nc
        if trainval == False:
            self.center = {}
            self.model = CNNnet(input_shape, outDim)
        else:
            self.center = {}
            self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')
            self.load_center(log_data + 'model_center_' + str(step) + '.csv')

这是一个 Protonets 类的定义,它有一个构造函数 __init__,用于初始化类的属性。其中的参数含义如下:

  • input_shape:输入数据的形状。
  • outDim:输出维度。
  • Ns:支持集(support set)的数量。
  • Nq:查询集(query set)的数量。
  • Nc:每次迭代所选类别数。
  • log_data:模型和中心的存储位置。
  • step:训练的步数。
  • trainval:是否重新开始训练模型。

根据 trainval 的取值,分为两种情况进行初始化:

  1. trainval=False:表示训练一个新的模型。此时,初始化一个空的中心字典 self.center,并创建一个名为 CNNnet 的模型对象 self.model,其输入形状为 input_shape,输出维度为 outDim
  2. trainval=True:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典 self.center。然后通过 torch.load 加载之前训练保存的模型文件 log_data + 'model_net_' + str(step) + '.pkl',并将其赋给 self.model。接着调用 load_center 方法加载之前训练保存的中心文件 log_data + 'model_center_' + str(step) + '.csv'

总结

这段代码是一个用于实现 Protonets 算法的类。

原型网络是一种基于神经网络的机器学习模型,可以在PyTorch框架中实现。它也被称为卷积神经网络(Convolutional Neural Networks,CNNs)的一种形式。原型网络的设计灵感来源于生物视觉系统,能够对图像进行高效的特征提取和图像识别。 原型网络的基本结构包括卷积层、池化层和全连接层。卷积层利用卷积操作从输入图像中提取特征,每个卷积核都负责检测图像中的不同特征。池化层则用于减少特征图的尺寸,并且提取最显著的特征。全连接层将特征映射到不同的类别,用于分类。 在PyTorch中,我们可以使用torch.nn模块来构建原型网络。首先,我们需要定义一个继承自torch.nn.Module的网络类,并在其中定义网络的组件,如卷积层和全连接层。然后,我们可以通过重写forward方法来定义网络的前向传播过程。在前向传播过程中,我们可以使用PyTorch提供的各种函数来实现卷积、池化和全连接操作。 为了训练原型网络,我们还需要定义一个损失函数和优化器。常用的损失函数包括交叉熵损失函数和均方差损失函数。我们可以使用torch.optim模块中的优化器来更新网络的权重,常用的优化器有随机梯度下降(SGD)和Adam。 在训练过程中,我们首先将输入数据传入网络中进行前向传播,然后计算损失函数的值。接着,通过反向传播计算损失函数对网络权重的梯度,并使用优化器更新网络的权重参数。重复这个过程直到达到设定的训练迭代次数。最后,我们可以使用训练好的网络对新的图像进行分类预测。 总之,原型网络是一种在PyTorch框架中实现的神经网络模型,它通过卷积、池化和全连接层来提取和分类图像特征。使用PyTorch的torch.nn模块和torch.optim模块,我们可以方便地构建、训练和利用原型网络进行图像分类任务。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值