很多小伙伴在刚刚结束深度学习算法的时候,肯定想用自己的数据来进行训练网络,但是不知到怎么写代码,下面这个代码就会为你解惑,自己可以根据实际情况来更改代码,训练自己的图片数据集。
下面我用蚂蚁和蜜蜂数据集为例,我的数据格式是这样的,如下图:
每个类别都会有相应的图片
from torch.utils.data import Dataset,DataLoader
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
transform = transforms.Compose([transforms.Resize([500, 500]), # 图像预处理
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
class_list={0:"ants",1:"bees"}#用于后续预测的时候可以使用,用预测到的标签来直接获取相应的类别
class Mydataset(Dataset):
def __init__(self,file_path="D:/PycharmProjects/pythonProject/classification-pytorch-main1/datasets",formate="train",transform=False):
self.transforms=transform
self.file_path=file_path
self.formate=formate
self.file_train=os.path.join(self.file_path,self.formate)
print(self.file_train)
files_class = os.listdir(self.file_train)
self.imgs=[]
for i, j in enumerate(files_class):
data = os.path.join(self.file_train, j)
print(data)
data_1 = os.listdir(data)
data_all = [[os.path.join(data, k), i] for k in data_1]
self.imgs += data_all
print(self.imgs)
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
img_path, label = self.imgs[index] # 选择文件路径
pil_img = Image.open(img_path).convert('RGB') # 利用PIL打开文件路径
if self.transforms:
img=transform(pil_img)
else:
pil_img = np.asarray(pil_img)
img = torch.from_numpy(pil_img)
return img, label
if __name__ == '__main__':
train_data= Mydataset(transform=True)
print(train_data.__getitem__(0)[0])
print(train_data.__getitem__(0)[1])
#验证能否传进模型中
train_dataloder=DataLoader(train_data,batch_size=8,shuffle=True)
for data in train_dataloder:
print(data[0].shape)
print(data[1])
#结果不唯一,其中的结果如下:
# torch.Size([8, 3, 500, 500])
# tensor([1, 0, 0, 1, 1, 0, 1, 1])
break
如果觉得有帮助,就点个赞吧,祝大家学业有成!