《动手学深度学习》7.5 批量归一化-从零开始实现
import torch
from torch import nn
import MyFunction as MF
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize X from scratch.

    In training mode (autograd enabled) the current mini-batch statistics
    are used and the running statistics are updated with an exponential
    moving average; in inference mode the running statistics are used.
    Returns the normalized output plus the updated running mean/variance.
    """
    if torch.is_grad_enabled():
        # Training mode: normalize with the statistics of this mini-batch.
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully-connected input (N, C): one statistic per feature.
            mu = X.mean(dim=0)
            sigma2 = ((X - mu) ** 2).mean(dim=0)
        else:
            # Convolutional input (N, C, H, W): one statistic per channel;
            # keepdim=True so the stats broadcast against X.
            mu = X.mean(dim=(0, 2, 3), keepdim=True)
            sigma2 = ((X - mu) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        x_norm = (X - mu) / torch.sqrt(sigma2 + eps)
        # Exponential moving average of the running statistics.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mu
        moving_var = momentum * moving_var + (1.0 - momentum) * sigma2
    else:
        # Inference mode: use the accumulated running statistics.
        x_norm = (X - moving_mean) / torch.sqrt(moving_var + eps)
    out = gamma * x_norm + beta  # learnable scale and shift
    return out, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
    """Batch-normalization layer built on the scratch `batch_norm` function.

    Args:
        num_features: outputs of a fully-connected layer, or channels of a
            convolutional layer.
        num_dims: 2 for fully-connected input (N, C), 4 for convolutional
            input (N, C, H, W).
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale/shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # FIX: register the running statistics as buffers so they are
        # included in state_dict (saved/restored with the model) and moved
        # automatically by net.to(device).  As plain tensor attributes they
        # were silently dropped on save, so a reloaded model evaluated with
        # freshly initialized running stats.
        # NOTE(review): checkpoints written by the old code lack these keys;
        # load them with strict=False if needed.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        # Buffers normally follow the module's device; keep the explicit
        # check as a safety net when the module was never moved.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
# LeNet-style network with a scratch BatchNorm after every conv/linear layer.
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

# Hyperparameters, data, and training.
lr, num_epochs, batch_size = 1.0, 10, 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_iter, test_iter = MF.load_data_fashion_mnist(batch_size)
MF.train_ch7_BN(net, train_iter, test_iter, num_epochs, lr, device=device)
- 训练结果
预测部分
import torch
from torch import nn
import MyFunction as MF
import os
import matplotlib.pyplot as plt
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Scratch batch normalization (same as the training section).

    Uses mini-batch statistics and updates the running averages while
    autograd is enabled; uses the running statistics at inference time.
    Returns (output, updated moving_mean, updated moving_var).
    """
    training = torch.is_grad_enabled()
    if not training:
        # Inference: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert X.ndim in (2, 4)
        if X.ndim == 2:
            # Dense input: per-feature statistics over the batch dimension.
            batch_mean = X.mean(dim=0)
            batch_var = ((X - batch_mean) ** 2).mean(dim=0)
        else:
            # Conv input: per-channel statistics over batch and spatial dims,
            # with kept dims so they broadcast against X.
            batch_mean = X.mean(dim=(0, 2, 3), keepdim=True)
            batch_var = ((X - batch_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - batch_mean) / torch.sqrt(batch_var + eps)
        # Blend the batch statistics into the running averages.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
        moving_var = momentum * moving_var + (1.0 - momentum) * batch_var
    return gamma * X_hat + beta, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
    """Batch-normalization layer built on the scratch `batch_norm` function.

    Args:
        num_features: outputs of a fully-connected layer, or channels of a
            convolutional layer.
        num_dims: 2 for fully-connected input (N, C), 4 for convolutional
            input (N, C, H, W).
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale/shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # FIX: register the running statistics as buffers so they are
        # included in state_dict (and therefore actually restored by
        # load_state_dict below) and moved automatically by net.to(device).
        # Plain tensor attributes were silently missing from checkpoints.
        # NOTE(review): checkpoints written by the old code lack these keys;
        # load them with strict=False if needed.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        # Buffers normally follow the module's device; keep the explicit
        # check as a safety net when the module was never moved.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
# Same LeNet-style architecture as in the training section, so the saved
# state_dict keys line up.
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4),
                    nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),
                    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4),
                    nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),
                    nn.Flatten(), nn.Linear(16 * 4 * 4, 120),
                    BatchNorm(120, num_dims=2), nn.Sigmoid(),
                    nn.Linear(120, 84), BatchNorm(84, num_dims=2),
                    nn.Sigmoid(), nn.Linear(84, 10))

# Checkpoint from the last training epoch (epoch index 9).
filename = r'./data/BN-9.pth'
if os.path.exists(filename):
    net.load_state_dict(torch.load(filename))
else:
    # FIX: report the path that is actually missing instead of a
    # hard-coded placeholder string.
    print(f"No such file or directory: '{filename}'")

lr, num_epochs, batch_size = 1.0, 10, 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_iter, test_iter = MF.load_data_fashion_mnist(batch_size)
MF.predict_ch7_BN(net, test_iter)
plt.show()
- 预测结果
MyFunction包里定义的相关函数
import datetime
def train_ch7_BN(net, train_iter, test_iter, num_epochs, lr, device):
    """Train `net` with SGD + cross-entropy, saving a checkpoint per epoch.

    Prints per-epoch loss/accuracy via a tqdm progress bar, evaluates on
    `test_iter` after each epoch, and writes ./data/BN-<epoch>.pth.
    """
    print("Start Training...")
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("==========" * 8 + "%s" % nowtime)

    def init_weights(m):
        # Xavier-uniform init for every linear and conv layer.
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    # NOTE(review): torch.cuda.get_device_name() raises on CPU-only
    # machines even though `device` may be "cpu" — confirm intended usage.
    print(f'training on {device}:{torch.cuda.get_device_name()}')
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    timer = MF.Timer()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        net.train()
        # NOTE(review): `tqdm` is not imported in this snippet — presumably
        # imported at MyFunction module level; verify.
        with tqdm(train_iter) as t:
            for X, y in t:
                timer.start()
                optimizer.zero_grad()
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                l.backward()
                optimizer.step()
                # Accumulate the un-averaged loss and correct-count; `n`
                # tracks examples seen so far this epoch.
                train_l_sum += l.item() * X.shape[0]
                train_acc_sum += MF.accuracy(y_hat,y)
                n += y.shape[0]
                train_l = train_l_sum / n
                train_acc = train_acc_sum / n
                timer.stop()
                t.set_description(f"epoch:{epoch}")
                # NOTE(review): timer.stop() is called again here (and below)
                # after the stop above — depending on MF.Timer semantics this
                # may record spurious intervals; confirm.
                t.set_postfix(loss="%.3f" % train_l, train_acc="%.3f" % train_acc, time="%.3f sec" % timer.stop())
        # Checkpoint after every epoch (loaded later by the predict script).
        torch.save(net.state_dict(),"./data/BN-%d.pth" %(epoch))
        test_acc = MF.evaluate_accuracy_BN(net, test_iter)
        print(f'epoch:{epoch+1},loss {train_l:.3f}, train_acc {train_acc:.3f}, test_acc {test_acc:.3f}, {timer.stop()} sec')
    # Throughput summary over all epochs.
    print(f'{n* num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("==========" * 8 + "%s" % nowtime)
    print('Finished Training...')
def evaluate_accuracy_BN(net, data_iter, device=None):
    """Compute the accuracy of `net` over `data_iter`.

    Args:
        net: model to evaluate (switched to eval mode if an nn.Module).
        data_iter: iterable of (X, y) batches; X may be a list of tensors.
        device: target device; defaults to where the model parameters live.

    Returns:
        Fraction of correctly classified examples.
    """
    acc_sum, n = 0.0, 0
    if isinstance(net, torch.nn.Module):
        net.eval()
        if not device:
            # Default to the device the model parameters live on.
            device = next(iter(net.parameters())).device
    # FIX: run evaluation under no_grad.  The scratch `batch_norm` dispatches
    # on torch.is_grad_enabled(), so without this the evaluation pass used
    # (and updated!) per-batch statistics instead of the running averages —
    # and it also wastefully built autograd graphs.
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            acc_sum += MF.accuracy(net(X), y)
            n += y.shape[0]
    return acc_sum / n
def predict_ch7_BN(net, test_iter, n=6):
    """Predict labels for the first test batch and plot `n` examples.

    Shows each image with its true label over the predicted label.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()
    # FIX: use the device the model actually lives on instead of guessing
    # from CUDA availability (the old guess could disagree with the model).
    device = next(net.parameters()).device
    for X, y in test_iter:
        h, w = X.shape[-2:]
        # FIX: Tensor.to() is not in-place — the previous bare
        # `X.to(device)` / `y.to(device)` calls discarded their results.
        X, y = X.to(device), y.to(device)
        trues = MF.get_fashion_mnist_labels(y)
        # no_grad: inference only, and it makes the scratch batch_norm
        # take its inference branch (running statistics).
        with torch.no_grad():
            preds = MF.get_fashion_mnist_labels(net(X).argmax(axis=1))
        titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
        # Move back to CPU for matplotlib in case the model is on GPU.
        MF.show_images(X[0:n].cpu().reshape((n, h, w)), 1, n, titles=titles[0:n])
        break