# Source code for optymus.methods.adaptive._yogi
import time
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm
from optymus.methods.utils import BaseOptimizer
from optymus.methods.utils._result import OptimizeResult
# Import-time side effect: switch JAX's global "jax_enable_x64" flag on so the
# optimizer state below is not restricted to JAX's default 32-bit precision.
# NOTE(review): this is a process-wide setting and affects every other JAX
# user in the same interpreter — confirm that is intended.
jax.config.update("jax_enable_x64", True)
class Yogi(BaseOptimizer):
    r"""Yogi optimizer

    Yogi is an adaptive learning rate optimization algorithm closely related to
    Adam and RMSprop. Unlike Adam, the second moment estimate is updated
    additively: the sign of the difference between the previous second moment
    estimate and the squared gradient, :math:`\text{sign}(v_{t-1} - g_t^2)`,
    decides whether :math:`v_t` is increased or decreased, and the size of the
    change is always :math:`(1 - \beta_2) g_t^2`.

    We can write the update rule for Yogi as follows:

    .. math::

        m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t

        v_t = v_{t-1} - (1 - \beta_2) (g_t^2) \text{sign}(v_{t-1} - g_t^2)

        \hat{m}_t = \frac{m_t}{1 - \beta_1^t}

        \hat{v}_t = \frac{v_t}{1 - \beta_2^t}

        x_{t+1} = x_t - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

    where :math:`m_t` and :math:`v_t` are the first and second moment estimates, respectively,
    :math:`g_t` is the gradient, :math:`\beta_1` and :math:`\beta_2` are the exponential decay rates
    for the first and second moment estimates, respectively, :math:`\alpha` is the learning rate,
    :math:`\epsilon` is a small constant to avoid division by zero, and :math:`t` is the current
    iteration (starting at 1). Both moment estimates are initialized to zero, and after
    every update the new iterate is passed through ``self.project``.

    References
    ----------
    [1] Zaheer, M., Reddi, S. J., Sachan, D. S., Kale, S., Kumar, S., & Hovy, E. (2018).
    Adaptive methods for nonconvex optimization. In Advances in Neural Information
    Processing Systems (pp. 8779-8788).

    Parameters
    ----------
    f_obj : callable
        Objective function to minimize
    f_cons : list of callables
        List of constraint functions. When non-empty the result is labeled
        ``"Yogi with Penalty"``; the gradient is always taken of
        ``self.penalized_obj`` (provided by the base class).
    args : tuple
        Arguments to pass to the objective function
    args_cons : tuple
        Arguments to pass to the constraint functions
    x0 : ndarray
        Initial guess
    beta1 : float
        Exponential decay rate for the first moment estimates
    beta2 : float
        Exponential decay rate for the second moment estimates
    eps : float
        Small constant to avoid division by zero
    tol : float
        Tolerance for the norm of the gradient
    learning_rate : float
        Learning rate
    max_iter : int
        Maximum number of iterations
    verbose : bool
        Whether to display a progress bar
    maximize : bool
        Whether to maximize the objective function
        (NOTE(review): not referenced in ``optimize``; presumably handled by
        the base class — confirm.)

    Returns
    -------
    method_name : str
        ``"Yogi"``, or ``"Yogi with Penalty"`` when ``f_cons`` is non-empty
    x0 : ndarray
        Initial guess as supplied
    xopt : ndarray
        Final iterate
    fmin : float
        ``f_obj`` evaluated at ``xopt``
    num_iter : int
        Number of parameter updates performed
    path : ndarray
        Iterates visited, starting with the initial point
    v : ndarray
        Second moment estimates (before bias correction), one per update
    f_history : ndarray
        ``penalized_obj`` evaluated at every point of ``path``
    grad_norms : ndarray
        Gradient norm at every iterate where the gradient was evaluated
    termination_reason : str
        ``"gradient_norm_below_tol"`` or ``"max_iter_reached"``
    time : float
        Elapsed wall-clock time in seconds
    """

    def optimize(self):
        """Run the Yogi iteration and return an ``OptimizeResult``.

        The loop stops early when the gradient norm of ``self.penalized_obj``
        drops below ``self.tol``; otherwise it performs ``self.max_iter``
        updates. See the class docstring for the fields of the result.
        """
        start_time = time.time()
        x = self.x0.astype(float)  # Ensure x0 is of a floating-point type
        grad = jax.grad(self.penalized_obj)

        # First and second moment estimates both start at zero.
        m = jnp.zeros_like(x)
        v = jnp.zeros_like(x)

        path = [x]
        v_list = []
        f_history = [float(self.penalized_obj(x))]
        grad_norms = []
        num_iter = 0
        termination_reason = "max_iter_reached"

        # t starts at 1 so the bias-correction denominators are never zero.
        progress_bar = tqdm(range(1, self.max_iter + 1), desc="Yogi", disable=not self.verbose)
        for t in progress_bar:
            # Evaluate the gradient and its norm exactly once per iteration;
            # both are reused for the stopping test and for the update.
            g = grad(x)
            g_norm = jnp.linalg.norm(g)
            grad_norms.append(float(g_norm))
            if g_norm < self.tol:
                termination_reason = "gradient_norm_below_tol"
                break

            g_sq = g**2
            m = self.beta1 * m + (1 - self.beta1) * g
            # Additive Yogi update: the sign term picks the direction, the
            # magnitude of the change is always (1 - beta2) * g^2.
            v = v - (1 - self.beta2) * g_sq * jnp.sign(v - g_sq)

            # Bias-corrected moment estimates.
            m_hat = m / (1 - self.beta1**t)
            v_hat = v / (1 - self.beta2**t)

            x = x - self.learning_rate * m_hat / (jnp.sqrt(v_hat) + self.eps)
            x = self.project(x)

            path.append(x)
            v_list.append(v)
            f_history.append(float(self.penalized_obj(x)))
            num_iter += 1

        end_time = time.time()
        elapsed_time = end_time - start_time

        return OptimizeResult({
            "method_name": "Yogi" if not self.f_cons else "Yogi with Penalty",
            "x0": self.x0,
            "xopt": x,
            "fmin": self.f_obj(x),
            "num_iter": num_iter,
            "path": jnp.array(path),
            "v": jnp.array(v_list),
            "f_history": jnp.array(f_history),
            "grad_norms": jnp.array(grad_norms),
            "termination_reason": termination_reason,
            "time": elapsed_time,
        })
# [docs]
def yogi(**kwargs):
    """Functional entry point for the Yogi optimizer.

    All keyword arguments are forwarded unchanged to the ``Yogi``
    constructor; the value returned is whatever ``Yogi.optimize()`` returns.
    """
    return Yogi(**kwargs).optimize()