Building a Classification Model with JAX

1. Complete training code:

import jax
import numpy as np
from jax import numpy as jnp
from jax import vmap, value_and_grad, jit
import jax.random as random
from jax.scipy.special import logsumexp
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

mnist_img_size = (28, 28)
lr = 0.01            # learning rate for plain SGD updates
batch_size = 128
training_epochs = 100


def init_model(layer_widths, scale=0.1):
    # Initialise (weights, bias) pairs for each layer of the MLP.
    params = []
    key = random.PRNGKey(0)
    for input_layer, output_layer in zip(layer_widths[:-1], layer_widths[1:]):
        # Split fresh sub-keys per layer so no two parameters reuse the same randomness.
        key, weights_key, bias_key = random.split(key, num=3)
        weights = random.normal(weights_key, shape=(output_layer, input_layer), dtype=jnp.float32)
        bias = random.normal(bias_key, shape=(output_layer,), dtype=jnp.float32)
        params.append((scale * weights, scale * bias))

    return params


def predict_fn(params, x):
    # Forward pass for a single flattened image; returns log-probabilities.
    hidden_layers = params[:-1]
    input_data = x
    for weights, bias in hidden_layers:
        input_data = jax.nn.relu(jnp.dot(weights, input_data) + bias)
    weights_last, bias_last = params[-1]
    logits = jnp.dot(weights_last, input_data) + bias_last

    # Subtracting logsumexp turns the logits into a log-softmax.
    return logits - logsumexp(logits)

batched_MLP_predict = vmap(predict_fn, in_axes=(None, 0))
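# The vmap above vectorises predict_fn over a batch: in_axes=(None, 0) broadcasts
# `params` unchanged and maps over the leading (batch) axis of the images.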

def loss_fn(params, images, gt_labels):
    # Cross-entropy: predictions are log-probabilities and gt_labels are one-hot,
    # so the element-wise product picks out the log-probability of the true class.
    predictions = batched_MLP_predict(params, images)

    return -jnp.mean(predictions * gt_labels)


@jit
def update_weights(params, images, gt_labels):
    loss, grads = value_and_grad(loss_fn)(params, images, gt_labels)

    # Plain SGD step applied to every leaf (weight or bias) of the parameter pytree.
    return loss, jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def custom_transform(x):
    # Flatten the 28x28 PIL image into a float32 vector of length 784.
    return np.ravel(np.array(x, dtype=np.float32))


def custom_collate_fn(batch):
    data = list(zip(*batch))
    images = np.array(data[0])
    labels = np.array(data[1])

    return images, labels


def load_data():
    train_mnist = MNIST(root="data/train_mnist", train=True, download=True, transform=custom_transform)
    test_mnist = MNIST(root="data/test_mnist", train=False, download=True, transform=custom_transform)

    train_loader = DataLoader(train_mnist, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
    test_loader = DataLoader(test_mnist, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

    return train_loader, test_loader


def accuracy(params, data):
    acc = 0
    for images, labels in data:
        pred_classes = jnp.argmax(batched_MLP_predict(params, images), axis=1)
        acc += np.sum(pred_classes == labels)

    # drop_last=True means every batch holds exactly batch_size samples.
    return acc / (len(data) * batch_size)


def training():
    train_data, test_data = load_data()
    layer_widths = [784, 512, 10]
    params = init_model(layer_widths)
    for epoch in range(training_epochs):
        for cnt, (images, labels) in enumerate(train_data):
            gt_labels = jax.nn.one_hot(labels, len(MNIST.classes))
            loss, params = update_weights(params, images, gt_labels)

            if cnt % 50 == 0:
                print("loss is:", loss)

        print(f'Epoch {epoch}, train_acc = {accuracy(params, train_data)}, test_acc = {accuracy(params, test_data)}')

    return params


params = training()
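
With the trained `params` returned above, a single flattened test image can be classified directly with `predict_fn`. The following is a minimal sketch (variable names such as `sample_image` are illustrative, not part of the original script):

# Minimal single-image inference sketch (assumes the training script above has run).
_, test_data = load_data()
images, labels = next(iter(test_data))
sample_image = images[0]                       # flattened 784-vector

log_probs = predict_fn(params, sample_image)   # log-softmax output, shape (10,)
print("predicted digit:", jnp.argmax(log_probs), "ground truth:", labels[0])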

2. Visualization code
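
A minimal sketch of the visualization step, assuming matplotlib is installed and `params` comes from the training run above: it reshapes a few flattened test images back to 28x28 and titles each one with the predicted and true digit.

import matplotlib.pyplot as plt

# Visualization sketch: show five test images with model predictions.
_, test_data = load_data()
images, labels = next(iter(test_data))
pred_digits = jnp.argmax(batched_MLP_predict(params, images), axis=1)

fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for i, ax in enumerate(axes):
    ax.imshow(images[i].reshape(mnist_img_size), cmap="gray")
    ax.set_title(f"pred: {pred_digits[i]}, true: {labels[i]}")
    ax.axis("off")
plt.show()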
