如题
import matplotlib.pyplot as plt
import numpy as np
# Epochs at which the losses were recorded. They are log-spaced, so the plot
# places them at evenly spaced x positions (x_uniform) and shows these values
# as the tick labels.
epoch_data = [1, 10, 100, 1000]
# Losses drawn under the 'Stokes' legend label: one row per subplot panel,
# one column per entry of epoch_data.
loss_data1 = np.array([[4.431531e-01, 3.673583e-02, 6.285390e-06, 4.1090e-06],
                       [1.623631e+00, 2.330988e-01, 1.241400e-04, 9.0449e-06],
                       [3.134755e-01, 1.535817e-01, 3.926213e-04, 5.283528e-06],
                       [1.510560e-01, 8.463229e-02, 3.146114e-04, 4.022342e-06],
                       [2.341002e-01, 1.572589e-01, 1.801148e-03, 5.845174e-06]])
# Losses drawn under the 'Darcy' legend label; same layout as loss_data1.
loss_data2 = np.array([[8.786691e-01, 1.692781e-01, 1.646515e-05, 1.020533e-05],
                       [2.459055e-01, 2.234710e-02, 1.430798e-05, 8.401059e-07],
                       [5.773173e-02, 3.107118e-02, 2.589364e-05, 2.196231e-06],
                       [7.178970e-02, 2.909906e-02, 1.070074e-05, 2.422505e-06],
                       [1.595414e-01, 6.577221e-02, 5.846239e-05, 3.661418e-06]])

# Both series must share one layout, and each row must have one value per epoch;
# otherwise the panels would silently pair the wrong curves or points.
if loss_data1.shape != loss_data2.shape or loss_data1.shape[1] != len(epoch_data):
    raise ValueError(
        f"loss arrays must both have shape (n_panels, {len(epoch_data)}); "
        f"got {loss_data1.shape} and {loss_data2.shape}"
    )

# One subplot per row of loss data, so editing the data cannot desynchronise
# the panel count.
fig, axs = plt.subplots(1, loss_data1.shape[0], figsize=(16, 8))
# Evenly spaced x positions, one per recorded epoch.
x_uniform = np.arange(len(epoch_data))
# Draw one panel per row of loss data, comparing the Stokes and Darcy series.
for panel_idx, (ax, stokes_loss, darcy_loss) in enumerate(zip(axs, loss_data1, loss_data2)):
    ax.plot(x_uniform, stokes_loss, 'b*--', alpha=0.5, linewidth=1, label='Stokes')
    ax.plot(x_uniform, darcy_loss, 'rs--', alpha=0.5, linewidth=1, label='Darcy')
    ax.legend()  # show the labels passed to the plot calls above
    ax.set_xlabel('Epoch')
    # Fix the tick positions to the sample positions before labelling them.
    # Calling set_xticklabels alone relabels whatever ticks the auto-locator
    # picked, so the epoch values would land on the wrong positions.
    ax.set_xticks(x_uniform)
    ax.set_xticklabels([str(epoch) for epoch in epoch_data])
    # NOTE(review): the losses span roughly six orders of magnitude; consider
    # ax.set_yscale('log') so the later epochs are not flattened onto zero.
    if panel_idx == 0:
        # Only the leftmost panel carries the shared y-axis label.
        ax.set_ylabel('loss_number')
plt.tight_layout()
plt.show()