diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..b46e77c Binary files /dev/null and b/.DS_Store differ diff --git a/.libero/config.yaml b/.libero/config.yaml new file mode 100644 index 0000000..4fa7ba6 --- /dev/null +++ b/.libero/config.yaml @@ -0,0 +1,5 @@ +assets: /home/xsuper/dsrl_pi0/LIBERO/libero/libero/assets +bddl_files: /home/xsuper/dsrl_pi0/LIBERO/libero/libero/bddl_files +benchmark_root: /home/xsuper/dsrl_pi0/LIBERO/libero/libero +datasets: /home/xsuper/dsrl_pi0/LIBERO/libero/datasets +init_states: /home/xsuper/dsrl_pi0/LIBERO/libero/libero/init_files diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..4b5a294 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda" +} \ No newline at end of file diff --git a/examples/scripts/run_libero.sh b/examples/scripts/run_libero.sh index 7e0f2b5..3bb3bd2 100644 --- a/examples/scripts/run_libero.sh +++ b/examples/scripts/run_libero.sh @@ -1,16 +1,21 @@ #!/bin/bash proj_name=DSRL_pi0_Libero device_id=0 - +#fff export DISPLAY=:0 export MUJOCO_GL=egl export PYOPENGL_PLATFORM=egl export MUJOCO_EGL_DEVICE_ID=$device_id - export OPENPI_DATA_HOME=./openpi export EXP=./logs/$proj_name; export CUDA_VISIBLE_DEVICES=$device_id export XLA_PYTHON_CLIENT_PREALLOCATE=false +export PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple +export PYTHONPATH=$PWD:$PWD/LIBERO:$PWD/openpi/src:$PYTHONPATH +export LIBERO_CONFIG_PATH=$PWD/.libero +export NUMBA_DISABLE_JIT=1 +export MPLCONFIGDIR=$PWD/.cache/matplotlib +mkdir -p "$MPLCONFIGDIR" pip install mujoco==3.3.1 @@ -31,4 +36,4 @@ python3 examples/launch_train_sim.py \ --resize_image 64 \ --action_magnitude 1.0 \ --query_freq 20 \ ---hidden_dims 128 \ \ No newline at end of file +--hidden_dims 128 \ diff --git a/examples/train_sim.py b/examples/train_sim.py index a9d99c9..a992137 100755 --- a/examples/train_sim.py +++ b/examples/train_sim.py @@ -1,9 +1,11 @@ #! /usr/bin/env python import os -# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs from https://github.com/huggingface/gym-aloha/tree/main?tab=readme-ov-file#-gpu-rendering-egl -xla_flags = os.environ.get('XLA_FLAGS', '') -xla_flags += ' --xla_gpu_triton_gemm_any=True' -os.environ['XLA_FLAGS'] = xla_flags +# Triton GEMM can improve throughput on some GPUs, but older JAX/XLA builds may +# crash while compiling half-precision kernels on newer architectures. +if os.environ.get('DSRL_ENABLE_TRITON_GEMM', '0') == '1': + xla_flags = os.environ.get('XLA_FLAGS', '') + xla_flags += ' --xla_gpu_triton_gemm_any=True' + os.environ['XLA_FLAGS'] = xla_flags import pathlib, copy @@ -161,4 +163,4 @@ def main(variant): replay_buffer = online_replay_buffer replay_buffer.seed(variant.seed) trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, shard_fn=shard_fn, agent_dp=agent_dp) - \ No newline at end of file +