Source code for optymus.methods.adaptive._adagrad

import time

import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

from optymus.methods.utils import BaseOptimizer
from optymus.methods.utils._result import OptimizeResult

# Enable 64-bit (double-precision) dtypes in JAX so that ``astype(float)`` in
# the optimizer yields float64 arrays. NOTE: this is a process-wide JAX
# configuration change that runs as soon as this module is imported, so it
# affects every JAX computation in the program, not only this optimizer.
jax.config.update("jax_enable_x64", True)


class AdaGrad(BaseOptimizer):
    r"""Adagrad optimizer

    Adagrad is an adaptive learning rate optimization algorithm that adapts the learning rate
    for each parameter based on the historical gradients. It is particularly useful for sparse
    data and non-stationary objectives.

    The update rule implemented here is:

    .. math::
        g_{t} = \nabla f(x_t)

        G_{t} = G_{t-1} + g_{t}^2

        x_{t+1} = x_t - \frac{\eta}{\sqrt{G_{t}} + \epsilon} g_{t}

    where :math:`g_{t}` is the gradient, :math:`G_{t}` is the element-wise sum of the squares of
    the gradients, :math:`\eta` is the learning rate, :math:`\epsilon` is a small constant added to
    the denominator (outside the square root) to avoid division by zero, and :math:`t` is the
    current iteration. After each step the iterate is passed through ``self.project``.

    References
    ----------
    [1] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(Jul), 2121-2159.

    Parameters
    ----------
    These are keyword arguments of the constructor, which this class inherits from
    ``BaseOptimizer`` (it defines no ``__init__`` of its own).

    f_obj : callable
        Objective function to minimize
    f_cons : list of callables
        List of constraint functions. Gradient steps are always taken on
        ``self.penalized_obj`` (presumably ``f_obj`` plus a constraint penalty;
        confirm in ``BaseOptimizer``). When non-empty, the reported method name
        is "Adagrad with Penalty".
    args : tuple
        Arguments to pass to the objective function
    args_cons : tuple
        Arguments to pass to the constraint functions
    x0 : ndarray
        Initial guess
    eps : float
        Small constant added to the denominator to avoid division by zero
    tol : float
        Tolerance for the norm of the gradient; iteration stops as soon as
        ``||g|| < tol``
    learning_rate : float
        Learning rate
    max_iter : int
        Maximum number of iterations
    verbose : bool
        Whether to display a progress bar
    maximize : bool
        Whether to maximize the objective function (not referenced directly in
        this class; presumably handled by ``BaseOptimizer``; confirm)

    Returns
    -------
    ``optimize()`` returns an ``OptimizeResult`` with the keys:

    method_name : str
        "Adagrad", or "Adagrad with Penalty" when constraints are given
    x0 : ndarray
        Initial guess as supplied
    xopt : ndarray
        Final iterate
    fmin : float
        Value of the (unpenalized) ``f_obj`` at ``xopt``
    num_iter : int
        Number of update steps performed
    path : ndarray
        Iterates, starting with the initial point
    g_sum : ndarray
        Accumulated sum of squared gradients after each update step
    f_history : ndarray
        Penalized objective values, starting at the initial point
    grad_norms : ndarray
        Gradient norm evaluated at the start of each iteration
    termination_reason : str
        "gradient_norm_below_tol" or "max_iter_reached"
    time : float
        Wall-clock time spent in ``optimize``, in seconds
    """

    def optimize(self):
        """Run Adagrad from ``self.x0`` and return an ``OptimizeResult``.

        See the class docstring for the update rule and the result keys.
        """
        start_time = time.perf_counter()
        x = self.x0.astype(float)  # Ensure x0 is of a floating-point type

        grad = jax.grad(self.penalized_obj)

        g_sq_sum = jnp.zeros_like(x)  # running per-coordinate sum of squared gradients
        path = [x]
        g_sum_list = []
        f_history = [float(self.penalized_obj(x))]
        grad_norms = []
        num_iter = 0
        termination_reason = "max_iter_reached"

        # Use the bar as a context manager so it is closed even when the loop
        # exits early on the gradient-norm test.
        with tqdm(
            range(1, self.max_iter + 1), desc="Adagrad", disable=not self.verbose
        ) as progress_bar:
            for _ in progress_bar:
                # Evaluate the gradient (and its norm) once per iteration and
                # reuse it for both the stopping test and the update.
                g = grad(x)
                g_norm = float(jnp.linalg.norm(g))
                grad_norms.append(g_norm)
                if g_norm < self.tol:
                    termination_reason = "gradient_norm_below_tol"
                    break

                g_sq_sum = g_sq_sum + g**2
                # Rebind instead of updating in place: an in-place ``-=`` on a
                # mutable (e.g. NumPy) x0 copy would also overwrite path[0].
                x = x - self.learning_rate * g / (jnp.sqrt(g_sq_sum) + self.eps)
                x = self.project(x)

                path.append(x)
                g_sum_list.append(g_sq_sum)
                f_history.append(float(self.penalized_obj(x)))
                num_iter += 1

        elapsed_time = time.perf_counter() - start_time
        return OptimizeResult({
            "method_name": "Adagrad" if not self.f_cons else "Adagrad with Penalty",
            "x0": self.x0,
            "xopt": x,
            "fmin": self.f_obj(x),
            "num_iter": num_iter,
            "path": jnp.array(path),
            "g_sum": jnp.array(g_sum_list),
            "f_history": jnp.array(f_history),
            "grad_norms": jnp.array(grad_norms),
            "termination_reason": termination_reason,
            "time": elapsed_time,
        })


def adagrad(**kwargs):
    """Convenience wrapper: build an ``AdaGrad`` optimizer and run it.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to the ``AdaGrad`` constructor.

    Returns
    -------
    OptimizeResult
        The value returned by ``AdaGrad.optimize()``.
    """
    return AdaGrad(**kwargs).optimize()