torch读取数据

  使用torchvision.datasets.ImageFolder来读取图片

torchvision.datasets.ImageFolder(root="root folder path", [transform, target_transform])
  • root : 指定图片存储的路径,在下面的例子中是’./data/dogcat_2’
  • transform: 一个函数,原始图片作为输入,返回一个转换后的图片。
  • target_transform - 一个函数,输入为target,输出对其的转换。例子,输入的是图片标注的string,输出为word的索引。
    有几个变量
  • self.classes - 用一个list保存 类名
  • self.class_to_idx - 类名对应的 索引
  • self.imgs - 保存(img-path, class) tuple的list

  定义一个torch.utils.data.Dataset数据类
Dataset有两个函数

  • _getitem_
  • _len_
class Dataset_1(torch.utils.data.Dataset):
  def __init__(self,root,is_resize=False,is_transfrom=False):
    self.root=root
    self.is_resize=is_resize
    self.is_transfrom=is_transfrom
    self.imgs_list=...#保存图片路径节省内存
    self.labs_list=...
  def __getitem__(self, index):
    img_path,lab=self.imgs_list[index],self.labs_list[index]
    img_data = Image.open(img_path)
    if self.is_transfrom:
      img_data=self.is_transfrom(img_data)
    return img_data,lab
  def __len__(self):
    return len(self.imgs_list)

  定义好Dataset数据类,之后使用DataLoader导入

torch.utils.data.DataLoader(dataset=Dataset_1, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)

  有时需要对数据进行处理

train_transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()
])
img = Image.open('test.png')
train_transforms(img)
训练数据
train_loader = torch.utils.data.DataLoader(dataset=Dataset_1, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
model = torchvision.models.__dict__['resnet101'](pretrained=True)
model.load_state_dict(torch.load('...pth'))
model.to(DEVICE)
# 训练
optimizer = torch.optim.SGD(model.parameters(),lr = 0.01)
loss_fn = torch.nn.MSELoss()
model.train()
for batch_index,(data,target) in enumerate(train_loader):
    data, target = data.to(DEVICE), target.to(DEVICE)
    output = model(data)
    loss = loss_fn(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# 推理
model.eval()
output = model(data)
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值