# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
import torch
import numpy as np
from torch.optim.optimizer import Optimizer

# Fix the random seeds and cuDNN settings so that repeated runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class MetaBalance(Optimizer):
    r"""Implements the MetaBalance algorithm.

    MetaBalance adapts the gradient magnitudes of the auxiliary losses on the
    shared parameters so that they stay close to the gradient magnitude of the
    target task (the first loss in ``loss_array``).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        relax_factor (float): hyper-parameter that controls the magnitude
            proximity, i.e. how strongly the auxiliary gradient magnitudes are
            pulled toward the target task's magnitude
        beta (float): hyper-parameter that controls the moving averages of the
            gradient magnitudes, set to 0.9 empirically
    """

    def __init__(self, params, relax_factor=0.7, beta=0.9):
        if not 0.0 <= relax_factor < 1.0:
            raise ValueError("Invalid relax factor: {}".format(relax_factor))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {}".format(beta))
        defaults = dict(relax_factor=relax_factor, beta=beta)
        super(MetaBalance, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, loss_array):
        """Performs a single optimization step.

        Arguments:
            loss_array (list of Tensor): the task losses, with the target
                task's loss first and the auxiliary losses after it.
        """
        self.balance_GradMagnitudes(loss_array)

    def balance_GradMagnitudes(self, loss_array):
        for loss_index, loss in enumerate(loss_array):
            loss.backward(retain_graph=True)
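            # The rest of this method was truncated in the source; the loop
            # below is a sketch of the per-parameter balancing step described
            # in the MetaBalance paper. Names such as "norms" and
            # "sum_gradient" are this sketch's choices, not confirmed API.
            for group in self.param_groups:
                beta = group["beta"]
                relax_factor = group["relax_factor"]
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    if p.grad.is_sparse:
                        raise RuntimeError("MetaBalance does not support sparse gradients")
                    state = self.state[p]
                    # Lazily allocate one moving-average magnitude per task and
                    # a buffer that accumulates the balanced gradients.
                    if len(state) == 0:
                        state["norms"] = [torch.zeros(1, device=p.device) for _ in loss_array]
                        state["sum_gradient"] = torch.zeros_like(p)
                    # Moving average of this task's gradient magnitude.
                    state["norms"][loss_index].mul_(beta).add_(torch.norm(p.grad), alpha=1.0 - beta)
                    # Rescale the gradient toward the target task's magnitude
                    # (task 0), softened by the relax factor; for task 0 the
                    # scale is ~1, so the target gradient is left unchanged.
                    scale = state["norms"][0] / (state["norms"][loss_index] + 1e-12)
                    p.grad.mul_(scale * relax_factor + (1.0 - relax_factor))
                    # Accumulate the balanced gradient, then clear p.grad so
                    # the next task's backward pass starts from zero.
                    state["sum_gradient"].add_(p.grad)
                    p.grad.zero_()
                    # After the last task, hand the accumulated balanced
                    # gradient back to p.grad for a downstream optimizer.
                    if loss_index == len(loss_array) - 1:
                        p.grad.add_(state["sum_gradient"])
                        state["sum_gradient"].zero_()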
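

# --- Usage sketch ---------------------------------------------------------
# A minimal, hypothetical example of how this optimizer is meant to be used:
# MetaBalance balances gradient magnitudes on the *shared* parameters, while
# a regular optimizer (Adam here) applies the actual update. The model, data,
# and hyper-parameters below are illustrative, not from the original file.
if __name__ == "__main__":
    shared = torch.nn.Linear(8, 4)
    head_target = torch.nn.Linear(4, 1)  # target task head
    head_aux = torch.nn.Linear(4, 1)     # auxiliary task head

    balancer = MetaBalance(shared.parameters(), relax_factor=0.7, beta=0.9)
    optimizer = torch.optim.Adam(
        list(shared.parameters())
        + list(head_target.parameters())
        + list(head_aux.parameters()),
        lr=1e-3,
    )

    x = torch.randn(16, 8)
    y_target = torch.randn(16, 1)
    y_aux = torch.randn(16, 1)

    h = shared(x)
    loss_target = torch.nn.functional.mse_loss(head_target(h), y_target)
    loss_aux = torch.nn.functional.mse_loss(head_aux(h), y_aux)

    optimizer.zero_grad()
    # The target task's loss must come first: auxiliary gradient magnitudes
    # are adapted toward the magnitude of loss_array[0].
    balancer.step([loss_target, loss_aux])
    optimizer.step()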