Source code for optymus.methods.second_order._newton_raphson

import time

import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

from optymus.methods.utils import BaseOptimizer
from optymus.methods.utils._result import OptimizeResult


class NewtonRaphson(BaseOptimizer):
    _default_line_search = "armijo"
    r"""Newton-Raphson method with different matrix types

    The Newton-Raphson method is a second-order optimization algorithm that uses
    different matrices (Hessian, Fisher Information, Identity) to compute the step direction.

    We can minimize the objective function :math:`f` by solving the following
    equation:

    .. math::
        M(x) p = -\nabla f(x)

    where :math:`M(x)` is the chosen matrix (Hessian, Fisher Information, Identity) of :math:`f`
    evaluated at point :math:`x`, :math:`\nabla f(x)` is the gradient of :math:`f` evaluated
    at point :math:`x`, and :math:`p` is the step direction.

    Parameters
    ----------
    f_obj : callable
        Objective function to minimize
    f_cons : callable
        Constraint function
    args : tuple
        Arguments for the objective function
    args_cons : tuple
        Arguments for the constraint function
    x0 : ndarray
        Initial guess
    tol : float
        Tolerance for stopping criteria
    learning_rate : float
        Learning rate for line search
    max_iter : int
        Maximum number of iterations
    h_type : str
        Type of matrix to use ('hessian', 'fisher', 'identity', 'bfgs').
        Use 'bfgs' for functions with custom_vjp (e.g., topology optimization).
    verbose : bool
        If True, prints progress
    maximize : bool
        If True, maximize the objective function

    Returns
    -------
    method_name : str
        Method name
    xopt : ndarray
        Optimal point
    fmin : float
        Minimum value
    num_iter : int
        Number of iterations
    path : ndarray
        Path taken
    alphas : ndarray
        Step sizes
    """

    def optimize(self, h_type="hessian"):
        start_time = time.time()
        x = self.x0.astype(float)  # Ensure x0 is of a floating-point type

        grad = jax.grad(self.penalized_obj)

        # Only compute exact Hessian if needed (avoids error with custom_vjp functions)
        hess = None
        if h_type in ("hessian", "fisher"):
            hess = jax.hessian(self.penalized_obj)

        # Initialize inverse Hessian approximation for BFGS
        B_inv = None
        if h_type == "bfgs":
            B_inv = jnp.eye(len(x))

        path = [x]
        alphas = []
        f_history = [float(self.penalized_obj(x))]
        grad_norms = []
        num_iter = 0
        termination_reason = "max_iter_reached"

        progres_bar = tqdm(range(self.max_iter), desc="Newton-Raphson", disable=not self.verbose)

        for _ in progres_bar:
            g = grad(x)
            grad_norms.append(float(jnp.linalg.norm(g)))

            if jnp.linalg.norm(g) < self.tol:
                termination_reason = "gradient_norm_below_tol"
                break

            # Compute search direction based on h_type
            if h_type == "hessian":
                M = hess(x)
                d = jnp.linalg.solve(M, -g)
            elif h_type == "fisher":
                M = -hess(x)
                d = jnp.linalg.solve(M, -g)
            elif h_type == "identity":
                d = -g  # No solve needed with identity matrix
            elif h_type == "bfgs":
                d = -jnp.dot(B_inv, g)  # No solve needed with inverse Hessian
            else:
                msg = f"Unknown h_type: {h_type}"
                raise ValueError(msg)

            r = self._do_line_search(x, d, g)
            x_new = self.project(r["xopt"])

            # BFGS inverse Hessian update
            if h_type == "bfgs":
                delta = x_new - x
                g_new = grad(x_new)
                gamma = g_new - g

                denom = jnp.dot(delta, gamma)
                if denom > 1e-10:  # Curvature condition check
                    rho = 1.0 / denom
                    I = jnp.eye(len(x))
                    B_inv = (I - rho * jnp.outer(delta, gamma)) @ B_inv
                    B_inv = B_inv @ (I - rho * jnp.outer(gamma, delta))
                    B_inv = B_inv + rho * jnp.outer(delta, delta)

            x = x_new

            alphas.append(r["alpha"])
            path.append(x)
            f_history.append(float(self.penalized_obj(x)))
            num_iter += 1

        end_time = time.time()
        elapsed_time = end_time - start_time

        method_suffix = f" ({h_type})" if h_type != "hessian" else ""
        penalty_suffix = " with Penalty" if self.f_cons else ""
        return OptimizeResult({
            "method_name": f"Newton-Raphson{method_suffix}{penalty_suffix}",
            "x0": self.x0,
            "xopt": x,
            "fmin": self.f_obj(x, *self.args),
            "num_iter": num_iter,
            "path": jnp.array(path),
            "alphas": jnp.array(alphas),
            "f_history": jnp.array(f_history),
            "grad_norms": jnp.array(grad_norms),
            "termination_reason": termination_reason,
            "time": elapsed_time,
        })


[docs] def newton_raphson(htype='hessian', **kwargs): optimizer = NewtonRaphson(**kwargs) return optimizer.optimize(htype)