【pytorch使用tfrecord数据】

pytorch读取tfrecord文件

import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm
from tfrecord.torch.dataset import TFRecordDataset

from model import resnet18


def parse_tfrecord(single_record):  #解码tfrecord文件并将其转化成训练时的张量,这个其实是对数据transform
    image, label = single_record    #和tensorflow解码类似
    image = torch.tensor(single_record["image/encoded"]) #先将其转化为向量
    
    label = torch.tensor(single_record["image/label"]).squeeze() 
    #label读出[[label1], [label2],...],如果不降维,你每次取label就直接是一个元组[label1],无法进行训练
    #降维之后就是[label1, label2,...]
    
    image = torchvision.io.decode_jpeg(image).float()
    #将image读出来重组成jpeg,操作和tensorflow类似,.float()是因为网络权重是float类型,两者必须相同
    
    return (image, label)
    

def main():
    device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    batch_size = 128

    data_root = os.path.abspath(os.path.join(os.getcwd()))  # get data root path
    image_path = os.path.join(data_root, "dataset","webface")  # web_face data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    
    train_tfrecord_path = os.path.join(image_path, "webface_train.tfrecord")
    index_path = None
    description = {"image/encoded": "byte", "image/label": "int"} 
    #description是解码tfrecord,一定要知道转化成tfrecord时它用的格式string,就是"image/encoded" "image/label",不然会报错
    
    dataset = TFRecordDataset(train_tfrecord_path, index_path, description, transform=parse_tfrecord)
    #这一步就把tfrecord加载到dataset中了,后续操作基本上没有特殊的地方了
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    #这个loader也可以继续用

    
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) 
    print('Using {} dataloader workers every process'.format(nw))


    net = resnet18() #加载你自己写的网络
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 10575)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.001)

    epochs = 10
    best_acc = 0.0
    save_path = './resNet18_webface.pth'
    train_steps = 408072
    val_num = 45341
    
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(loader, file=sys.stdout)
        for step, data in enumerate(loader):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            torch.save(net.state_dict(), save_path)
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
    print('Finished Training')


if __name__ == '__main__':
    main()

pytorch的dataset类

最简单的dataset:

torch.utils.data.Dataset()
#有len和getitem两个函数,自己可以重写

torch.utils.data.TensorDataset(x,y)
#继承自Dataset,把数据x,y作为参数加载进dataset

pytorch加载图片:

from torchvision.datatsets import ImageFolder
dataset = ImageFolder(root='root_path', transform=None, loader=default_loader)
#要求root是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别
#transform对图片做预处理,然后通过loader将图片转换成我们需要的图片类型进入神经网络

pytorch使用自带的数据集:

torchvision.datasets.MINST(root,transform,train,download)
#train=Ture就是训练集,download=False是无需下载数据集,直接下载很慢,可以先准备好

#其余还有下列数据集
dset.CocoCaptions(root="dir where images are", annFile="json annotation file", 			 
                  [transform, target_transform])
dset.CocoDetection(root="dir where images are", annFile="json annotation file", 
                   [transform, target_transform])

dset.LSUN(db_path, classes='train', [transform, target_transform])


Imagenet-12
#https://github.com/pytorch/examples/blob/27e2a46c1d1505324032b1d94fc6ce24d5b67e97/imagenet/main.py#L48-L62

dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

dset.STL10(root, split='train', transform=None, target_transform=None, download=False)

pytorch的dataloader类

类似于迭代器,还可以预处理数据

torch.utils.data.DataLoader(Dataset,batch_size,shuffle,collate_fn)
#collate_fn表示如何取样本的,可以定义自己的函数来实现自己想要的功能
#Dataset只是一个打包工具,DataLoader把数据读入内存

TFRecord文件

tfrecord文件包含了类型为tf.train.Example的协议内存块,而协议内存块中又包含了字段features,features中包含了若干个feature,每个feature是一个map,key-value键值对,key是string类型,valuew是feature类型的消息体,包含三种类型:byte,float,int;key和value都是列表;

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
Python读取TFRecord文件的方法如下: ```python import tensorflow as tf # 创建一个TFRecordDataset对象 dataset = tf.data.TFRecordDataset('data.tfrecord') # 定义读取函数 def parser(record): features = { 'image': tf.io.FixedLenFeature([], dtype=tf.string), 'label': tf.io.FixedLenFeature([], dtype=tf.int64) } parsed = tf.io.parse_single_example(record, features) image = tf.io.decode_jpeg(parsed['image'], channels=3) label = parsed['label'] return image, label # 应用读取函数到每个record dataset = dataset.map(parser) # 创建迭代器 iterator = dataset.make_one_shot_iterator() # 获取数据 image, label = iterator.get_next() ``` 以上代码演示了如何读取名为`data.tfrecord`的TFRecord文件,并解析其中的图像和标签信息。在解析函数`parser`中,我们先定义了TFRecord文件中包含的特征信息,然后使用`tf.io.parse_single_example`函数解析单个record,并对图像数据进行解码。最后,我们使用`map`函数将解析函数应用到每个record上。 当然,如果您使用的是PyTorch,也可以使用以下代码读取TFRecord文件: ```python import torch import torchvision.datasets as datasets import torchvision.transforms as transforms # 定义解析函数 def parser(record): features = { 'image': tf.io.FixedLenFeature([], dtype=tf.string), 'label': tf.io.FixedLenFeature([], dtype=tf.int64) } parsed = tf.io.parse_single_example(record, features) image = tf.io.decode_jpeg(parsed['image'], channels=3) label = parsed['label'] return image, label # 创建数据集对象 dataset = datasets.DatasetFolder( 'data.tfrecord', loader=lambda x: torch.load(x), extensions=('tfrecord') ) # 应用解析函数到每个record dataset.transform = transforms.Compose([ parser ]) # 创建数据加载器 dataloader = torch.utils.data.DataLoader( dataset, batch_size=32, shuffle=True ) # 获取数据 for images, labels in dataloader: # 使用数据进行训练或预测 pass ``` 以上代码演示了如何使用PyTorch的`DatasetFolder`读取TFRecord文件,并使用解析函数`parser`解析图像和标签信息。最后,我们创建了一个数据加载器,并使用其中的数据进行训练或预测。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小橘AI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值