Skip to content
This repository was archived by the owner on Nov 6, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions benchmarks/mppi_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmarks for MPPI."""

# pylint: disable=invalid-name
import numpy as np

import google_benchmark as benchmark
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from trajax import optimizers
from benchmarks import util


@jax.jit
def cartpole(state, action, timestep, params=(10.0, 1.0, 0.5)):
"""Classic cartpole system.

Args:
state: state, (4, ) array
action: control, (1, ) array
timestep: scalar time
params: tuple of (MASS_CART, MASS_POLE, LENGTH_POLE)

Returns:
xdot: state time derivative, (4, )
"""
del timestep # Unused

mc, mp, l = params
g = 9.81

q = state[0:2]
qd = state[2:]
s = jnp.sin(q[1])
c = jnp.cos(q[1])

H = jnp.array([[mc + mp, mp * l * c], [mp * l * c, mp * l * l]])
C = jnp.array([[0.0, -mp * qd[1] * l * s], [0.0, 0.0]])

G = jnp.array([[0.0], [mp * g * l * s]])
B = jnp.array([[1.0], [0.0]])

CqdG = jnp.dot(C, jnp.expand_dims(qd, 1)) + G
f = jnp.concatenate(
(qd, jnp.squeeze(-jsp.linalg.solve(H, CqdG, assume_a='pos'))))

v = jnp.squeeze(jsp.linalg.solve(H, B, assume_a='pos'))
g = jnp.concatenate((jnp.zeros(2), v))
xdot = f + g * action

return xdot


def cartpole_mppi_benchmark_setup():
"""Cartpole MPPI benchmark."""

def angle_wrap(th):
return (th) % (2 * jnp.pi)

def state_wrap(s):
return jnp.array([s[0], angle_wrap(s[1]), s[2], s[3]])

def squish(u):
return 5 * jnp.tanh(u)

horizon = 50
dt = 0.1
eq_point = jnp.array([0, jnp.pi, 0, 0])

def cost(x, u, t):
err = state_wrap(x - eq_point)
stage_cost = 0.1 * jnp.dot(err, err) + 0.01 * jnp.dot(u, u)
final_cost = 1000 * jnp.dot(err, err)
return jnp.where(t == horizon, final_cost, stage_cost)

def dynamics(x, u, t):
return x + dt * cartpole(x, squish(u), t)

x0 = jnp.array([0.0, 0.2, 0.0, -0.1])
def bench(x0):
_, U, _, = optimizers.mppi(
cost,
dynamics,
x0,
jnp.zeros((horizon, 1)),
np.array([-5.0]), np.array([5.0]),
)
return U

return bench, (x0,)


# Workaround: hold refs to benchmark-registered functions, as bindings assume
# these exist during exit cleanup
benchmarks = (util.register_jit_benchmark('cartpole_mppi_benchmark',
cartpole_mppi_benchmark_setup),)


if __name__ == '__main__':
benchmark.main()
145 changes: 144 additions & 1 deletion tests/optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,5 +819,148 @@ def control_limits(u):
np.linalg.norm(X[-1] - goal, ord=np.inf), constraints_threshold)


def testRandomShooting1(self):
"""
test_CEM1
Description:
Attempts to use the Cross Entropy Method to solve the acrobot problem from "testAcrobotSolve"
"""

T = 50
goal = np.array([np.pi, 0.0, 0.0, 0.0])
dynamics = euler(acrobot, dt=0.1)

def cost(x, u, t, params):
delta = x - goal
terminal_cost = 0.5 * params[0] * np.dot(delta, delta)
stagewise_cost = 0.5 * params[1] * np.dot(
delta, delta) + 0.5 * params[2] * np.dot(u, u)
return np.where(t == T, terminal_cost, stagewise_cost)

x0 = np.zeros(4)
U = np.zeros((T, 1))
params = np.array([1000.0, 0.1, 0.01])
true_obj = 4959.476212
self.assertLess(
np.abs(
optimizers.objective(
functools.partial(cost, params=params), dynamics, U, x0) -
true_obj), 1e-6)

# optimal_obj = 51.0
cem_hyperparams = frozendict({
'sampling_smoothing': 0.2,
'evolution_smoothing': 0.1,
'elite_portion': 0.1,
'max_iter': 100,
'num_samples': 20_000
})
X_opt, U_opt, obj = optimizers.random_shooting(
functools.partial(cost, params=params),
dynamics,
x0,
U,
np.array([-10.0]), np.array([10.0]),
hyperparams=cem_hyperparams,
)
self.assertAlmostEqual(obj, true_obj, places=4)


def testCEM1(self):
"""
test_CEM1
Description:
Attempts to use the Cross Entropy Method to solve the acrobot problem from "testAcrobotSolve"
"""

T = 50
goal = np.array([np.pi, 0.0, 0.0, 0.0])
dynamics = euler(acrobot, dt=0.1)

def cost(x, u, t, params):
delta = x - goal
terminal_cost = 0.5 * params[0] * np.dot(delta, delta)
stagewise_cost = 0.5 * params[1] * np.dot(
delta, delta) + 0.5 * params[2] * np.dot(u, u)
return np.where(t == T, terminal_cost, stagewise_cost)

x0 = np.zeros(4)
U = np.zeros((T, 1))
params = np.array([1000.0, 0.1, 0.01])
zero_input_obj = 4959.476212
self.assertLess(
np.abs(
optimizers.objective(
functools.partial(cost, params=params), dynamics, U, x0) -
zero_input_obj), 1e-6)

optimal_obj = 51.0
cem_hyperparams = frozendict({
'sampling_smoothing': 0.2,
'evolution_smoothing': 0.1,
'elite_portion': 0.1,
'max_iter': 500,
'num_samples': 20_000
})
X_opt, U_opt, obj = optimizers.cem(
functools.partial(cost, params=params),
dynamics,
x0,
U,
np.array([-5.0]), np.array([5.0]),
hyperparams=cem_hyperparams,
)
self.assertLessEqual(obj, zero_input_obj)
self.assertLessEqual(obj, 10*optimal_obj)
# Objective is Around 171


def testMPPI(self):
"""
test_MPPI1
Description:
Attempts to use Model Predictive Integral Control to solve the acrobot problem from "testAcrobotSolve"
"""

T = 50
goal = np.array([np.pi, 0.0, 0.0, 0.0])
dynamics = euler(acrobot, dt=0.1)

def cost(x, u, t, params):
delta = x - goal
terminal_cost = 0.5 * params[0] * np.dot(delta, delta)
stagewise_cost = 0.5 * params[1] * np.dot(
delta, delta) + 0.5 * params[2] * np.dot(u, u)
return np.where(t == T, terminal_cost, stagewise_cost)

x0 = np.zeros(4)
U = np.zeros((T, 1))
params = np.array([1000.0, 0.1, 0.01])
zero_input_obj = 4959.476212
self.assertLess(
np.abs(
optimizers.objective(
functools.partial(cost, params=params), dynamics, U, x0) -
zero_input_obj), 1e-6)

optimal_obj = 51.0
mppi_hyperparams = frozendict({
'noise_stdev': 0.1,
'lambda': 0.1,
'max_iter': 500,
'num_samples': 20_000
})
X_opt, U_opt, obj = optimizers.mppi(
functools.partial(cost, params=params),
dynamics,
x0,
U,
np.array([-5.0]), np.array([5.0]),
hyperparams=mppi_hyperparams,
)
self.assertLessEqual(obj, zero_input_obj)
self.assertLessEqual(obj, 2*optimal_obj)
# Objective is Around 59

if __name__ == '__main__':
absltest.main()
absltest.main()
Loading