# Code (码代码)
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
def mean_list(valid_ids, valid_values, stride=10, iters=1000):
    """Average consecutive, non-overlapping windows of values for a smoothed curve.

    ``len(valid_ids) // stride`` windows of *stride* points each are taken from
    the start of *valid_values*; any leftover points are dropped. Only the
    length of *valid_ids* is used.

    Args:
        valid_ids: sequence whose length determines how many windows are formed.
        valid_values: per-point values in the same order as *valid_ids*.
        stride: number of consecutive points per window (must be > 0).
        iters: iteration spacing between consecutive points; window ``k`` is
            placed at x = ``k * stride * iters``.

    Returns:
        (used_ids, mean_values, std_values): the x position of each window
        (ints), the per-window mean, and the per-window standard deviation
        (``np.std`` with its default ``ddof=0``).
    """
    values = np.asarray(valid_values)
    n_windows = len(valid_ids) // stride
    used_ids = []
    mean_values = []
    std_values = []
    for k in range(n_windows):
        # Window k covers indices [k*stride, (k+1)*stride). Starting the slice
        # at k would give overlapping, growing windows whenever stride > 1.
        window = values[k * stride:(k + 1) * stride]
        used_ids.append(int(k * stride * iters))
        mean_values.append(np.mean(window))
        std_values.append(np.std(window))
    # NOTE(review): window 0 is placed at x=0. If the first value was logged
    # at iteration `iters` rather than 0, the curve is shifted left by one
    # step -- confirm the intended x-axis alignment.
    return used_ids, mean_values, std_values
# Command-line switches (both default to off):
#   -s / --server    headless run: switch to the 'agg' backend, read logs from
#                    base_path/<run>/<run>/ and skip plt.show()
#   -a / --show_all  draw training and validation curves in one figure
parser = argparse.ArgumentParser()
for _flag_pair in (('-s', '--server'), ('-a', '--show_all')):
    parser.add_argument(*_flag_pair, action='store_true', default=False)
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Run directories (under base_path) to compare. Legends drop the first 21
# characters of each name (the timestamp and the '_' after it).
path = [
    '2021-02-07--13-39-13_seg_consist_suhu_snemi3d_d5_u200_pre400k_w1',
    '2021-03-08--09-23-59_seg_consist_suhu_snemi3d_d5_u200_pre400k_w1_550',
    '2021-03-02--13-41-48_seg_consist_suhu_snemi3d_d5_u50_pre400k_w1',
    '2021-03-02--13-41-48_seg_consist_suhu_snemi3d_d5_u100_pre400k_w1',
]

base_path = '../models'       # root directory holding one folder per run
out_path = './loss_curves'    # destination of the PNG figures

start = 0              # index of the first point drawn for each raw curve
valid_stride = 1       # validation points averaged into one plotted point
train_mode = 1         # comma-separated field holding the training value
                       # (runs named with 'consistency'/'affine' use field 2)
train_mse_ylim = [0, 0]   # (low, high); applied only when high > 0
valid_mse_ylim = [0, 0]   # (low, high); applied only when high > 0

all_title = ['train', 'Validation']   # figure titles: [train, validation]
img_name = ['train', 'valid']         # output file stems: [train, validation]

# Training log read from every run directory.
f_train = ['loss.txt'] * len(path)

# Validation log to plot -- keep exactly one option active.
#   split_sign=True  : the iteration is the text after the last '-' in field 0
#   split_sign=False : the iteration is the text after the last '=' in field 0
#   valid_mode       : comma-separated field holding the plotted metric
# f_valid = ['valid.txt'] * len(path); split_sign = True; valid_mode = 2
# f_valid = ['valid_ac3_50.txt'] * len(path); split_sign = True; valid_mode = 2
f_valid = ['waterz_ac3_50.txt'] * len(path)
split_sign = False
valid_mode = 4
# f_valid = ['waterz_ac4_50.txt'] * len(path); split_sign = False; valid_mode = 4
# f_valid = ['lmc_ac3_50.txt'] * len(path); split_sign = False; valid_mode = 4
# f_valid = ['waterz_cremia_50.txt'] * len(path); split_sign = False; valid_mode = 4
# f_valid = ['waterz_cremib_50.txt'] * len(path); split_sign = False; valid_mode = 4
# f_valid = ['waterz_cremic_50.txt'] * len(path); split_sign = False; valid_mode = 4
# f_valid = ['waterz_fib_50.txt'] * len(path); split_sign = False; valid_mode = 4
# Create the figure directory. exist_ok=True replaces the separate existence
# check, which could race with another process creating the directory.
os.makedirs(out_path, exist_ok=True)
train_ids = []       # per run: training iterations used as x values
train_mse = []       # per run: training values aligned with train_ids
valid_ids = []       # per run: validation iterations used as x values
valid_mse = []       # per run: validation values aligned with valid_ids
# Flat list with four entries per run, in path order:
# [best train id, best train value, best valid id, best valid value].
all_best_value = []


def _run_file(run_dir, file_name):
    """Return the path of log *file_name* for experiment *run_dir*.

    With --server the logs are read from base_path/<run>/<run>/ instead of
    base_path/<run>/.
    """
    if args.server:
        return os.path.join(base_path, run_dir, run_dir, file_name)
    return os.path.join(base_path, run_dir, file_name)


def _read_lines(file_path):
    """Return the non-blank lines of a text log without their trailing newline.

    The file is closed before returning. Only the trailing newline is
    stripped, so a final line without one is kept intact (dropping the last
    character unconditionally would truncate its value).
    """
    with open(file_path, 'r') as log:
        return [line.rstrip('\n') for line in log if line.strip()]


def _parse_train_log(lines, run_dir):
    """Map training iteration -> value from the lines of a training log.

    The iteration is the text after the last space in comma-separated field 0.
    The value is the text after the last '=' in field 2 for runs whose name
    contains 'consistency' or 'affine', and in field ``train_mode`` otherwise.
    Iteration 1 is skipped; the first entry per iteration wins.
    """
    field = 2 if ('consistency' in run_dir or 'affine' in run_dir) else train_mode
    values = {}
    for line in lines:
        parts = line.split(',')
        iteration = int(parts[0].split(' ')[-1])
        value = float(parts[field].split('=')[-1])
        if iteration != 1:
            values.setdefault(iteration, value)
    return values


def _parse_valid_log(lines):
    """Map validation iteration -> metric; also return the largest iteration.

    The iteration is the text after the last '-' (split_sign) or '='
    (otherwise) in field 0. The metric is the text after the last '=' in field
    ``valid_mode``. Iteration 1 is skipped; the first entry per iteration wins.
    """
    sep = '-' if split_sign else '='
    values = {}
    max_id = 0
    for line in lines:
        parts = line.split(',')
        iteration = int(parts[0].split(sep)[-1])
        value = float(parts[valid_mode].split('=')[-1])
        if iteration != 1:
            values.setdefault(iteration, value)
        # Track the true maximum so unsorted or appended logs still work.
        max_id = max(max_id, iteration)
    return values, max_id


def _dense_series(values_by_id, step, last_id, fill_value):
    """Sample *values_by_id* at step, 2*step, ..., last_id (inclusive).

    Iterations missing from the log are drawn at *fill_value* but cannot be
    chosen as the best point, so a placeholder is never reported as best.
    Returns (ids, values, best_id, best_value). The best point is the first
    minimum; it defaults to (step, 10000) when no logged value exists.
    """
    ids = list(range(step, last_id + 1, step))
    values = [values_by_id.get(k, fill_value) for k in ids]
    best_id, best_value = step, 10000
    for k in ids:
        if k in values_by_id and values_by_id[k] < best_value:
            best_id, best_value = k, values_by_id[k]
    return ids, values, best_id, best_value


for run_idx, run_dir in enumerate(path):
    train_values = _parse_train_log(
        _read_lines(_run_file(run_dir, f_train[run_idx])), run_dir)
    # One point every 100 iterations, up to and including 100 * (#entries),
    # matching the inclusive end point of the validation series below.
    ids_train, values_mse_train, best_train_mse_id, best_train_mse_value = \
        _dense_series(train_values, 100, 100 * len(train_values), 0.2)

    valid_values, max_id = _parse_valid_log(
        _read_lines(_run_file(run_dir, f_valid[run_idx])))
    # One point every 1000 iterations, up to the last full 1000 logged.
    ids_valid, values_mse_valid, best_valid_mse_id, best_valid_mse_value = \
        _dense_series(valid_values, 1000, 1000 * (max_id // 1000), 6)

    train_ids.append(ids_train)
    train_mse.append(values_mse_train)
    valid_ids.append(ids_valid)
    valid_mse.append(values_mse_valid)
    all_best_value.extend([best_train_mse_id, best_train_mse_value,
                           best_valid_mse_id, best_valid_mse_value])
def _run_label(run_dir, best_id, best_value):
    """Legend text: run name minus its 21-char timestamp prefix, best id and value."""
    return '%s_%d_%.6f' % (run_dir[21:], best_id, best_value)


if args.server:
    plt.switch_backend('agg')

# Training curves, with validation curves interleaved when --show_all is set.
plt.figure(figsize=(10, 10), dpi=100)
for i, run_dir in enumerate(path):
    plt.plot(train_ids[i][start:], train_mse[i][start:],
             label=_run_label(run_dir, all_best_value[4 * i], all_best_value[4 * i + 1]))
    if args.show_all:
        plt.plot(valid_ids[i][start:], valid_mse[i][start:],
                 label=_run_label(run_dir, all_best_value[4 * i + 2], all_best_value[4 * i + 3]))
plt.title('train+valid' if args.show_all else all_title[0])
if train_mse_ylim[1] > 0:
    plt.ylim(train_mse_ylim[0], train_mse_ylim[1])
plt.legend()
plt.grid()
plt.savefig(os.path.join(out_path, img_name[0] + '.png'), bbox_inches='tight')
if not args.server:
    plt.show()
plt.close('all')
if args.server:
    plt.switch_backend('agg')

# Validation curves, averaged over windows of valid_stride points, with a
# shaded band of +/- one standard deviation around each mean.
# (To draw raw points instead, plot valid_ids[i][start:] vs valid_mse[i][start:].)
plt.figure(figsize=(10, 10), dpi=100)
for i, run_dir in enumerate(path):
    xs, means, stds = mean_list(valid_ids[i], valid_mse[i], stride=valid_stride, iters=1000)
    band_low = [m - s for m, s in zip(means, stds)]
    band_high = [m + s for m, s in zip(means, stds)]
    legend_text = '%s_%d_%.6f' % (run_dir[21:], all_best_value[4 * i + 2], all_best_value[4 * i + 3])
    plt.plot(xs, means, label=legend_text)
    plt.fill_between(xs, band_low, band_high, alpha=0.3)
plt.title(all_title[1])
if valid_mse_ylim[1] > 0:
    plt.ylim(valid_mse_ylim[0], valid_mse_ylim[1])
plt.legend()
plt.grid()
plt.savefig(os.path.join(out_path, img_name[1] + '.png'), bbox_inches='tight')
if not args.server:
    plt.show()
plt.close('all')
# Results (结果展示)