pytorch读取tfrecord文件
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm
from tfrecord.torch.dataset import TFRecordDataset
from model import resnet18
def parse_tfrecord(single_record):
    """Transform one decoded TFRecord example into an (image, label) pair.

    Acts as the ``transform`` for ``TFRecordDataset``: ``single_record`` is a
    dict keyed by the feature names used when the TFRecord was written
    ("image/encoded": raw JPEG bytes, "image/label": int label).

    Returns:
        tuple(Tensor, Tensor): decoded float image (CHW) and scalar label.
    """
    # BUG FIX: the original first did `image, label = single_record`, which
    # unpacks the *dict keys* (two strings) — dead and misleading, since both
    # names are immediately overwritten by the keyed lookups below.
    image = torch.tensor(single_record["image/encoded"])
    # Labels arrive shaped [[l1], [l2], ...]; squeeze so each label is a
    # scalar and a batch collates to [l1, l2, ...] as CrossEntropyLoss needs.
    label = torch.tensor(single_record["image/label"]).squeeze()
    # Decode the JPEG bytes; .float() so the dtype matches the (float32)
    # network weights.
    image = torchvision.io.decode_jpeg(image).float()
    return (image, label)
def main():
    """Train a ResNet-18 classifier on the WebFace TFRecord dataset.

    Reads ``dataset/webface/webface_train.tfrecord`` relative to the current
    working directory, fine-tunes ``resnet18`` with a 10575-way head
    (the number of WebFace identities), and checkpoints once per epoch to
    ``./resNet18_webface.pth``.
    """
    device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    batch_size = 128
    data_root = os.path.abspath(os.path.join(os.getcwd()))  # data root path
    image_path = os.path.join(data_root, "dataset", "webface")  # WebFace data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_tfrecord_path = os.path.join(image_path, "webface_train.tfrecord")

    index_path = None
    # The keys and types here must match exactly what was used when the
    # TFRecord was written ("image/encoded", "image/label"), otherwise
    # decoding raises an error.
    description = {"image/encoded": "byte", "image/label": "int"}
    dataset = TFRecordDataset(train_tfrecord_path, index_path, description,
                              transform=parse_tfrecord)
    # NOTE(review): nw below is computed but deliberately NOT passed as
    # num_workers — TFRecordDataset is an IterableDataset and, without an
    # index file, multiple workers would each replay the full record stream
    # (duplicating data). Confirm sharding support before enabling workers.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    net = resnet18()  # user-defined network from model.py
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 10575)  # 10575 WebFace identities
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()
    # construct an optimizer over trainable parameters only
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.001)

    epochs = 10
    save_path = './resNet18_webface.pth'
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(loader, file=sys.stdout)
        # BUG FIX: iterate the tqdm wrapper, not the bare loader, so the
        # progress bar actually advances (the original only set .desc).
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # BUG FIX: checkpoint once per epoch; the original saved inside the
        # inner loop, serializing the full model after every single batch.
        torch.save(net.state_dict(), save_path)
    print('Finished Training')
# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
pytorch的dataset类
最简单的dataset:
torch.utils.data.Dataset()
#有len和getitem两个函数,自己可以重写
torch.utils.data.TensorDataset(x,y)
#继承自Dataset,把数据x,y作为参数加载进dataset
pytorch加载图片:
from torchvision.datasets import ImageFolder
dataset = ImageFolder(root='root_path', transform=None, loader=default_loader)
#要求root是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别
#transform对图片做预处理,然后通过loader将图片转换成我们需要的图片类型进入神经网络
pytorch使用自带的数据集:
torchvision.datasets.MNIST(root,transform,train,download)
#train=True就是训练集,download=False是无需下载数据集,直接下载很慢,可以先准备好
#其余还有下列数据集
dset.CocoCaptions(root="dir where images are", annFile="json annotation file",
[transform, target_transform])
dset.CocoDetection(root="dir where images are", annFile="json annotation file",
[transform, target_transform])
dset.LSUN(db_path, classes='train', [transform, target_transform])
Imagenet-12
#https://github.com/pytorch/examples/blob/27e2a46c1d1505324032b1d94fc6ce24d5b67e97/imagenet/main.py#L48-L62
dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
dset.STL10(root, split='train', transform=None, target_transform=None, download=False)
pytorch的dataloader类
类似于迭代器,还可以预处理数据
torch.utils.data.DataLoader(Dataset,batch_size,shuffle,collate_fn)
#collate_fn表示如何取样本的,可以定义自己的函数来实现自己想要的功能
#Dataset只是一个打包工具,DataLoader把数据读入内存
TFRecord文件
tfrecord文件包含了类型为tf.train.Example的协议内存块,而协议内存块中又包含了字段features,features中包含了若干个feature,每个feature是一个map,key-value键值对,key是string类型,value是feature类型的消息体,包含三种类型:byte,float,int;key和value都是列表;