pytorch(08)数据模型的读取(2)

import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
BASE_PATH = os.path.abspath(__file__)
# print(BASE_PATH)
base_path = os.path.abspath(os.path.join(BASE_PATH, '..', 'TestDir'))
# print(base_path)
data_dir = os.path.abspath(os.path.join(BASE_PATH, '..', 'RMB_data'))
random.seed(1)
# print(data_dir)
test_label = {"1": 0, "100": 1}
data_info = list()
for path, dirs, files in os.walk(base_path):
    for sub_dir in dirs:
        # print(sub_dir)
        sub_dirlist = os.listdir(os.path.join(base_path, sub_dir))
        pynames = list(filter(lambda y: y.endswith('.jpg'), sub_dirlist))
        # print(pynames)
        # print(test_label[sub_dir])
        for pyname in pynames:
            datainfo_dir = os.path.join(base_path, sub_dir, pyname)
            t_label=test_label[sub_dir]
            t_label = int(t_label)
            data_info.append((datainfo_dir, t_label))
# print(data_info)
new_data_info = list()
for data_info_e in data_info:
    x_dir, x_label = data_info_e
    x_img = Image.open(x_dir).convert('RGB')
    ok_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ])
    x_img = ok_transform(x_img)
    new_data_info.append((x_img,x_label))

# print(len(new_data_info[0][0]))
print(len(new_data_info))
newdataLoader = DataLoader(new_data_info,batch_size=14, shuffle=True)
for ids, data in enumerate(newdataLoader):
    print(ids)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值