'''A collection of commonly used helper functions.'''
import torch
from torch import nn
import torchvision
def accuracy(y_hat, y):
    '''
    input:
        y_hat: predicted scores or class indices
        y: ground-truth labels
    output: number of correct predictions
    '''
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)  # pick the class with the highest score
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
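# A quick sanity check for accuracy() on a hypothetical toy batch (these
# values are illustrative, not part of the original module): two of the
# three argmax predictions match the labels.
if __name__ == '__main__':
    _y_hat = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
    _y = torch.tensor([1, 0, 0])
    assert accuracy(_y_hat, _y) == 2.0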
class Accumulator:
    '''
    An accumulator that maintains n running sums.
    '''
def __init__(self, n):
self.data = [0.0] * n
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0.0] * len(self.data)
def __getitem__(self, idx):
return self.data[idx]
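# Minimal usage sketch for Accumulator (toy numbers, purely illustrative):
# accumulate (correct, total) pairs across two fake batches.
if __name__ == '__main__':
    _m = Accumulator(2)
    _m.add(3, 10)  # batch 1: 3 correct out of 10
    _m.add(5, 10)  # batch 2: 5 correct out of 10
    assert _m[0] / _m[1] == 0.4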
def evaluate_accuracy(net, data_iter):
    '''
    input:
        net: network model
        data_iter: data iterator (train_iter or test_iter)
    output:
        the model's accuracy over all of data_iter
    '''
    if isinstance(net, torch.nn.Module):
        net.eval()  # switch to evaluation mode
    metric = Accumulator(2)  # (number of correct predictions, number of examples)
    with torch.no_grad():  # no gradients needed during evaluation
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
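# Illustrative check for evaluate_accuracy() with an identity "model" and a
# single hand-made batch (hypothetical data, not from the original module);
# the logits already argmax to the correct labels, so accuracy is 1.0.
if __name__ == '__main__':
    _logits = torch.tensor([[2.0, 1.0], [0.0, 3.0]])
    _labels = torch.tensor([0, 1])
    assert evaluate_accuracy(nn.Identity(), [(_logits, _labels)]) == 1.0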
def evaluate_accuracy_device(net, data_iter, device=None):
    '''
    Accuracy evaluation on an explicit device.
    input:
        net: network model
        data_iter: data iterator
        device: device to evaluate on (defaults to the device of net's parameters)
    output:
        accuracy over data_iter
    '''
if isinstance(net, torch.nn.Module):
net.eval()
if not device:
device = next(iter(net.parameters())).device
metric = Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
X = X.to(device)
y = y.to(device)
metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
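# Illustrative use of evaluate_accuracy_device(): a toy one-batch check run
# explicitly on CPU (the model, data, and device choice are all assumptions
# for demonstration).
if __name__ == '__main__':
    _net = nn.Linear(4, 2)
    _X, _y = torch.randn(6, 4), torch.randint(0, 2, (6,))
    _acc = evaluate_accuracy_device(_net, [(_X, _y)], torch.device('cpu'))
    print(f'toy accuracy on cpu: {_acc:.3f}')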
def train_epoch(net, train_iter, loss, updater):
    '''
    input:
        net: model
        train_iter: training-set data iterator
        loss: loss function
        updater: optimizer, or a custom update function taking a batch size
    output:
        (average loss, accuracy) over the epoch
    '''
    if isinstance(net, torch.nn.Module):
        net.train()  # switch to training mode
    metric = Accumulator(3)  # (total loss, number correct, number of examples)
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # built-in optimizer: the loss is already averaged over the batch
            updater.zero_grad()
            l.backward()
            updater.step()
            metric.add(float(l) * len(y), accuracy(y_hat, y), y.numel())
        else:
            # custom updater: the loss is per-example, so sum before backward
            l.sum().backward()
            updater(X.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]
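# A one-epoch smoke test for train_epoch() on random toy data with a linear
# classifier (shapes and hyperparameters here are illustrative assumptions).
if __name__ == '__main__':
    _net = nn.Linear(4, 3)
    _X = torch.randn(8, 4)
    _y = torch.randint(0, 3, (8,))
    _opt = torch.optim.SGD(_net.parameters(), lr=0.1)
    _l, _a = train_epoch(_net, [(_X, _y)], nn.CrossEntropyLoss(), _opt)
    print(f'toy epoch: loss={_l:.3f}, acc={_a:.3f}')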
def train(net, train_iter, test_iter, loss, num_epochs, updater):
    '''
    net: model
    train_iter: training dataset iterator
    test_iter: test dataset iterator
    loss: loss function
    num_epochs: number of epochs
    updater: optimizer
    '''
    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'epoch {epoch + 1}: train_loss:{train_loss};train_acc:{train_acc}')
        print(f'test_acc:{test_acc}')
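# Sketch of a full training run on Fashion-MNIST via torchvision; the './data'
# root, batch size, learning rate, and model below are assumptions, not part
# of the original module (the first run downloads the dataset).
if __name__ == '__main__':
    from torch.utils import data
    _trans = torchvision.transforms.ToTensor()
    _train_ds = torchvision.datasets.FashionMNIST(
        root='./data', train=True, transform=_trans, download=True)
    _test_ds = torchvision.datasets.FashionMNIST(
        root='./data', train=False, transform=_trans, download=True)
    _train_iter = data.DataLoader(_train_ds, batch_size=256, shuffle=True)
    _test_iter = data.DataLoader(_test_ds, batch_size=256)
    _net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
    _opt = torch.optim.SGD(_net.parameters(), lr=0.1)
    train(_net, _train_iter, _test_iter, nn.CrossEntropyLoss(), 3, _opt)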
def train_device(net, train_iter, test_iter, loss=None, num_epochs=10, lr=0.1, updater=None, device=None):
    '''
    Training loop with an explicit device; compared with train(), it moves the
    model and data to(device) and applies Xavier initialization.
    Possible improvement: make the initialization scheme configurable.
    '''
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    if not device:
        device = next(iter(net.parameters())).device
    print('training on', device)
    net.to(device)
    if not updater:
        updater = torch.optim.SGD(net.parameters(), lr=lr)
    if not loss:
        loss = nn.CrossEntropyLoss()
    for _ in range(num_epochs):
        net.train()  # evaluate_accuracy_device sets eval mode, so switch back
        metric = Accumulator(3)  # (total loss, number correct, number of examples)
        for i, (X, y) in enumerate(train_iter):
            updater.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            updater.step()
            metric.add(float(l) * len(y), accuracy(y_hat, y), y.numel())
        train_loss, train_acc = metric[0] / metric[2], metric[1] / metric[2]
        print(f'train_loss:{train_loss};train_acc:{train_acc}')
        test_acc = evaluate_accuracy_device(net, test_iter, device)
        print(f'test_acc:{test_acc}')
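# Usage sketch for train_device(); it reuses the hypothetical _train_iter and
# _test_iter built in the Fashion-MNIST block above and picks CUDA when
# available (all names and hyperparameters are illustrative).
if __name__ == '__main__':
    _dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _mlp = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                         nn.Linear(256, 10))
    train_device(_mlp, _train_iter, _test_iter, num_epochs=3, lr=0.1, device=_dev)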
def sgd(params, lr, batch_size):
    '''
    Minibatch stochastic gradient descent.
    input:
        params: list of torch.Tensor (weights and biases)
        lr: learning rate
        batch_size: minibatch size used to average the gradients
    '''
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
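# Tiny worked example for sgd() (hypothetical values): one parameter with a
# pre-set gradient; with lr=1.0 and batch_size=2 the update is -grad / 2.
if __name__ == '__main__':
    _w = torch.tensor([1.0, 2.0], requires_grad=True)
    _w.grad = torch.tensor([2.0, 4.0])
    sgd([_w], lr=1.0, batch_size=2)
    assert torch.allclose(_w, torch.tensor([0.0, 0.0]))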
def cross_entropy(y_hat, y):
    '''
    Cross-entropy loss.
    input:
        y_hat: predicted class probabilities
        y: ground-truth class indices
    Returns, for each example, the negative log of the probability that y_hat
    assigns to the ground-truth class (one loss value per example).
    '''
return -torch.log(y_hat[range(len(y_hat)), y])
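# Worked example for cross_entropy() with toy probabilities: the loss of each
# sample is -log of the probability assigned to its true class.
if __name__ == '__main__':
    _probs = torch.tensor([[0.1, 0.9], [0.5, 0.5]])
    _gt = torch.tensor([1, 0])
    # expected per-sample losses: [-log(0.9), -log(0.5)]
    assert torch.allclose(cross_entropy(_probs, _gt),
                          -torch.log(torch.tensor([0.9, 0.5])))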
def squared_loss(y_hat, y):
    '''
    Squared loss.
    input:
        y_hat: predicted values
        y: ground-truth values
    '''
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
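# Toy check for squared_loss(): elementwise (y_hat - y)^2 / 2 on made-up values.
if __name__ == '__main__':
    _pred = torch.tensor([2.0, 0.0])
    _target = torch.tensor([1.0, 0.0])
    assert torch.allclose(squared_loss(_pred, _target), torch.tensor([0.5, 0.0]))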
def sequence_mask(X, valid_len, value=0):
    '''
    X: sequence, 2-D matrix (modified in place)
    valid_len: number of valid entries in each row
    value: the value written into invalid positions
    '''
    max_len = X.size(1)
    # build the mask on X's device so this also works for GPU tensors
    mask = torch.arange(0, max_len, dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
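# Illustrative call to sequence_mask() on toy values: keep the first valid_len
# entries of each row and overwrite the rest (cloned, since X is modified in place).
if __name__ == '__main__':
    _seq = torch.tensor([[1, 2, 3], [4, 5, 6]])
    _masked = sequence_mask(_seq.clone(), torch.tensor([1, 2]))
    assert torch.equal(_masked, torch.tensor([[1, 0, 0], [4, 5, 0]]))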
def mask_softmax(X, valid_len=None):
    '''
    X: sequence scores, 3-D tensor of shape (batch_size, h, w)
    valid_len: number of valid entries, 1-D or 2-D tensor
    '''
    if valid_len is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_len.dim() == 1:
            # one length per batch element: repeat it for every row
            valid_len = torch.repeat_interleave(valid_len, shape[1])
        else:
            # one length per row: flatten to match the reshaped X
            valid_len = valid_len.reshape(-1)
        # push invalid positions to a large negative value so softmax ~ 0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_len, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
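# Illustrative call to mask_softmax() on toy shapes: a (2, 1, 4) score tensor
# with per-batch valid lengths; masked positions receive near-zero probability.
if __name__ == '__main__':
    _scores = torch.rand(2, 1, 4)
    _out = mask_softmax(_scores, torch.tensor([2, 3]))
    print(_out)  # row 0 spreads mass over 2 entries, row 1 over 3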