使用pytorch导入自建数据集

使用pytorch导入自建数据集
以mini_imagenet为例
其实是关键需要数据集的结构为

data
	train
		类别1
			image1
			image2
			……
		类别2
			image1
			image2
			……
	test
		类别1
			image1
			image2
			……
		类别2
			image1
			image2
			……
	val(可选)
		类别1
			image1
			image2
			……
		类别2
			image1
			image2
			……
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from wideresnet import WideResNet

BATCH_SIZE = 4
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理
 # 需要更多数据预处理,自己查
])
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理
 # 需要更多数据预处理,自己查
])

#读取数据
dataset_train = datasets.ImageFolder('./train', transform_train)
dataset_test = datasets.ImageFolder('./test', transform)
#dataset_val = datasets.ImageFolder('data/val', transform)

# 上面这一段是加载测试集的
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True) # 训练集
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True) # 测试集
#val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) # 验证集
# 对应文件夹的label
print(dataset_train.class_to_idx)   # 这是一个字典,可以查看每个标签对应的文件夹,也就是你的类别。
                                    # 训练好模型后输入一张图片测试,比如输出是99,就可以用字典查询找到你的类别名称
print(dataset_test.class_to_idx)
#print(dataset_val.class_to_idx)


if __name__ == '__main__':
    model = WideResNet(40, 100, 4, 0.0)
    for batch_idx, (images, labels) in enumerate(train_loader):
        # compute output
        outputs = model(images)
        print(data.shape)
        print(target)
  • 1
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是一个使用PyTorch导入3D数据集的示例代码: ```python import numpy as np import torch from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data_file): self.data = np.load(data_file) self.data = np.transpose(self.data, (0, 4, 1, 2, 3)) # 将数据的维度顺序转换为(batch_size, channel, depth, height, width) def __len__(self): return self.data.shape[0] def __getitem__(self, index): x = self.data[index] y = np.random.randint(0, 2) # 假设数据集是二分类问题,随机生成标签 return torch.from_numpy(x), torch.tensor(y) # 使用示例 data_file = 'data.npy' dataset = MyDataset(data_file) dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) for x, y in dataloader: print(x.shape, y.shape) # 输出(batch_size, channel, depth, height, width)和(batch_size,) ``` 在上述示例中,我们首先定义了一个`MyDataset`类,该类继承自`torch.utils.data.Dataset`,并实现了`__init__`、`__len__`和`__getitem__`方法。`__init__`方法从文件中加载数据,`__len__`方法返回数据集的大小,`__getitem__`方法返回指定索引的数据和标签。 在`__getitem__`方法中,我们使用`numpy`将数据转换为`torch.Tensor`类型,并返回它们。在使用时,我们可以使用`torch.utils.data.DataLoader`类将数据集加载到内存中,并迭代访问。在上述示例中,我们使用了一个简单的循环,每次获取一个批次的数据,打印它们的形状。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值