本专栏内容是学习 深度学习麋了鹿 的《图像分割UNet硬核讲解》(带你手撸unet代码)部分笔记。
内容包括从数据集→网络结构→训练→测试。(附代码)
本节是 UNet 训练及代码实现 笔记
上节内容: UNet 网络结构及代码实现 。
下面开始本节内容:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from 图像分割UNet硬核讲解.Data01 import MyDataset
from 图像分割UNet硬核讲解.Net02 import UNet
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 如果 cuda 可以用先用 cuda
# 权重地址,需要新建一个名为 params 的文件
weight_path = 'params/unet.pth'
data_path = 'D:\\STUDY1\\MILELU\\图像分割UNet硬核讲解\\VOC\\VOCdevkit\\VOC2007'
sava_path = 'train_image'
if __name__ == '__main__':
data_loader = DataLoader(MyDataset(data_path), batch_size=10, shuffle=True)
net = UNet().to(device)
if os.path.exists(weight_path):
net.load_state_dict(torch.load(weight_path))
print('successful load weight !')
else:
print('not successful load weight')
opt = optim.Adam(net.parameters())
loss_fun = nn.BCELoss()
epoch = 1
while True:
for i, (image, segment_image) in enumerate(data_loader):
image = image.to(device)
segment_image = segment_image.to(device)
out_image = net(image)
train_loss = loss_fun(out_image, segment_image)
opt.zero_grad()
train_loss.backward()
opt.step()
if i % 5 == 0:
print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')
if i % 50 == 0:
torch.save(net.state_dict(), weight_path)
_image = image[0]
_segment_image = segment_image[0]
_out_image = out_image[0]
img = torch.stack([_image, _segment_image, _out_image], dim=0)
save_image(img, f'{sava_path}/{i}.png')
epoch += 1
输出结果:
生成的图像:
随意选择一张:
本节内容没什么可讲的,重点就在于代码上面。
另外,如果运行后内存不够,可以选择减小代码中的 batch_size 。当然也可以选择回到上上节中的 utils01.py ,将里面的照片尺寸改小点。
下节内容:UNet 测试及代码实现