I. main.py
1. Main content of the paper
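(1) Network
    (1) Dataset preprocessing: obtain mos (y) and mos_std from the dataset.
    (2) Input image: 32x32 patches, x.
    (3) Local contrast normalization (LCN): contrast is normalized within each small window rather than over the image as a whole. The local mean and local standard deviation of the feature maps are computed with separable convolutions, and element-wise subtraction and division are then applied to the feature maps. The pixel at position (i, j) is normalized using the local mean and local standard deviation of the window around it.
    (4) Convolutional layer: (7x7, 1, 50), i.e. 7x7 kernels, 1 input channel, 50 output channels.
    (5) Pooling: max and min. Each feature map is reduced to one max value and one min value, giving 2x50 values in total.
    (6) FC: (50x2, 800), (800, 800), (800, 1).
    (7) ReLU: applied after the FC layers.
    (8) Dropout: on the second FC layer, dropout = 0.5.
    (9) Optimizer: Adam.
    (10) Loss = L1(y, y_pred).
    (11) Output: the predicted quality score y_pred.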
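The LCN step in item (3) can be sketched in plain Python. This is only a minimal illustration, not the code used in the repository: the function name, the window half-sizes P and Q, the constant C that guards against division by zero, and the clipping of the window at the image border are all assumptions made for the example.

```python
import math

def local_contrast_normalize(img, P=3, Q=3, C=1.0):
    """Normalize each pixel by the mean and std of its local window.

    img is a list of rows of floats. P, Q and C are assumed values for
    this sketch; windows are clipped at the image border.
    """
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            # Collect the (2P+1) x (2Q+1) neighbourhood around (i, j)
            window = [img[r][c]
                      for r in range(max(0, i - P), min(h, i + P + 1))
                      for c in range(max(0, j - Q), min(w, j + Q + 1))]
            mu = sum(window) / len(window)
            sigma = math.sqrt(sum((v - mu) ** 2 for v in window) / len(window))
            # Element-wise subtraction of the local mean, division by the local std
            out[i][j] = (img[i][j] - mu) / (sigma + C)
    return out

# A flat patch has no local contrast, so every normalized pixel is 0
flat_patch = [[5.0] * 8 for _ in range(8)]
normalized = local_contrast_normalize(flat_patch)
```

The nested loops are slow on full images; the separable-convolution form mentioned above computes the same local statistics far more efficiently.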
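(2) Evaluation criteria
SROCC (Spearman rank-order correlation coefficient), KROCC (Kendall rank-order correlation coefficient), PLCC (Pearson linear correlation coefficient), RMSE (root mean squared error), MAE (mean absolute error), OR (outlier ratio)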
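To make these criteria concrete, here is a small stdlib-only sketch of SROCC, PLCC, RMSE, MAE and OR (KROCC is left out to keep it short). The function names are made up for this example, and the OR definition used here (the fraction of predictions more than two standard deviations away from the subjective score) is an assumption about the convention; in practice library routines such as those in scipy.stats would normally be used instead.

```python
import math

def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda k: values[k])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared = (start + end) / 2 + 1
        for pos in range(start, end + 1):
            ranks[order[pos]] = shared
        start = end + 1
    return ranks

def plcc(y, y_pred):
    """Pearson linear correlation coefficient."""
    n = len(y)
    my, mp = sum(y) / n, sum(y_pred) / n
    cov = sum((a - my) * (b - mp) for a, b in zip(y, y_pred))
    sy = math.sqrt(sum((a - my) ** 2 for a in y))
    sp = math.sqrt(sum((b - mp) ** 2 for b in y_pred))
    return cov / (sy * sp)

def srocc(y, y_pred):
    """Spearman correlation = Pearson correlation of the ranks."""
    return plcc(average_ranks(y), average_ranks(y_pred))

def rmse(y, y_pred):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(y, y_pred)) / len(y))

def mae(y, y_pred):
    return sum(abs(a - b) for a, b in zip(y, y_pred)) / len(y)

def outlier_ratio(y, y_pred, y_std):
    """Assumed definition: share of predictions with |error| > 2 * std."""
    outliers = sum(1 for a, b, s in zip(y, y_pred, y_std) if abs(a - b) > 2 * s)
    return outliers / len(y)
```

SROCC and KROCC only look at the ordering of the predictions, while PLCC, RMSE and MAE also depend on their scale, which is why a perfectly monotonic but badly scaled predictor can score SROCC = 1 with a large RMSE.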
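(3) Experiments
The experiments vary the kernel size, the number of kernels, the batch size, and the number of batches; the best value of each of these four parameters is chosen by SROCC and PLCC.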
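(4) Local quality prediction for an image:
Before this module, CNNIQA is trained, validated, and tested on LIVE only.
    (1) Select reference images from TID2008 that do not appear in LIVE and split each one vertically into 4 parts. Parts 2, 3, and 4 are replaced with versions distorted at different levels by JPEG, JPEG2000, WN, or BLUR, giving a synthetic image in which the reference and several distortion levels coexist. Quality scores are predicted on 16x16 patches with stride = 8 and normalized to [0, 255] to form the quality prediction map.
    (2) Select from TID2008 several distortions that are absent from LIVE and have a strong local effect (the paper chooses JPEG, JPEG2000, and blocking artifacts); the results show that CNNIQA can localize local distortions.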
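The scanning and rescaling in step (1) can be sketched as follows. The notes only say that the scores are normalized to [0, 255]; min-max scaling is an assumption here, and both function names are made up for the illustration. The per-patch scores themselves would come from the trained model, which is not shown.

```python
def patch_origins(height, width, patch=16, stride=8):
    """Top-left corners of all patches that fit inside the image."""
    return [(top, left)
            for top in range(0, height - patch + 1, stride)
            for left in range(0, width - patch + 1, stride)]

def to_gray_levels(scores):
    """Min-max scale patch scores to integers in [0, 255] (assumed scheme)."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # All patches have the same score: the map is uniformly dark
        return [0] * len(scores)
    return [round(255 * (s - lo) / (hi - lo)) for s in scores]

# A 32x32 image scanned with 16x16 patches and stride 8 gives a 3x3 grid
origins = patch_origins(32, 32)
gray = to_gray_levels([0.0, 20.0, 100.0])
```

Each gray level is then drawn at the location of its patch, so regions with similar predicted quality appear as areas of similar brightness in the map.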
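(5) Datasets (60% training, 20% validation, 20% testing)
    (1) LIVE: 779 distorted images and 29 reference images, 5 distortion types (JPEG, JPEG2000, Gaussian blur, white noise, fast fading), 5 distortion levels, dmos ∈ [0, 100].
    (2) TID2008: 1700 distorted images and 25 reference images, 17 distortion types, 4 distortion levels, mos ∈ [0, 9].
    (3) Cross-dataset test: train the network on the LIVE database to fix the network parameters; randomly split the TID2008 dataset 100 times into an 80% part and a 20% part. The 80% part is used for estimating the parameters of the logistic function that maps dmos → mos nonlinearly, and the 20% part is used for the final test.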
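A 60/20/20 split can be sketched as below. The notes do not say what unit is split; this sketch assumes the split is made over reference image indices (so that all distorted versions of one reference stay in the same subset), which is a common choice in IQA but an assumption here. The function name and the seed parameter are hypothetical.

```python
import random

def split_reference_ids(num_refs, train_ratio=0.6, val_ratio=0.2, seed=0):
    """Shuffle reference indices and cut them into train/val/test parts."""
    ids = list(range(num_refs))
    random.Random(seed).shuffle(ids)  # a fixed seed keeps the split reproducible
    n_train = int(train_ratio * num_refs)
    n_val = int(val_ratio * num_refs)
    train = ids[:n_train]
    val = ids[n_train:n_train + n_val]
    test = ids[n_train + n_val:]  # whatever remains goes to the test set
    return train, val, test

# LIVE has 29 reference images
train_ids, val_ids, test_ids = split_reference_ids(29)
```

Because 29 is not divisible by 5, the three parts come out as 17, 5 and 7 references; repeating the call with different seeds gives the repeated random splits used when reporting averaged results.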
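2. main
The preceding part is identical to the original source code (link).
Below, matplotlib.pyplot is used to plot the loss and the other data and to save the figures.
(1) Training
The training data are stored as lists. train_loss, validation_SROCC, validation_KROCC, validation_PLCC, validation_RMSE, validation_MAE, validation_OR and the other lists are initialized as empty.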
# Initialize the lists that collect the loss and the metrics (SROCC, etc.)
train_loss, validation_SROCC, validation_KROCC, validation_PLCC, validation_RMSE, validation_MAE, validation_OR = [], [], [], [], [], [], []
testing_SROCC, testing_KROCC, testing_PLCC, testing_RMSE, testing_MAE, testing_OR = [], [], [], [], [], []

# The training step is the train_step function defined through ignite.
# One step: put the model in training mode and zero the gradients, unpack the batch into x and y,
# compute y_pred = model(x), compute the loss between y and y_pred, backpropagate, update the parameters.
# Here x is the pixel data of the image and y is the mos.
# The step returns the loss, which ignite stores in state.output.
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    global loss
    writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
    # At the end of each epoch, record the loss of its last iteration
    if engine.state.iteration % len(train_loader) == 0:
        loss = engine.state.output
        print("=== > start training")
        print("Epoch[{}] Iteration[{}/{}] Loss:{:.4f}"
              .format(engine.state.epoch, engine.state.iteration, len(train_loader), engine.state.output))
        train_loss.append(loss)
(2) Validation
metrics_validation is used later when plotting the curves and saving the data, so it is declared as a global variable here.
# Run validation after every epoch.
# Evaluation loads each batch and runs the forward pass; ignite then computes the metrics.
# The computed metrics are available in evaluator.state.metrics.
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    global metrics_validation, best_criterion, best_epoch
    evaluator.run(val_loader)  # run the evaluator on the validation loader
    metrics_validation = evaluator.state.metrics  # fetch the evaluation metrics
    SROCC, KROCC, PLCC, RMSE, MAE, OR = metrics_validation['IQA_performance']
    print("=== > start Validation")
    print("Validation Results - Epoch: {} SROCC: {:.4f} KROCC: {:.4f} PLCC: {:.4f} RMSE: {:.4f} MAE: {:.4f} OR: {:.2f}%"
          .format(engine.state.epoch, SROCC, KROCC, PLCC, RMSE, MAE, 100 * OR))
    # Record this epoch's metrics for plotting
    validation_SROCC.append(SROCC)
    validation_KROCC.append(KROCC)
    validation_PLCC.append(PLCC)
    validation_RMSE.append(RMSE)
    validation_MAE.append(MAE)
    validation_OR.append(OR)
    writer.add_scalar("validation/SROCC", SROCC, engine.state.epoch)
    writer.add_scalar("validation/KROCC", KROCC, engine.state.epoch)
    writer.add_scalar("validation/PLCC", PLCC, engine.state.epoch)
    writer.add_scalar("validation/RMSE", RMSE, engine.state.epoch)
    writer.add_scalar("validation/MAE", MAE, engine.state.epoch)
    writer.add_scalar("validation/OR", OR, engine.state.epoch)
    # Keep the model with the best validation SROCC so far
    if SROCC > best_criterion:
        best_criterion = SROCC
        best_epoch = engine.state.epoch
        torch.save(model.state_dict(), trained_model_file)  # save the model
(3) Testing
# Run the test set after every epoch, if testing during training is enabled
@trainer.on(Events.EPOCH_COMPLETED)
def log_testing_results(engine):
    global metrics_testing
    if config["test_ratio"] > 0 and config['test_during_training']:
        evaluator.run(test_loader)
        metrics_testing = evaluator.state.metrics
        SROCC, KROCC, PLCC, RMSE, MAE, OR = metrics_testing['IQA_performance']
        print("=== > start Testing ")
        print("Testing Results - Epoch: {} SROCC: {:.4f} KROCC: {:.4f} PLCC: {:.4f} RMSE: {:.4f} MAE: {:.4f} OR: {:.2f}%"
              .format(engine.state.epoch, SROCC, KROCC, PLCC, RMSE, MAE, 100 * OR))
        # Record this epoch's metrics for plotting
        testing_SROCC.append(SROCC)
        testing_KROCC.append(KROCC)
        testing_PLCC.append(PLCC)
        testing_RMSE.append(RMSE)
        testing_MAE.append(MAE)
        testing_OR.append(OR)
        writer.add_scalar("testing/SROCC", SROCC, engine.state.epoch)
        writer.add_scalar("testing/KROCC", KROCC, engine.state.epoch)
        writer.add_scalar("testing/PLCC", PLCC, engine.state.epoch)
        writer.add_scalar("testing/RMSE", RMSE, engine.state.epoch)
        writer.add_scalar("testing/MAE", MAE, engine.state.epoch)
        writer.add_scalar("testing/OR", OR, engine.state.epoch)
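The validation and testing handlers each maintain six separate lists. One possible way to reduce the repetition, shown here only as a sketch, is to keep the history in a dictionary keyed by metric name. METRIC_NAMES, new_history and record are hypothetical names introduced for this example, and the numbers passed in are placeholders, not results.

```python
METRIC_NAMES = ("SROCC", "KROCC", "PLCC", "RMSE", "MAE", "OR")

def new_history():
    """One empty list per metric."""
    return {name: [] for name in METRIC_NAMES}

def record(history, performance):
    """Append one epoch's (SROCC, KROCC, PLCC, RMSE, MAE, OR) tuple."""
    if len(performance) != len(METRIC_NAMES):
        raise ValueError("expected one value per metric")
    for name, value in zip(METRIC_NAMES, performance):
        history[name].append(float(value))

validation_history = new_history()
# Placeholder values for illustration only
record(validation_history, (0.5, 0.4, 0.6, 1.0, 0.8, 0.1))
```

Inside a handler, record(validation_history, metrics_validation['IQA_performance']) would replace the six append calls, and the plotting code could loop over METRIC_NAMES instead of naming each list.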
(4) Plot the curves and save them
# loss
fig = plt.figure(figsize=(3, 4), facecolor='white', edgecolor='black', dpi=150)
ax = fig.add_subplot(1, 1, 1)  # add_subplot(nrows, ncols, index of the current subplot)
plt.ylim(0, 1)
ax.plot(train_loss, label="train loss")  # the label gives plt.legend() an entry to show
ax.set_title("train loss")
ax.set_xlabel("epoch")  # train_loss holds one value per epoch
ax.set_ylabel("loss")
plt.legend()
plt.savefig(args.save_path+"train-loss.png")
plt.show()
# SROCC+KROCC+PLCC+OR
fig = plt.figure(figsize=(6, 4), facecolor='white', edgecolor='black'