loss和validation曲线对比

码代码

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt

def mean_list(valid_ids, valid_values, stride=10, iters=1000):
    """Average a metric series over consecutive, non-overlapping windows.

    Args:
        valid_ids: Iteration ids of the series. Only its length is used, to
            decide how many full windows of ``stride`` entries fit.
        valid_values: Metric values aligned with ``valid_ids``.
        stride: Number of consecutive values per window. A trailing partial
            window (fewer than ``stride`` values) is dropped.
        iters: Iterations between consecutive entries. Window ``k`` is placed
            at x = ``k * stride * iters``, so the first window sits at 0.

    Returns:
        Tuple ``(used_ids, mean_values, std_values)`` of equal-length lists,
        one entry per full window.
    """
    values = np.asarray(valid_values)
    num_windows = len(valid_ids) // stride
    used_ids = []
    mean_values = []
    std_values = []
    for k in range(num_windows):
        # Window k starts at k*stride, so windows tile the series without
        # overlapping each other when stride > 1.
        window = values[k * stride:(k + 1) * stride]
        used_ids.append(int(k * stride * iters))
        mean_values.append(np.mean(window))
        std_values.append(np.std(window))
    return used_ids, mean_values, std_values

# Command-line switches (both off unless given):
#   -s / --server    run directories are nested one level deeper, the 'agg'
#                    backend is used and figures are saved without plt.show()
#   -a / --show_all  draw training and validation curves on a single figure
parser = argparse.ArgumentParser()
for flag_pair in (('-s', '--server'), ('-a', '--show_all')):
    parser.add_argument(*flag_pair, action='store_true', default=False)
args = parser.parse_args()

# Run directories (under base_path) to compare. The trailing 'placeholder'
# entry lets every real entry end with a comma; it is dropped below.
path = [
    '2021-02-07--13-39-13_seg_consist_suhu_snemi3d_d5_u200_pre400k_w1',
    '2021-03-08--09-23-59_seg_consist_suhu_snemi3d_d5_u200_pre400k_w1_550',
    '2021-03-02--13-41-48_seg_consist_suhu_snemi3d_d5_u50_pre400k_w1',
    '2021-03-02--13-41-48_seg_consist_suhu_snemi3d_d5_u100_pre400k_w1',

    'placeholder'
]

valid_stride = 1         # window size passed to mean_list for the validation plot
start = 0                # first sample index drawn for the raw curves
train_mode = 1           # comma-separated column of loss.txt holding the value
valid_mode = 2           # comma-separated column holding the validation metric
train_mse_ylim = [0, 0]  # [low, high]; only applied when high > 0
valid_mse_ylim = [0, 0]

path = path[:-1]

# Validation log formats: file name -> (split_sign, valid_mode).
# split_sign=True: the iteration id follows the last '-' of the first field;
# split_sign=False: it follows the last '='.
valid_formats = {
    'valid.txt': (True, valid_mode),
    'valid_ac3_50.txt': (True, valid_mode),
    'waterz_ac3_50.txt': (False, 4),
    'waterz_ac4_50.txt': (False, 4),
    'lmc_ac3_50.txt': (False, 4),
    'waterz_cremia_50.txt': (False, 4),
    'waterz_cremib_50.txt': (False, 4),
    'waterz_cremic_50.txt': (False, 4),
    'waterz_fib_50.txt': (False, 4),
}
valid_log_name = 'waterz_ac3_50.txt'  # pick one key of valid_formats
split_sign, valid_mode = valid_formats[valid_log_name]

f_train = ['loss.txt'] * len(path)
f_valid = [valid_log_name] * len(path)

out_path = './loss_curves'
all_title = ['train', 'Validation']
img_name = ['train', 'valid']

base_path = '../models'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(out_path, exist_ok=True)

# Per-run series collected for plotting; index i corresponds to path[i].
train_ids = []
train_mse = []
valid_ids = []
valid_mse = []
# Flat list with four entries per run:
# [best_train_id, best_train_value, best_valid_id, best_valid_value, ...]
all_best_value = []

TRAIN_LOG_STEP = 100        # grid spacing on which training values are looked up
VALID_LOG_STEP = 1000       # grid spacing on which validation values are looked up
TRAIN_MISSING_VALUE = 0.2   # plotted where a training grid step is absent
VALID_MISSING_VALUE = 6     # plotted where a validation grid step is absent


def _read_log_lines(run_name, file_name):
    """Return the non-blank lines of one run's log file, newline-stripped.

    With ``--server`` the run directory is nested twice
    (``base_path/run/run/file``); otherwise it is ``base_path/run/file``.
    """
    if args.server:
        log_path = os.path.join(base_path, run_name, run_name, file_name)
    else:
        log_path = os.path.join(base_path, run_name, file_name)
    # rstrip (not x[:-1]) so a last line lacking '\n' keeps its final char;
    # `with` guarantees the handle is closed.
    with open(log_path, 'r') as log_file:
        return [line.rstrip('\r\n') for line in log_file if line.strip()]


def _parse_train_log(run_name, lines):
    """Map iteration id -> training value parsed from loss-log lines.

    The id is the last space-separated token of the first comma field. The
    value follows '=' in column 2 for runs whose name contains 'consistency'
    or 'affine', otherwise in column ``train_mode``. Iteration 1 is skipped
    and the first entry wins for duplicate ids.
    """
    column = 2 if ('consistency' in run_name or 'affine' in run_name) else train_mode
    id_to_value = {}
    for line in lines:
        fields = line.split(',')
        step = int(fields[0].split(' ')[-1])
        value = float(fields[column].split('=')[-1])
        if step != 1:
            id_to_value.setdefault(step, value)
    return id_to_value


def _parse_valid_log(lines):
    """Parse validation-log lines into ``(id -> metric, largest id seen)``.

    The id follows the last '-' (``split_sign``) or '=' of the first comma
    field; the metric follows '=' in column ``valid_mode``. Iteration 1 is
    skipped and the first entry wins for duplicate ids.
    """
    id_sep = '-' if split_sign else '='
    id_to_value = {}
    max_id = 0
    for line in lines:
        fields = line.split(',')
        step = int(fields[0].split(id_sep)[-1])
        value = float(fields[valid_mode].split('=')[-1])
        if step != 1:
            id_to_value.setdefault(step, value)
        # Take the true maximum rather than relying on the file being sorted.
        max_id = max(max_id, step)
    return id_to_value, max_id


def _resample(id_to_value, step, count, missing_value):
    """Look up ``id_to_value`` at step, 2*step, ..., count*step (inclusive).

    Ids absent from the mapping get ``missing_value``. Returns
    ``(ids, values, best_id, best_value)`` where best is the first minimum
    below 10000; otherwise best_id stays ``step`` and best_value 10000.
    """
    ids = list(range(step, step * count + 1, step))
    values = [id_to_value.get(k, missing_value) for k in ids]
    best_id, best_value = step, 10000
    for k, value in zip(ids, values):
        if value < best_value:
            best_id, best_value = k, value
    return ids, values, best_id, best_value


for run_idx, run_name in enumerate(path):
    train_by_id = _parse_train_log(run_name, _read_log_lines(run_name, f_train[run_idx]))
    # The grid ends at len(train_by_id)*step inclusive, so the last logged
    # step is kept (same inclusive end as the validation grid).
    ids_train, values_mse_train, best_train_mse_id, best_train_mse_value = _resample(
        train_by_id, TRAIN_LOG_STEP, len(train_by_id), TRAIN_MISSING_VALUE)

    valid_by_id, max_id = _parse_valid_log(_read_log_lines(run_name, f_valid[run_idx]))
    ids_valid, values_mse_valid, best_valid_mse_id, best_valid_mse_value = _resample(
        valid_by_id, VALID_LOG_STEP, max_id // VALID_LOG_STEP, VALID_MISSING_VALUE)

    train_ids.append(ids_train)
    train_mse.append(values_mse_train)
    valid_ids.append(ids_valid)
    valid_mse.append(values_mse_valid)
    all_best_value.extend([best_train_mse_id, best_train_mse_value,
                           best_valid_mse_id, best_valid_mse_value])

def _run_label(run_idx, best_offset):
    """Legend label: run name minus its first 21 chars, plus best id and value.

    The 21 characters are the date-time prefix of the run names in ``path``
    (e.g. '2021-02-07--13-39-13_'). ``best_offset`` 0 selects the training
    best and 2 the validation best, from all_best_value's four-per-run layout.
    """
    base = 4 * run_idx + best_offset
    return '%s_%d_%.6f' % (path[run_idx][21:], all_best_value[base], all_best_value[base + 1])


def _finish_figure(title, ylim, image_name):
    """Title, optional y-limits, legend and grid; save, show locally, close."""
    plt.title(title)
    if ylim[1] > 0:
        plt.ylim(ylim[0], ylim[1])
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(out_path, image_name + '.png'), bbox_inches='tight')
    if not args.server:
        plt.show()
    plt.close('all')


if args.server:
    # Non-interactive backend for server runs. Set once here; it stays in
    # effect for every figure below.
    plt.switch_backend('agg')

if args.show_all:
    plt.figure(figsize=(10, 10), dpi=100)
    for i in range(len(path)):
        plt.plot(train_ids[i][start:], train_mse[i][start:], label=_run_label(i, 0))
        plt.plot(valid_ids[i][start:], valid_mse[i][start:], label=_run_label(i, 2))
    _finish_figure('train+valid', train_mse_ylim, img_name[0])
else:
    plt.figure(figsize=(10, 10), dpi=100)
    for i in range(len(path)):
        plt.plot(train_ids[i][start:], train_mse[i][start:], label=_run_label(i, 0))
    _finish_figure(all_title[0], train_mse_ylim, img_name[0])

    plt.figure(figsize=(10, 10), dpi=100)
    for i in range(len(path)):
        used_ids, mean_values, std_values = mean_list(
            valid_ids[i], valid_mse[i], stride=valid_stride, iters=1000)
        mean_arr = np.asarray(mean_values)
        std_arr = np.asarray(std_values)
        plt.plot(used_ids, mean_values, label=_run_label(i, 2))
        # Shaded band of +/- one standard deviation around the window means.
        plt.fill_between(used_ids, mean_arr - std_arr, mean_arr + std_arr, alpha=0.3)
    _finish_figure(all_title[1], valid_mse_ylim, img_name[1])

结果展示

在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

深山里的小白羊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值