PyTorch学习(一)

两大指南函数

  1. dir()

    功能:返回包含查询对象的所有属性和方法名称的列表

    dir(torch)
    # 输出: torch中所包含的全部方法
    dir(torch.cuda.is_available)
    # 输出:该函数所有属性
    
  2. help()

    功能:查看函数或模块用途的详细说明

使用dir()和help()查询函数仅输入函数名,不输入括号。

数据加载和预处理

  1. Dataset类

    Dataset是一个抽象类,它的主要作用是封装数据集。你可以将Dataset看作是数据的集合,

    定义了如何获取数据集中的单个样本。Dataset类需要实现__init__,__len__和__getitem__三个方法。

    一些内置的Dataset类,比如ImageFolderCIFAR10等,它们已经实现了这些方法,可以直接使用。

  2. DataLoader类

    DataLoader是Dataset的包装器,它的主要作用是提供批量加载数据的能力。

    使用DataLoader时,需要指定以下参数:

    • dataset:你创建的Dataset实例。
    • batch_size:每个批次的样本数量。
    • shuffle:是否在每个epoch开始时打乱数据。
    • num_workers:加载数据时使用的进程数量。
    from torch.utils.data import DataLoader
    
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
    

    iter函数将train_dataloader变成一个迭代器,使用next函数可以以此从迭代器中生成一个一个的批次

    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    img = train_features[0].squeeze()
    label = train_labels[0]
    

    可视化中squeeze()的作用

    squeeze函数用于去除张量中所有大小为1的维度

    1. 简化可视化:在可视化图像时,我们通常只需要[channels, height, width]的三维张量。squeeze函数可以帮助去除批次维度(batch_size),使得我们可以直接处理单个图像。
    2. 去除冗余维度:如果图像数据在加载或预处理过程中被错误地增加了额外的维度,squeeze可以帮助去除这些不必要的维度。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值