import pickle
import matplotlib.pyplot as plt
# Load the pickled training histories for the CNN and capsule models.
# `with` guarantees both file handles are closed once loading finishes
# (the previous bare open() calls leaked them for the life of the process).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point these paths at files you trust (presumably local training output).
with open('/media/xm/0ABA09D10ABA09D1/交通标志识别/Capsule-master/cnn/traffic_cnn_show.txt', 'rb') as f1:
    cnn = pickle.load(f1)
with open('/media/xm/0ABA09D10ABA09D1/交通标志识别/Capsule-master/capsule/traffic_capsule_show.txt', 'rb') as f2:
    capsule = pickle.load(f2)
def _plot_series(series, style, label):
    """Plot one recorded metric against its own index 0 .. len(series) - 1."""
    values = list(series)
    # Derive the x-axis from this series alone, so histories of different
    # lengths can share one figure without a length-mismatch ValueError.
    # An example is two models trained for a different number of epochs.
    plt.plot(range(len(values)), values, style, label=label)


def _finish_figure(loss_type, ylabel, legend_loc):
    """Apply the shared axis labels and legend, then display the figure."""
    plt.grid(False)
    plt.xlabel(loss_type)
    plt.ylabel(ylabel)
    plt.legend(loc=legend_loc)
    plt.show()


def loss_plot(loss_type, cnn_history=None, capsule_history=None):
    """Plot validation loss and validation accuracy for the CNN and capsule runs.

    Two figures are shown one after the other:
    1. ``'self.val_loss'`` curves, with the legend in the upper right.
    2. ``'self.val_acc'`` curves, with the legend in the lower right.

    Args:
        loss_type: Text used as the x-axis label. The capsule curves are only
            added to each figure when this equals ``'epoch'``.
        cnn_history: Mapping with ``'self.val_loss'`` and ``'self.val_acc'``
            sequences. Defaults to the module-level ``cnn`` history.
        capsule_history: Mapping with the same keys. Defaults to the
            module-level ``capsule`` history. It is only read when
            ``loss_type == 'epoch'``.

    NOTE(review): each sequence presumably holds one value per epoch; confirm
    against whatever produced the pickled history files.
    """
    if cnn_history is None:
        cnn_history = cnn
    if capsule_history is None:
        capsule_history = capsule
    show_capsule = loss_type == 'epoch'

    # Figure 1: validation loss.
    plt.figure()
    _plot_series(cnn_history['self.val_loss'], 'g', 'cnn_val_loss')
    if show_capsule:
        _plot_series(capsule_history['self.val_loss'], 'k', 'capsule_val_loss')
    _finish_figure(loss_type, 'val_loss', "upper right")

    # Figure 2: validation accuracy.
    plt.figure()
    _plot_series(cnn_history['self.val_acc'], 'r', 'cnn_val_acc')
    if show_capsule:
        _plot_series(capsule_history['self.val_acc'], 'b', 'capsule_val_acc')
    _finish_figure(loss_type, 'val_accuracy', "lower right")
# Draw both comparison figures. Passing 'epoch' also adds the capsule curves
# and uses 'epoch' as the x-axis label.
loss_plot(loss_type='epoch')