Code for VeLO 2: Training Versatile Learned Optimizers by Scaling Up

The previous article showed how to train an MLP with VeLO; this article covers using VeLO to train ResNets.

This post follows https://colab.research.google.com/drive/1-ms12IypE-EdDSNjhFMdRdBbMnH94zpH#scrollTo=RQBACAPQZyB-, and walks through using a learned optimizer from the VeLO family on (an environment-setup sketch follows the list):

  • a simple image classification task
  • ResNets
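
The VeLO optimizer lives in the open-source learned_optimization package. A minimal setup sketch, assuming the package is installed straight from GitHub (the original Colab may pin a specific commit; as far as I know, the pretrained VeLO checkpoint is fetched automatically the first time the optimizer is built):

# Install learned_optimization, which ships the VeLO prefab used later on.
!pip install git+https://github.com/google/learned_optimization.git
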
#@title imports, configuration, and model classes

from absl import app
from datetime import datetime

from functools import partial
from typing import Any, Callable, Sequence, Tuple

from flax import linen as nn

import jax
import jax.numpy as jnp

from jaxopt import loss
from jaxopt import OptaxSolver
from jaxopt import tree_util

import optax

import tensorflow_datasets as tfds
import tensorflow as tf


# Datasets that can be used with this notebook
dataset_names = [
    "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
]


L2REG = 1e-4
LEARNING_RATE = .2
EPOCHS = 10
MOMENTUM = .9
DATASET = 'cifar100' #@param [ "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"]
MODEL = 'resnet18' #@param ["resnet1", "resnet18", "resnet34"]
TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 1024


# Load the dataset and return an infinite iterator of numpy batches.
def load_dataset(split, *, is_training, batch_size):
  version = 3
  ds, ds_info = tfds.load(
      f"{DATASET}:{version}.*.*",
      as_supervised=True,  # remove useless keys
      split=split,
      with_info=True)
  ds = ds.cache().repeat()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  return iter(tfds.as_numpy(ds)), ds_info
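
# Quick usage sketch for load_dataset (the shapes shown assume cifar100 and are
# only illustrative):
#   train_it, info = load_dataset("train", is_training=True, batch_size=TRAIN_BATCH_SIZE)
#   images, labels = next(train_it)
#   images.shape  -> (256, 32, 32, 3), dtype uint8
#   labels.shape  -> (256,)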


class ResNetBlock(nn.Module):
  """ResNet block."""
  filters: int
  conv: Any
  norm: Any
  act: Callable
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x,):
    residual = x
    y = self.conv(self.filters, (3, 3), self.strides)(x)
    y = self.norm()(y)
    y = self.act(y)
    y = self.conv(self.filters, (3, 3))(y)
    y = self.norm(scale_init=nn.initializers.zeros)(y)

    if residual.shape != y.shape:
      residual = self.conv(self.filters, (1, 1),
                           self.strides, name='conv_proj')(residual)
      residual = self.norm(name='norm_proj')(residual)

    return self.act(residual + y)


class ResNet(nn.Module):
  """ResNetV1."""
  stage_sizes: Sequence[int]
  block_cls: Any
  num_classes: int
  num_filters: int = 64
  dtype: Any = jnp.float32
  act: Callable = nn.relu

  @nn.compact
  def __call__(self, x, train: bool = True):
    conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = partial(nn.BatchNorm,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   dtype=self.dtype)

    x = conv(self.num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')(x)
    x = norm(name='bn_init')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_cls(self.num_filters * 2 ** i,
                           strides=strides,
                           conv=conv,
                           norm=norm,
                           act=self.act)(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)
    return x


# Note: ResNet has no explicit __init__ because flax.linen.Module subclasses are
# dataclasses; the annotated class attributes above automatically become
# constructor fields (a short sketch follows the partial definitions below).
ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock)
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
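
# Optional sanity check (sketch): functools.partial just pre-fills stage_sizes
# and block_cls, so the partials are used like ordinary constructors.
#   model = ResNet18(num_classes=100)   # == ResNet(stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, num_classes=100)
#   variables = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)), train=True)
#   sorted(variables.keys())            # ['batch_stats', 'params']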



#@title training loop definition (run this cell to launch training)
import functools
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src import tree_util

# VeLO is exposed through the learned_optimization package's prefab module;
# prefab.optax_lopt is used below to build the optimizer.
from learned_optimization.research.general_lopt import prefab


# This named tuple just stores the solver's state between iterations.
class OptaxState(NamedTuple):
  """Named tuple containing state information."""
  iter_num: int  # number of iterations performed so far
  value: float   # current loss value
  error: float   # optimality error (l2 norm of the gradient)
  internal_state: NamedTuple
  aux: Any

# We need to reimplement jaxopt's OptaxSolver.update method (here called
# lopt_update) to properly pass in the loss value that VeLO expects.
def lopt_update(self,
            params: Any,
            state: NamedTuple,
            *args,
            **kwargs) -> base.OptStep:
  """Performs one iteration of the optax solver.

  Args:
    params: pytree containing the parameters (here, the ResNet parameters).
    state: named tuple containing the solver state.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.
  Returns:
    (params, state)
  """
  if self.pre_update:
    params, state = self.pre_update(params, state, *args, **kwargs)

  (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)

  # Note: the only difference between this function and the stock
  # jaxopt.OptaxSolver.update is that `extra_args` is passed here.
  # If you would like to use a different optimizer, you will likely need to
  # remove these extra_args.
  delta, opt_state = self.opt.update(grad, state.internal_state, params, extra_args={"loss": value})
  params = self._apply_updates(params, delta)

  # Computes optimality error before update to re-use grad evaluation.
  new_state = OptaxState(iter_num=state.iter_num + 1,
                          error=tree_util.tree_l2_norm(grad),
                          value=value,
                          aux=aux,
                          internal_state=opt_state)
  return base.OptStep(params=params, state=new_state)
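
# For a standard optax transformation (e.g. optax.sgd), the update call above
# would instead read:
#   delta, opt_state = self.opt.update(grad, state.internal_state, params)
# VeLO's optax wrapper additionally consumes the current loss, hence the
# extra_args={"loss": value} keyword.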




def train():

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  # tf.config.experimental.set_visible_devices([], 'GPU')


  # typical data loading and iterator setup

  train_ds, ds_info = load_dataset("train", is_training=True,
                                    batch_size=TRAIN_BATCH_SIZE)
  test_ds, _ = load_dataset("test", is_training=False,
                            batch_size=TEST_BATCH_SIZE)
  input_shape = (1,) + ds_info.features["image"].shape
  num_classes = ds_info.features["label"].num_classes
  iter_per_epoch_train = ds_info.splits['train'].num_examples // TRAIN_BATCH_SIZE
  iter_per_epoch_test = ds_info.splits['test'].num_examples // TEST_BATCH_SIZE

  # Set up model.
  if MODEL == "resnet1":
    net = ResNet1(num_classes=num_classes)
  elif MODEL == "resnet18":
    net = ResNet18(num_classes=num_classes)
  elif MODEL == "resnet34":
    net = ResNet34(num_classes=num_classes)
  else:
    raise ValueError("Unknown model.")

  def predict(params, inputs, aux, train=False):
    x = inputs.astype(jnp.float32) / 255.
    all_params = {"params": params, "batch_stats": aux}
    if train:
      # Returns logits and net_state (which contains the key "batch_stats").
      return net.apply(all_params, x, train=True, mutable=["batch_stats"])
    else:
      # Returns logits only.
      return net.apply(all_params, x, train=False)

  logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

  def loss_from_logits(params, l2reg, logits, labels):
    mean_loss = jnp.mean(logistic_loss(labels, logits))
    sqnorm = tree_util.tree_l2_norm(params, squared=True)
    return mean_loss + 0.5 * l2reg * sqnorm

  def accuracy_and_loss(params, l2reg, data, aux):
    inputs, labels = data
    logits = predict(params, inputs, aux)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    loss = loss_from_logits(params, l2reg, logits, labels)
    return accuracy, loss

  def loss_fun(params, l2reg, data, aux):
    inputs, labels = data
    logits, net_state = predict(params, inputs, aux, train=True)
    loss = loss_from_logits(params, l2reg, logits, labels)
    # batch_stats will be stored in state.aux
    return loss, net_state["batch_stats"]
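
  # For reference: jax.value_and_grad(loss_fun, has_aux=True) (used below)
  # returns ((value, aux), grad), where value is the scalar regularized loss,
  # aux is the updated batch_stats pytree (later stored in state.aux by
  # lopt_update), and grad mirrors the params pytree.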

  # The default optimizer used by jaxopt is commented out here
  # opt = optax.sgd(learning_rate=LEARNING_RATE,
  #                 momentum=MOMENTUM,
  #                 nesterov=True)
  
  # VeLO conditions on the total number of training steps, so it must be told
  # how long training will run when the optimizer is constructed.
  NUM_STEPS = EPOCHS * iter_per_epoch_train
  opt = prefab.optax_lopt(NUM_STEPS)
  
  # We need has_aux=True because loss_fun returns batch_stats.
  solver = OptaxSolver(opt=opt, 
                       fun=jax.value_and_grad(loss_fun, has_aux=True), 
                       maxiter=EPOCHS * iter_per_epoch_train, 
                       has_aux=True, 
                       value_and_grad=True)

  # Initialize parameters.
  rng = jax.random.PRNGKey(0)
  # init_vars holds two variable collections: "params" (the trainable weights)
  # and "batch_stats" (the BatchNorm running statistics).
  init_vars = net.init(rng, jnp.zeros(input_shape), train=True)
  params = init_vars["params"]
  batch_stats = init_vars["batch_stats"]
  start = datetime.now().replace(microsecond=0)

  # Run training loop.
  state = solver.init_state(params, L2REG, next(test_ds), batch_stats)  # initialize the solver state
  # jax.jit compiles the update step with XLA so that each iteration runs as a
  # single fused computation on the accelerator.
  jitted_update = jax.jit(functools.partial(lopt_update, self=solver))
  print(f'Iterations: {solver.maxiter}')

  for _ in range(solver.maxiter):  # maximum number of solver iterations
    train_minibatch = next(train_ds)

    if state.iter_num % iter_per_epoch_train == iter_per_epoch_train - 1:
      # Once per epoch evaluate the model on the train and test sets.
      test_acc, test_loss = 0., 0.
      # make a pass over test set to compute test accuracy
      for _ in range(iter_per_epoch_test):
          tmp = accuracy_and_loss(params, L2REG, next(test_ds), batch_stats)
          test_acc += tmp[0] / iter_per_epoch_test
          test_loss += tmp[1] / iter_per_epoch_test

      train_acc, train_loss = 0., 0.
      # make a pass over train set to compute train accuracy
      for _ in range(iter_per_epoch_train):
          tmp = accuracy_and_loss(params, L2REG, next(train_ds), batch_stats)
          train_acc += tmp[0] / iter_per_epoch_train
          train_loss += tmp[1] / iter_per_epoch_train

      train_acc = jax.device_get(train_acc)
      train_loss = jax.device_get(train_loss)
      test_acc = jax.device_get(test_acc)
      test_loss = jax.device_get(test_loss)
      # time elapsed without microseconds
      time_elapsed = (datetime.now().replace(microsecond=0) - start)

      print(f"[Epoch {(state.iter_num+1) // (iter_per_epoch_train+1)}/{EPOCHS}] "
            f"Train acc: {train_acc:.3f}, train loss: {train_loss:.3f}. "
            f"Test acc: {test_acc:.3f}, test loss: {test_loss:.3f}. "
            f"Time elapsed: {time_elapsed}")

    params, state = jitted_update(params=params,
                                  state=state,
                                  l2reg=L2REG,
                                  data=train_minibatch,
                                  aux=batch_stats)
    batch_stats = state.aux

train()

Comparing against the baseline

The per-epoch accuracies and losses below are hard-coded from a baseline run and a VeLO run (a sketch of the baseline configuration follows).
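
The baseline was presumably the SGD-with-momentum optimizer that is commented out in the training cell above. A sketch of the only change needed inside train() to reproduce it (swap this in for the prefab.optax_lopt line; the extra_args={"loss": value} call in lopt_update would also have to be dropped for a plain optax optimizer):

  # Baseline: plain SGD with Nesterov momentum instead of VeLO.
  opt = optax.sgd(learning_rate=LEARNING_RATE,
                  momentum=MOMENTUM,
                  nesterov=True)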

#@title Comparing a baseline optimizer vs. VeLO on resnet18 cifar100.
baseline_train_acc = [0.235,
0.333,
0.428,
0.430,
0.480,
0.528,
0.591,
0.617,
0.661,
0.709,]

baseline_test_acc = [0.216,
0.298,
0.362,
0.343,
0.359,
0.371,
0.375,
0.377,
0.379,
0.399,]


velo_train_acc = [0.170,
0.270,
0.346,
0.331,
0.466,
0.477,
0.551,
0.749,
0.848,
0.955,]

velo_test_acc = [0.163,
0.255,
0.310,
0.290,
0.377,
0.369,
0.385,
0.458,
0.464,
0.492,]


import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

figure(figsize=(8, 6), dpi=80)

plt.plot(range(10), baseline_train_acc, label="Baseline Train Accuracy", c='b', linestyle='dashed')
plt.plot(range(10), baseline_test_acc, label="Baseline Test Accuracy", c='b')
plt.plot(range(10), velo_train_acc, label="VeLO Train Accuracy", c='r', linestyle='dashed')
plt.plot(range(10), velo_test_acc, label="VeLO Test Accuracy", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy Curves for Resnet18 on Cifar100")
plt.legend()
plt.show()






#@title Comparing a baseline optimizer vs. VeLO on resnet18 cifar100 (loss curves).
baseline_train_loss = [3.470,
2.979,
2.535,
2.567,
2.351,
2.183,
1.970,
1.925,
1.781,
1.644,]

baseline_test_loss = [3.571,
3.206,
2.899,
3.064,
3.055,
3.107,
3.170,
3.447,
3.530,
3.597,]


velo_train_loss = [3.701,
3.071,
2.771,
2.948,
2.294,
2.287,
2.059,
1.268,
0.948,
0.645,]

velo_test_loss = [3.739,
3.188,
2.974,
3.266,
2.797,
2.925,
3.062,
2.769,
2.950,
2.882]


import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

figure(figsize=(8, 6), dpi=80)

plt.plot(range(10), baseline_train_loss, label="Baseline Train Loss", c='b', linestyle='dashed')
plt.plot(range(10), baseline_test_loss, label="Baseline Test Loss", c='b')
plt.plot(range(10), velo_train_loss, label="VeLO Train Loss", c='r', linestyle='dashed')
plt.plot(range(10), velo_test_loss, label="VeLO Test Loss", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Curves for Resnet18 on Cifar100 ")
plt.legend()
plt.show()