pytorch实现带标签格式数据的模型训练

1.训练数据读入

注:以下模拟数据,主要讲解方法。

标签数据


下面函数即为实现标签数据的读入

def reader(txt):

    fh = open(txt)  
    c=0  
    imgs=[]  
    class_names=[]  
    for line in  fh.readlines():  
        if c==0:  
            class_names=[n.strip() for n in line.rstrip().split('   ')]  
        else:  
            cls = line.split()   
            fn = cls.pop(0)
            imgs.append((fn, tuple([float(v) for v in cls])))  
        c=c+1

    return class_names,imgs

其中,返回imgs是标签元组,即[1,0,0,1],class_names为属性名,即sex。

如人脸特征数据,也可以通过reader()读入。

2.简单模型设计(以全连层为例)

cmodel=nn.Linear(100, 2) ,(或者nn.Sequential(nn.Linear(100, 2))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.classify=cmodel
    def forward(self, x):
        x=self.classify(x)
        return x,

3.模型训练

训练集读入

train_data_loader = torch.utils.data.DataLoader(  \
         ImageFloder(root = "./fea.txt", label = "./label.txt"), batch_size= 2, shuffle= False, num_workers= 4)

其中,root,label分别是特征与标签文件地址, ImageFloder类定义如下:

class ImageFloder(data.Dataset):  
    def __init__(self, root, label):

self.classes1,self.imgs1 = reader(label)
        self.classes2,self.imgs2 = reader(root)

    def __getitem__(self, index):  
        fn1, label1 = self.imgs1[index]
        fn2, label2 = self.imgs2[index]

return torch.Tensor(label1),torch.Tensor(label2)

    def __len__(self):  
        return len(self.imgs1)

训练代码详见项目:

https://github.com/eeric/pytorch-model-training-label

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值