import os
from typing import List

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from natsort import natsorted
from tqdm import tqdm
# Return type of ``load_diabetes_data``: a two-element list ``[features, labels]``.
# (A list literal such as ``[torch.Tensor, torch.Tensor]`` is not a valid type hint.)
Vector = List[torch.Tensor]


def load_diabetes_data(csv_file_path: str, delim: str, data_type=np.float32) -> Vector:
    """Load a numeric CSV file and split it into feature and label tensors.

    Every column except the last is treated as a feature; the last column is
    treated as the label.

    Args:
        csv_file_path: Path of the CSV file to read.
        delim: Column delimiter passed to ``np.loadtxt``.
        data_type: NumPy dtype used to parse the file (default ``np.float32``).

    Returns:
        ``[x_data, y_data]``, both created with ``torch.from_numpy``:
        ``x_data`` has shape ``(rows, cols - 1)`` and ``y_data`` has shape
        ``(rows, 1)``.

    Raises:
        FileNotFoundError: If ``csv_file_path`` does not exist.
    """
    if not os.path.exists(csv_file_path):
        # Previously this only printed a warning and then crashed inside
        # np.loadtxt anyway; fail early with a clear message instead.
        raise FileNotFoundError('csv file not exists: {}'.format(csv_file_path))
    # ndmin=2 keeps a single-row file two-dimensional so the column slicing
    # below still works.
    x_y_data = np.loadtxt(csv_file_path, dtype=data_type, delimiter=delim, ndmin=2)
    x_data = torch.from_numpy(x_y_data[:, :-1])
    # Indexing with a list ([-1]) keeps the label column 2-D: (rows, 1).
    y_data = torch.from_numpy(x_y_data[:, [-1]])
    return [x_data, y_data]
class Model(torch.nn.Module):
    """Fully connected classifier: three Linear layers, each followed by a sigmoid.

    With the default sizes the network maps 8 input features to a single output
    in (0, 1). This suits ``torch.nn.BCELoss``, which ``train`` uses.

    Args:
        in_features: Number of input features (default 8).
        hidden1: Width of the first hidden layer (default 6).
        hidden2: Width of the second hidden layer (default 4).
        out_features: Number of outputs (default 1).
    """

    def __init__(self, in_features: int = 8, hidden1: int = 6,
                 hidden2: int = 4, out_features: int = 1):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, hidden1)
        self.linear2 = torch.nn.Linear(hidden1, hidden2)
        self.linear3 = torch.nn.Linear(hidden2, out_features)
        # A single Sigmoid module is reused as the activation after every layer.
        self.ac_func = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass; the last dimension of ``x`` must equal ``in_features``."""
        x = self.ac_func(self.linear1(x))
        x = self.ac_func(self.linear2(x))
        x = self.ac_func(self.linear3(x))
        return x
def train(x_data: torch.Tensor, y_data: torch.Tensor, epoch_num: int,
          lr: float = 0.1, img_dir: str = "./img") -> None:
    """Train a fresh ``Model`` with full-batch SGD and plot the loss live.

    After every epoch the loss curve is redrawn in an interactive matplotlib
    figure and saved as ``<img_dir>/<epoch, zero-padded to 4 digits>.jpg``.
    A one-line progress indicator is also printed.

    Args:
        x_data: Input features; the whole tensor is fed to the model every epoch.
        y_data: Targets passed to ``BCELoss`` together with the model output.
        epoch_num: Number of training epochs.
        lr: SGD learning rate (default 0.1).
        img_dir: Directory receiving the per-epoch loss images (default
            ``"./img"``). It is created if it does not exist.
    """
    model = Model()
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # plt.savefig() does not create missing directories.
    os.makedirs(img_dir, exist_ok=True)
    loss_list = []
    epoch_list = []
    # Interactive mode lets the figure refresh while training runs.
    plt.ion()
    try:
        for epoch in range(epoch_num):
            # Forward pass.
            y_pred = model(x_data)
            loss = criterion(y_pred, y_data)
            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            loss_list.append(loss_value)
            epoch_list.append(epoch)
            # Redraw the whole loss curve and save this frame.
            plt.clf()
            plt.plot(epoch_list, loss_list, 'r-')
            plt.title("loss")
            plt.xlabel("epoch")
            plt.ylabel("loss")
            plt.pause(0.1)  # pause 100 ms so the figure can refresh
            plt.savefig(os.path.join(img_dir, "{:0>4d}.jpg".format(epoch)))
            # The old formula divided by (epoch_num - 1) and crashed when epoch_num == 1.
            percent = 100.0 if epoch_num <= 1 else epoch * 100 / (epoch_num - 1)
            print('\r Epoch: {:>3.0f}%[{}->{}], loss: {}'.format(
                percent,
                int(epoch / 10) * '*',
                (int(epoch_num / 10) - 1 - int(epoch / 10)) * '.',
                loss_value), end='')
    finally:
        # Leave interactive mode even if training fails part-way.
        plt.ioff()
    # Terminate the '\r'-overwritten progress line.
    print()
def save_loss_line_to_gif(loss_img_path: str, gif_img_path: str) -> None:
    """Merge the ``.jpg`` frames in a directory into one animated GIF.

    Frames are read in natural sort order (``natsorted``), so ``2.jpg`` comes
    before ``10.jpg``. If there is nothing to merge, a message is printed and
    no file is written.

    Args:
        loss_img_path: Directory containing the frame images.
        gif_img_path: Output path of the GIF.
    """
    if not os.path.isdir(loss_img_path):
        # isdir (not exists) so that a plain file path does not crash os.listdir.
        print('no data to merge.')
        return
    # train() saves its frames with plt.savefig() as .jpg; ignore every other
    # file, including a GIF previously written into the same directory.
    frame_names = [name for name in natsorted(os.listdir(loss_img_path))
                   if name.split('.')[-1] == 'jpg']
    if not frame_names:
        # Do not hand an empty frame list to imageio.mimsave.
        print('no data to merge.')
        return
    gif_buffer = [imageio.imread(os.path.join(loss_img_path, name))
                  for name in tqdm(frame_names)]
    imageio.mimsave(gif_img_path, gif_buffer, 'GIF', duration=0.1)
if __name__ == "__main__":
    # CSV file read by load_diabetes_data, and its column separator.
    csv_path = 'diabetes.csv'
    delimiter = ','
    # Directory scanned for loss frames (train() saves them under ./img) and
    # the path of the merged GIF.
    loss_image_path = "./img"
    save_gif_path = "./img/res.gif"

    i_data, o_data = load_diabetes_data(csv_path, delim=delimiter)
    train(i_data, o_data, epoch_num=100)
    save_loss_line_to_gif(loss_image_path, save_gif_path)
# 测试结果 (Test results):