# Define your own dataset class
import os

import numpy as np
import PIL.Image as Image
import torch
import torch.utils.data as data
class myDataset(torch.utils.data.Dataset):
    """Dataset of (image, label) pairs read from a flat directory.

    Every file directly inside ``root`` whose name ends in ``.jpg``, ``.png``
    or ``.JPG`` is treated as an input image.  The path of the matching label
    is derived by replacing the substring ``'rgb_origin'`` in the image path
    with ``'rgb_origin_mask'``.

    Args:
        root: Directory scanned (non-recursively) for image files.
        augment: Optional callable applied to the loaded image and, in a
            separate call, to the loaded label.
    """

    # Case-sensitive file-name suffixes that mark a directory entry as an image.
    _IMAGE_SUFFIXES = ('.jpg', '.png', '.JPG')

    def __init__(self, root, augment=None):
        super().__init__()
        self.dataSource = root
        # ``with`` closes the scandir iterator promptly.  Sorting makes the
        # index -> file mapping reproducible, because os.scandir yields
        # entries in arbitrary order.
        with os.scandir(root) as entries:
            paths = sorted(
                entry.path
                for entry in entries
                if entry.name.endswith(self._IMAGE_SUFFIXES)
            )
        self.image_files = np.array(paths)
        self.augment = augment

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        image_path = self.image_files[index]
        label_path = image_path.replace('rgb_origin', 'rgb_origin_mask')
        # NOTE(review): open_image is not defined in this class; it is
        # expected to be provided elsewhere in the module.
        image = open_image(image_path)
        label = open_image(label_path)
        if self.augment:
            # NOTE(review): the transform is called once for the image and
            # once for the label.  If it is random, the two results may no
            # longer line up -- confirm it is deterministic.
            image = self.augment(image)
            label = self.augment(label)
        return image, label

    def __len__(self):
        """Return the number of image files found in ``root``."""
        return len(self.image_files)
# Load the dataset and use it
# Build the dataset from the image directory and wrap it in a shuffling
# DataLoader that yields batches of 16 samples.
img_path = 'rgb_origin'
finger_datasets = myDataset(img_path)
train_loader = torch.utils.data.DataLoader(
    dataset=finger_datasets,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

# Spatial size used to flatten the network output; loop-invariant, so it is
# defined once instead of on every batch.
width_out = 128
height_out = 128

# Create the loss and the optimizer ONCE.  Re-creating the optimizer for every
# batch throws away its momentum buffers, so momentum=0.99 would never take
# effect.
# NOTE(review): UNet is not defined in this part of the file.  It is both
# called and asked for .parameters(), so it is presumably a model instance --
# confirm.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(UNet.parameters(), lr=0.01, momentum=0.99)

for epoch in range(100):
    # The batch is unpacked as ``images, labels``: the body uses ``labels``,
    # and naming the input ``data`` would clobber the ``torch.utils.data``
    # import alias.
    for images, labels in train_loader:
        outputs = UNet(images)
        # Move dimension 1 to the last position (presumably the class
        # channel -- confirm) so each row after flattening holds the scores
        # of one pixel.
        outputs = outputs.permute(0, 2, 3, 1)
        m = outputs.shape[0]
        # reshape replaces the deprecated non-in-place Tensor.resize and
        # copes with the non-contiguous tensor produced by permute.
        outputs = outputs.reshape(m * width_out * height_out, 2)
        labels = labels.reshape(m * width_out * height_out)

        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        print(loss)
        loss.backward()
        optimizer.step()