# Source code for optymus.methods.adaptive._rmsprop
import time
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm
from optymus.methods.utils import BaseOptimizer
from optymus.methods.utils._result import OptimizeResult
# Turn on JAX's 64-bit mode via the "jax_enable_x64" flag. This runs at import
# time of this module, so it is a process-wide side effect of importing it.
jax.config.update("jax_enable_x64", True)
class RMSProp(BaseOptimizer):
    r"""RMSprop optimizer

    RMSprop is an adaptive learning rate optimization algorithm that divides the learning rate
    by a running average of the squared gradients. It is particularly useful for non-stationary
    objectives.

    The update rule implemented by :meth:`optimize` is:

    .. math::
        E[g^2]_t = \beta E[g^2]_{t-1} + (1 - \beta) g_t^2

        x_{t+1} = x_t - \frac{\eta}{\sqrt{E[g^2]_t} + \epsilon} g_t

    where :math:`E[g^2]_t` is the running average of the squared gradients, :math:`g_t` is the gradient,
    :math:`\beta` is the decay rate, :math:`\eta` is the learning rate, :math:`\epsilon` is a small constant
    to avoid division by zero (added to the square root, outside of it), and :math:`t` is the current
    iteration.

    References
    ----------
    [1] Tieleman, T., & Hinton, G. (2012). Lecture 6.5-rmsprop: Divide the gradient by a running average
    of its recent magnitude. COURSERA: Neural Networks for Machine Learning, 4(2), 26-31.

    Parameters
    ----------
    f_obj : callable
        Objective function to minimize
    f_cons : list of callables
        List of constraint functions to minimize
    args : tuple
        Arguments to pass to the objective function
    args_cons : tuple
        Arguments to pass to the constraint functions
    x0 : ndarray
        Initial guess
    beta : float
        Decay rate.
        NOTE(review): ``optimize`` reads the decay rate from ``self.beta1``; confirm the
        constructor keyword against ``BaseOptimizer``.
    eps : float
        Small constant to avoid division by zero
    tol : float
        Tolerance for the norm of the gradient
    learning_rate : float
        Learning rate
    max_iter : int
        Maximum number of iterations
    verbose : bool
        Whether to display a progress bar
    maximize : bool
        Whether to maximize the objective function

    Returns
    -------
    method_name : str
        Method name
    x0 : ndarray
        Initial guess as supplied
    xopt : ndarray
        Optimal point
    fmin : float
        Minimum value
    num_iter : int
        Number of iterations
    path : ndarray
        Path taken
    eg2 : ndarray
        Running average of the squared gradients
    f_history : ndarray
        Penalized objective value at the start point and after every update
    grad_norms : ndarray
        Gradient norm evaluated at the start of every iteration
    termination_reason : str
        ``"gradient_norm_below_tol"`` or ``"max_iter_reached"``
    time : float
        Wall-clock duration of the run in seconds
    """

    def optimize(self):
        """Run RMSprop from ``self.x0`` and return an ``OptimizeResult``.

        Iterates at most ``self.max_iter`` times and stops early as soon as the
        norm of the gradient of ``self.penalized_obj`` drops below ``self.tol``.
        """
        start_time = time.time()
        x = self.x0.astype(float)  # Ensure x0 is of a floating-point type

        grad = jax.grad(self.penalized_obj)
        Eg2 = jnp.zeros_like(x)

        path = [x]
        eg2_list = []
        f_history = [float(self.penalized_obj(x))]
        grad_norms = []
        num_iter = 0
        termination_reason = "max_iter_reached"

        progress_bar = tqdm(range(1, self.max_iter + 1), desc="RMSProp", disable=not self.verbose)
        for _ in progress_bar:
            # One gradient evaluation and one norm per iteration; the norm is
            # recorded even on the iteration that triggers termination.
            g = grad(x)
            g_norm = float(jnp.linalg.norm(g))
            grad_norms.append(g_norm)
            if g_norm < self.tol:
                termination_reason = "gradient_norm_below_tol"
                break

            # Exponential moving average of the element-wise squared gradient.
            Eg2 = self.beta1 * Eg2 + (1 - self.beta1) * g**2

            # Descent step: move *from the current iterate* against the
            # gradient, scaled per-coordinate by the RMS of recent gradients.
            x = x - self.learning_rate * g / (jnp.sqrt(Eg2) + self.eps)
            x = self.project(x)

            path.append(x)
            eg2_list.append(Eg2)
            f_history.append(float(self.penalized_obj(x)))
            num_iter += 1

        elapsed_time = time.time() - start_time

        return OptimizeResult({
            "method_name": "RMSprop" if not self.f_cons else "RMSprop with Penalty",
            "x0": self.x0,
            "xopt": x,
            "fmin": self.f_obj(x),
            "num_iter": num_iter,
            "path": jnp.array(path),
            "eg2": jnp.array(eg2_list),
            "f_history": jnp.array(f_history),
            "grad_norms": jnp.array(grad_norms),
            "termination_reason": termination_reason,
            "time": elapsed_time,
        })
# [docs]
def rmsprop(**kwargs):
    """Functional entry point: build an ``RMSProp`` from ``kwargs`` and run it.

    All keyword arguments are forwarded unchanged to the ``RMSProp``
    constructor; the value returned is whatever ``RMSProp.optimize()`` returns.
    """
    return RMSProp(**kwargs).optimize()