文章目录
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset,DataLoader,Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
一、torch.utils
1、torch.utils.data
详情介绍 DataLoader、Dataset、TensorDataset
import torch
import torch.utils.data as Data
torch.manual_seed(1)
BATCH_SIZE = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
torch_dataset = Data.TensorDataset(x,y) #把数据放在数据库中
loader = Data.DataLoader(
# 从dataset数据库中每次抽出batch_size个数据
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,#将数据打乱
num_workers=2, #使用两个线程
)
(1)DataLoader:(构建可迭代的数据装载器)
torch.utils.data.DataLoader(): 构建可迭代的数据装载器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
dataLoader的参数很多,但我们常用的主要有5个:
- dataset: Dataset类, 决定数据从哪读取以及如何读取
- bathsize: 批大小
- num_works: 是否多进程读取机制
- shuffle: 每个epoch是否乱序
- drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据,默认False(不放弃),True(放弃)
(2)TensorDataset
(3)Dataset
2、torch.utils.tensorboard
3、torch.utils.datasets
二、torchvision,该包主要由3个子包组成(torchvision.datasets、torchvision.models、torchvision.transforms)
1、trochvison的数据集使用 datasets
1、trochvison的数据集使用 transforms
1.torchvision.datasets
①、 下载 CIFAR10和MNIST数据集:方法相同
dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)
dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
import torchvision
import torchvision.transforms as transforms
#下载训练集
train_set=torchvision.datasets.CIFAR10(root="./dataset" #下载位置
,train=True #只下载训练集
,download=True #本地没有存储就下载,本地已有就不用下载
,transform=transforms.ToTensor() #接收目标并对其进行转换的函数/转换
)
#下载训练集
test_set=torchvision.datasets.CIFAR10(root="./dataset" #下载位置
,train=False #只下载训练集
,download=True #本地没有存储就下载,本地已有就不用下载
,transform=transforms.ToTensor() #接收目标并对其进行转换的函数/转换
)
2.torchvision.models
十一、torchvision.model 现有网络模型的使用和修改、保存
3.torchvision.transforms
(1) transforms.ToTensor()
transforms.ToTensor()函数的作用是将原始的PILImage格式或者numpy.array格式的数据格式化为可被pytorch快速处理的张量类型。
输入模式为(L、LA、P、I、F、RGB、YCbCr、RGBA、CMYK、1)的PIL Image 或 numpy.ndarray (形状为H x W x C)数据范围是[0, 255] 到一个 Torch.FloatTensor,其形状 (C x H x W) 在 [0.0, 1.0] 范围内。
import numpy as np
from torchvision import transforms
a = np.random.random((224,224,3))
transform = transforms.Compose([
transforms.ToTensor()
])
b = transform(a)
print(b.shape)
#torch.Size([3, 224, 224])
三、nn.Module
- 模板
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
总
1. 导入包
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset,DataLoader,Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
2.加载数据
#下载训练集
train_set=torchvision.datasets.CIFAR10(root="./dataset" #下载位置
,train=True #只下载训练集
,download=True #本地没有存储就下载,本地已有就不用下载
,transform=transforms.ToTensor() #接收目标并对其进行转换的函数/转换
)
#下载测试集
test_set=torchvision.datasets.CIFAR10(root="./dataset" #下载位置
,train=False #只下载训练集
,download=True #本地没有存储就下载,本地已有就不用下载
,transform=transforms.ToTensor() #接收目标并对其进行转换的函数/转换
3.处理数据
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)