import time
import tracemalloc
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
from optymus.methods.utils import BaseOptimizer
from optymus.methods.utils._result import OptimizeResult
class ParticleSwarmOptimization(BaseOptimizer):
    """Particle swarm optimizer with an optional divergence term in the velocity update.

    The class reads ``self.max_iter``, ``self.verbose``, ``self.f_cons`` and
    ``self.penalized_obj``; none of them is defined here (presumably they are
    supplied by ``BaseOptimizer`` -- confirm against that class).
    """

    def optimize(self, divergence, bounds, alpha, w, c1, c2, n_particles):
        """Run the swarm for ``self.max_iter`` iterations and collect the result.

        Parameters
        ----------
        divergence
            Either the string ``"baseline"`` (no divergence term) or a callable
            invoked as ``divergence(X_prob, gbest_prob)``, where ``X_prob`` is
            the row-wise softmax of the personal-best positions and
            ``gbest_prob`` the softmax of the global best. Its result is indexed
            with ``[:, None]``, so it is expected to hold one value per particle.
        bounds
            Sequence of ``(lower, upper)`` pairs, one pair per dimension.
        alpha
            Scale factor applied to the divergence term.
        w
            Multiplier applied to the previous velocity.
        c1
            Multiplier of the pull towards each particle's personal best.
        c2
            Multiplier of the pull towards the global best.
        n_particles
            Swarm size. ``None`` selects ``max(30, int(10 * sqrt(dimensions)))``.

        Returns
        -------
        OptimizeResult
            Built from a dict with the keys ``method_name``, ``x0`` (mean of the
            *initial* swarm), ``xopt``, ``fmin``, ``num_iter``,
            ``path_particles`` (one single-element list per iteration),
            ``path``, ``f_history``, ``termination_reason``, ``time`` (seconds),
            ``memory_peak`` (traced peak in bytes divided by 1e6),
            ``particle_diversity`` and ``velocity_diversity``.
        """
        start_time = time.perf_counter()

        # Only manage tracemalloc when nobody else is tracing: stopping a session
        # the caller started would silently discard their measurements. In that
        # case the reported peak covers the caller's whole tracing session.
        owns_tracing = not tracemalloc.is_tracing()
        if owns_tracing:
            tracemalloc.start()

        try:
            dimensions = len(bounds)
            lb, ub = jnp.array(bounds).T
            if n_particles is None:
                n_particles = max(30, int(10 * np.sqrt(dimensions)))
            swarm_shape = (n_particles, dimensions)

            def evaluate(positions):
                # Penalized objective of every particle, one call per row.
                return jnp.array([self.penalized_obj(x) for x in positions])

            # Fixed seed. ``key`` itself is not rebound by this first split, so
            # the loop below splits from the same seed key; this keeps the random
            # stream identical to the earlier implementation.
            key = jax.random.PRNGKey(42)
            key_pos, key_vel = jax.random.split(key)

            # Positions are drawn inside the bounds, velocities within +/- the
            # width of each dimension.
            X = jax.random.uniform(key_pos, shape=swarm_shape, minval=lb, maxval=ub)
            span = jnp.abs(ub - lb)
            V = jax.random.uniform(key_vel, shape=swarm_shape, minval=-span, maxval=span)

            # Captured now so that "x0" describes the starting swarm rather than
            # the swarm left over after the last iteration.
            x0 = jnp.mean(X, axis=0)

            path = []
            path_gbest = []
            particle_diversity = []
            velocity_diversity = []

            # JAX arrays are immutable, so plain rebinding is enough (no copies).
            pbest = X
            pbest_val = evaluate(X)
            gbest = pbest[jnp.argmin(pbest_val)]
            gbest_val = jnp.min(pbest_val)
            f_history = [float(gbest_val)]

            # Loop-invariant pieces of the divergence handling.
            is_baseline = divergence == "baseline"
            zero_term = jnp.zeros_like(X)

            num_iter = 0
            progress_bar = tqdm(
                range(self.max_iter),
                desc="Particle Swarm Optimization",
                disable=not self.verbose,
            )
            for _ in progress_bar:
                key, key_r1, key_r2 = jax.random.split(key, 3)
                r1 = jax.random.uniform(key_r1, shape=swarm_shape)
                r2 = jax.random.uniform(key_r2, shape=swarm_shape)

                if is_baseline:
                    divergence_term = zero_term
                else:
                    # Softmax turns the personal bests (per particle) and the
                    # global best into probability vectors for the divergence.
                    X_prob = jax.nn.softmax(pbest, axis=1)
                    gbest_prob = jax.nn.softmax(gbest)
                    div = divergence(X_prob, gbest_prob)
                    divergence_term = alpha * div[:, None] * jnp.sign(X - gbest)

                # Velocity and position update, then clamp into the bounds.
                V = w * V + c1 * r1 * (pbest - X) + c2 * r2 * (gbest - X) - divergence_term
                X = jnp.clip(X + V, lb, ub)

                # Keep, per particle, whichever of old/new position scores lower.
                current_val = evaluate(X)
                improved = current_val < pbest_val
                pbest = jnp.where(improved[:, None], X, pbest)
                pbest_val = jnp.where(improved, current_val, pbest_val)

                new_gbest_val = jnp.min(pbest_val)
                if new_gbest_val < gbest_val:
                    gbest = pbest[jnp.argmin(pbest_val)]
                    gbest_val = new_gbest_val

                # History: each iteration contributes a single-element list to
                # "path_particles", matching the established result layout.
                path.append([X])
                path_gbest.append(gbest)
                particle_diversity.append(jnp.linalg.norm(jnp.std(X, axis=0)))
                velocity_diversity.append(jnp.linalg.norm(jnp.std(V, axis=0)))
                f_history.append(float(gbest_val))
                num_iter += 1

            elapsed_time = time.perf_counter() - start_time
            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Runs on errors too, so a failing objective cannot leave tracing on.
            if owns_tracing:
                tracemalloc.stop()

        return OptimizeResult({
            "method_name": "Particle Swarm Optimization"
            if not self.f_cons
            else "Particle Swarm Optimization with Penalty",
            "x0": x0,
            "xopt": gbest,
            "fmin": gbest_val,
            "num_iter": num_iter,
            "path_particles": path,
            "path": jnp.array(path_gbest),
            "f_history": jnp.array(f_history),
            "termination_reason": "max_iter_reached",
            "time": elapsed_time,
            "memory_peak": peak / 1e6,
            "particle_diversity": particle_diversity,
            "velocity_diversity": velocity_diversity,
        })
[docs]
def particle_swarm(
    divergence="baseline",
    bounds=[(-5, 5), (-5, 5)],  # noqa
    alpha=0.5,
    w=0.1,
    c1=0.25,
    c2=2,
    n_particles=30,
    **kwargs,
):
    """Run Particle Swarm Optimization and return the optimizer's result.

    Convenience wrapper: builds a ``ParticleSwarmOptimization`` from
    ``**kwargs`` and immediately calls its ``optimize`` method.

    Parameters
    ----------
    divergence
        ``"baseline"`` for no divergence term, otherwise a callable that
        ``optimize`` invokes on softmax-normalised positions.
    bounds
        Sequence of ``(lower, upper)`` pairs, one per dimension. The default
        list is only read by ``optimize``, never modified.
    alpha
        Scale factor of the divergence term.
    w
        Multiplier applied to the previous velocity.
    c1
        Multiplier of the pull towards each particle's personal best.
    c2
        Multiplier of the pull towards the global best.
    n_particles
        Number of particles; ``None`` lets ``optimize`` pick a size from the
        number of dimensions.
    **kwargs
        Forwarded unchanged to the ``ParticleSwarmOptimization`` constructor.

    Returns
    -------
    Whatever ``ParticleSwarmOptimization.optimize`` returns.
    """
    return ParticleSwarmOptimization(**kwargs).optimize(
        divergence=divergence,
        bounds=bounds,
        alpha=alpha,
        w=w,
        c1=c1,
        c2=c2,
        n_particles=n_particles,
    )