[Untitled]

This post walks through a PyTorch Trainer class for training an image super-resolution model, covering the HRNet segmentation branch, the D_Net discriminator, the loss functions (SSIM and a custom histogram loss), data preprocessing, multi-scale data loading, and the training and test loops, along with model initialization, optimizer setup, and performance monitoring.

"""
Name : trainer_seg_hist_fuse.py
Author : xxxxxx
Time : 2022/9/12 21:08
"""
# Trainer class for training an image super-resolution model
import os
from decimal import Decimal

import utility
import torch.nn as nn

import torch
from tqdm import tqdm
import pytorch_ssim
import torchvision
from PIL import Image
from torch.utils.data.dataloader import DataLoader

from hrseg.hrseg_model import create_hrnet
from loss.myloss import hist_loss

import torch.optim as optim

from model.D_Net import Discriminator as D_Net
from model.D_Net import calculate_loss_D, calculate_loss_G





def tensor_save_rgbimage(tensor, filename, cuda=False):
    if cuda:
        img = tensor.clone().cpu().clamp(0, 255).numpy()
    else:
        img = tensor.clone().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)
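
# Usage sketch (hypothetical, not called in this file): save one CHW image
# from a CUDA batch whose values are already in [0, 255], e.g. the quantized
# `phr` produced in Trainer.test() below:
#
#   tensor_save_rgbimage(phr[0], 'out.png', cuda=True)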


class vgg_v2(nn.Module):
    def __init__(self, vgg_model):
        super(vgg_v2, self).__init__()
        self.vgg_layers = vgg_model.features
        # Indices into torchvision's vgg16().features:
        # 1 = relu1_1, 3 = relu1_2, 6 = relu2_1, 8 = relu2_2
        self.layer_name_mapping = {
            '1': "relu1_1",
            '3': "relu1_2",
            '6': "relu2_1",
            '8': "relu2_2"
        }

    def forward(self, x):
        output = []
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output.append(x)
        return output


def vgg_loss(vgg, img, gt):
    # size_average is deprecated; reduction='mean' is the equivalent
    mse = nn.MSELoss(reduction='mean')
    img_vgg = vgg(img)
    gt_vgg = vgg(gt)

    # return 0.4*mse(img_vgg[2], gt_vgg[2]) + 0.2*mse(img_vgg[3], gt_vgg[3])
    return (mse(img_vgg[0], gt_vgg[0]) + 0.6 * mse(img_vgg[1], gt_vgg[1])
            + 0.4 * mse(img_vgg[2], gt_vgg[2]) + 0.2 * mse(img_vgg[3], gt_vgg[3]))


def vgg_init(vgg_loc):
    # Build VGG-16 without downloading weights, then load the local checkpoint
    vgg_model = torchvision.models.vgg16(pretrained=False).cuda()
    vgg_model.load_state_dict(torch.load(vgg_loc))
    trainable(vgg_model, False)

    return vgg_model


def trainable(net, trainable):
    for para in net.parameters():
        para.requires_grad = trainable
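
# For reference, a minimal sketch (assumptions: img/gt are NCHW float batches
# on the GPU; the checkpoint path matches the commented-out call in
# Trainer.train() below) of how the three VGG helpers fit together:
#
#   vgg_model = vgg_init('./pretrained/vgg16-397923af.pth')  # frozen VGG-16
#   vgg = vgg_v2(vgg_model)
#   vgg.eval()
#   percept_loss = vgg_loss(vgg, img, gt)  # weighted MSE over relu1_1..relu2_2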


class Trainer():
    def __init__(self, args, loader, my_model, my_loss, ckp, adv=False):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test

        self.model = my_model
        self.loss = my_loss

        # adversarial-training switch
        self.adv = adv

        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8

    def train(self):
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        lr = self.scheduler.get_lr()[0]

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()

        self.model.train()
        # Note: eval() here immediately overrides train() above, so BatchNorm/
        # Dropout layers run in eval mode during training (kept as written).
        self.model.eval()

        timer_data, timer_model = utility.timer(), utility.timer()
        criterion_ssim = pytorch_ssim.SSIM(window_size=11)
        criterion_mse = nn.MSELoss(reduction='mean')  # size_average is deprecated

        # vgg_model = vgg_init('./pretrained/vgg16-397923af.pth')
        # vgg = vgg_v2(vgg_model)
        # vgg.eval()

        # define seg model
        seg_model = create_hrnet().cuda()
        seg_model.eval()

        # define discriminator and its optimizer
        if self.adv:
            model_D = D_Net().cuda()
            optimizer_D = optim.Adam(model_D.parameters(), lr=float(lr / 1000), betas=(0.9, 0.999), eps=1e-8)

        # note: `lr` is rebound here, from the learning rate to the low-res batch
        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()

            lr = lr / 255.0
            hr = hr / 255.0

            [b, c, h, w] = hr.shape

            # phr1, phr2, phr4 = self.model(lr, 3)
            res_g3_s1, res_g3_s2, res_g3_s4, feat_g3_s1, feat_g3_s2, feat_g3_s4 = self.model.forward_1(lr, 3)

            # use seg_model
            seg_map, seg_orin, seg_fea = seg_model(res_g3_s1)

            phr1, phr2, phr4 = self.model.forward_2(lr, res_g3_s1, res_g3_s2, res_g3_s4, feat_g3_s1, feat_g3_s2,
                                                    feat_g3_s4, seg_orin, seg_fea)

            hr4 = hr[:, :, 0::4, 0::4]
            hr2 = hr[:, :, 0::2, 0::2]
            hr1 = hr
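            # The strided slices above are nearest-neighbour downsamples of the
            # ground truth: 0::4 keeps every 4th pixel (the x4 target for phr4),
            # 0::2 every 2nd (the x2 target for phr2), so each pyramid output
            # gets a matching-resolution supervision signal.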

            # use seg_model on the full-resolution prediction
            seg_map, seg_orin, seg_fea = seg_model(phr1)

            if self.adv:
                loss_D = calculate_loss_D(model_D, hr1, phr1, seg_map)

                optimizer_D.zero_grad()
                loss_D.backward(retain_graph=True)
                optimizer_D.step()

                hist_loss_ = hist_loss(seg_map, phr1, hr1, gpu_id='cuda')
                rect_loss = criterion_ssim(phr1, hr1) + criterion_ssim(phr2, hr2) + criterion_ssim(phr4, hr4)
                loss_G = calculate_loss_G(model_D, hr1, phr1, seg_map)

                full_loss = rect_loss + hist_loss_ + 0.1 * loss_G
                self.optimizer.zero_grad()
                full_loss.backward()
                self.optimizer.step()

            else:
                # use hist loss
                hist_loss_ = hist_loss(seg_map, phr1, hr1, gpu_id='cuda')

                rect_loss = criterion_ssim(phr1, hr1) + criterion_ssim(phr2, hr2) + criterion_ssim(phr4, hr4)

                full_loss = rect_loss + hist_loss_

                if full_loss.item() < self.args.skip_threshold * self.error_last:
                    full_loss.backward()
                    self.optimizer.step()
                else:
                    print('Skip this batch {}! (Loss: {})'.format(
                        batch + 1, full_loss.item()
                    ))

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                if self.adv:
                    self.ckp.write_log('[{}/{}]\t{}\t{}\t{}\tD: {}\tG: {}\t{:.1f}+{:.1f}s'.format(
                        (batch + 1) * self.args.batch_size,
                        len(self.loader_train.dataset),
                        full_loss.item(),
                        rect_loss.item(),
                        hist_loss_.item(),
                        loss_D.item(),
                        loss_G.item(),
                        # percept_loss.item(),
                        timer_model.release(),
                        timer_data.release()))
                else:
                    self.ckp.write_log('[{}/{}]\t{}\t{}\t{}\t{:.1f}+{:.1f}s'.format(
                        (batch + 1) * self.args.batch_size,
                        len(self.loader_train.dataset),
                        full_loss.item(),
                        rect_loss.item(),
                        hist_loss_.item(),
                        # percept_loss.item(),
                        timer_model.release(),
                        timer_data.release()))

            timer_data.tic()

        # print(rect_loss.item())

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]

    def test(self):
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        # define seg model
        seg_model = create_hrnet().cuda()
        seg_model.eval()

        timer_test = utility.timer()
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                eval_acc = 0
                self.loader_test.dataset.set_scale(idx_scale)

                # wrap the dataset in a standard DataLoader for evaluation
                loader_test_standard = DataLoader(self.loader_test.dataset, batch_size=self.loader_test.batch_size,
                                                  shuffle=False, num_workers=self.loader_test.num_workers)

                tqdm_test = tqdm(loader_test_standard, ncols=80)
                for idx_img, (lr, hr, filename) in enumerate(tqdm_test):
                    filename = filename[0]
                    no_eval = (hr.nelement() == 1)
                    if not no_eval:
                        lr, hr = self.prepare(lr, hr)
                    else:
                        lr, = self.prepare(lr)

                    lr = lr / 255.0
                    hr = hr / 255.0

                    [b, c, h, w] = hr.shape
                    n_map = torch.zeros(b, c, h, w).cuda()

                    # phr1, phr2, phr4 = self.model(lr, 3)
                    res_g3_s1, res_g3_s2, res_g3_s4, feat_g3_s1, feat_g3_s2, feat_g3_s4 = self.model.forward_1(lr, 3)

                    # use seg_model
                    seg_map, seg_orin, seg_fea = seg_model(res_g3_s1)

                    phr1, phr2, phr4 = self.model.forward_2(lr, res_g3_s1, res_g3_s2, res_g3_s4, feat_g3_s1,
                                                            feat_g3_s2, feat_g3_s4, seg_orin, seg_fea)

                    phr = utility.quantize(phr1 * 255, self.args.rgb_range)
                    lr = utility.quantize(lr * 255, self.args.rgb_range)
                    hr = utility.quantize(hr * 255, self.args.rgb_range)

                    save_list = [hr, lr, phr, lr]

                    if not no_eval:
                        eval_acc += utility.calc_psnr(
                            phr, hr, scale, self.args.rgb_range,
                            benchmark=self.loader_test.dataset.benchmark
                        )

                    if self.args.save_results:
                        self.ckp.save_results(filename, save_list, scale, epoch)

                self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        self.args.data_test,
                        scale,
                        self.ckp.log[-1, idx_scale],
                        best[0][idx_scale],
                        best[1][idx_scale] + 1
                    )
                )

        self.ckp.write_log(
            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.scheduler.last_epoch + 1
            return epoch >= self.args.epochs
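
For context, the trainer is driven from main_train.py (visible in the traceback below). A minimal sketch of the usual EDSR-style driving loop, consistent with terminate() above (hypothetical; the actual main script is not shown):

    t = Trainer(args, loader, model, loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()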

Running main_train.py then fails as soon as the trainer iterates the training loader:

Traceback (most recent call last):
  File "main_train.py", line 34, in <module>
    t.train()
  File "/home/amax/SALLIE/DRBN_SKF/src/trainer_seg_hist_fuse.py", line 142, in train
    for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
TypeError: iter() returned non-iterator of type '_MSDataLoaderIter'
Traceback (most recent call last):
  File "/home/amax/.conda/envs/xxxx/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/amax/.conda/envs/xxxx/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/amax/.conda/envs/xxxx/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/amax/.conda/envs/xxxx/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
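
In Python 3, the object returned by __iter__ must itself implement __next__. The _MSDataLoaderIter shown below (presumably returned by the repo's MSDataLoader.__iter__, which is not shown here) inherits __iter__ from DataLoader but defines no __next__ of its own, hence the TypeError; the BrokenPipeError is just a follow-on symptom of the worker feeder threads dying afterwards. A minimal reproduction of the same error, independent of PyTorch:

    class Broken:
        def __iter__(self):
            # Returns an object without __next__, so a for-loop over Broken()
            # fails with: TypeError: iter() returned non-iterator
            return Broken()

    for _ in Broken():  # raises the same TypeError as above
        pass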

# A PyTorch data loader implementing multi-scale data loading
import sys
import threading
import queue
import random
import collections

import torch
import torch.multiprocessing as multiprocessing

from torch._C import _set_worker_signal_handlers
from torch.utils.data import _utils
from torch.utils.data.dataloader import DataLoader


_use_shared_memory = False

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)

        except Exception:
            data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
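
# For debugging, the same per-batch logic can be mirrored single-process
# (a sketch, not used by the loader itself). Note that set_scale() runs
# *before* the dataset is indexed, exactly as in _ms_loop above:
#
#   def ms_batches(dataset, batch_sampler, collate_fn, scale):
#       for batch_indices in batch_sampler:
#           idx_scale = 0
#           if len(scale) > 1 and dataset.train:
#               idx_scale = random.randrange(0, len(scale))
#               dataset.set_scale(idx_scale)
#           samples = collate_fn([dataset[i] for i in batch_indices])
#           samples.append(idx_scale)
#           yield samples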


class _MSDataLoaderIter(DataLoader):
    # As defined here, this class inherits __iter__ from DataLoader but
    # provides no __next__ of its own — the source of the TypeError above.
    def __init__(self, loader):
        self.loader = loader
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()
        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()

            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.data_loader = DataLoader(
                self.loader.dataset,
                batch_sampler=self.batch_sampler,
                collate_fn=self.collate_fn,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                timeout=self.timeout,
                worker_init_fn=self.worker_init_fn
            )

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    maybe_device_id = None
                self.pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.pin_memory_thread.daemon = True
                self.pin_memory_thread.start()

            else:
                self.data_queue = self.worker_result_queue

            # Prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def _put_indices(self):
        try:
            index = next(self.sample_iter)
        except StopIteration:
            return
        self.worker_result_queue.put((self.worker_queue_idx, index))
        self.worker_queue_idx += 1
        self.batches_outstanding += 1



 

The fix is to make _MSDataLoaderIter a real iterator by giving it __iter__ and __next__:

class _MSDataLoaderIter(DataLoader):
    # ... __init__ and _put_indices as above ...

    def __iter__(self):
        return self

    def __next__(self):
        if self.batches_outstanding == 0:
            # _shutdown_workers is assumed to be defined elsewhere in the class
            self._shutdown_workers()
            raise StopIteration

        try:
            # DataLoader's default timeout is 0, which queue.Queue would treat
            # as "wait zero seconds"; block indefinitely in that case instead.
            timeout = self.timeout if self.timeout > 0 else None
            data = self.data_queue.get(timeout=timeout)
        except Exception as e:
            # Re-raise rather than retry: a bare retry loop would spin forever
            # on any persistent error.
            print(f"Error while getting data: {e}")
            raise
        idx, samples = data
        self.batches_outstanding -= 1
        self._put_indices()
        return samples

With the iterator fixed, training proceeds one step further and then fails inside prepare():

Traceback (most recent call last):
  File "main_train.py", line 34, in <module>
    t.train()
  File "/home/amax/SALLIE/DRBN_SKF/src/trainer_seg_hist_fuse.py", line 143, in train
    lr, hr = self.prepare(lr, hr)
  File "/home/amax/SALLIE/DRBN_SKF/src/trainer_seg_hist_fuse.py", line 324, in prepare
    return [_prepare(a) for a in args]
  File "/home/amax/SALLIE/DRBN_SKF/src/trainer_seg_hist_fuse.py", line 324, in <listcomp>
    return [_prepare(a) for a in args]
  File "/home/amax/SALLIE/DRBN_SKF/src/trainer_seg_hist_fuse.py", line 322, in _prepare
    return tensor.to(device)
AttributeError: 'int' object has no attribute 'to'

The cause is that a non-tensor value reaches prepare(); since _ms_loop appends the integer idx_scale to every batch, an int is the most likely culprit. Guarding _prepare with an isinstance check lets non-tensors pass through untouched:

class YourClass:
    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            # Only move/cast actual tensors; ints such as idx_scale are
            # returned unchanged.
            if isinstance(tensor, torch.Tensor):
                if self.args.precision == 'half':
                    tensor = tensor.half()
                return tensor.to(device)
            else:
                return tensor

        return [_prepare(a) for a in args]
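
A quick sanity check of the guard (hypothetical; SimpleNamespace stands in for the real argparse args object):

    import torch
    from types import SimpleNamespace

    obj = YourClass()
    obj.args = SimpleNamespace(cpu=True, precision='single')
    t, idx_scale = obj.prepare(torch.zeros(1, 3, 4, 4), 2)
    print(t.device, idx_scale)  # -> cpu 2: the int passes through unchanged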
