diff --git a/benchmarks/mppi_benchmark.py b/benchmarks/mppi_benchmark.py
new file mode 100644
index 0000000..5aeb13d
--- /dev/null
+++ b/benchmarks/mppi_benchmark.py
@@ -0,0 +1,125 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Benchmarks for MPPI."""
+
+# pylint: disable=invalid-name
+import numpy as np
+
+import google_benchmark as benchmark
+import jax
+import jax.numpy as jnp
+import jax.scipy as jsp
+from trajax import optimizers
+from benchmarks import util
+
+
+@jax.jit
+def cartpole(state, action, timestep, params=(10.0, 1.0, 0.5)):
+  """Classic cartpole system.
+
+  Args:
+    state: state, (4, ) array; state[0:2] are positions, state[2:] their rates,
+      and state[1] is the angle passed to sin/cos.
+    action: control, (1, ) array
+    timestep: scalar time (unused)
+    params: tuple of (MASS_CART, MASS_POLE, LENGTH_POLE)
+
+  Returns:
+    xdot: state time derivative, (4, )
+  """
+  del timestep  # Unused
+
+  mc, mp, l = params
+  gravity = 9.81
+
+  q = state[0:2]
+  qd = state[2:]
+  s = jnp.sin(q[1])
+  c = jnp.cos(q[1])
+
+  # Terms of the manipulator equation H qdd + C qd + G = B u.
+  H = jnp.array([[mc + mp, mp * l * c], [mp * l * c, mp * l * l]])
+  C = jnp.array([[0.0, -mp * qd[1] * l * s], [0.0, 0.0]])
+  G = jnp.array([[0.0], [mp * gravity * l * s]])
+  B = jnp.array([[1.0], [0.0]])
+
+  # Uncontrolled part of xdot: [qd, -H^{-1} (C qd + G)].
+  CqdG = jnp.dot(C, jnp.expand_dims(qd, 1)) + G
+  drift = jnp.concatenate(
+      (qd, jnp.squeeze(-jsp.linalg.solve(H, CqdG, assume_a='pos'))))
+
+  # Controlled part of xdot: [0, 0, H^{-1} B] scaled by the action.
+  v = jnp.squeeze(jsp.linalg.solve(H, B, assume_a='pos'))
+  control_gain = jnp.concatenate((jnp.zeros(2), v))
+  return drift + control_gain * action
+
+
+def cartpole_mppi_benchmark_setup():
+  """Builds the cartpole MPPI benchmark.
+
+  Returns:
+    A `(bench, args)` pair: `bench(x0)` runs `optimizers.mppi` from the initial
+    state `x0` and returns the control sequence it produces; `args` is `(x0,)`.
+  """
+
+  def angle_wrap(th):
+    return th % (2 * jnp.pi)
+
+  def state_wrap(s):
+    return jnp.array([s[0], angle_wrap(s[1]), s[2], s[3]])
+
+  def squish(u):
+    # Smoothly saturates the raw control to (-5, 5).
+    return 5 * jnp.tanh(u)
+
+  horizon = 50
+  dt = 0.1
+  eq_point = jnp.array([0, jnp.pi, 0, 0])
+
+  def cost(x, u, t):
+    err = state_wrap(x - eq_point)
+    sq_err = jnp.dot(err, err)
+    stage_cost = 0.1 * sq_err + 0.01 * jnp.dot(u, u)
+    final_cost = 1000 * sq_err
+    return jnp.where(t == horizon, final_cost, stage_cost)
+
+  def dynamics(x, u, t):
+    # Explicit Euler step of the continuous-time cartpole.
+    return x + dt * cartpole(x, squish(u), t)
+
+  x0 = jnp.array([0.0, 0.2, 0.0, -0.1])
+
+  def bench(x0):
+    _, U, _ = optimizers.mppi(
+        cost,
+        dynamics,
+        x0,
+        jnp.zeros((horizon, 1)),
+        np.array([-5.0]),
+        np.array([5.0]),
+    )
+    return U
+
+  return bench, (x0,)
+
+
+# Workaround: hold refs to benchmark-registered functions, as bindings assume
+# these exist during exit cleanup
+benchmarks = (util.register_jit_benchmark('cartpole_mppi_benchmark',
+                                          cartpole_mppi_benchmark_setup),)
+
+
+if __name__ == '__main__':
+  benchmark.main()
diff --git a/tests/optimizers_test.py
b/tests/optimizers_test.py
index ed3f621..c316fea 100644
--- a/tests/optimizers_test.py
+++ b/tests/optimizers_test.py
@@ -819,5 +819,147 @@ def control_limits(u):
         np.linalg.norm(X[-1] - goal, ord=np.inf), constraints_threshold)
 
+  def testRandomShooting1(self):
+    """Checks random shooting reports the objective of the controls it returns.
+
+    Uses the acrobot problem from "testAcrobotSolve".
+    """
+    # Local import: this patch adds no module-level frozendict import here.
+    from frozendict import frozendict  # pylint: disable=g-import-not-at-top
+
+    T = 50
+    goal = np.array([np.pi, 0.0, 0.0, 0.0])
+    dynamics = euler(acrobot, dt=0.1)
+
+    def cost(x, u, t, params):
+      delta = x - goal
+      terminal_cost = 0.5 * params[0] * np.dot(delta, delta)
+      stagewise_cost = 0.5 * params[1] * np.dot(
+          delta, delta) + 0.5 * params[2] * np.dot(u, u)
+      return np.where(t == T, terminal_cost, stagewise_cost)
+
+    x0 = np.zeros(4)
+    U = np.zeros((T, 1))
+    params = np.array([1000.0, 0.1, 0.01])
+    cost_fn = functools.partial(cost, params=params)
+    zero_input_obj = 4959.476212
+    initial_obj = optimizers.objective(cost_fn, dynamics, U, x0)
+    self.assertLess(np.abs(initial_obj - zero_input_obj), 1e-6)
+
+    hyperparams = frozendict({
+        'sampling_smoothing': 0.2,
+        'evolution_smoothing': 0.1,
+        'elite_portion': 0.1,
+        'max_iter': 100,
+        'num_samples': 20_000
+    })
+    _, U_opt, obj = optimizers.random_shooting(
+        cost_fn,
+        dynamics,
+        x0,
+        U,
+        np.array([-10.0]), np.array([10.0]),
+        hyperparams=hyperparams,
+    )
+    # The reported objective must belong to the returned controls; the zero
+    # initial guess would reproduce zero_input_obj instead.
+    expected_obj = optimizers.objective(cost_fn, dynamics, U_opt, x0)
+    self.assertLessEqual(
+        np.abs(obj - expected_obj), 1e-4 * np.abs(expected_obj))
+
+  def testCEM1(self):
+    """Checks CEM improves on the zero-control objective for the acrobot.
+
+    Uses the acrobot problem from "testAcrobotSolve".
+    """
+    # Local import: this patch adds no module-level frozendict import here.
+    from frozendict import frozendict  # pylint: disable=g-import-not-at-top
+
+    T = 50
+    goal = np.array([np.pi, 0.0, 0.0, 0.0])
+    dynamics = euler(acrobot, dt=0.1)
+
+    def cost(x, u, t, params):
+      delta = x - goal
+      terminal_cost = 0.5 * params[0] * np.dot(delta, delta)
+      stagewise_cost = 0.5 * params[1] * np.dot(
+          delta, delta) + 0.5 * params[2] * np.dot(u, u)
+      return np.where(t == T, terminal_cost, stagewise_cost)
+
+    x0 = np.zeros(4)
+    U = np.zeros((T, 1))
+    params = np.array([1000.0, 0.1, 0.01])
+    cost_fn = functools.partial(cost, params=params)
+    zero_input_obj = 4959.476212
+    initial_obj = optimizers.objective(cost_fn, dynamics, U, x0)
+    self.assertLess(np.abs(initial_obj - zero_input_obj), 1e-6)
+
+    optimal_obj = 51.0
+    cem_hyperparams = frozendict({
+        'sampling_smoothing': 0.2,
+        'evolution_smoothing': 0.1,
+        'elite_portion': 0.1,
+        'max_iter': 500,
+        'num_samples': 20_000
+    })
+    _, _, obj = optimizers.cem(
+        cost_fn,
+        dynamics,
+        x0,
+        U,
+        np.array([-5.0]), np.array([5.0]),
+        hyperparams=cem_hyperparams,
+    )
+    # Objective is around 171.
+    self.assertLessEqual(obj, zero_input_obj)
+    self.assertLessEqual(obj, 10 * optimal_obj)
+
+  def testMPPI(self):
+    """Checks MPPI improves on the zero-control objective for the acrobot.
+
+    Uses the acrobot problem from "testAcrobotSolve".
+    """
+    # Local import: this patch adds no module-level frozendict import here.
+    from frozendict import frozendict  # pylint: disable=g-import-not-at-top
+
+    T = 50
+    goal = np.array([np.pi, 0.0, 0.0, 0.0])
+    dynamics = euler(acrobot, dt=0.1)
+
+    def cost(x, u, t, params):
+      delta = x - goal
+      terminal_cost = 0.5 * params[0] * np.dot(delta, delta)
+      stagewise_cost = 0.5 * params[1] * np.dot(
+          delta, delta) + 0.5 * params[2] * np.dot(u, u)
+      return np.where(t == T, terminal_cost, stagewise_cost)
+
+    x0 = np.zeros(4)
+    U = np.zeros((T, 1))
+    params = np.array([1000.0, 0.1, 0.01])
+    cost_fn = functools.partial(cost, params=params)
+    zero_input_obj = 4959.476212
+    initial_obj = optimizers.objective(cost_fn, dynamics, U, x0)
+    self.assertLess(np.abs(initial_obj - zero_input_obj), 1e-6)
+
+    optimal_obj = 51.0
+    # NOTE(review): optimizers.mppi does not read 'noise_stdev' in this patch.
+    mppi_hyperparams = frozendict({
+        'noise_stdev': 0.1,
+        'lambda': 0.1,
+        'max_iter': 500,
+        'num_samples': 20_000
+    })
+    _, _, obj = optimizers.mppi(
+        cost_fn,
+        dynamics,
+        x0,
+        U,
+        np.array([-5.0]), np.array([5.0]),
+        hyperparams=mppi_hyperparams,
+    )
+    # Objective is around 59.
+    self.assertLessEqual(obj, zero_input_obj)
+    self.assertLessEqual(obj, 2 * optimal_obj)
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/trajax/optimizers.py
b/trajax/optimizers.py
index 23fc40a..6a9a61e 100644
--- a/trajax/optimizers.py
+++ b/trajax/optimizers.py
@@ -48,6 +48,7 @@ from __future__ import print_function
 
 
 from functools import partial  # pylint: disable=g-importing-member
+from frozendict import frozendict
 
 import jax
 from jax import custom_derivatives
@@ -985,13 +986,13 @@ def hess_vec_prod(u, v):
 
 
 def default_cem_hyperparams():
-  return {
+  return frozendict({
       'sampling_smoothing': 0.,
       'evolution_smoothing': 0.1,
       'elite_portion': 0.1,
       'max_iter': 10,
       'num_samples': 400
-  }
+  })
 
 
 @partial(jit, static_argnums=(4,))
@@ -1051,7 +1052,7 @@ def body_fun(t, noises):
   return samples
 
 
-@partial(jit, static_argnums=(0, 1))
+@partial(jit, static_argnums=(0, 1, 6, 7))
 def cem(cost,
         dynamics,
         init_state,
@@ -1108,7 +1109,7 @@ def loop_body(_, args):
     return mean, stdev, random_key
 
   # TODO(sindhwani): swap with lax.scan to make this optimizer differentiable.
-  mean, stdev, random_key = lax.fori_loop(0, hyperparams['max_iter'], loop_body,
+  mean, stdev, random_key = lax.fori_loop(0, hyperparams["max_iter"], loop_body,
                                           (mean, stdev, random_key))
 
   X = rollout(dynamics, mean, init_state)
@@ -1116,7 +1117,164 @@ def loop_body(_, args):
   return X, mean, obj
 
 
-@partial(jit, static_argnums=(0, 1))
+# Sampling-based zeroth-order optimization via Model Predictive Path Integral
+# (MPPI) control.
+
+
+def default_mppi_hyperparams():
+  """Returns the default (hashable) hyperparameters for `mppi`."""
+  # NOTE(review): 'noise_stdev' is not read by noise_samples, mppi_update or
+  # mppi below; noise is always drawn with unit standard deviation.
+  return frozendict({
+      'noise_stdev': 0.,
+      'lambda': 0.1,
+      'max_iter': 10,
+      'num_samples': 400
+  })
+
+
+@partial(jit, static_argnums=(2,))
+def noise_samples(random_key, controls, hyperparams):
+  """Samples a batch of standard-normal noise sequences.
+
+  Args:
+    random_key: a jax.random.PRNGKey() random seed
+    controls: control sequence, has dimension (horizon, dim_control); only its
+      shape is used.
+    hyperparams: hashable dictionary of hyperparameters with following keys:
+      num_samples -- number of noise sequences to sample
+
+  Returns:
+    Array of sampled noise, with dimension (num_samples, horizon,
+    dim_control).
+  """
+  horizon, dim_control = controls.shape[:2]
+  return jax.random.normal(
+      random_key, shape=(hyperparams['num_samples'], horizon, dim_control))
+
+
+@partial(jit, static_argnums=(3,))
+def mppi_update(old_controls, noise_seq, costs, hyperparams):
+  """Applies the importance-weighted MPPI update to a control sequence.
+
+  Args:
+    old_controls: current control sequence, (horizon, dim_control)
+    noise_seq: sampled noise sequences, (num_samples, horizon, dim_control)
+    costs: cost of each perturbed control sequence, (num_samples,)
+    hyperparams: hashable dictionary of hyperparameters with following keys:
+      lambda -- temperature of the exponential weighting
+
+  Returns:
+    Updated control sequence, (horizon, dim_control).
+  """
+  lam = hyperparams['lambda']
+
+  # Importance sampling: shifting by the minimum cost keeps every exponent
+  # non-positive, and the small constant guards the normalization.
+  unnormalized = np.exp(-1. / lam * (costs - np.min(costs)))
+  weights = unnormalized / (np.sum(unnormalized, axis=0) + 1e-10)
+
+  # Update inputs with the weighted average of the sampled noise.
+  return old_controls + np.sum(
+      weights[:, np.newaxis, np.newaxis] * noise_seq, axis=0)
+
+
+@jit
+def clip_controls(controls, control_low, control_high):
+  """Clips controls to the box [control_low, control_high].
+
+  Args:
+    controls: array whose last axis is the control dimension
+    control_low: lower bound of control space
+    control_high: upper bound of control space
+
+  Returns:
+    Clipped controls, same shape as `controls`.
+  """
+  # Tile the bounds over the leading (sample / time) axes of `controls`.
+  batch_shape = controls.shape[:-1]
+  control_low = jax.lax.broadcast(control_low, batch_shape)
+  control_high = jax.lax.broadcast(control_high, batch_shape)
+  return np.clip(controls, control_low, control_high)
+
+
+@partial(jit, static_argnums=(0, 1, 7))
+def mppi(cost,
+         dynamics,
+         init_state,
+         init_controls,
+         control_low,
+         control_high,
+         random_key=None,
+         hyperparams=None):
+  """Model Predictive Path Integral (MPPI) Control.
+
+  MPPI is a sampling-based optimization algorithm. At each iteration, MPPI
+  samples a batch of Gaussian noise sequences and perturbs actions with it. It
+  then uses importance sampling to weight control inputs based on their cost
+  and compute the actions, which are used as initial solution in the next
+  iteration.
+
+  The implementation follows Algorithm 2 in
+  https://ieeexplore.ieee.org/document/7989202 allowing also for approximate
+  models s.a. neural network dynamic models.
+
+  Args:
+    cost: cost(x, u, t) returns a scalar
+    dynamics: dynamics(x, u, t) returns next state
+    init_state: initial state
+    init_controls: initial controls, of the shape (horizon, dim_control)
+    control_low: lower bound of control space
+    control_high: upper bound of control space
+    random_key: jax.random.PRNGKey() that serves as a random seed; defaults to
+      PRNGKey(0). It is traced rather than static because key arrays are not
+      hashable.
+    hyperparams: a hashable dictionary (e.g. frozendict) of algorithm
+      hyperparameters with following keys
+      noise_stdev -- intended standard deviation of the zero mean Gaussian
+        noise; NOTE(review): not read by this implementation yet, which
+        samples unit-variance noise.
+      lambda -- temperature parameter in MPPI defining the free energy of the
+        control system
+      max_iter -- maximum number of iterations
+      num_samples -- number of action sequences
+
+  Returns:
+    X: Optimal state trajectory.
+    U: Optimized control sequence, an array of shape (horizon, dim_control)
+    obj: scalar objective achieved.
+  """
+  if random_key is None:
+    random_key = random.PRNGKey(0)
+  if hyperparams is None:
+    hyperparams = default_mppi_hyperparams()
+  U0 = np.array(init_controls)
+  obj_fn = partial(objective, cost, dynamics)
+
+  def body_fn(carry, _):
+    controls, key = carry
+    key, rng = random.split(key)
+
+    # Get random noise sequences and evaluate the perturbed, clipped controls.
+    noise_seq = noise_samples(rng, controls, hyperparams)
+    batch_controls = clip_controls(controls + noise_seq, control_low,
+                                   control_high)
+    costs = vmap(obj_fn, in_axes=(0, None))(batch_controls, init_state)
+
+    # Update the solution using MPPI and keep it inside the control bounds.
+    controls = mppi_update(controls, noise_seq, costs, hyperparams)
+    controls = clip_controls(controls, control_low, control_high)
+    return (controls, key), None
+
+  (controls, _), _ = lax.scan(
+      body_fn, (U0, random_key), None, length=hyperparams['max_iter'])
+
+  X = rollout(dynamics, controls, init_state)
+  obj = objective(cost, dynamics, controls, init_state)
+  return X, controls, obj
+
+
+@partial(jit, static_argnums=(0, 1, 6, 7))
 def random_shooting(cost,
                     dynamics,
                     init_state,
@@ -1164,8 +1322,8 @@ def random_shooting(cost,
   best_idx = np.argmin(costs)
   U = controls[best_idx]
 
-  X = rollout(dynamics, mean, init_state)
-  obj = objective(cost, dynamics, mean, init_state)
+  X = rollout(dynamics, U, init_state)
+  obj = objective(cost, dynamics, U, init_state)
   return X, U, obj
 
 