1. Complete training code:
import jax
import numpy as np
from jax import numpy as jnp
from jax import vmap, value_and_grad, jit
import jax.random as random
from jax.scipy.special import logsumexp
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
mnist_img_size = (28, 28)
lr = 0.01               # SGD learning rate
batch_size = 128
training_epochs = 100
def init_model(layer_widths, scale=0.1):
    params = []
    key = random.PRNGKey(0)
    for input_layer, output_layer in zip(layer_widths[:-1], layer_widths[1:]):
        # Split off fresh subkeys for every layer, so each layer gets an
        # independent random stream rather than reusing the same key.
        key, weights_key, bias_key = random.split(key, num=3)
        weights = random.normal(weights_key, shape=(output_layer, input_layer), dtype=jnp.float32)
        bias = random.normal(bias_key, shape=(output_layer,), dtype=jnp.float32)
        params.append((scale * weights, scale * bias))
    return params
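# For layer_widths = [784, 512, 10], init_model returns a pytree (a list of
# (weights, bias) tuples) with shapes [(512, 784), (512,)] and [(10, 512), (10,)].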
def predict_fn(params, x):
    # Forward pass for a single flattened image x of shape (784,).
    hidden_layers = params[:-1]
    input_data = x
    for weights, bias in hidden_layers:
        input_data = jax.nn.relu(jnp.dot(weights, input_data) + bias)
    weights_last, bias_last = params[-1]
    logits = jnp.dot(weights_last, input_data) + bias_last
    # Subtracting logsumexp turns the logits into log-probabilities (log-softmax).
    return logits - logsumexp(logits)
batched_MLP_predict = vmap(predict_fn, in_axes=(None, 0))
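# vmap lifts the single-example predict_fn to a batched function:
# in_axes=(None, 0) keeps params shared across the batch while mapping over
# the leading axis of x. Illustrative shape check (assuming a [784, ..., 10] network):
#     dummy_batch = jnp.ones((batch_size, 784))
#     batched_MLP_predict(params, dummy_batch).shape  # -> (batch_size, 10)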
def loss_fn(params, images, gt_labels):
    # Cross-entropy: the inner sum picks out the log-probability of the true
    # class (gt_labels is one-hot), then we average over the batch.
    predictions = batched_MLP_predict(params, images)
    return -jnp.mean(jnp.sum(predictions * gt_labels, axis=1))
@jit
def update_weights(params, images, gt_labels):
    loss, grads = value_and_grad(loss_fn)(params, images, gt_labels)
    # Plain SGD step, applied leaf-wise over the parameter pytree.
    return loss, jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
def custom_transform(x):
    # PIL image -> flat float32 vector of length 784.
    return np.ravel(np.array(x, dtype=np.float32))

def custom_collate_fn(batch):
    # Collate into plain NumPy arrays instead of torch tensors,
    # so batches can be fed straight into JAX.
    data = list(zip(*batch))
    images = np.array(data[0])
    labels = np.array(data[1])
    return images, labels
def load_data():
    train_mnist = MNIST(root="data/train_mnist", train=True, download=True, transform=custom_transform)
    test_mnist = MNIST(root="data/test_mnist", train=False, download=True, transform=custom_transform)
    train_loader = DataLoader(train_mnist, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
    test_loader = DataLoader(test_mnist, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)
    return train_loader, test_loader
def accuracy(params, data):
    acc = 0
    for images, labels in data:
        pred_classes = jnp.argmax(batched_MLP_predict(params, images), axis=1)
        acc += np.sum(pred_classes == labels)
    # drop_last=True guarantees every batch has exactly batch_size examples.
    return acc / (len(data) * batch_size)
def training():
    train_data, test_data = load_data()
    layer_widths = [784, 512, 10]
    params = init_model(layer_widths)
    for epoch in range(training_epochs):
        for cnt, (images, labels) in enumerate(train_data):
            # One-hot encode the integer labels over the 10 digit classes.
            gt_labels = jax.nn.one_hot(labels, len(MNIST.classes))
            loss, params = update_weights(params, images, gt_labels)
            if cnt % 50 == 0:
                print("loss is:", loss)
        print(f'Epoch {epoch}, train_acc = {accuracy(params, train_data)}, test_acc = {accuracy(params, test_data)}')
    return params
params = training()
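Once training has finished, the returned params can be fed straight back into predict_fn for inference. A minimal sketch (illustrative only; demo_images and demo_labels are names introduced here, taken from the test loader that load_data already builds):

_, demo_loader = load_data()
demo_images, demo_labels = next(iter(demo_loader))
log_probs = predict_fn(params, demo_images[0])   # log-probabilities for a single image
print("predicted:", int(jnp.argmax(log_probs)), "label:", int(demo_labels[0]))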
2. Visualization code