1.训练数据读入
注:以下模拟数据,主要讲解方法。
标签数据
下面函数即为实现标签数据的读入
def reader(txt):
fh = open(txt)c=0
imgs=[]
class_names=[]
for line in fh.readlines():
if c==0:
class_names=[n.strip() for n in line.rstrip().split(' ')]
else:
cls = line.split()
fn = cls.pop(0)
imgs.append((fn, tuple([float(v) for v in cls])))
c=c+1
return class_names,imgs
其中,返回imgs是标签元组,即[1,0,0,1],class_names为属性名,即sex。
如人脸特征数据,也可以通过reader()读入。
2.简单模型设计(以全连层为例)
cmodel=nn.Linear(100, 2) ,(或者nn.Sequential(nn.Linear(100, 2))
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.classify=cmodel
def forward(self, x):
x=self.classify(x)
return x,
3.模型训练
训练集读入
train_data_loader = torch.utils.data.DataLoader( \
ImageFloder(root = "./fea.txt", label = "./label.txt"), batch_size= 2, shuffle= False, num_workers= 4)
其中,root,label分别是特征与标签文件地址, ImageFloder类定义如下:
class ImageFloder(data.Dataset):
def __init__(self, root, label):
self.classes1,self.imgs1 = reader(label)
self.classes2,self.imgs2 = reader(root)
def __getitem__(self, index):
fn1, label1 = self.imgs1[index]
fn2, label2 = self.imgs2[index]
return torch.Tensor(label1),torch.Tensor(label2)
def __len__(self):
return len(self.imgs1)
训练代码详见项目:
https://github.com/eeric/pytorch-model-training-label