"""JAX replacements for scipy.optimize functions using optimistix and optax.
This module provides JAX-compatible implementations of scipy.optimize
functions, using optimistix for root-finding and minimization, and
optax for gradient-based optimization.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
if TYPE_CHECKING:
from numpy.typing import ArrayLike
[docs]
@dataclass
class OptimizeResult:
"""Container for optimization result.
Compatible with scipy.optimize.OptimizeResult.
Attributes
----------
x : ndarray
Solution vector
fun : float
Function value at solution
success : bool
Whether optimization converged
message : str
Description of termination
nfev : int
Number of function evaluations
nit : int
Number of iterations
jac : ndarray, optional
Jacobian at solution (if available)
"""
x: np.ndarray
fun: float
success: bool
message: str
nfev: int = 0
nit: int = 0
jac: np.ndarray | None = None
[docs]
def minimize(
    fun: Callable,
    x0: ArrayLike,
    args: tuple = (),
    method: str | None = None,
    jac: Callable | bool | None = None,
    bounds: list[tuple] | None = None,
    tol: float | None = None,
    options: dict | None = None,
) -> OptimizeResult:
    """Minimize a scalar function using optimistix/optax.

    Parameters
    ----------
    fun : callable
        Objective function to minimize, called as ``fun(x, *args)``.
    x0 : array-like
        Initial guess.
    args : tuple
        Extra arguments passed to ``fun``.
    method : str, optional
        Optimization method: 'BFGS', 'L-BFGS-B', 'CG', 'Adam', 'SGD'.
        ``None`` selects BFGS on the optimistix path.
    jac : callable or bool, optional
        Jacobian function or True for auto-diff.  Only forwarded to the scipy
        fallback; the optimistix/optax path always uses JAX autodiff.
    bounds : list of tuple, optional
        Bounds for each variable.  The optimistix solvers are unconstrained,
        so bounded problems are delegated to ``scipy.optimize.minimize``
        (if scipy is not installed the bounds are ignored and a
        ``RuntimeWarning`` is issued).  'Adam'/'SGD' ignore bounds.
    tol : float, optional
        Tolerance for termination (default ``1e-8``).
    options : dict, optional
        Solver-specific options.  On the JAX path ``maxiter`` (default 1000)
        caps the number of steps and ``learning_rate`` (default 0.01)
        configures the optax methods.  ``learning_rate`` is not forwarded to
        scipy, which does not recognise it.

    Returns
    -------
    OptimizeResult
        Optimization result container.
    """
    options = options or {}
    # ``tol or 1e-8`` would silently replace an explicit ``tol=0.0``.
    tol = 1e-8 if tol is None else tol
    maxiter = options.get("maxiter", 1000)
    learning_rate = options.get("learning_rate", 0.01)
    # "learning_rate" only configures optax; scipy warns about unknown options.
    scipy_options = {k: v for k, v in options.items() if k != "learning_rate"}

    uses_optax = method is not None and method.upper() in ("ADAM", "SGD")
    if bounds is not None and not uses_optax:
        # The optimistix quasi-Newton/CG solvers cannot enforce bounds, so
        # solving there would return an unconstrained (possibly infeasible)
        # minimum.  Prefer scipy, which supports bounded minimization.
        try:
            return _minimize_scipy(
                fun, x0, args, method, jac, bounds, tol, scipy_options
            )
        except ImportError:
            import warnings

            warnings.warn(
                "scipy is not installed; 'bounds' are ignored by the "
                "unconstrained optimistix solvers",
                RuntimeWarning,
                stacklevel=2,
            )

    # Try optimistix first, fall back to scipy when the JAX stack is missing.
    try:
        return _minimize_optimistix(
            fun, x0, args, method, jac, bounds, tol, maxiter, learning_rate
        )
    except ImportError:
        return _minimize_scipy(
            fun, x0, args, method, jac, bounds, tol, scipy_options
        )
def _minimize_optimistix(
    fun: Callable,
    x0: ArrayLike,
    args: tuple,
    method: str | None,
    jac: Callable | bool | None,
    bounds: list[tuple] | None,
    tol: float,
    maxiter: int,
    learning_rate: float = 0.01,
) -> OptimizeResult:
    """Minimization using optimistix.

    ``method`` picks the solver: ``'CG'`` uses nonlinear conjugate gradients,
    ``'Adam'``/``'SGD'`` are handed to ``_minimize_optax``, and anything else
    (including ``None``, ``'BFGS'`` and ``'L-BFGS-B'``) uses BFGS.  ``jac`` and
    ``bounds`` are accepted for signature parity with the scipy fallback but
    are not used here: gradients come from JAX autodiff and the solvers are
    unconstrained.

    Solver failures are not raised; they are reported as
    ``success=False`` with the exception text as ``message``.

    Raises
    ------
    ImportError
        If jax or optimistix cannot be imported.
    """
    import jax.numpy as jnp
    import optimistix as optx

    x0 = jnp.asarray(x0)
    # JAX only differentiates with respect to floating/complex arrays, so an
    # integer guess such as ``[1, 2]`` (which scipy accepts) would make every
    # solver fail.  Promote it to the default floating dtype.
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        x0 = x0.astype(float)

    # optimistix passes an auxiliary-args slot; user ``args`` are closed over.
    def objective(x, _):
        if args:
            return fun(x, *args)
        return fun(x)

    name = (method or "BFGS").upper()
    if name in ("ADAM", "SGD"):
        # First-order methods are implemented with optax.
        return _minimize_optax(fun, x0, args, method, tol, maxiter, learning_rate)
    if name == "CG":
        solver: Any = optx.NonlinearCG(rtol=tol, atol=tol)
    else:
        # BFGS, L-BFGS-B (bounds unsupported here) and unrecognised names.
        solver = optx.BFGS(rtol=tol, atol=tol)

    try:
        sol: Any = optx.minimise(
            objective,
            solver,
            x0,
            max_steps=maxiter,
        )
        return OptimizeResult(
            x=np.asarray(sol.value),
            fun=float(objective(sol.value, None)),
            success=True,
            message="Optimization converged",
            nit=int(sol.stats.get("num_steps", 0)) if hasattr(sol, "stats") else 0,
        )
    except Exception as e:
        # Best effort: report failure at the starting point rather than raise.
        return OptimizeResult(
            x=np.asarray(x0),
            fun=float(objective(x0, None)),
            success=False,
            message=str(e),
        )
def _minimize_optax(
    fun: Callable,
    x0: ArrayLike,
    args: tuple,
    method: str,
    tol: float,
    maxiter: int,
    learning_rate: float = 0.01,
) -> OptimizeResult:
    """Minimization using optax gradient descent.

    Runs ``optax.sgd`` for ``method='SGD'`` and ``optax.adam`` otherwise
    (``'Adam'`` and any unrecognised name) for at most ``maxiter`` steps,
    stopping early once the gradient norm drops below ``tol``.

    The inner loop uses ``jax.lax.while_loop`` so the entire update sequence
    is JIT-compiled rather than interpreted step-by-step in Python. This
    avoids the O(maxiter) Python overhead that occurred with the previous
    ``for i in range(maxiter)`` implementation. (BUG-038)
    """
    import jax
    import jax.numpy as jnp
    import optax

    x0 = jnp.asarray(x0)
    # jax.grad rejects integer inputs and the while_loop carry must keep one
    # fixed dtype, so promote integer/bool guesses to the default float dtype.
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        x0 = x0.astype(float)

    # Wrap function with args
    def objective(x):
        if args:
            return fun(x, *args)
        return fun(x)

    if method.upper() == "SGD":
        optimizer = optax.sgd(learning_rate=learning_rate)
    else:
        optimizer = optax.adam(learning_rate=learning_rate)

    # Pre-compute gradient function (traced once inside the while_loop).
    grad_fn = jax.grad(objective)

    # Loop state: (x, opt_state, iteration, converged)
    init_opt_state = optimizer.init(x0)

    def cond_fn(state):
        _x, _opt_state, i, converged = state
        return jnp.logical_and(~converged, i < maxiter)

    def body_fn(state):
        x, opt_state, i, _converged = state
        grads = grad_fn(x)
        updates, new_opt_state = optimizer.update(grads, opt_state, x)
        new_x = optax.apply_updates(x, updates)
        # Convergence is judged on the gradient at the pre-update point.
        new_converged = jnp.linalg.norm(grads) < tol
        return new_x, new_opt_state, i + 1, new_converged

    init_state = (x0, init_opt_state, jnp.zeros((), dtype=jnp.int32), jnp.array(False))
    final_x, _final_opt_state, nit, converged = jax.lax.while_loop(
        cond_fn, body_fn, init_state
    )

    success = bool(converged)
    n_iterations = int(nit)
    message = (
        f"Converged after {n_iterations} iterations"
        if success
        else f"Maximum iterations ({maxiter}) reached"
    )
    return OptimizeResult(
        x=np.asarray(final_x),
        fun=float(objective(final_x)),
        success=success,
        message=message,
        nit=n_iterations,
    )
def _minimize_scipy(
    fun: Callable,
    x0: ArrayLike,
    args: tuple,
    method: str | None,
    jac: Callable | bool | None,
    bounds: list[tuple] | None,
    tol: float,
    options: dict,
) -> OptimizeResult:
    """Delegate to ``scipy.optimize.minimize`` and repackage its result.

    Diagnostics that the chosen scipy method does not report (``nfev``,
    ``nit``, ``jac``) keep the ``OptimizeResult`` defaults.
    """
    from scipy.optimize import minimize as scipy_minimize

    raw = scipy_minimize(
        fun, x0, args=args, method=method, jac=jac,
        bounds=bounds, tol=tol, options=options,
    )
    # scipy's result object is a dict subclass, so an unreported diagnostic is
    # simply a missing key.
    return OptimizeResult(
        x=raw.x,
        fun=raw.fun,
        success=raw.success,
        message=raw.message,
        nfev=raw.get("nfev", 0),
        nit=raw.get("nit", 0),
        jac=raw.get("jac"),
    )
[docs]
def curve_fit(
    f: Callable,
    xdata: ArrayLike,
    ydata: ArrayLike,
    p0: ArrayLike | None = None,
    sigma: ArrayLike | None = None,
    absolute_sigma: bool = False,
    bounds: tuple | None = None,
    maxfev: int = 1000,
    **kwargs,
) -> tuple[np.ndarray, np.ndarray]:
    """Nonlinear least squares curve fitting using optimistix.

    Parameters
    ----------
    f : callable
        Model function: ``f(x, *params)``
    xdata : array-like
        Independent variable
    ydata : array-like
        Dependent variable (data to fit)
    p0 : array-like, optional
        Initial guess for parameters
    sigma : array-like, optional
        Uncertainties in ydata
    absolute_sigma : bool
        If True, sigma is used in absolute sense
    bounds : tuple, optional
        Bounds (lower, upper) for parameters
    maxfev : int
        Maximum function evaluations
    **kwargs
        Extra keyword arguments for ``scipy.optimize.curve_fit``.

    Returns
    -------
    popt : ndarray
        Optimal parameters
    pcov : ndarray
        Covariance matrix of parameters

    Notes
    -----
    The optimistix Levenberg-Marquardt path is unconstrained and takes no
    extra options.  When ``bounds`` or ``**kwargs`` are supplied the fit is
    therefore delegated to ``scipy.optimize.curve_fit`` so they are honoured;
    if scipy is not installed they are ignored with a ``RuntimeWarning``.
    """
    if bounds is not None or kwargs:
        try:
            return _curve_fit_scipy(
                f, xdata, ydata, p0, sigma, absolute_sigma, bounds, maxfev, **kwargs
            )
        except ImportError:
            import warnings

            warnings.warn(
                "scipy is not installed; 'bounds' and extra keyword arguments "
                "are ignored by the optimistix curve fit",
                RuntimeWarning,
                stacklevel=2,
            )

    try:
        return _curve_fit_optimistix(
            f, xdata, ydata, p0, sigma, absolute_sigma, bounds, maxfev
        )
    except ImportError:
        return _curve_fit_scipy(
            f, xdata, ydata, p0, sigma, absolute_sigma, bounds, maxfev, **kwargs
        )
def _curve_fit_optimistix(
    f: Callable,
    xdata: ArrayLike,
    ydata: ArrayLike,
    p0: ArrayLike | None,
    sigma: ArrayLike | None,
    absolute_sigma: bool,
    bounds: tuple | None,
    maxfev: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Curve fitting using optimistix least squares.

    Minimizes ``sum(((f(xdata, *p) - ydata) / sigma) ** 2)`` with
    Levenberg-Marquardt.  The parameter covariance is estimated the way
    ``scipy.optimize.curve_fit`` does it: a pseudo-inverse of ``J^T J`` built
    from the SVD of the residual Jacobian ``J`` (numerically-zero singular
    values dropped), scaled by the reduced chi-square unless
    ``absolute_sigma`` is true.

    ``bounds`` is not enforced by this solver.  A 2-D ``sigma`` (a full
    covariance matrix) and any failure of the JAX path are handled by
    delegating to ``scipy.optimize.curve_fit``.

    Raises
    ------
    ImportError
        If jax or optimistix cannot be imported.
    """
    import inspect

    import jax
    import jax.numpy as jnp
    import optimistix as optx

    if sigma is not None and np.ndim(sigma) > 1:
        # A 2-D sigma is a covariance matrix in scipy's API; the element-wise
        # weighting below would silently broadcast it into a wrong residual.
        return _curve_fit_scipy(
            f,
            np.asarray(xdata),
            np.asarray(ydata),
            np.asarray(p0) if p0 is not None else None,
            np.asarray(sigma),
            absolute_sigma,
            bounds,
            maxfev,
        )

    xdata = jnp.asarray(xdata)
    ydata = jnp.asarray(ydata)
    if p0 is None:
        # One parameter per argument of ``f`` after the independent variable.
        p0 = jnp.ones(len(inspect.signature(f).parameters) - 1)
    else:
        p0 = jnp.asarray(p0)
    # JAX cannot differentiate with respect to integer parameters.
    if not jnp.issubdtype(p0.dtype, jnp.inexact):
        p0 = p0.astype(float)

    weights = None if sigma is None else 1.0 / jnp.asarray(sigma)

    def residual_fn(params, _):
        residual = f(xdata, *params) - ydata
        if weights is not None:
            residual = residual * weights
        # Flatten so the Jacobian is always (n_data, n_params).
        return jnp.ravel(residual)

    solver: Any = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
    try:
        sol: Any = optx.least_squares(
            residual_fn,
            solver,
            p0,
            max_steps=maxfev,
        )
        popt = np.asarray(sol.value)
        jacobian = np.asarray(jax.jacobian(lambda p: residual_fn(p, None))(sol.value))
        residuals = np.asarray(residual_fn(sol.value, None))
        n_params = popt.size
        dof = residuals.size - n_params

        # pcov = pinv(J^T J), computed from the SVD of J.  A plain inverse is
        # avoided: JAX's ``inv`` does not raise on singular input (it yields
        # inf/nan entries) and forming J^T J squares the condition number.
        _, s, vt = np.linalg.svd(jacobian, full_matrices=False)
        if s.size:
            # Relative cutoff in the Jacobian's own precision (float32 under
            # JAX's default config).
            threshold = np.finfo(jacobian.dtype).eps * max(jacobian.shape) * s[0]
            keep = s > threshold
            s, vt = s[keep], vt[keep]
        pcov = (vt.T / s**2) @ vt

        if np.isnan(pcov).any():
            pcov = np.full((n_params, n_params), np.inf)
        elif not absolute_sigma:
            if dof > 0:
                # Scale by the reduced chi-square s^2 = chi^2 / dof.
                pcov = pcov * (float(np.sum(residuals**2)) / dof)
            else:
                # s^2 is undefined when there are no spare data points.
                pcov = np.full((n_params, n_params), np.inf)
        return popt, pcov
    except Exception:
        # Fall back to scipy
        return _curve_fit_scipy(
            f,
            np.asarray(xdata),
            np.asarray(ydata),
            np.asarray(p0) if p0 is not None else None,
            np.asarray(sigma) if sigma is not None else None,
            absolute_sigma,
            bounds,
            maxfev,
        )
def _curve_fit_scipy(
f: Callable,
xdata: ArrayLike,
ydata: ArrayLike,
p0: ArrayLike | None,
sigma: ArrayLike | None,
absolute_sigma: bool,
bounds: tuple | None,
maxfev: int,
**kwargs,
) -> tuple[np.ndarray, np.ndarray]:
"""Fallback to scipy.optimize.curve_fit."""
from scipy.optimize import curve_fit as scipy_curve_fit
bounds_scipy = bounds if bounds is not None else (-np.inf, np.inf)
return scipy_curve_fit(
f,
xdata,
ydata,
p0=p0,
sigma=sigma,
absolute_sigma=absolute_sigma,
bounds=bounds_scipy,
maxfev=maxfev,
**kwargs,
)
[docs]
def least_squares(
    fun: Callable,
    x0: ArrayLike,
    bounds: tuple = (-np.inf, np.inf),
    method: str = "trf",
    ftol: float = 1e-8,
    xtol: float = 1e-8,
    gtol: float = 1e-8,
    max_nfev: int | None = None,
    **kwargs,
) -> OptimizeResult:
    """Solve nonlinear least squares using optimistix.

    Parameters
    ----------
    fun : callable
        Function returning residuals: fun(x) -> residuals
    x0 : array-like
        Initial guess
    bounds : tuple
        Lower and upper bounds
    method : str
        Method: 'trf', 'dogbox', 'lm' (used by the scipy path only)
    ftol, xtol, gtol : float
        Tolerances (``gtol`` is used by the scipy path only)
    max_nfev : int, optional
        Maximum function evaluations (default 1000)
    **kwargs
        Extra keyword arguments for ``scipy.optimize.least_squares``
        (for example ``args``/``kwargs`` forwarded to ``fun``).

    Returns
    -------
    OptimizeResult
        Optimization result; ``fun`` holds the cost
        ``0.5 * sum(residuals**2)``.

    Notes
    -----
    The optimistix path is unconstrained and calls ``fun(x)`` with no extra
    arguments.  When any bound is finite or ``**kwargs`` are given, the
    problem is delegated to ``scipy.optimize.least_squares`` so they are
    honoured; if scipy is not installed they are ignored with a
    ``RuntimeWarning``.
    """
    max_nfev = max_nfev or 1000

    try:
        lower, upper = bounds
        bounded = bool(np.any(np.isfinite(lower)) or np.any(np.isfinite(upper)))
    except (TypeError, ValueError):
        # Not a plain (lower, upper) pair, e.g. a scipy ``Bounds`` object:
        # let scipy interpret it.
        bounded = True

    if bounded or kwargs:
        try:
            return _least_squares_scipy(
                fun, x0, bounds, method, ftol, xtol, gtol, max_nfev, **kwargs
            )
        except ImportError:
            import warnings

            warnings.warn(
                "scipy is not installed; bounds and extra keyword arguments "
                "are ignored by the optimistix least-squares solver",
                RuntimeWarning,
                stacklevel=2,
            )

    try:
        return _least_squares_optimistix(fun, x0, bounds, ftol, xtol, max_nfev)
    except ImportError:
        return _least_squares_scipy(
            fun, x0, bounds, method, ftol, xtol, gtol, max_nfev, **kwargs
        )
def _least_squares_optimistix(
    fun: Callable,
    x0: ArrayLike,
    bounds: tuple,
    ftol: float,
    xtol: float,
    max_nfev: int,
) -> OptimizeResult:
    """Least squares using optimistix Levenberg-Marquardt.

    ``bounds`` is accepted for signature parity with the scipy fallback but
    is not enforced.  ``xtol`` and ``ftol`` are passed to the solver as its
    ``rtol`` and ``atol``.  The reported ``fun`` is the cost
    ``0.5 * sum(residuals**2)``.  Solver failures are reported as
    ``success=False`` with the exception text as ``message``.

    Raises
    ------
    ImportError
        If jax or optimistix cannot be imported.
    """
    import jax.numpy as jnp
    import optimistix as optx

    x0 = jnp.asarray(x0)
    # JAX cannot differentiate with respect to integer arrays; scipy accepts
    # integer guesses, so promote to the default floating dtype.
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        x0 = x0.astype(float)

    # optimistix passes an auxiliary-args slot that ``fun`` does not take.
    def residual_fn(x, _):
        return fun(x)

    solver: Any = optx.LevenbergMarquardt(rtol=xtol, atol=ftol)
    try:
        sol: Any = optx.least_squares(
            residual_fn,
            solver,
            x0,
            max_steps=max_nfev,
        )
        residuals = fun(sol.value)
        cost = float(0.5 * jnp.sum(residuals**2))
        return OptimizeResult(
            x=np.asarray(sol.value),
            fun=cost,
            success=True,
            message="Optimization converged",
            nit=int(sol.stats.get("num_steps", 0)) if hasattr(sol, "stats") else 0,
        )
    except Exception as e:
        # Best effort: report failure at the starting point rather than raise.
        residuals = fun(x0)
        cost = float(0.5 * np.sum(np.asarray(residuals) ** 2))
        return OptimizeResult(
            x=np.asarray(x0),
            fun=cost,
            success=False,
            message=str(e),
        )
def _least_squares_scipy(
    fun: Callable,
    x0: ArrayLike,
    bounds: tuple,
    method: str,
    ftol: float,
    xtol: float,
    gtol: float,
    max_nfev: int,
    **kwargs,
) -> OptimizeResult:
    """Fallback to scipy.optimize.least_squares.

    The result's ``fun`` is scipy's ``cost`` (``0.5 * sum(residuals**2)``)
    rather than the residual vector, and ``nit`` is filled from scipy's
    Jacobian-evaluation count ``njev``, which is used here as a stand-in for
    the iteration count.
    """
    from scipy.optimize import least_squares as scipy_least_squares

    result = scipy_least_squares(
        fun,
        x0,
        bounds=bounds,
        method=method,
        ftol=ftol,
        xtol=xtol,
        gtol=gtol,
        max_nfev=max_nfev,
        **kwargs,
    )
    return OptimizeResult(
        x=result.x,
        fun=result.cost,
        success=result.success,
        message=result.message,
        nfev=result.nfev,
        # scipy sets njev to None for method="lm" with a finite-difference
        # Jacobian; OptimizeResult.nit must stay an int.
        nit=result.get("njev") or 0,
        jac=result.get("jac"),
    )
[docs]
def root(
    fun: Callable,
    x0: ArrayLike,
    args: tuple = (),
    method: str = "hybr",
    jac: Callable | bool | None = None,
    tol: float | None = None,
    options: dict | None = None,
) -> OptimizeResult:
    """Find roots of a function using optimistix.

    Parameters
    ----------
    fun : callable
        Function returning residuals, called as ``fun(x, *args)``
    x0 : array-like
        Initial guess
    args : tuple
        Extra arguments
    method : str
        Method for the scipy fallback; the optimistix path always uses Newton
    jac : callable, optional
        Jacobian function (used by the scipy fallback only; the optimistix
        path differentiates ``fun`` with JAX)
    tol : float, optional
        Tolerance (default ``1e-8``)
    options : dict, optional
        Solver options; ``maxiter`` (default 1000) caps the optimistix steps.
        The dict is forwarded unchanged to scipy.

    Returns
    -------
    OptimizeResult
        Root finding result; ``fun`` is the sum of squared residuals at ``x``.
    """
    # ``tol or 1e-8`` would silently replace an explicit ``tol=0.0``.
    tol = 1e-8 if tol is None else tol
    options = options or {}
    maxiter = options.get("maxiter", 1000)
    # Try optimistix first, fall back to scipy when the JAX stack is missing.
    try:
        return _root_optimistix(fun, x0, args, tol, maxiter)
    except ImportError:
        return _root_scipy(fun, x0, args, method, jac, tol, options)
def _root_optimistix(
    fun: Callable,
    x0: ArrayLike,
    args: tuple,
    tol: float,
    maxiter: int,
) -> OptimizeResult:
    """Root finding using optimistix Newton iterations.

    The reported ``fun`` is the sum of squared residuals at ``x``.  Solver
    failures are reported as ``success=False`` with the exception text as
    ``message``.

    Raises
    ------
    ImportError
        If jax or optimistix cannot be imported.
    """
    import jax.numpy as jnp
    import optimistix as optx

    x0 = jnp.asarray(x0)
    # Newton needs a JAX Jacobian, which is undefined for integer arrays;
    # promote integer guesses (accepted by scipy) to the default float dtype.
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        x0 = x0.astype(float)

    # optimistix passes an auxiliary-args slot; user ``args`` are closed over.
    def residual_fn(x, _):
        if args:
            return fun(x, *args)
        return fun(x)

    solver: Any = optx.Newton(rtol=tol, atol=tol)
    try:
        sol: Any = optx.root_find(
            residual_fn,
            solver,
            x0,
            max_steps=maxiter,
        )
        return OptimizeResult(
            x=np.asarray(sol.value),
            fun=float(jnp.sum(residual_fn(sol.value, None) ** 2)),
            success=True,
            message="Root finding converged",
            nit=int(sol.stats.get("num_steps", 0)) if hasattr(sol, "stats") else 0,
        )
    except Exception as e:
        # Best effort: report failure at the starting point rather than raise.
        return OptimizeResult(
            x=np.asarray(x0),
            fun=float(np.sum(np.asarray(residual_fn(x0, None)) ** 2)),
            success=False,
            message=str(e),
        )
def _root_scipy(
    fun: Callable,
    x0: ArrayLike,
    args: tuple,
    method: str,
    jac: Callable | bool | None,
    tol: float,
    options: dict,
) -> OptimizeResult:
    """Fallback to scipy.optimize.root.

    ``fun`` in the returned result is the sum of squared residuals (0.0 when
    scipy reports no residual), matching the optimistix path.  ``nfev`` and
    ``nit`` are copied when the chosen scipy method reports them and
    otherwise default to 0.
    """
    from scipy.optimize import root as scipy_root

    result = scipy_root(
        fun,
        x0,
        args=args,
        method=method,
        jac=jac,
        tol=tol,
        options=options,
    )
    # scipy's result object is a dict subclass: unreported fields are missing keys.
    residual = result.get("fun")
    return OptimizeResult(
        x=result.x,
        fun=float(np.sum(np.asarray(residual) ** 2)) if residual is not None else 0.0,
        success=result.success,
        message=result.message,
        nfev=result.get("nfev", 0),
        nit=result.get("nit", 0),
    )