目录
一、如何获取数据?
1.Dataset
Dataset 功能:提供一种方法去获取数据及其label
- 如何获取每一个数据及其label
- 告诉我们有多少的数据
import os
from torch.utils.data import Dataset
from PIL import Image
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir="dataset/train"
ants_label_dir="ants"
ants_dataset=MyData(root_dir,ants_label_dir)
2.DataLoader
DataLoader:为后面的网络提供不同的数据形式
总结
以上就是今天的内容,主要说明如何使用pytorch获取数据及其label