# Source code for optymus.methods.adaptive._adamax

import time

import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

from optymus.methods.utils import BaseOptimizer
from optymus.methods.utils._result import OptimizeResult

# Turn on 64-bit (double-precision) arrays in JAX so the optimizer runs in float64.
# NOTE: this updates JAX's global configuration at import time, so it affects
# every JAX computation in the process, not only this module.
jax.config.update("jax_enable_x64", True)


class Adamax(BaseOptimizer):
    r"""Adamax optimizer

    Adamax is a variant of the Adam optimization algorithm. It replaces the L2-norm
    based second-moment estimate with an exponentially weighted infinity norm of
    the gradients. It is particularly useful for non-stationary objectives.

    The update rule implemented here is:

    .. math::
        m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t

        u_t = \max(\beta_2 u_{t-1}, |g_t|)

        x_{t+1} = x_t - \frac{\eta}{u_t + \epsilon} m_t

    where :math:`m_t` and :math:`u_t` are the first-moment and infinity-norm estimates,
    respectively, :math:`g_t` is the gradient, :math:`\beta_1` and :math:`\beta_2` are the
    exponential decay rates for :math:`m_t` and :math:`u_t`, respectively, :math:`\eta` is
    the learning rate, :math:`\epsilon` is a small constant that avoids division by zero,
    and :math:`t` is the current iteration. After each step the iterate is passed
    through ``self.project``.

    Unlike Algorithm 2 (AdaMax) in [1], this implementation does not bias-correct
    :math:`m_t` by :math:`1 - \beta_1^t`.

    References
    ----------
    [1] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.

    Parameters
    ----------
    f_obj : callable
        Objective function to minimize
    f_cons : list of callables
        List of constraint functions. When non-empty, the reported method name is
        "Adamax with Penalty" and the gradient is taken of ``self.penalized_obj``.
    args : tuple
        Arguments to pass to the objective function
    args_cons : tuple
        Arguments to pass to the constraint functions
    x0 : ndarray
        Initial guess
    beta1 : float
        Exponential decay rate for the first-moment estimate ``m``
    beta2 : float
        Exponential decay rate for the infinity-norm estimate ``u``
    eps : float
        Small constant to avoid division by zero
    tol : float
        Tolerance on the L2 norm of the gradient of ``self.penalized_obj``.
        Iteration stops as soon as the norm falls below this value.
    learning_rate : float
        Learning rate
    max_iter : int
        Maximum number of iterations
    verbose : bool
        Whether to display a progress bar
    maximize : bool
        Whether to maximize the objective function (not referenced directly in this class)

    Returns
    -------
    OptimizeResult
        A mapping with the following keys:

        method_name : str
            "Adamax", or "Adamax with Penalty" when ``f_cons`` is non-empty
        x0 : ndarray
            Initial guess, as given
        xopt : ndarray
            Final iterate
        fmin : float
            ``self.f_obj`` evaluated at ``xopt`` (not ``penalized_obj``)
        num_iter : int
            Number of completed update steps
        path : ndarray
            ``x0`` followed by every iterate (``num_iter + 1`` rows)
        u : ndarray
            Infinity-norm estimate after each update step (``num_iter`` rows)
        f_history : ndarray
            ``self.penalized_obj`` at ``x0`` and at every iterate
        grad_norms : ndarray
            L2 norm of the gradient at each checked iterate. This is one entry longer
            than ``num_iter`` when the run stops on the tolerance.
        termination_reason : str
            "gradient_norm_below_tol" or "max_iter_reached"
        time : float
            Wall-clock time in seconds
    """

    def optimize(self):
        """Run Adamax until the gradient norm drops below ``tol`` or ``max_iter`` steps are taken.

        Returns
        -------
        OptimizeResult
            See the class docstring for the list of keys.
        """
        start_time = time.time()
        x = self.x0.astype(float)  # Ensure x0 is of a floating-point type

        grad = jax.grad(self.penalized_obj)
        m = jnp.zeros_like(x)  # first-moment estimate
        u = jnp.zeros_like(x)  # exponentially weighted infinity norm
        path = [x]
        u_list = []
        f_history = [float(self.penalized_obj(x))]
        grad_norms = []
        num_iter = 0
        termination_reason = "max_iter_reached"

        progress_bar = tqdm(range(1, self.max_iter + 1), desc="Adamax", disable=not self.verbose)

        for _ in progress_bar:
            # Evaluate the gradient (and its norm) once per iteration and reuse it
            # for both the stopping test and the moment updates.
            g = grad(x)
            g_norm = float(jnp.linalg.norm(g))
            grad_norms.append(g_norm)
            if g_norm < self.tol:
                termination_reason = "gradient_norm_below_tol"
                break

            m = self.beta1 * m + (1 - self.beta1) * g
            u = jnp.maximum(self.beta2 * u, jnp.abs(g))
            # Rebind instead of using `-=`. If x started as a NumPy array (astype keeps
            # the NumPy type), an in-place update would also mutate the arrays already
            # stored in `path`.
            x = x - self.learning_rate * m / (u + self.eps)
            x = self.project(x)

            path.append(x)
            u_list.append(u)
            f_history.append(float(self.penalized_obj(x)))
            num_iter += 1

        elapsed_time = time.time() - start_time
        return OptimizeResult({
            "method_name": "Adamax" if not self.f_cons else "Adamax with Penalty",
            "x0": self.x0,
            "xopt": x,
            "fmin": self.f_obj(x),
            "num_iter": num_iter,
            "path": jnp.array(path),
            "u": jnp.array(u_list),
            "f_history": jnp.array(f_history),
            "grad_norms": jnp.array(grad_norms),
            "termination_reason": termination_reason,
            "time": elapsed_time,
        })


def adamax(**kwargs):
    """Build an :class:`Adamax` optimizer from keyword arguments and run it.

    Parameters
    ----------
    **kwargs
        Keyword arguments forwarded unchanged to :class:`Adamax`.

    Returns
    -------
    OptimizeResult
        The result returned by :meth:`Adamax.optimize`.
    """
    optimizer = Adamax(**kwargs)
    return optimizer.optimize()