编写一个自己的项目,并不是从零开始。而是在已有的框架下,往里填充东西。这个框架可以是别人的项目,也可以是自己的框架。下面是我自己的PyTorch框架。
1. 变量封装
为了方便管理项目中众多变量,将所有变量封装到一个对象中,需要用到时直接从该对象中获取。使用argparse封装变量。封装变量可以单独创建一个py文件,取名为arg_parser.py。代码如下:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""Central definition of every project hyper-parameter; import `args` from here."""
import argparse

parser = argparse.ArgumentParser(description='该项目的一句话简介,可写可不写')
parser.add_argument('--batch_size', type=int, default=16, help='该项的解释,可写可不写')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--snapshots_folder', type=str, default='snapshots/')
parser.add_argument('--sample_output_folder', type=str, default='samples/')
# Options referenced by train.py and the main script that were missing here,
# which made `args.grad_clip_norm` etc. raise AttributeError at runtime.
parser.add_argument('--orig_images_path', type=str, default='data/original/')
parser.add_argument('--hazy_images_path', type=str, default='data/haze/')
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--val_batch_size', type=int, default=8)
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument('--grad_clip_norm', type=float, default=0.1)
parser.add_argument('--display_iter', type=int, default=10)
parser.add_argument('--snapshot_iter', type=int, default=200)
# parse_known_args keeps this module importable even when the host process
# (a test runner, notebook, etc.) carries its own command-line flags.
args, _ = parser.parse_known_args()
2. 获取训练/测试集
一般情况下,训练集都是我们自己的,也就是说PyTorch并不会内置。所以需要我们自己创建Dataset。自己实现Dataset,只需要继承该类,并实现其中的两个函数——__getitem__(self, item)和__len__(self)。这两个函数的功能分别是根据传入的下标获取数据,以及获取数据的数量。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import glob
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
def populate_train_list(orig_images_path, haze_images_path):
    """Pair each hazy image with its ground-truth image and split 90/10.

    Hazy files are expected to be named '<scene>_<id>_<param>.jpg'; the
    matching ground truth is '<scene>_<id>.jpg' inside orig_images_path.
    Both path arguments must end with a path separator, since they are
    concatenated with file names directly.

    Returns:
        (train_list, test_list): two lists of [hazy_path, gt_path] pairs;
        the first 90% of pairs go to training, the rest to validation.
    """
    # Sort the glob result: directory order is platform-dependent, and an
    # unsorted list would make the train/validation split nondeterministic.
    image_list_haze = sorted(glob.glob(haze_images_path + '*.jpg'))
    data = []
    for image_path in image_list_haze:
        # os.path.basename is portable; the original split on '\\' only
        # stripped the directory on Windows.
        name = os.path.basename(image_path)
        parts = name.split('_')
        gt = parts[0] + '_' + parts[1] + '.jpg'
        data.append([haze_images_path + name, orig_images_path + gt])
    # Plain list slicing: the original list-wrap + np.squeeze collapsed the
    # pair dimension whenever a split contained exactly one element.
    split = int(len(data) * 0.9)
    return data[:split], data[split:]
class dehaze_dataset(Dataset):
    """Paired hazy / ground-truth image dataset for dehazing.

    mode='train' keeps the first 90% of the pairs produced by
    populate_train_list; any other mode keeps the remaining 10%.
    Each item is a (hazy, gt) pair of float tensors in CHW layout,
    scaled to [0, 1] and resized to 480x640.
    """

    def __init__(self, orig_images_path, haze_images_path, mode='train'):
        super(dehaze_dataset, self).__init__()
        self.mode = mode
        train_dataset, test_dataset = populate_train_list(orig_images_path, haze_images_path)
        self.data_list = train_dataset if mode == 'train' else test_dataset

    def __getitem__(self, item):
        haze_image_path, gt_path = self.data_list[item]
        # convert('RGB') guards against grayscale/RGBA inputs, which would
        # break the fixed 3-channel permute below.
        haze_image = Image.open(haze_image_path).convert('RGB')
        gt = Image.open(gt_path).convert('RGB')
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        haze_image = haze_image.resize((480, 640), Image.LANCZOS)
        gt = gt.resize((480, 640), Image.LANCZOS)
        haze_image = np.asarray(haze_image) / 255.0
        gt = np.asarray(gt) / 255.0
        haze_image = torch.from_numpy(haze_image).float()
        gt = torch.from_numpy(gt).float()
        # HWC -> CHW, the layout PyTorch convolutions expect.
        return haze_image.permute(2, 0, 1), gt.permute(2, 0, 1)

    def __len__(self):
        return len(self.data_list)
3. 网络框架
此处以AODNet为例。创建一个名为AODNet.py的文件,代码如下:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class AODNet(nn.Module):
    """AOD-Net: estimates the transmission map K(x) from a hazy image I(x)
    and recovers the clean image via J(x) = K(x) * I(x) - K(x) + 1.

    Layer attribute names are kept identical to the original so that
    previously saved state_dicts still load.
    """

    def __init__(self):
        super(AODNet, self).__init__()
        self.conv_1 = nn.Conv2d(3, 3, 1)
        self.conv_2 = nn.Conv2d(3, 3, 3, padding=1)
        self.conv_3 = nn.Conv2d(6, 3, 5, padding=2)
        self.conv_4 = nn.Conv2d(6, 3, 7, padding=3)
        self.conv_5 = nn.Conv2d(12, 3, 3, padding=1)

    def forward(self, x):
        x1 = F.relu(self.conv_1(x))
        # Per the AOD-Net paper, conv_2 consumes conv_1's output (the
        # layers form a cascade); the original fed the raw input x here.
        x2 = F.relu(self.conv_2(x1))
        concat1 = torch.cat((x1, x2), dim=1)
        x3 = F.relu(self.conv_3(concat1))
        concat2 = torch.cat((x2, x3), dim=1)
        x4 = F.relu(self.conv_4(concat2))
        # Final fusion sees every intermediate feature map (12 channels).
        concat3 = torch.cat((x1, x2, x3, x4), dim=1)
        k = F.relu(self.conv_5(concat3))
        # Physical recovery formula, ReLU-clamped to non-negative values.
        return F.relu((k * x) - k + 1)
网络框架千变万化,但是复现一个已给出参数的网络框架还是比较简单的。
4. 训练
创建一个名为train.py的文件。训练过程有一个固定的模板,只需要记住该模板即可。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torchvision
from arg_parser import args
def train(net, train_loader, test_dataset, optimizer, device):
    """Run the full MSE training loop with periodic logging and snapshots.

    Args:
        net: model to optimize (updated in place).
        train_loader: DataLoader yielding (haze, gt) training batches.
        test_dataset: DataLoader yielding (haze, gt) validation batches.
            (Name kept for caller compatibility although it is a loader.)
        optimizer: optimizer built over net.parameters().
        device: torch.device on which to run.
    """
    criterion = torch.nn.MSELoss().to(device)
    for epoch in range(args.epochs):
        net.train()
        train_loss = 0.0
        for batch_idx, (haze, gt) in enumerate(train_loader, 0):
            haze, gt = haze.to(device), gt.to(device)
            output = net(haze)
            loss = criterion(output, gt)
            # .item() detaches the scalar; accumulating the tensor itself
            # would keep every batch's autograd graph alive (memory leak).
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm (no trailing underscore) is deprecated/removed;
            # the in-place variant is the supported API.
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip_norm)
            optimizer.step()
            if (batch_idx + 1) % args.display_iter == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:6f}'.format(
                    epoch, (batch_idx + 1) * len(haze), len(train_loader.dataset),
                    100. * (batch_idx + 1) / len(train_loader), train_loss / args.display_iter
                ))
                train_loss = 0.0
            if (batch_idx + 1) % args.snapshot_iter == 0:
                torch.save(net.state_dict(), args.snapshots_folder + 'Epoch' + str(epoch) + '.pt')
        # Validation: no gradients needed, and every tensor passed to
        # torch.cat must live on the same device (the original left gt on CPU).
        net.eval()
        with torch.no_grad():
            for batch_idx, (haze, gt) in enumerate(test_dataset, 0):
                haze, gt = haze.to(device), gt.to(device)
                output = net(haze)
                torchvision.utils.save_image(
                    torch.cat((haze, output, gt), dim=0),
                    args.sample_output_folder + str(batch_idx + 1) + '.jpg')
        torch.save(net.state_dict(), args.snapshots_folder + 'epoch_' + str(epoch) + '.pt')
    torch.save(net.state_dict(), args.snapshots_folder + 'dehazer.pt')
5. 测试
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torchvision
import numpy as np
from PIL import Image
from AODNet import AODNet
def main():
    """Dehaze a single image with a trained AOD-Net and save hazy+clean side by side."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_path = 'D:/...'  # TODO: point at a real hazy image
    dehaze_net = AODNet().to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    dehaze_net.load_state_dict(torch.load('snapshots/dehazer.pt', map_location=device))
    dehaze_net.eval()
    # convert('RGB') guards against grayscale/RGBA inputs breaking permute.
    data_hazy = Image.open(image_path).convert('RGB')
    data_hazy = np.asarray(data_hazy) / 255.0
    data_hazy = torch.from_numpy(data_hazy).float()
    # HWC -> CHW, add batch dim; .to(device) replaces the original
    # hard-coded .cuda(), which crashed on CPU-only machines.
    data_hazy = data_hazy.permute(2, 0, 1).unsqueeze(0).to(device)
    # Inference needs no autograd graph.
    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)
    import os
    # save_image fails if the output directory does not exist.
    os.makedirs('results', exist_ok=True)
    torchvision.utils.save_image(torch.cat((data_hazy, clean_image), 0), "results/" + image_path.split("/")[-1])


if __name__ == "__main__":
    main()
6. 主函数
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import torch
from torch.utils.data import DataLoader
from train import train
from AODNet import AODNet
from arg_parser import args
from create_dataset import dehaze_dataset
def main():
    """Build datasets, loaders, model and optimizer, then start training."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_dataset = dehaze_dataset(args.orig_images_path, args.hazy_images_path)
    test_dataset = dehaze_dataset(args.orig_images_path, args.hazy_images_path, mode='test')
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.train_batch_size, shuffle=True)
    # Validation needs no shuffling; a fixed order keeps the dumped sample
    # images comparable from epoch to epoch.
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.val_batch_size, shuffle=False)
    net = AODNet().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    train(net, train_loader, test_loader, optimizer, device)


if __name__ == "__main__":
    # makedirs(exist_ok=True) also creates missing parent directories and is
    # race-safe, unlike the original exists()+mkdir pair.
    os.makedirs(args.snapshots_folder, exist_ok=True)
    os.makedirs(args.sample_output_folder, exist_ok=True)
    main()