持续更新 …
定义自己的数据集:Dataset 与 DataLoader
class Dataset_name(Dataset):
    """Template dataset serving one split: 'train', 'test' or 'valid'.

    Args:
        flag: Name of the split this instance represents.
        csv_paths: Optional list of CSV file paths, forwarded unchanged to
            ``__load_data__``. Defaults to ``None``.
    """

    def __init__(self, flag='train', csv_paths: "list | None" = None):
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        # The loader takes a ``csv_paths`` argument; calling it with no
        # argument raised TypeError, so the value is passed explicitly.
        self.__load_data__(csv_paths)

    def __getitem__(self, index):
        """Return the sample at ``index`` (placeholder, not implemented)."""
        pass

    def __len__(self):
        """Return the number of samples (placeholder, not implemented)."""
        pass

    def __load_data__(self, csv_paths: "list | None" = None):
        """Load the data, then print the train/valid shapes.

        The loading step is a placeholder. It has to assign ``self.train_X``,
        ``self.train_Y``, ``self.valid_X`` and ``self.valid_Y``, because the
        summary printed below reads ``.shape`` from each of them.

        Args:
            csv_paths: Optional list of CSV file paths to read from.
        """
        pass  # TODO: read ``csv_paths`` and populate the four attributes.

        print("train_X.shape:{}\ntrain_Y.shape:{}\n"
              "valid_X.shape:{}\nvalid_Y.shape:{}\n".format(
                  self.train_X.shape, self.train_Y.shape,
                  self.valid_X.shape, self.valid_Y.shape))
# One dataset object per split, created in the order train -> valid.
train_dataset = Dataset_name(flag='train')
valid_dataset = Dataset_name(flag='valid')

# Wrap each dataset in a shuffling loader with batch_size=64.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)
开始训练并调整学习率
# Epoch index -> learning rate to switch to after that epoch.
# Built once up front because the table never changes between epochs.
lr_adjust = {
    2: 5e-5,
    4: 1e-5,
    6: 5e-6,
    8: 1e-6,
    10: 5e-7,
    15: 1e-7,
}

for epoch in range(args.epochs):
    # =====================train============================
    Your_model.train()
    train_epoch_loss = []
    # Progress is printed at batch 0 and about every half epoch. max(1, ...)
    # keeps the modulo below from dividing by zero on a one-batch loader.
    log_every = max(1, len(train_dataloader) // 2)
    for idx, (data_x, data_y) in enumerate(train_dataloader, 0):
        data_x = data_x.to(torch.float32).to(args.device)
        data_y = data_y.to(torch.float32).to(args.device)
        outputs = Your_model(data_x)
        optimizer.zero_grad()
        # (prediction, target) order, the same as in the validation loop.
        loss = criterion(outputs, data_y)
        loss.backward()
        optimizer.step()
        train_epoch_loss.append(loss.item())
        train_loss.append(loss.item())
        if idx % log_every == 0:
            print("epoch={}/{},{}/{}of train, loss={}".format(
                epoch, args.epochs, idx, len(train_dataloader), loss.item()))
    train_epochs_loss.append(np.average(train_epoch_loss))

    # =====================valid============================
    Your_model.eval()
    valid_epoch_loss = []
    # No backward pass happens here, so gradient tracking is switched off.
    with torch.no_grad():
        for idx, (data_x, data_y) in enumerate(valid_dataloader, 0):
            data_x = data_x.to(torch.float32).to(args.device)
            data_y = data_y.to(torch.float32).to(args.device)
            outputs = Your_model(data_x)
            loss = criterion(outputs, data_y)
            valid_epoch_loss.append(loss.item())
            valid_loss.append(loss.item())
    valid_epochs_loss.append(np.average(valid_epoch_loss))

    # ==================early stopping======================
    # NOTE(review): the path is a placeholder; as a raw string it contains
    # two literal backslashes after "c:" - replace it with a real location.
    early_stopping(valid_epochs_loss[-1],
                   model=Your_model, path=r'c:\\your_model_to_save')
    if early_stopping.early_stop:
        print("Early stopping")
        break

    # ====================adjust lr========================
    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating learning rate to {}'.format(lr))
load_model
保存的模型参数(state_dict)中,key 值与当前模型不一致:每个 key 前面多了 `module.` 前缀(用 nn.DataParallel 包装后保存的模型会出现这种情况)
from collections import OrderedDict

# map_location='cpu' lets a checkpoint that was saved on a GPU be opened on a
# machine without one; load_state_dict() below copies the values into the
# model's own parameters.
state_dict = torch.load('checkpoint.pt', map_location='cpu')
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Drop the leading `module.` only when the key really starts with it;
    # slicing every key unconditionally would corrupt keys without the prefix.
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)  # load the renamed weights into the model
编写预测类
以下代码用于加载一个预先训练好的模型并进行预测,大多数情况下可以直接使用,但不排除个别情况需要自行调整!
from ceevee.base import AbstractPredictor
class MySuperPredictor(AbstractPredictor):
    """Predictor that loads saved weights into ``ModelClass`` and runs it.

    Args:
        weights_path: Path of the saved state dict to load.
    """

    def __init__(self, weights_path: str):
        super().__init__()
        self.model = self._load_model(weights_path=weights_path)

    def process(self, x, *kw):
        """Run the model on ``x`` with gradient tracking disabled.

        Args:
            x: Input passed directly to the model.
            *kw: Extra positional arguments; accepted but not used here.

        Returns:
            Whatever the model returns for ``x``.
        """
        with torch.no_grad():
            res = self.model(x)
        return res

    @staticmethod
    def _load_model(weights_path):
        """Build ``ModelClass``, load its weights and set inference mode.

        Args:
            weights_path: Path of the saved state dict.

        Returns:
            The model with the weights loaded, in evaluation mode.
        """
        model = ModelClass()
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # that come from a trusted source.
        weights = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(weights)
        # Without eval(), dropout and batch-norm layers would keep their
        # training-time behaviour during prediction.
        model.eval()
        return model