pytorch应用

三大步骤:数据读取、网络构建、其他辅助
数据读取
常见的数据例如mnist就用torchvision的datasets方法来进行傻瓜式读取就行
对于分类问题,可以采用torchvision.datasets.ImageFolder读取image和label信息:

data_dir = '/data' 
image_datasets = {x: datasets.ImageFolder( 
os.path.join(data_dir, x), data_transforms[x]
), for x in ['train', 'val']}

torchvision.datasets.ImageFolder只是返回list,list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:torch.utils.data.DataLoader。torch.utils.data.DataLoader类可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。注意,这里是对图像和标签分别封装成一个Tensor。

当你的数据不是按照一个类别一个文件夹这种方式存储时,你就要自定义一个类来读取数据,自定义的这个类必须继承自torch.utils.data.Dataset这个基类,最后同样用torch.utils.data.DataLoader封装成Tensor。

利用torchvision.model中的模型就可以满足条件(还有pretrain的参数),如果最后分类classes数目不相同那么可以提取前一层fc:

# coding=UTF-8
import torchvision.models as models
 
#调用模型
model = models.resnet50(pretrained=True)
#提取fc层中固定的参数
fc_features = model.fc.in_features
#修改类别为9
model.fc = nn.Linear(fc_features, 9)

一般我们都把分类的前一层FC叫features

对于自己读取数据要继承datasets并重写__len__()和__getitem__()两个方法
def getitem(self, index):可见需要index对所有的数据进行遍历读取,最终返回image&label也是所有数据的!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值