Pytorch复习

1.数据预处理

torch.utils.data — PyTorch 1.7.1 documentation

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform_MIss-Y的博客-CSDN博客_tusimple数据集

1.1继承torch.utils.data.Dataset,重写__getitem__和__len__,来从硬盘读取数据,在__getitem__中实现transform方法

from torch.utils.data import Dataset,DataLoader
import json
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms

class CIFAR10_IMG(Dataset):

    def __init__(self, root, train=True, transform = None, target_transform=None):
        super(CIFAR10_IMG, self).__init__()
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        if self.train :
            file_annotation = root + '/annotations/cifar10_train.json'
            img_folder = root + '/train_cifar10/'
        else:
            file_annotation = root + '/annotations/cifar10_test.json'
            img_folder = root + '/test_cifar10/'
        fp = open(file_annotation,'r')
        data_dict = json.load(fp)

        assert len(data_dict['images'])==len(data_dict['categories'])
        num_data = len(data_dict['images'])

        self.filenames = []
        self.labels = []
        self.img_folder = img_folder
        for i in range(num_data):
            self.filenames.append(data_dict['images'][i])
            self.labels.append(data_dict['categories'][i])

    def __getitem__(self, index):
        img_name = self.img_folder + self.filenames[index]
        label = self.labels[index]

        img = plt.imread(img_name)
        img = self.transform(img)   #可以根据指定的转化形式对数据集进行转换

        #return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
        return img, label

    def __len__(self):
        return len(self.filenames)

1.2继承torch.utils.data.DataLoader,加载

train_dataset = datasets.CIFAR10_IMG('./datasets',train=True,transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10_IMG('./datasets',train=False,transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=True)

2.模型搭建(以ResNet为例)

2.1ResNet结构

2.1使用torchvision

import torchvision model = torchvision.models.resnet50(pretrained=True)

2.2继承nn.Module类定义自己的模型,重新实现构造函数__init__构造函数和forward这两个方法

pytorch教程之nn.Module类详解——使用Module类来自定义模型_MIss-Y的博客-CSDN博客_nn是什么意思

3.损失函数

使用torch.nn

二分类交叉熵(需要加softmax或sigmoid)

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

多分类交叉熵:

torch.nn.CrossEntropyLoss()

L1范数,平均绝对误差

torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

L2范数,平均平方误差

torch.nn.MSELoss()

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值