pytorch实现带标签格式数据的模型训练

1.训练数据读入

注:以下模拟数据,主要讲解方法。

标签数据


下面函数即为实现标签数据的读入

def reader(txt):

    fh = open(txt)  
    c=0  
    imgs=[]  
    class_names=[]  
    for line in  fh.readlines():  
        if c==0:  
            class_names=[n.strip() for n in line.rstrip().split('   ')]  
        else:  
            cls = line.split()   
            fn = cls.pop(0)
            imgs.append((fn, tuple([float(v) for v in cls])))  
        c=c+1

    return class_names,imgs

其中,返回imgs是标签元组,即[1,0,0,1],class_names为属性名,即sex。

如人脸特征数据,也可以通过reader()读入。

2.简单模型设计(以全连层为例)

cmodel=nn.Linear(100, 2) ,(或者nn.Sequential(nn.Linear(100, 2))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.classify=cmodel
    def forward(self, x):
        x=self.classify(x)
        return x,

3.模型训练

训练集读入

train_data_loader = torch.utils.data.DataLoader(  \
         ImageFloder(root = "./fea.txt", label = "./label.txt"), batch_size= 2, shuffle= False, num_workers= 4)

其中,root,label分别是特征与标签文件地址, ImageFloder类定义如下:

class ImageFloder(data.Dataset):  
    def __init__(self, root, label):

self.classes1,self.imgs1 = reader(label)
        self.classes2,self.imgs2 = reader(root)

    def __getitem__(self, index):  
        fn1, label1 = self.imgs1[index]
        fn2, label2 = self.imgs2[index]

return torch.Tensor(label1),torch.Tensor(label2)

    def __len__(self):  
        return len(self.imgs1)

训练代码详见项目:

https://github.com/eeric/pytorch-model-training-label

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch中,训练模型通常分为以下几个步骤: 1. 数据准备:首先,你需要准备好用于训练模型数据集。这可能涉及到数据的加载、预处理、划分等操作。 2. 模型定义:接下来,你需要定义模型的结构。在PyTorch中,你可以使用`torch.nn.Module`类来创建自定义的神经网络模型,并定义模型的前向传播过程。 3. 损失函数定义:在训练模型过程中,你需要定义一个损失函数来衡量模型的预测结果与真实标签之间的差异。PyTorch提供了各种损失函数,如均方误差(MSE)、交叉熵损失等。 4. 优化器定义:为了更新模型的参数,你需要选择一个优化器算法。在PyTorch中,你可以使用`torch.optim`模块中的优化器,如随机梯度下降(SGD)、Adam等。 5. 训练循环:接下来,你需要编写一个循环来迭代训练模型。在每个迭代步骤中,你需要执行以下操作: - 前向传播:将输入数据传入模型,并获得模型的预测结果。 - 计算损失:将预测结果与真实标签进行比较,并计算损失值。 - 反向传播:根据损失值,计算梯度并反向传播到模型的参数。 - 参数更新:使用优化器来更新模型的参数。 6. 模型评估:在训练过程中,你可以定期评估模型在验证集或测试集上的性能。这可以帮助你监控和调整模型训练过程。 7. 模型保存和加载:在训练完成后,你可以将模型保存到硬盘上,以便以后使用。同样地,你也可以从保存的模型文件中加载模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值