pytorch深度学习总结1

最近使用pytorch踩过的一些坑,记录一下,偏应用。

1.图片加载

pytorch中的datasets.ImageFolder函数直接可以读取自己的图片的数据集。
数据集存放:
把每一类的图片放到一个文件夹里面,加载时地址只用写到类别文件夹的上一级目录。例如下图中dataset文件夹存放了4个类别的图片,那么图片加载时写入的地址就是** ‘F:\dataset’** 。datasets.ImageFolder会自动根据文件夹类别给数据打上标签。
在这里插入图片描述

from torchvision import datasets, transforms
import torch
import os
ef load_data(root_path, dir, batch_size, phase):
    transform_dict = {
        'tar':transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),])}
    data = datasets.ImageFolder(root=os.path.join(root_path,dir), transform=transform_dict[phase])  ##即各类别文件夹所在目录的上一级目录
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True,drop_last=False, num_workers=4)
    #设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。drop_last=False不丢弃
    return data_loader

上面这段示例代码可以当作模板使用,但其实我还有一个问题没搞懂,就是transforms.Normalize标准化时输入的平均值和标准差为什么不是0.5,而是一堆奇怪的小数,有哪位大佬路过可以帮忙解答一下。

2.模型搭建

pytorch的模型搭建这里,没有什么特别要记录的,网上很多例子,照着搭自己的模型就行。只有一点,是在我真正开始动手操作的时候才发现的,之前照着书学习的时候没发现或者看到了没有注意就略过了。

在pytorch中搭建一个My_Net类作为自己的模型,在调用时按照下面流程调用传入自己的数据就行,它会直接执行My_Net类中的forward 函数,完成前向传播,不需要单独调用forward函数。

class My_Net(nn.Module):
    def __init__(self,):
        super(My_Net,self).__init__()
        ……
    def forward(data):
        ……
        return result
    def my_loss(  ):
        ……
        return loss
 #调用
 data
 model = My_Net()
 result=model(data)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值