都是通用代码~~ 给自己学习用 一个整体的使用流程 学习用哦 仅是学习啊 本人还是用y5的
一、导入包以及设置随机种子
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
二、以类的方式定义超参数
三、定义自己的模型
四、定义早停类(此步骤可以省略)
class EarlyStopping():
def __init__(self,patience=7,verbose=False,delta=0):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self,val_loss,model,path):
print("val_loss={}".format(val_loss))
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss,model,path)
elif score < self.best_score+self.delta:
self.counter+=1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter>=self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss,model,path)
self.counter = 0
def save_checkpoint(self,val_loss,model,path):
if self.verbose:
print(
f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), path+'/'+'model_checkpoint.pth')
self.val_loss_min = val_loss
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
五、定义自己的数据集Dataset,DataLoader
class Dataset_name(Dataset):
def __init__(self, flag='train'):
assert flag in ['train', 'test', 'valid']
self.flag = flag
self.__load_data__()
def __getitem__(self, index):
pass
def __len__(self):
pass
def __load_data__(self, csv_paths: list):
pass
print(
"train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
.format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))
train_dataset = Dataset_name(flag='train')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataset = Dataset_name(flag='valid')
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
六、实例化模型,设置loss,优化器等
model = Your_model().to(args.device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(Your_model.parameters(),lr=args.learning_rate)
train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []
early_stopping = EarlyStopping(patience=args.patience,verbose=True)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
七、开始训练以及调整lr
for epoch in range(args.epochs):
Your_model.train()
train_epoch_loss = []
for idx,(data_x,data_y) in enumerate(train_dataloader,0):
data_x = data_x.to(torch.float32).to(args.device)
data_y = data_y.to(torch.float32).to(args.device)
outputs = Your_model(data_x)
optimizer.zero_grad()
loss = criterion(data_y,outputs)
loss.backward()
optimizer.step()
train_epoch_loss.append(loss.item())
train_loss.append(loss.item())
if idx%(len(train_dataloader)//2)==0:
print("epoch={}/{},{}/{}of train, loss={}".format(
epoch, args.epochs, idx, len(train_dataloader),loss.item()))
train_epochs_loss.append(np.average(train_epoch_loss))
#=====================valid============================
Your_model.eval()
valid_epoch_loss = []
for idx,(data_x,data_y) in enumerate(valid_dataloader,0):
data_x = data_x.to(torch.float32).to(args.device)
data_y = data_y.to(torch.float32).to(args.device)
outputs = Your_model(data_x)
loss = criterion(outputs,data_y)
valid_epoch_loss.append(loss.item())
valid_loss.append(loss.item())
valid_epochs_loss.append(np.average(valid_epoch_loss))
#==================early stopping======================
early_stopping(valid_epochs_loss[-1],model=Your_model,path=r'c:\\your_model_to_save')
if early_stopping.early_stop:
print("Early stopping")
break
#====================adjust lr========================
lr_adjust = {
2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
10: 5e-7, 15: 1e-7, 20: 5e-8
}
if epoch in lr_adjust.keys():
lr = lr_adjust[epoch]
for param_group in optimizer.param_groups:
param_group['lr'] = lr
print('Updating learning rate to {}'.format(lr))
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
八、绘图
九、预测
完事