diff --git a/docker/README.md b/docker/README.md index e2e354bb8..e9eab249b 100644 --- a/docker/README.md +++ b/docker/README.md @@ -7,7 +7,7 @@ _tl;dr [dockerhub url](https://hub.docker.com/r/justheuristic/practical_rl/)_ We recommend you to use either native docker (recommended for linux) or kitematic(recommended for windows). * Installing [kitematic](https://kitematic.com/), a simple interface to docker (all platforms) * Pure docker: Guide for [windows](https://docs.docker.com/docker-for-windows/), [linux](https://docs.docker.com/engine/installation/), or [macOS](https://docs.docker.com/docker-for-mac/). -* If you want to use your GPU make sure you have [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) and [NVidia driver](https://www.nvidia.com/en-us/drivers/unix/) + [CUDA 10.2](https://developer.nvidia.com/cuda-downloads) installed +* If you want to use your GPU make sure you have [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) and [NVidia driver](https://www.nvidia.com/en-us/drivers/unix/) + [CUDA 10.2](https://developer.nvidia.com/cuda-downloads) installed Below are the instructions for both approaches. 
diff --git a/setup_colab.sh b/setup_colab.sh index 2957520c3..c9b812bf9 100755 --- a/setup_colab.sh +++ b/setup_colab.sh @@ -6,7 +6,7 @@ wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/x # Download & import Atari ROMs (Colab stopped bundling them around the beginning of June 2021) -gdown -q https://drive.google.com/uc?id=1dCLEJcJGDDV4l5ssoexP2TEOVuBfyh7D +gdown -q https://drive.google.com/uc?id=1c6_W2Fig92hm5FRIc2Mpc_ZZyr6o52lF # Alternative download: # wget -q http://www.atarimania.com/roms/Roms.rar diff --git a/week01_intro/README.md b/week01_intro/README.md index 6120dbe34..0c81dc013 100644 --- a/week01_intro/README.md +++ b/week01_intro/README.md @@ -11,7 +11,7 @@ ## More materials: -* __[recommended]__ - awesome openai post about evolution strategies - [blog post](https://blog.openai.com/evolution-strategies/), [article](https://arxiv.org/abs/1703.03864) +* __[recommended]__ - awesome openai post about evolution strategies - [blog post](https://openai.com/research/evolution-strategies), [article](https://arxiv.org/abs/1703.03864) * __[recommended]__ - formal explanation of crossentropy method in [general](https://people.smp.uq.edu.au/DirkKroese/ps/CEEncycl.pdf) and for [optimization](https://people.smp.uq.edu.au/DirkKroese/ps/CEopt.pdf) * Deep learning course (if you want to learn in parallel) - https://github.com/yandexdataschool/HSE_deeplearning * Video on genetic algorithms (english) - [video](https://www.youtube.com/watch?v=ejxfTy4lI6I) @@ -23,10 +23,10 @@ ## Practice assignment -Instant dive in: [__seminar_gym_interface__](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week01_intro/seminar_gym_interface.ipynb), [__crossentropy_method__](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week01_intro/crossentropy_method.ipynb), +Instant dive in: 
[__seminar_gymnasium_interface__](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week01_intro/seminar_gymnasium_interface.ipynb), [__crossentropy_method__](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week01_intro/crossentropy_method.ipynb), [__deep_crossentropy_method__](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week01_intro/deep_crossentropy_method.ipynb) -* Open `gym_interface.ipynb` and follow instructions from there +* Open `seminar_gymnasium_interface.ipynb` and follow instructions from there * After you're done there, proceed to `crossentropy_method.ipynb` * You can find homework and bonus assignment descriptions at the end of that notebook. -* Note: so far it's enough to say `pip install gym` on top of any data-science-stuffed python, but we'd appreciate if you gradually switch to [full installation](https://github.com/openai/gym#installing-everything). +* Note: so far it's enough to say `pip install gymnasium` on top of any data-science-stuffed python, but we'd appreciate if you gradually switch to [full installation](https://github.com/Farama-Foundation/Gymnasium). diff --git a/week01_intro/crossentropy_method.ipynb b/week01_intro/crossentropy_method.ipynb index 64f0da89c..40ae90548 100644 --- a/week01_intro/crossentropy_method.ipynb +++ b/week01_intro/crossentropy_method.ipynb @@ -1,422 +1,508 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Crossentropy method\n", - "\n", - "This notebook will teach you to solve reinforcement learning problems with crossentropy method. We'll follow-up by scaling everything up and using neural network policy." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7XGyc-FCG35I" + }, + "source": [ + "# Crossentropy method\n", + "\n", + "This notebook will teach you to solve reinforcement learning problems with crossentropy method. 
We'll follow-up by scaling everything up and using neural network policy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2jwz8moTG35K" + }, + "outputs": [], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oe7EKolvLC67" + }, + "outputs": [], + "source": [ + "!pip install gymnasium[toy_text]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ltjzx5AFG35K" + }, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "env = gym.make(\"Taxi-v3\", render_mode=\"rgb_array\")\n", + "print(env.reset(seed=0))\n", + "plt.imshow(env.render())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "O6waF01eG35L", + "outputId": "8ca46444-7a7f-4091-c50c-0cab96a995e9" + }, + "outputs": [], + "source": [ + "n_states = env.observation_space.n\n", + "n_actions = env.action_space.n\n", + "\n", + "print(f\"n_states={n_states}, n_actions={n_actions}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IHXnU2QWG35L" + }, + "source": [ + "# Create stochastic policy\n", + "\n", + "This time our policy should be a probability distribution.\n", + "\n", + "```policy[s,a] = P(take action a | in state s)```\n", + "\n", + "Since we still 
use integer state and action representations, you can use a 2-dimensional array to represent the policy.\n", + "\n", + "Please initialize the policy __uniformly__, that is, probabilities of all actions should be equal." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "9qL5eW-rG35L" + }, + "outputs": [], + "source": [ + "def initialize_policy(n_states, n_actions):\n", + " \n", + "\n", + " return policy\n", + "\n", + "\n", + "policy = initialize_policy(n_states, n_actions)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "G1SeRRGgG35L" + }, + "outputs": [], + "source": [ + "assert type(policy) in (np.ndarray, np.matrix)\n", + "assert np.allclose(policy, 1.0 / n_actions)\n", + "assert np.allclose(np.sum(policy, axis=1), 1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9zR1fCUrG35L" + }, + "source": [ + "# Play the game\n", + "\n", + "Just like before, but we also record all states and actions we took." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "d4v8WmvlG35L" + }, + "outputs": [], + "source": [ + "def generate_session(env, policy, t_max=10**4):\n", + " \"\"\"\n", + " Play game until end or for t_max ticks.\n", + " :param policy: an array of shape [n_states,n_actions] with action probabilities\n", + " :returns: list of states, list of actions and sum of rewards\n", + " \"\"\"\n", + " states, actions = [], []\n", + " total_reward = 0.0\n", + "\n", + " s, _ = env.reset()\n", + "\n", + " for t in range(t_max):\n", + " # Hint: you can use np.random.choice for sampling action\n", + " # https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html\n", + "\n", + " a = \n", + "\n", + " new_s, r, terminated, truncated, _ = env.step(a)\n", + "\n", + " # Record information we just got from the environment.\n", + " states.append(s)\n", + " actions.append(a)\n", + " total_reward += r\n", + "\n", + " s = new_s\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " return states, actions, total_reward\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "a1EUUZ29G35M" + }, + "outputs": [], + "source": [ + "s, a, r = generate_session(env, policy)\n", + "assert type(s) == type(a) == list\n", + "assert len(s) == len(a)\n", + "assert type(r) in [float, np.float64]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_5YEDTKnG35M" + }, + "outputs": [], + "source": [ + "# let's see the initial reward distribution\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "sample_rewards = [generate_session(env, policy, t_max=1000)[-1] for _ in range(200)]\n", + "\n", + "plt.hist(sample_rewards, bins=20)\n", + "plt.vlines([np.percentile(sample_rewards, 50)], [0], [100], label=\"50'th percentile\", color='green')\n", + "plt.vlines([np.percentile(sample_rewards, 90)], [0], [100], label=\"90'th percentile\", color='red')\n", + 
"plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EeWtL3F5G35M" + }, + "source": [ + "### Crossentropy method steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "00WWzr0KG35N" + }, + "outputs": [], + "source": [ + "def select_elites(states_batch, actions_batch, rewards_batch, percentile):\n", + " \"\"\"\n", + " Select states and actions from games that have rewards >= percentile\n", + " :param states_batch: list of lists of states, states_batch[session_i][t]\n", + " :param actions_batch: list of lists of actions, actions_batch[session_i][t]\n", + " :param rewards_batch: list of rewards, rewards_batch[session_i]\n", + "\n", + " :returns: elite_states,elite_actions, both 1D lists of states and respective actions from elite sessions\n", + "\n", + " Please return elite states and actions in their original order\n", + " [i.e. sorted by session number and timestep within session]\n", + "\n", + " If you are confused, see examples below. 
Please don't assume that states are integers\n", + " (they will become different later).\n", + " \"\"\"\n", + "\n", + " reward_threshold = \n", + "\n", + " elite_states = \n", + " elite_actions = \n", + "\n", + " return elite_states, elite_actions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "teOLBGojG35N" + }, + "outputs": [], + "source": [ + "states_batch = [\n", + " [1, 2, 3], # game1\n", + " [4, 2, 0, 2], # game2\n", + " [3, 1], # game3\n", + "]\n", + "\n", + "actions_batch = [\n", + " [0, 2, 4], # game1\n", + " [3, 2, 0, 1], # game2\n", + " [3, 3], # game3\n", + "]\n", + "rewards_batch = [\n", + " 3, # game1\n", + " 4, # game2\n", + " 5, # game3\n", + "]\n", + "\n", + "test_result_0 = select_elites(states_batch, actions_batch, rewards_batch, percentile=0)\n", + "test_result_30 = select_elites(\n", + " states_batch, actions_batch, rewards_batch, percentile=30\n", + ")\n", + "test_result_90 = select_elites(\n", + " states_batch, actions_batch, rewards_batch, percentile=90\n", + ")\n", + "test_result_100 = select_elites(\n", + " states_batch, actions_batch, rewards_batch, percentile=100\n", + ")\n", + "\n", + "assert np.all(test_result_0[0] == [1, 2, 3, 4, 2, 0, 2, 3, 1]) and np.all(\n", + " test_result_0[1] == [0, 2, 4, 3, 2, 0, 1, 3, 3]\n", + "), \"For percentile 0 you should return all states and actions in chronological order\"\n", + "assert np.all(test_result_30[0] == [4, 2, 0, 2, 3, 1]) and np.all(\n", + " test_result_30[1] == [3, 2, 0, 1, 3, 3]\n", + "), \"For percentile 30 you should only select states/actions from two first\"\n", + "assert np.all(test_result_90[0] == [3, 1]) and np.all(\n", + " test_result_90[1] == [3, 3]\n", + "), \"For percentile 90 you should only select states/actions from one game\"\n", + "assert np.all(test_result_100[0] == [3, 1]) and np.all(\n", + " test_result_100[1] == [3, 3]\n", + "), \"Please make sure you use >=, not >. 
Also double-check how you compute percentile.\"\n", + "\n", + "print(\"Ok!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wYLBHFFwG35N" + }, + "outputs": [], + "source": [ + "def get_new_policy(elite_states, elite_actions):\n", + " \"\"\"\n", + " Given a list of elite states/actions from select_elites,\n", + " return a new policy where each action probability is proportional to\n", + "\n", + " policy[s_i,a_i] ~ #[occurrences of s_i and a_i in elite states/actions]\n", + "\n", + " Don't forget to normalize the policy to get valid probabilities and handle the 0/0 case.\n", + " For states that you never visited, use a uniform distribution (1/n_actions for all actions).\n", + "\n", + " :param elite_states: 1D list of states from elite sessions\n", + " :param elite_actions: 1D list of actions from elite sessions\n", + "\n", + " \"\"\"\n", + "\n", + " new_policy = np.zeros([n_states, n_actions])\n", + "\n", + " \n", + " # Don't forget to set 1/n_actions for all actions in unvisited states.\n", + "\n", + " return new_policy\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "I1VbNcpoG35O" + }, + "outputs": [], + "source": [ + "elite_states = [1, 2, 3, 4, 2, 0, 2, 3, 1]\n", + "elite_actions = [0, 2, 4, 3, 2, 0, 1, 3, 3]\n", + "\n", + "new_policy = get_new_policy(elite_states, elite_actions)\n", + "\n", + "assert np.isfinite(\n", + " new_policy\n", + ").all(), \"Your new policy contains NaNs or +-inf. 
Make sure you don't divide by zero.\"\n", + "assert np.all(\n", + " new_policy >= 0\n", + "), \"Your new policy can't have negative action probabilities\"\n", + "assert np.allclose(\n", + " new_policy.sum(axis=-1), 1\n", + "), \"Your new policy should be a valid probability distribution over actions\"\n", + "\n", + "reference_answer = np.array(\n", + " [\n", + " [1.0, 0.0, 0.0, 0.0, 0.0],\n", + " [0.5, 0.0, 0.0, 0.5, 0.0],\n", + " [0.0, 0.33333333, 0.66666667, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.5, 0.5],\n", + " ]\n", + ")\n", + "assert np.allclose(new_policy[:4, :5], reference_answer)\n", + "\n", + "print(\"Ok!\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WVvZq1fSG35O" + }, + "source": [ + "# Training loop\n", + "Generate sessions, select N best and fit to those." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3CmH7Aj4G35O" + }, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "\n", + "def show_progress(rewards_batch, log, percentile, reward_range=[-990, +10]):\n", + " \"\"\"\n", + " A convenience function that displays training progress.\n", + " No cool math here, just charts.\n", + " \"\"\"\n", + "\n", + " mean_reward = np.mean(rewards_batch)\n", + " threshold = np.percentile(rewards_batch, percentile)\n", + " log.append([mean_reward, threshold])\n", + "\n", + " plt.figure(figsize=[8, 4])\n", + " plt.subplot(1, 2, 1)\n", + " plt.plot(list(zip(*log))[0], label=\"Mean rewards\")\n", + " plt.plot(list(zip(*log))[1], label=\"Reward thresholds\")\n", + " plt.legend()\n", + " plt.grid()\n", + "\n", + " plt.subplot(1, 2, 2)\n", + " plt.hist(rewards_batch, range=reward_range)\n", + " plt.vlines(\n", + " [np.percentile(rewards_batch, percentile)],\n", + " [0],\n", + " [100],\n", + " label=\"percentile\",\n", + " color=\"red\",\n", + " )\n", + " plt.legend()\n", + " plt.grid()\n", + " clear_output(True)\n", + " print(\"mean reward = %.3f, threshold=%.3f\" % 
(mean_reward, threshold))\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tz0Yd964G35O" + }, + "outputs": [], + "source": [ + "# reset policy just in case\n", + "policy = initialize_policy(n_states, n_actions)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-cNq5hndG35O" + }, + "outputs": [], + "source": [ + "n_sessions = 250 # sample this many sessions\n", + "percentile = 50 # discard this percentage of sessions with lowest rewards\n", + "learning_rate = 0.5 # how quickly the policy is updated, on a scale from 0 to 1\n", + "\n", + "log = []\n", + "\n", + "for i in range(100):\n", + " %time sessions = [ ]\n", + "\n", + " states_batch, actions_batch, rewards_batch = zip(*sessions)\n", + "\n", + " elite_states, elite_actions = \n", + "\n", + " new_policy = \n", + "\n", + " policy = learning_rate * new_policy + (1 - learning_rate) * policy\n", + "\n", + " # display results on chart\n", + " show_progress(rewards_batch, log, percentile)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K5LIoVTuG35O" + }, + "source": [ + "### Reflecting on results\n", + "\n", + "You may have noticed that the taxi problem quickly converges from less than -1000 to a near-optimal score and then descends back into -50/-100. This is in part because the environment has some innate randomness. Namely, the starting points of passenger/driver change from episode to episode.\n", + "\n", + "In case CEM failed to learn how to win from one distinct starting point, it will simply discard it because no sessions from that starting point will make it into the \"elites\".\n", + "\n", + "To mitigate that problem, you can either reduce the threshold for elite sessions (duct tape way) or change the way you evaluate strategy (theoretically correct way). 
For each starting state, you can sample an action randomly, and then evaluate this action by running _several_ games starting from it and averaging the total reward. Choosing elite sessions with this kind of sampling (where each session's reward is counted as the average of the rewards of all sessions with the same starting state and action) should improve the performance of your policy." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ubIHBgQ-G35O" + }, + "source": [ + "\n", + "### You're not done yet!\n", + "\n", + "Go to [`./deep_crossentropy_method.ipynb`](./deep_crossentropy_method.ipynb) for a more serious task" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "import numpy as np\n", - "\n", - "env = gym.make(\"Taxi-v3\")\n", - "env.reset()\n", - "env.render()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_states = env.observation_space.n\n", - "n_actions = env.action_space.n\n", - "\n", - "print(\"n_states=%i, n_actions=%i\" % (n_states, n_actions))" - ] - }, - { - 
"cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create stochastic policy\n", - "\n", - "This time our policy should be a probability distribution.\n", - "\n", - "```policy[s,a] = P(take action a | in state s)```\n", - "\n", - "Since we still use integer state and action representations, you can use a 2-dimensional array to represent the policy.\n", - "\n", - "Please initialize the policy __uniformly__, that is, probabililities of all actions should be equal." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def initialize_policy(n_states, n_actions):\n", - " \n", - " \n", - " return policy\n", - "\n", - "policy = initialize_policy(n_states, n_actions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert type(policy) in (np.ndarray, np.matrix)\n", - "assert np.allclose(policy, 1./n_actions)\n", - "assert np.allclose(np.sum(policy, axis=1), 1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Play the game\n", - "\n", - "Just like before, but we also record all states and actions we took." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_session(env, policy, t_max=10**4):\n", - " \"\"\"\n", - " Play game until end or for t_max ticks.\n", - " :param policy: an array of shape [n_states,n_actions] with action probabilities\n", - " :returns: list of states, list of actions and sum of rewards\n", - " \"\"\"\n", - " states, actions = [], []\n", - " total_reward = 0.\n", - "\n", - " s = env.reset()\n", - "\n", - " for t in range(t_max):\n", - " # Hint: you can use np.random.choice for sampling action\n", - " # https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html\n", - " a = \n", - "\n", - " new_s, r, done, info = env.step(a)\n", - "\n", - " # Record information we just got from the environment.\n", - " states.append(s)\n", - " actions.append(a)\n", - " total_reward += r\n", - "\n", - " s = new_s\n", - " if done:\n", - " break\n", - "\n", - " return states, actions, total_reward" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "s, a, r = generate_session(env, policy)\n", - "assert type(s) == type(a) == list\n", - "assert len(s) == len(a)\n", - "assert type(r) in [float, np.float]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# let's see the initial reward distribution\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "\n", - "sample_rewards = [generate_session(env, policy, t_max=1000)[-1] for _ in range(200)]\n", - "\n", - "plt.hist(sample_rewards, bins=20)\n", - "plt.vlines([np.percentile(sample_rewards, 50)], [0], [100], label=\"50'th percentile\", color='green')\n", - "plt.vlines([np.percentile(sample_rewards, 90)], [0], [100], label=\"90'th percentile\", color='red')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Crossentropy method steps" - ] 
- }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def select_elites(states_batch, actions_batch, rewards_batch, percentile):\n", - " \"\"\"\n", - " Select states and actions from games that have rewards >= percentile\n", - " :param states_batch: list of lists of states, states_batch[session_i][t]\n", - " :param actions_batch: list of lists of actions, actions_batch[session_i][t]\n", - " :param rewards_batch: list of rewards, rewards_batch[session_i]\n", - "\n", - " :returns: elite_states,elite_actions, both 1D lists of states and respective actions from elite sessions\n", - "\n", - " Please return elite states and actions in their original order \n", - " [i.e. sorted by session number and timestep within session]\n", - "\n", - " If you are confused, see examples below. Please don't assume that states are integers\n", - " (they will become different later).\n", - " \"\"\"\n", - "\n", - " reward_threshold = \n", - "\n", - " elite_states = \n", - " elite_actions = \n", - "\n", - " return elite_states, elite_actions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "states_batch = [\n", - " [1, 2, 3], # game1\n", - " [4, 2, 0, 2], # game2\n", - " [3, 1], # game3\n", - "]\n", - "\n", - "actions_batch = [\n", - " [0, 2, 4], # game1\n", - " [3, 2, 0, 1], # game2\n", - " [3, 3], # game3\n", - "]\n", - "rewards_batch = [\n", - " 3, # game1\n", - " 4, # game2\n", - " 5, # game3\n", - "]\n", - "\n", - "test_result_0 = select_elites(states_batch, actions_batch, rewards_batch, percentile=0)\n", - "test_result_30 = select_elites(states_batch, actions_batch, rewards_batch, percentile=30)\n", - "test_result_90 = select_elites(states_batch, actions_batch, rewards_batch, percentile=90)\n", - "test_result_100 = select_elites(states_batch, actions_batch, rewards_batch, percentile=100)\n", - "\n", - "assert np.all(test_result_0[0] == [1, 2, 3, 4, 2, 0, 2, 3, 1]) 
\\\n", - " and np.all(test_result_0[1] == [0, 2, 4, 3, 2, 0, 1, 3, 3]), \\\n", - " \"For percentile 0 you should return all states and actions in chronological order\"\n", - "assert np.all(test_result_30[0] == [4, 2, 0, 2, 3, 1]) and \\\n", - " np.all(test_result_30[1] == [3, 2, 0, 1, 3, 3]), \\\n", - " \"For percentile 30 you should only select states/actions from two first\"\n", - "assert np.all(test_result_90[0] == [3, 1]) and \\\n", - " np.all(test_result_90[1] == [3, 3]), \\\n", - " \"For percentile 90 you should only select states/actions from one game\"\n", - "assert np.all(test_result_100[0] == [3, 1]) and\\\n", - " np.all(test_result_100[1] == [3, 3]), \\\n", - " \"Please make sure you use >=, not >. Also double-check how you compute percentile.\"\n", - "\n", - "print(\"Ok!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_new_policy(elite_states, elite_actions):\n", - " \"\"\"\n", - " Given a list of elite states/actions from select_elites,\n", - " return a new policy where each action probability is proportional to\n", - "\n", - " policy[s_i,a_i] ~ #[occurrences of s_i and a_i in elite states/actions]\n", - "\n", - " Don't forget to normalize the policy to get valid probabilities and handle the 0/0 case.\n", - " For states that you never visited, use a uniform distribution (1/n_actions for all states).\n", - "\n", - " :param elite_states: 1D list of states from elite sessions\n", - " :param elite_actions: 1D list of actions from elite sessions\n", - "\n", - " \"\"\"\n", - "\n", - " new_policy = np.zeros([n_states, n_actions])\n", - "\n", - " \n", - " # Don't forget to set 1/n_actions for all actions in unvisited states.\n", - "\n", - " return new_policy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "elite_states = [1, 2, 3, 4, 2, 0, 2, 3, 1]\n", - "elite_actions = [0, 2, 4, 3, 2, 0, 1, 3, 3]\n", - "\n", - 
"new_policy = get_new_policy(elite_states, elite_actions)\n", - "\n", - "assert np.isfinite(new_policy).all(), \\\n", - " \"Your new policy contains NaNs or +-inf. Make sure you don't divide by zero.\"\n", - "assert np.all(new_policy >= 0), \\\n", - " \"Your new policy can't have negative action probabilities\"\n", - "assert np.allclose(new_policy.sum(axis=-1), 1), \\\n", - " \"Your new policy should be a valid probability distribution over actions\"\n", - "\n", - "reference_answer = np.array([\n", - " [1., 0., 0., 0., 0.],\n", - " [0.5, 0., 0., 0.5, 0.],\n", - " [0., 0.33333333, 0.66666667, 0., 0.],\n", - " [0., 0., 0., 0.5, 0.5]])\n", - "assert np.allclose(new_policy[:4, :5], reference_answer)\n", - "\n", - "print(\"Ok!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Training loop\n", - "Generate sessions, select N best and fit to those." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.display import clear_output\n", - "\n", - "def show_progress(rewards_batch, log, percentile, reward_range=[-990, +10]):\n", - " \"\"\"\n", - " A convenience function that displays training progress. 
\n", - " No cool math here, just charts.\n", - " \"\"\"\n", - "\n", - " mean_reward = np.mean(rewards_batch)\n", - " threshold = np.percentile(rewards_batch, percentile)\n", - " log.append([mean_reward, threshold])\n", - " \n", - " plt.figure(figsize=[8, 4])\n", - " plt.subplot(1, 2, 1)\n", - " plt.plot(list(zip(*log))[0], label='Mean rewards')\n", - " plt.plot(list(zip(*log))[1], label='Reward thresholds')\n", - " plt.legend()\n", - " plt.grid()\n", - "\n", - " plt.subplot(1, 2, 2)\n", - " plt.hist(rewards_batch, range=reward_range)\n", - " plt.vlines([np.percentile(rewards_batch, percentile)],\n", - " [0], [100], label=\"percentile\", color='red')\n", - " plt.legend()\n", - " plt.grid()\n", - " clear_output(True)\n", - " print(\"mean reward = %.3f, threshold=%.3f\" % (mean_reward, threshold))\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# reset policy just in case\n", - "policy = initialize_policy(n_states, n_actions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_sessions = 250 # sample this many sessions\n", - "percentile = 50 # take this percent of session with highest rewards\n", - "learning_rate = 0.5 # how quickly the policy is updated, on a scale from 0 to 1\n", - "\n", - "log = []\n", - "\n", - "for i in range(100):\n", - " %time sessions = [ ]\n", - "\n", - " states_batch, actions_batch, rewards_batch = zip(*sessions)\n", - "\n", - " elite_states, elite_actions = \n", - "\n", - " new_policy = \n", - "\n", - " policy = learning_rate * new_policy + (1 - learning_rate) * policy\n", - "\n", - " # display results on chart\n", - " show_progress(rewards_batch, log, percentile)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reflecting on results\n", - "\n", - "You may have noticed that the taxi problem quickly converges from less than -1000 to a near-optimal score and 
then descends back into -50/-100. This is in part because the environment has some innate randomness. Namely, the starting points of passenger/driver change from episode to episode.\n", - "\n", - "In case CEM failed to learn how to win from one distinct starting point, it will simply discard it because no sessions from that starting point will make it into the \"elites\".\n", - "\n", - "To mitigate that problem, you can either reduce the threshold for elite sessions (duct tape way) or change the way you evaluate strategy (theoretically correct way). For each starting state, you can sample an action randomly, and then evaluate this action by running _several_ games starting from it and averaging the total reward. Choosing elite sessions with this kind of sampling (where each session's reward is counted as the average of the rewards of all sessions with the same starting state and action) should improve the performance of your policy." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "### You're not done yet!\n", - "\n", - "Go to [`./deep_crossentropy_method.ipynb`](./deep_crossentropy_method.ipynb) for a more serious task" - ] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 1 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/week01_intro/deep_crossentropy_method.ipynb b/week01_intro/deep_crossentropy_method.ipynb index 38eff5276..6cf7aa349 100644 --- a/week01_intro/deep_crossentropy_method.ipynb +++ b/week01_intro/deep_crossentropy_method.ipynb @@ -1,455 +1,524 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Deep Crossentropy method\n", - "\n", - "In this section we'll extend your CEM implementation with neural networks! You will train a multi-layer neural network to solve simple continuous state space games. 
__Please make sure you're done with tabular crossentropy method from the previous notebook.__\n", - "\n", - "![img](https://watanimg.elwatannews.com/old_news_images/large/249765_Large_20140709045740_11.jpg)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "\n", - "# if you see \" has no attribute .env\", remove .env or update gym\n", - "env = gym.make(\"CartPole-v0\").env\n", - "\n", - "env.reset()\n", - "n_actions = env.action_space.n\n", - "state_dim = env.observation_space.shape[0]\n", - "\n", - "plt.imshow(env.render(\"rgb_array\"))\n", - "print(\"state vector dim =\", state_dim)\n", - "print(\"n_actions =\", n_actions)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Neural Network Policy\n", - "\n", - "For this assignment we'll utilize the simplified neural network implementation from __[Scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html)__. Here's what you'll need:\n", - "\n", - "* `agent.partial_fit(states, actions)` - make a single training pass over the data. 
Maximize the probability of :actions: from :states:\n", - "* `agent.predict_proba(states)` - predict probabilities of all actions, a matrix of shape __[len(states), n_actions]__\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.neural_network import MLPClassifier\n", - "\n", - "agent = MLPClassifier(\n", - " hidden_layer_sizes=(20, 20),\n", - " activation='tanh',\n", - ")\n", - "\n", - "# initialize agent to the dimension of state space and number of actions\n", - "agent.partial_fit([env.reset()] * n_actions, range(n_actions), range(n_actions))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_session(env, agent, t_max=1000):\n", - " \"\"\"\n", - " Play a single game using agent neural network.\n", - " Terminate when game finishes or after :t_max: steps\n", - " \"\"\"\n", - " states, actions = [], []\n", - " total_reward = 0\n", - "\n", - " s = env.reset()\n", - "\n", - " for t in range(t_max):\n", - " \n", - " # use agent to predict a vector of action probabilities for state :s:\n", - " probs = \n", - "\n", - " assert probs.shape == (env.action_space.n,), \"make sure probabilities are a vector (hint: np.reshape)\"\n", - " \n", - " # use the probabilities you predicted to pick an action\n", - " # sample proportionally to the probabilities, don't just take the most likely action\n", - " a = \n", - " # ^-- hint: try np.random.choice\n", - "\n", - " new_s, r, done, info = env.step(a)\n", - "\n", - " # record sessions like you did before\n", - " states.append(s)\n", - " actions.append(a)\n", - " total_reward += r\n", - "\n", - " s = new_s\n", - " if done:\n", - " break\n", - " return states, actions, total_reward" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dummy_states, dummy_actions, dummy_reward = generate_session(env, agent, t_max=5)\n", 
- "print(\"states:\", np.stack(dummy_states))\n", - "print(\"actions:\", dummy_actions)\n", - "print(\"reward:\", dummy_reward)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### CEM steps\n", - "Deep CEM uses exactly the same strategy as the regular CEM, so you can copy your function code from previous notebook.\n", - "\n", - "The only difference is that now each observation is not a number but a `float32` vector." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def select_elites(states_batch, actions_batch, rewards_batch, percentile=50):\n", - " \"\"\"\n", - " Select states and actions from games that have rewards >= percentile\n", - " :param states_batch: list of lists of states, states_batch[session_i][t]\n", - " :param actions_batch: list of lists of actions, actions_batch[session_i][t]\n", - " :param rewards_batch: list of rewards, rewards_batch[session_i]\n", - "\n", - " :returns: elite_states,elite_actions, both 1D lists of states and respective actions from elite sessions\n", - "\n", - " Please return elite states and actions in their original order \n", - " [i.e. sorted by session number and timestep within session]\n", - "\n", - " If you are confused, see examples below. Please don't assume that states are integers\n", - " (they will become different later).\n", - " \"\"\"\n", - "\n", - " \n", - " \n", - " return elite_states, elite_actions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Training loop\n", - "Generate sessions, select N best and fit to those." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.display import clear_output\n", - "\n", - "def show_progress(rewards_batch, log, percentile, reward_range=[-990, +10]):\n", - " \"\"\"\n", - " A convenience function that displays training progress. 
\n", - " No cool math here, just charts.\n", - " \"\"\"\n", - "\n", - " mean_reward = np.mean(rewards_batch)\n", - " threshold = np.percentile(rewards_batch, percentile)\n", - " log.append([mean_reward, threshold])\n", - "\n", - " clear_output(True)\n", - " print(\"mean reward = %.3f, threshold=%.3f\" % (mean_reward, threshold))\n", - " plt.figure(figsize=[8, 4])\n", - " plt.subplot(1, 2, 1)\n", - " plt.plot(list(zip(*log))[0], label='Mean rewards')\n", - " plt.plot(list(zip(*log))[1], label='Reward thresholds')\n", - " plt.legend()\n", - " plt.grid()\n", - "\n", - " plt.subplot(1, 2, 2)\n", - " plt.hist(rewards_batch, range=reward_range)\n", - " plt.vlines([np.percentile(rewards_batch, percentile)],\n", - " [0], [100], label=\"percentile\", color='red')\n", - " plt.legend()\n", - " plt.grid()\n", - "\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_sessions = 100\n", - "percentile = 70\n", - "log = []\n", - "\n", - "for i in range(100):\n", - " # generate new sessions\n", - " sessions = [ ]\n", - "\n", - " states_batch, actions_batch, rewards_batch = map(np.array, zip(*sessions))\n", - "\n", - " elite_states, elite_actions = \n", - "\n", - " \n", - "\n", - " show_progress(rewards_batch, log, percentile, reward_range=[0, np.max(rewards_batch)])\n", - "\n", - " if np.mean(rewards_batch) > 190:\n", - " print(\"You Win! 
You may stop training now via KeyboardInterrupt.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Record sessions\n", - "\n", - "import gym.wrappers\n", - "\n", - "with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n", - " sessions = [generate_session(env_monitor, agent) for _ in range(100)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Show video. This may not work in some setups. If it doesn't\n", - "# work for you, you can download the videos and view them locally.\n", - "\n", - "from pathlib import Path\n", - "from base64 import b64encode\n", - "from IPython.display import HTML\n", - "\n", - "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", - "video_path = video_paths[-1] # You can also try other indices\n", - "\n", - "if 'google.colab' in sys.modules:\n", - " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open('rb') as fp:\n", - " mp4 = fp.read()\n", - " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", - "else:\n", - " data_url = str(video_path)\n", - "\n", - "HTML(\"\"\"\n", - "\n", - "\"\"\".format(data_url))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Homework part I\n", - "\n", - "### Tabular crossentropy method\n", - "\n", - "You may have noticed that the taxi problem quickly converges from -100 to a near-optimal score and then descends back into -50/-100. This is in part because the environment has some innate randomness. Namely, the starting points of passenger/driver change from episode to episode.\n", - "\n", - "### Tasks\n", - "- __1.1__ (2 pts) Find out how the algorithm performance changes if you use a different `percentile` and/or `n_sessions`. 
Provide here some figures so we can see how the hyperparameters influence the performance.\n", - "- __1.2__ (1 pts) Tune the algorithm to end up with positive average score.\n", - "\n", - "It's okay to modify the existing code.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "``````" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Homework part II\n", - "\n", - "### Deep crossentropy method\n", - "\n", - "By this moment, you should have got enough score on [CartPole-v0](https://gym.openai.com/envs/CartPole-v0) to consider it solved (see the link). It's time to try something harder.\n", - "\n", - "* if you have any trouble with CartPole-v0 and feel stuck, feel free to ask us or your peers for help.\n", - "\n", - "### Tasks\n", - "\n", - "* __2.1__ (3 pts) Pick one of environments: `MountainCar-v0` or `LunarLander-v2`.\n", - " * For MountainCar, get average reward of __at least -150__\n", - " * For LunarLander, get average reward of __at least +50__\n", - "\n", - "See the tips section below, it's kinda important.\n", - "__Note:__ If your agent is below the target score, you'll still get some of the points depending on the result, so don't be afraid to submit it.\n", - " \n", - " \n", - "* __2.2__ (up to 6 pts) Devise a way to speed up training against the default version\n", - " * Obvious improvement: use [`joblib`](https://joblib.readthedocs.io/en/latest/). However, note that you will probably need to spawn a new environment in each of the workers instead of passing it via pickling. (2 pts)\n", - " * Try re-using samples from 3-5 last iterations when computing threshold and training. (2 pts)\n", - " * Obtain __-100__ at `MountainCar-v0` or __+200__ at `LunarLander-v2` (2 pts). Feel free to experiment with hyperparameters, architectures, schedules etc.\n", - " \n", - "__Please list what you did in Anytask submission form__. 
This reduces probability that somebody misses something.\n", - " \n", - " \n", - "### Tips\n", - "* Gym page: [MountainCar](https://gym.openai.com/envs/MountainCar-v0), [LunarLander](https://gym.openai.com/envs/LunarLander-v2)\n", - "* Sessions for MountainCar may last for 10k+ ticks. Make sure ```t_max``` param is at least 10k.\n", - " * Also it may be a good idea to cut rewards via \">\" and not \">=\". If 90% of your sessions get reward of -10k and 10% are better, than if you use percentile 20% as threshold, R >= threshold __fails to cut off bad sessions__ while R > threshold works alright.\n", - "* _issue with gym_: Some versions of gym limit game time by 200 ticks. This will prevent cem training in most cases. Make sure your agent is able to play for the specified __t_max__, and if it isn't, try `env = gym.make(\"MountainCar-v0\").env` or otherwise get rid of TimeLimit wrapper.\n", - "* If you use old _swig_ lib for LunarLander-v2, you may get an error. See this [issue](https://github.com/openai/gym/issues/100) for solution.\n", - "* If it doesn't train, it's a good idea to plot reward distribution and record sessions: they may give you some clue. If they don't, call course staff :)\n", - "* 20-neuron network is probably not enough, feel free to experiment.\n", - "\n", - "You may find the following snippet useful:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "I_i1q1TWG9zH" + }, + "source": [ + "# Deep Crossentropy method\n", + "\n", + "In this section we'll extend your CEM implementation with neural networks! You will train a multi-layer neural network to solve simple continuous state space games. 
__Please make sure you're done with tabular crossentropy method from the previous notebook.__\n", + "\n", + "![img](https://watanimg.elwatannews.com/old_news_images/large/249765_Large_20140709045740_11.jpg)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "t4CJ1sRyG9zJ" + }, + "outputs": [], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C2xd5vPwPVCb" + }, + "outputs": [], + "source": [ + "# Install gymnasium if you didn't\n", + "!pip install gymnasium[toy_text,classic_control]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_2zbc7ahG9zK" + }, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "# if you see \" has no attribute .env\", remove .env or update gym\n", + "env = gym.make(\"CartPole-v0\", render_mode=\"rgb_array\").env\n", + "\n", + "env.reset()\n", + "n_actions = env.action_space.n\n", + "state_dim = env.observation_space.shape[0]\n", + "\n", + "plt.imshow(env.render())\n", + "print(\"state vector dim =\", state_dim)\n", + "print(\"n_actions =\", n_actions)\n", + "\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z72_alhdG9zK" + }, + "source": [ + "# Neural Network Policy\n", + "\n", + "For this assignment we'll utilize the 
simplified neural network implementation from __[Scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html)__. Here's what you'll need:\n", + "\n", + "* `agent.partial_fit(states, actions)` - make a single training pass over the data. Maximize the probability of :actions: from :states:\n", + "* `agent.predict_proba(states)` - predict probabilities of all actions, a matrix of shape __[len(states), n_actions]__\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wLItY4unG9zL" + }, + "outputs": [], + "source": [ + "from sklearn.neural_network import MLPClassifier\n", + "\n", + "agent = MLPClassifier(\n", + " hidden_layer_sizes=(20, 20),\n", + " activation=\"tanh\",\n", + ")\n", + "\n", + "# initialize agent to the dimension of state space and number of actions\n", + "agent.partial_fit([env.reset()[0]] * n_actions, range(n_actions), range(n_actions))\n" + ] + }, { - "data": { - "image/png": "\n", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 45, + "metadata": { + "id": "eyFS3oUmG9zL" + }, + "outputs": [], + "source": [ + "def generate_session(env, agent, t_max=1000):\n", + " \"\"\"\n", + " Play a single game using agent neural network.\n", + " Terminate when game finishes or after :t_max: steps\n", + " \"\"\"\n", + " states, actions = [], []\n", + " total_reward = 0\n", + "\n", + " s, _ = env.reset()\n", + "\n", + " for t in range(t_max):\n", + "\n", + " # use agent to predict a vector of action probabilities for state :s:\n", + " probs = \n", + "\n", + " assert probs.shape == (env.action_space.n,), \"make sure probabilities are a vector (hint: np.reshape)\"\n", + "\n", + " # use the probabilities you predicted to pick an action\n", + " # sample proportionally to the probabilities, don't just take the most likely action\n", + " a = \n", + " # ^-- hint: try np.random.choice\n", + "\n", + " new_s, r, terminated, truncated, _ = env.step(a)\n", + "\n", + " # record sessions like you did before\n", + " states.append(s)\n", + " actions.append(a)\n", + " total_reward += r\n", + "\n", + " s = new_s\n", + " if terminated or truncated:\n", + " break\n", + " return states, actions, total_reward\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4xgrTCgJG9zL" + }, + "outputs": [], + "source": [ + "dummy_states, dummy_actions, dummy_reward = generate_session(env, agent, t_max=5)\n", + "print(\"states:\", np.stack(dummy_states))\n", + "print(\"actions:\", dummy_actions)\n", + "print(\"reward:\", dummy_reward)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p85lt16qG9zL" + }, + "source": [ + "### CEM steps\n", + "Deep CEM uses exactly the same strategy as the regular CEM, so you can copy your function code from previous notebook.\n", + "\n", + "The only difference is that now each observation is not a number but a `float32` vector." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "id": "4On-p7p4G9zL" + }, + "outputs": [], + "source": [ + "def select_elites(states_batch, actions_batch, rewards_batch, percentile=50):\n", + " \"\"\"\n", + " Select states and actions from games that have rewards >= percentile\n", + " :param states_batch: list of lists of states, states_batch[session_i][t]\n", + " :param actions_batch: list of lists of actions, actions_batch[session_i][t]\n", + " :param rewards_batch: list of rewards, rewards_batch[session_i]\n", + "\n", + " :returns: elite_states,elite_actions, both 1D lists of states and respective actions from elite sessions\n", + "\n", + " Please return elite states and actions in their original order\n", + " [i.e. sorted by session number and timestep within session]\n", + "\n", + " If you are confused, see examples below. Please don't assume that states are integers\n", + " (they will become different later).\n", + " \"\"\"\n", + "\n", + " \n", + "\n", + " return elite_states, elite_actions\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xc40V4DaG9zM" + }, + "source": [ + "# Training loop\n", + "Generate sessions, select N best and fit to those." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "id": "PPwVKwF7G9zM" + }, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "\n", + "def show_progress(rewards_batch, log, percentile, reward_range=[-990, +10]):\n", + " \"\"\"\n", + " A convenience function that displays training progress.\n", + " No cool math here, just charts.\n", + " \"\"\"\n", + "\n", + " mean_reward = np.mean(rewards_batch)\n", + " threshold = np.percentile(rewards_batch, percentile)\n", + " log.append([mean_reward, threshold])\n", + "\n", + " clear_output(True)\n", + " print(\"mean reward = %.3f, threshold=%.3f\" % (mean_reward, threshold))\n", + " plt.figure(figsize=[8, 4])\n", + " plt.subplot(1, 2, 1)\n", + " plt.plot(list(zip(*log))[0], label=\"Mean rewards\")\n", + " plt.plot(list(zip(*log))[1], label=\"Reward thresholds\")\n", + " plt.legend()\n", + " plt.grid()\n", + "\n", + " plt.subplot(1, 2, 2)\n", + " plt.hist(rewards_batch, range=reward_range)\n", + " plt.vlines(\n", + " [np.percentile(rewards_batch, percentile)],\n", + " [0],\n", + " [100],\n", + " label=\"percentile\",\n", + " color=\"red\",\n", + " )\n", + " plt.legend()\n", + " plt.grid()\n", + "\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "euK7WRQiG9zM" + }, + "outputs": [], + "source": [ + "n_sessions = 100\n", + "percentile = 70\n", + "log = []\n", + "\n", + "for i in range(100):\n", + " # generate new sessions\n", + " sessions = [ ]\n", + "\n", + " states_batch, actions_batch, rewards_batch = map(np.array, zip(*sessions))\n", + "\n", + " elite_states, elite_actions = \n", + "\n", + " \n", + "\n", + " show_progress(\n", + " rewards_batch, log, percentile, reward_range=[0, np.max(rewards_batch)]\n", + " )\n", + "\n", + " if np.mean(rewards_batch) > 190:\n", + " print(\"You Win! 
You may stop training now via KeyboardInterrupt.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yeNWKjtsG9zM" + }, + "source": [ + "# Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RJwsWl4kG9zM" + }, + "outputs": [], + "source": [ + "# Record sessions\n", + "\n", + "from gymnasium.wrappers import RecordVideo\n", + "\n", + "with RecordVideo(\n", + " env=gym.make(\"CartPole-v0\", render_mode=\"rgb_array\"),\n", + " video_folder=\"./videos\",\n", + " episode_trigger=lambda episode_number: True,\n", + ") as env_monitor:\n", + " sessions = [generate_session(env_monitor, agent) for _ in range(100)]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kLPXdME7G9zN" + }, + "outputs": [], + "source": [ + "# Show video. This may not work in some setups. If it doesn't\n", + "# work for you, you can download the videos and view them locally.\n", + "\n", + "from pathlib import Path\n", + "from base64 import b64encode\n", + "from IPython.display import HTML\n", + "\n", + "video_paths = sorted([s for s in Path(\"videos\").iterdir() if s.suffix == \".mp4\"])\n", + "video_path = video_paths[-1] # You can also try other indices\n", + "\n", + "if \"google.colab\" in sys.modules:\n", + " # https://stackoverflow.com/a/57378660/1214547\n", + " with video_path.open(\"rb\") as fp:\n", + " mp4 = fp.read()\n", + " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + "else:\n", + " data_url = str(video_path)\n", + "\n", + "HTML(\n", + " \"\"\"\n", + "\n", + "\"\"\".format(\n", + " data_url\n", + " )\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6d_3oOQ1G9zN" + }, + "source": [ + "# Homework part I\n", + "\n", + "### Tabular crossentropy method\n", + "\n", + "You may have noticed that the taxi problem quickly converges from -100 to a near-optimal score and then descends back into -50/-100. 
This is in part because the environment has some innate randomness. Namely, the starting points of passenger/driver change from episode to episode.\n", + "\n", + "### Tasks\n", + "- __1.1__ (2 pts) Find out how the algorithm performance changes if you use a different `percentile` and/or `n_sessions`. Provide here some figures so we can see how the hyperparameters influence the performance.\n", + "- __1.2__ (1 pts) Tune the algorithm to end up with positive average score.\n", + "\n", + "It's okay to modify the existing code.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L88LySiVG9zN" + }, + "source": [ + "``````" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7LpAJc4rG9zN" + }, + "source": [ + "# Homework part II\n", + "\n", + "### Deep crossentropy method\n", + "\n", + "By this moment, you should have got enough score on [CartPole-v0](https://gymnasium.farama.org/environments/classic_control/cart_pole/) to consider it solved (see the link). It's time to try something harder.\n", + "\n", + "* if you have any trouble with CartPole-v0 and feel stuck, feel free to ask us or your peers for help.\n", + "\n", + "### Tasks\n", + "\n", + "* __2.1__ (3 pts) Pick one of environments: `MountainCar-v0` or `LunarLander-v2`.\n", + " * For MountainCar, get average reward of __at least -150__\n", + " * For LunarLander, get average reward of __at least +50__\n", + "\n", + "See the tips section below, it's kinda important.\n", + "__Note:__ If your agent is below the target score, you'll still get some of the points depending on the result, so don't be afraid to submit it.\n", + " \n", + " \n", + "* __2.2__ (up to 6 pts) Devise a way to speed up training against the default version\n", + " * Obvious improvement: use [`joblib`](https://joblib.readthedocs.io/en/latest/). However, note that you will probably need to spawn a new environment in each of the workers instead of passing it via pickling. 
(2 pts)\n", + " * Try re-using samples from the last 3-5 iterations when computing threshold and training. (2 pts)\n", + " * Obtain __-100__ at `MountainCar-v0` or __+200__ at `LunarLander-v2` (2 pts). Feel free to experiment with hyperparameters, architectures, schedules etc.\n", + " \n", + "__Please list what you did in Anytask submission form__. This reduces the probability that somebody misses something.\n", + " \n", + " \n", + "### Tips\n", + "* Gymnasium pages: [MountainCar](https://gymnasium.farama.org/environments/classic_control/mountain_car/), [LunarLander](https://gymnasium.farama.org/environments/box2d/lunar_lander/)\n", + "* Sessions for MountainCar may last for 10k+ ticks. Make sure ```t_max``` param is at least 10k.\n", + " * Also it may be a good idea to cut rewards via \">\" and not \">=\". If 90% of your sessions get reward of -10k and 10% are better, then if you use percentile 20% as threshold, R >= threshold __fails to cut off bad sessions__ while R > threshold works alright.\n", + "* _issue with gym_: Some versions of gym limit game time by 200 ticks. This will prevent cem training in most cases. Make sure your agent is able to play for the specified __t_max__, and if it isn't, try `env = gym.make(\"MountainCar-v0\").env` or otherwise get rid of TimeLimit wrapper.\n", + "* If you use old _swig_ lib for LunarLander-v2, you may get an error. See this [issue](https://github.com/openai/gym/issues/100) for solution.\n", + "* If it doesn't train, it's a good idea to plot reward distribution and record sessions: they may give you some clue. 
If they don't, call course staff :)\n", + "* 20-neuron network is probably not enough, feel free to experiment.\n", + "\n", + "You may find the following snippet useful:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qcjz-nm_G9zN" + }, + "outputs": [], + "source": [ + "def visualize_mountain_car(env, agent):\n", + " # Compute policy for all possible x and v (with discretization)\n", + " xs = np.linspace(env.min_position, env.max_position, 100)\n", + " vs = np.linspace(-env.max_speed, env.max_speed, 100)\n", + "\n", + " grid = np.dstack(np.meshgrid(xs, vs[::-1])).transpose(1, 0, 2)\n", + " grid_flat = grid.reshape(len(xs) * len(vs), 2)\n", + " probs = (\n", + " agent.predict_proba(grid_flat).reshape(len(xs), len(vs), 3).transpose(1, 0, 2)\n", + " )\n", + "\n", + " # # The above code is equivalent to the following:\n", + " # probs = np.empty((len(vs), len(xs), 3))\n", + " # for i, v in enumerate(vs[::-1]):\n", + " # for j, x in enumerate(xs):\n", + " # probs[i, j, :] = agent.predict_proba([[x, v]])[0]\n", + "\n", + " # Draw policy\n", + " f, ax = plt.subplots(figsize=(7, 7))\n", + " ax.imshow(\n", + " probs,\n", + " extent=(env.min_position, env.max_position, -env.max_speed, env.max_speed),\n", + " aspect=\"auto\",\n", + " )\n", + " ax.set_title(\"Learned policy: red=left, green=nothing, blue=right\")\n", + " ax.set_xlabel(\"position (x)\")\n", + " ax.set_ylabel(\"velocity (v)\")\n", + "\n", + " # Sample a trajectory and draw it\n", + " states, actions, _ = generate_session(env, agent)\n", + " states = np.array(states)\n", + " ax.plot(states[:, 0], states[:, 1], color=\"white\")\n", + "\n", + " # Draw every 3rd action from the trajectory\n", + " for (x, v), a in zip(states[::3], actions[::3]):\n", + " if a == 0:\n", + " plt.arrow(x, v, -0.1, 0, color=\"white\", head_length=0.02)\n", + " elif a == 2:\n", + " plt.arrow(x, v, 0.1, 0, color=\"white\", head_length=0.02)\n", + "\n", + "\n", + "with gym.make(\"MountainCar-v0\", 
render_mode=\"rgb_array\").env as env:\n", + " visualize_mountain_car(env, agent)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Dzk41lDPG9zO" + }, + "source": [ + "### Bonus tasks\n", + "\n", + "* __2.3 bonus__ (2 pts) Try to find a network architecture and training params that solve __both__ environments above (_Points depend on implementation. If you attempted this task, please mention it in Anytask submission._)\n", + "\n", + "* __2.4 bonus__ (4 pts) Solve continuous action space task with `MLPRegressor` or similar.\n", + " * Since your agent only predicts the \"expected\" action, you will have to add noise to ensure exploration.\n", + " * Choose one of [MountainCarContinuous-v0](https://gymnasium.farama.org/environments/classic_control/mountain_car_continuous/) (90+ pts to solve), [LunarLanderContinuous-v2](https://gymnasium.farama.org/environments/box2d/lunar_lander/) (`env = gym.make(\"LunarLander-v2\", continuous=True)`) (200+ pts to solve)\n", + " * 4 points for solving. Slightly less for getting some results below solution threshold. Note that discrete and continuous environments may have slightly different rules, aside from action spaces." 
] - }, - "metadata": {}, - "output_type": "display_data" } - ], - "source": [ - "def visualize_mountain_car(env, agent):\n", - " # Compute policy for all possible x and v (with discretization)\n", - " xs = np.linspace(env.min_position, env.max_position, 100)\n", - " vs = np.linspace(-env.max_speed, env.max_speed, 100)\n", - " \n", - " grid = np.dstack(np.meshgrid(xs, vs[::-1])).transpose(1, 0, 2)\n", - " grid_flat = grid.reshape(len(xs) * len(vs), 2)\n", - " probs = agent.predict_proba(grid_flat).reshape(len(xs), len(vs), 3).transpose(1, 0, 2)\n", - "\n", - " # # The above code is equivalent to the following:\n", - " # probs = np.empty((len(vs), len(xs), 3))\n", - " # for i, v in enumerate(vs[::-1]):\n", - " # for j, x in enumerate(xs):\n", - " # probs[i, j, :] = agent.predict_proba([[x, v]])[0]\n", - "\n", - " # Draw policy\n", - " f, ax = plt.subplots(figsize=(7, 7))\n", - " ax.imshow(probs, extent=(env.min_position, env.max_position, -env.max_speed, env.max_speed), aspect='auto')\n", - " ax.set_title('Learned policy: red=left, green=nothing, blue=right')\n", - " ax.set_xlabel('position (x)')\n", - " ax.set_ylabel('velocity (v)')\n", - " \n", - " # Sample a trajectory and draw it\n", - " states, actions, _ = generate_session(env, agent)\n", - " states = np.array(states)\n", - " ax.plot(states[:, 0], states[:, 1], color='white')\n", - " \n", - " # Draw every 3rd action from the trajectory\n", - " for (x, v), a in zip(states[::3], actions[::3]):\n", - " if a == 0:\n", - " plt.arrow(x, v, -0.1, 0, color='white', head_length=0.02)\n", - " elif a == 2:\n", - " plt.arrow(x, v, 0.1, 0, color='white', head_length=0.02)\n", - "\n", - "with gym.make('MountainCar-v0').env as env:\n", - " visualize_mountain_car(env, agent_mountain_car)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Bonus tasks\n", - "\n", - "* __2.3 bonus__ (2 pts) Try to find a network architecture and training params that solve __both__ environments above (_Points depend on 
implementation. If you attempted this task, please mention it in Anytask submission._)\n", - "\n", - "* __2.4 bonus__ (4 pts) Solve continuous action space task with `MLPRegressor` or similar.\n", - " * Since your agent only predicts the \"expected\" action, you will have to add noise to ensure exploration.\n", - " * Choose one of [MountainCarContinuous-v0](https://gym.openai.com/envs/MountainCarContinuous-v0) (90+ pts to solve), [LunarLanderContinuous-v2](https://gym.openai.com/envs/LunarLanderContinuous-v2) (200+ pts to solve) \n", - " * 4 points for solving. Slightly less for getting some results below solution threshold. Note that discrete and continuous environments may have slightly different rules, aside from action spaces." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 1 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/week01_intro/project_starter_evolution_strategies.ipynb b/week01_intro/project_starter_evolution_strategies.ipynb index 68f869cdf..d21475cbf 100644 --- a/week01_intro/project_starter_evolution_strategies.ipynb +++ b/week01_intro/project_starter_evolution_strategies.ipynb @@ -8,7 +8,7 @@ "\n", 
"![img](https://t4.ftcdn.net/jpg/00/17/46/81/240_F_17468143_wY3hsHyfNYoMdG9BlC56HI4JA7pNu63h.jpg)\n", "\n", - "Remember the idea behind Evolution Strategies? Here's a neat [blog post](https://blog.openai.com/evolution-strategies/) about 'em.\n", + "Remember the idea behind Evolution Strategies? Here's a neat [blog post](https://openai.com/research/evolution-strategies) about 'em.\n", "\n", "Can you reproduce their success? You will have to implement evolutionary strategies and see how they work.\n", "\n", diff --git a/week01_intro/seminar_gym_interface.ipynb b/week01_intro/seminar_gymnasium_interface.ipynb similarity index 69% rename from week01_intro/seminar_gym_interface.ipynb rename to week01_intro/seminar_gymnasium_interface.ipynb index 7865c8db8..f47c6fe07 100644 --- a/week01_intro/seminar_gym_interface.ipynb +++ b/week01_intro/seminar_gymnasium_interface.ipynb @@ -34,11 +34,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### OpenAI Gym\n", + "### OpenAI Gym --> Farama Gymnasium\n", "\n", "We're gonna spend several next weeks learning algorithms that solve decision processes. We are then in need of some interesting decision problems to test our algorithms.\n", "\n", - "That's where OpenAI Gym comes into play. It's a Python library that wraps many classical decision problems including robot control, videogames and board games.\n", + "That's where Gymnasium comes into play. 
It's a Python library that wraps many classical decision problems including robot control, videogames and board games.\n", + "\n", + "The Gym library by OpenAI has been replaced by Gymnasium, which keeps all functionality compatible with the latest version of Gym.\n", + "\n", + "Announcement: https://farama.org/Announcing-The-Farama-Foundation\n", + "\n", + "GitHub: https://github.com/Farama-Foundation/Gymnasium\n", + "\n", + "Documentation: https://gymnasium.farama.org/\n", "\n", "So here's how it works:" ] @@ -49,14 +57,14 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", + "import gymnasium as gym\n", "\n", - "env = gym.make(\"MountainCar-v0\")\n", + "env = gym.make(\"MountainCar-v0\", render_mode=\"rgb_array\")\n", "env.reset()\n", "\n", - "plt.imshow(env.render('rgb_array'))\n", + "plt.imshow(env.render())\n", "print(\"Observation space:\", env.observation_space)\n", - "print(\"Action space:\", env.action_space)" + "print(\"Action space:\", env.action_space)\n" ] }, { @@ -70,16 +78,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Gym interface\n", + "### Gymnasium interface\n", "\n", "The three main methods of an environment are\n", - "* `reset()`: reset environment to the initial state, _return first observation_\n", + "* `reset()`: reset environment to the initial state, return first observation and dict with auxiliary info\n", "* `render()`: show current environment state (a more colorful version :) )\n", - "* `step(a)`: commit action `a` and return `(new_observation, reward, is_done, info)`\n", + "* `step(a)`: commit action `a` and return `(new_observation, reward, terminated, truncated, info)`\n", " * `new_observation`: an observation right after committing the action `a`\n", " * `reward`: a number representing your reward for committing action `a`\n", - " * `is_done`: True if the MDP has just finished, False if still in progress\n", - " * `info`: some auxiliary stuff about what just happened. For now, ignore it." 
+ " * `terminated`: True if the MDP has just finished, False if still in progress\n", + " * `truncated`: True if the number of steps elapsed >= max episode steps\n", + " * `info`: some auxiliary stuff about what just happened. For now, ignore it.\n", + "\n", + "A detailed explanation of the difference between `terminated` and `truncated` and how it should be used:\n", + "1. https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/\n", + "2. https://gymnasium.farama.org/content/migration-guide/" ] }, { @@ -88,10 +101,14 @@ "metadata": {}, "outputs": [], "source": [ - "obs0 = env.reset()\n", + "# Set seed to reproduce initial state in stochastic environment\n", + "obs0, info = env.reset(seed=0)\n", "print(\"initial observation code:\", obs0)\n", "\n", - "# Note: in MountainCar, observation is just two numbers: car position and velocity" + "obs0, info = env.reset(seed=1)\n", + "print(\"initial observation code:\", obs0)\n", + "\n", + "# Note: in MountainCar, observation is just two numbers: car position and velocity\n" ] }, { @@ -101,13 +118,14 @@ "outputs": [], "source": [ "print(\"taking action 2 (right)\")\n", - "new_obs, reward, is_done, _ = env.step(2)\n", + "new_obs, reward, terminated, truncated, _ = env.step(2)\n", "\n", "print(\"new observation code:\", new_obs)\n", "print(\"reward:\", reward)\n", - "print(\"is game over?:\", is_done)\n", + "print(\"is game over?:\", terminated)\n", + "print(\"is game truncated due to time limit?:\", truncated)\n", "\n", - "# Note: as you can see, the car has moved to the right slightly (around 0.0005)" + "# Note: as you can see, the car has moved to the right slightly (around 0.0005)\n" ] }, { @@ -134,10 +152,10 @@ "# Create env manually to set time limit. 
Please don't change this.\n", "TIME_LIMIT = 250\n", "env = gym.wrappers.TimeLimit(\n", - " gym.envs.classic_control.MountainCarEnv(),\n", + " gym.make(\"MountainCar-v0\", render_mode=\"rgb_array\"),\n", " max_episode_steps=TIME_LIMIT + 1,\n", ")\n", - "actions = {'left': 0, 'stop': 1, 'right': 2}" + "actions = {\"left\": 0, \"stop\": 1, \"right\": 2}\n" ] }, { @@ -151,12 +169,12 @@ " # (a tuple of position and velocity), the current time step, or both,\n", " # if you want.\n", " position, velocity = obs\n", - " \n", + "\n", " # This is an example policy. You can try running it, but it will not work.\n", " # Your goal is to fix that. You don't need anything sophisticated here,\n", " # and you can hard-code any policy that seems to work.\n", " # Hint: think how you would make a swing go farther and faster.\n", - " return actions['right']" + " return actions[\"right\"]\n" ] }, { @@ -168,29 +186,31 @@ "plt.figure(figsize=(4, 3))\n", "display.clear_output(wait=True)\n", "\n", - "obs = env.reset()\n", + "obs, _ = env.reset()\n", "for t in range(TIME_LIMIT):\n", " plt.gca().clear()\n", - " \n", + "\n", " action = policy(obs, t) # Call your policy\n", - " obs, reward, done, _ = env.step(action) # Pass the action chosen by the policy to the environment\n", - " \n", + " obs, reward, terminated, truncated, _ = env.step(\n", + " action\n", + " ) # Pass the action chosen by the policy to the environment\n", + "\n", " # We don't do anything with reward here because MountainCar is a very simple environment,\n", " # and reward is a constant -1. Therefore, your goal is to end the episode as quickly as possible.\n", "\n", " # Draw game image on display.\n", - " plt.imshow(env.render('rgb_array'))\n", - " \n", + " plt.imshow(env.render())\n", + "\n", " display.display(plt.gcf())\n", " display.clear_output(wait=True)\n", "\n", - " if done:\n", + " if terminated or truncated:\n", " print(\"Well done!\")\n", " break\n", "else:\n", " print(\"Time limit exceeded. 
Try again.\")\n", "\n", - "display.clear_output(wait=True)" + "display.clear_output(wait=True)\n" ] }, { @@ -200,7 +220,7 @@ "outputs": [], "source": [ "assert obs[0] > 0.47\n", - "print(\"You solved it!\")" + "print(\"You solved it!\")\n" ] } ], diff --git a/week02_value_based/mdp.py b/week02_value_based/mdp.py index 371a66691..1d148f22f 100644 --- a/week02_value_based/mdp.py +++ b/week02_value_based/mdp.py @@ -249,7 +249,7 @@ def render(self): print('\n'.join(map(''.join, desc_copy)), end='\n\n') -def plot_graph(mdp, graph_size='10,10', s_node_size='1,5', +def plot_graph(mdp, s_node_size='1,5', a_node_size='0,5', rankdir='LR', ): """ Function for pretty drawing MDP graph with graphviz library. @@ -259,7 +259,6 @@ def plot_graph(mdp, graph_size='10,10', s_node_size='1,5', python library for graphviz for pip users: pip install graphviz :param mdp: - :param graph_size: size of graph plot :param s_node_size: size of state nodes :param a_node_size: size of action nodes :param rankdir: order for drawing @@ -292,7 +291,7 @@ def plot_graph(mdp, graph_size='10,10', s_node_size='1,5', 'fontsize': '16'} graph = Digraph(name='MDP') - graph.attr(rankdir=rankdir, size=graph_size) + graph.attr(rankdir=rankdir) for state_node in mdp._transition_probs: graph.node(state_node, **s_node_attrs) diff --git a/week03_model_free/README.md b/week03_model_free/README.md index fb6d865bf..b2271d15b 100644 --- a/week03_model_free/README.md +++ b/week03_model_free/README.md @@ -18,9 +18,4 @@ ### Assignments -Just as usual, start with -- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week03_model_free/seminar_qlearning.ipynb) -`seminar_qlearning.ipynb` _Implement q-learning agent and test it on Taxi and CartPole with binarizer_ - -and then proceed to -- [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week03_model_free/homework.ipynb) `homework.ipynb` _Implement EV-SARSA agent, experience replay + bonus tasks_ +Just as usual, start with `homework.ipynb` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yandexdataschool/Practical_RL/blob/master/week03_model_free/homework.ipynb). For the seminar, implement a Q-learning agent and test it on Taxi and on CartPole with a discretizer. Then, for the homework, implement an EV-SARSA agent and experience replay, plus the bonus tasks. diff --git a/week03_model_free/homework.ipynb b/week03_model_free/homework.ipynb index 036728ef3..6eb18ea33 100644 --- a/week03_model_free/homework.ipynb +++ b/week03_model_free/homework.ipynb @@ -1,29 +1,5 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[Part I: On-policy learning and SARSA (3 points)](#Part-I:-On-policy-learning-and-SARSA-(3-points))\n", - "\n", - "[Part II: Experience replay (4 points)](#Part-II:-experience-replay-(4-points))\n", - "\n", - "[Bonus I: TD($ \\lambda $) (5+ points)](#Bonus-I:-TD($\\lambda$)-(5+-points))\n", - "\n", - "[Bonus II: More pacman (5+ points)](#Bonus-II:-More-pacman-(5+-points))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Part I: On-policy learning and SARSA (3 points)\n", - "\n", - "_This notebook builds upon `qlearning.ipynb`, or to be exact your implementation of QLearningAgent._\n", - "\n", - "The policy we're gonna use is epsilon-greedy policy, where agent takes optimal action with probability $(1-\\epsilon)$, otherwise samples action at random. Note that agent __can__ occasionally sample optimal action during random sampling by pure chance." 
- ] - }, { "cell_type": "code", "execution_count": null, @@ -33,6 +9,7 @@ "import sys, os\n", "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + "\n", " !touch .setup_complete\n", "\n", "# This code creates a virtual display to draw game images on.\n", @@ -53,11 +30,24 @@ "%matplotlib inline" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install -q gymnasium[classic-control]" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can copy your `QLearningAgent` implementation from previous notebook." + "## Seminar: Q-learning (1.5 points)\n", + "\n", + "This notebook will guide you through the implementation of the vanilla Q-learning algorithm.\n", + "\n", + "You need to implement QLearningAgent (follow the instructions for each method) and use it on a number of tests below." ] }, { @@ -89,7 +79,6 @@ " which returns Q(state,action)\n", " - self.set_qvalue(state,action,value)\n", " which sets Q(state,action) := value\n", - "\n", " !!!Important!!!\n", " Note: please avoid using self._qValues directly. \n", " There's a special self.get_qvalue/set_qvalue for that.\n", @@ -182,6 +171,335 @@ " return chosen_action" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Try it on taxi\n", + "\n", + "Here we use the Q-Learning agent on the Taxi-v3 environment from Gymnasium.\n", + "You will need to complete a few of its functions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "env = gym.make(\"Taxi-v3\", render_mode='rgb_array')\n", + "\n", + "n_actions = env.action_space.n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s, _ = env.reset()\n", + "plt.imshow(env.render())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent = QLearningAgent(\n", + " alpha=0.5, epsilon=0.25, discount=0.99,\n", + " get_legal_actions=lambda s: range(n_actions))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def play_and_train(env, agent, t_max=10**4):\n", + " \"\"\"\n", + " This function should \n", + " - run a full game, actions given by agent's e-greedy policy\n", + " - train agent using agent.update(...) whenever it is possible\n", + " - return total reward\n", + " \"\"\"\n", + " total_reward = 0.0\n", + " s, _ = env.reset()\n", + "\n", + " for t in range(t_max):\n", + " # get agent to pick action given state s.\n", + " a = \n", + "\n", + " next_s, r, done, _, _ = env.step(a)\n", + "\n", + " # train (update) agent for state s\n", + " \n", + "\n", + " s = next_s\n", + " total_reward += r\n", + " if done:\n", + " break\n", + "\n", + " return total_reward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "rewards = []\n", + "for i in range(1000):\n", + " rewards.append(play_and_train(env, agent))\n", + " agent.epsilon *= 0.99\n", + "\n", + " if i % 100 == 0:\n", + " clear_output(True)\n", + " plt.title('eps = {:e}, mean reward = {:.1f}'.format(agent.epsilon, np.mean(rewards[-10:])))\n", + " plt.plot(rewards)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Seminar: Discretized state spaces (1.5 points)\n", + "\n", + "Use agent to train efficiently on `CartPole-v0`. This environment has a continuous set of possible states, so you will have to group them into bins somehow.\n", + "\n", + "The simplest way is to use `round(x, n_digits)` (or `np.round`) to round a real number to a given amount of digits. The tricky part is to get the `n_digits` right for each state to train effectively.\n", + "\n", + "Note that you don't need to convert state to integers, but to __tuples__ of any kind of values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_env():\n", + " return gym.make('CartPole-v0', render_mode='rgb_array').env # .env unwraps the TimeLimit wrapper\n", + "\n", + "env = make_env()\n", + "n_actions = env.action_space.n\n", + "\n", + "print(\"first state: %s\" % (env.reset()[0]))\n", + "plt.imshow(env.render())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Play a few games\n", + "\n", + "We need to estimate observation distributions. To do so, we'll play a few games and record all states." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def visualize_cartpole_observation_distribution(seen_observations):\n", + " seen_observations = np.array(seen_observations)\n", + " \n", + " # The meaning of the observations is documented in\n", + " # https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py\n", + "\n", + " f, axarr = plt.subplots(2, 2, figsize=(16, 9), sharey=True)\n", + " for i, title in enumerate(['Cart Position', 'Cart Velocity', 'Pole Angle', 'Pole Velocity At Tip']):\n", + " ax = axarr[i // 2, i % 2]\n", + " ax.hist(seen_observations[:, i], bins=20)\n", + " ax.set_title(title)\n", + " xmin, xmax = ax.get_xlim()\n", + " ax.set_xlim(min(xmin, -xmax), max(-xmin, xmax))\n", + " ax.grid()\n", + " f.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seen_observations = []\n", + "for _ in range(1000):\n", + " s, _ = env.reset()\n", + " seen_observations.append(s)\n", + " done = False\n", + " while not done:\n", + " s, r, done, _, _ = env.step(env.action_space.sample())\n", + " seen_observations.append(s)\n", + " \n", + "visualize_cartpole_observation_distribution(seen_observations)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Discretize environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gymnasium.core import ObservationWrapper\n", + "\n", + "\n", + "class Discretizer(ObservationWrapper):\n", + " def observation(self, state):\n", + " # Hint: you can do that with round(x, n_digits).\n", + " # You may pick a different n_digits for each dimension.\n", + " state = \n", + "\n", + " return tuple(state)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env = Discretizer(make_env())" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seen_observations = []\n", + "for _ in range(1000):\n", + " s, _ = env.reset()\n", + " seen_observations.append(s)\n", + " done = False\n", + " while not done:\n", + " s, r, done, _, _ = env.step(env.action_space.sample())\n", + " seen_observations.append(s)\n", + " if done:\n", + " break\n", + " \n", + "visualize_cartpole_observation_distribution(seen_observations)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Learn discretized policy\n", + "\n", + "Now let's train a policy that uses discretized state space.\n", + "\n", + "__Tips:__\n", + "\n", + "* Note that increasing the number of digits for one dimension of the observations increases your state space by a factor of $10$.\n", + "* If your discretization is too fine-grained, your agent will take much longer than 10000 steps to converge. You can either increase the number of iterations and reduce epsilon decay or change discretization. In practice we found that this kind of mistake is rather frequent.\n", + "* If your discretization is too coarse, your agent may fail to find the optimal policy. In practice we found that on this particular environment this kind of mistake is rare.\n", + "* **Start with a coarse discretization** and make it more fine-grained if that seems necessary.\n", + "* Having $10^3$–$10^4$ distinct states is recommended (`len(agent._qvalues)`), but not required.\n", + "* If things don't work without annealing $\\varepsilon$, consider adding that, but make sure that it doesn't go to zero too quickly.\n", + "\n", + "A reasonable agent should attain an average reward of at least 50." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "def moving_average(x, span=100):\n", + " return pd.DataFrame({'x': np.asarray(x)}).x.ewm(span=span).mean().values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent = QLearningAgent(\n", + " alpha=0.5, epsilon=0.25, discount=0.99,\n", + " get_legal_actions=lambda s: range(n_actions))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rewards = []\n", + "epsilons = []\n", + "\n", + "for i in range(10000):\n", + " reward = play_and_train(env, agent)\n", + " rewards.append(reward)\n", + " epsilons.append(agent.epsilon)\n", + " \n", + " # OPTIONAL: \n", + "\n", + " if i % 100 == 0:\n", + " rewards_ewma = moving_average(rewards)\n", + " \n", + " clear_output(True)\n", + " plt.plot(rewards, label='rewards')\n", + " plt.plot(rewards_ewma, label='rewards ewma@100')\n", + " plt.legend()\n", + " plt.grid()\n", + " plt.title('eps = {:e}, rewards ewma@100 = {:.1f}'.format(agent.epsilon, rewards_ewma[-1]))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('Your agent has learned {} Q-values.'.format(len(agent._qvalues)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Homework Part I: On-policy learning and SARSA (3 points)\n", + "\n", + "The policy we're gonna use is epsilon-greedy policy, where agent takes the optimal action with probability $(1-\\epsilon)$, otherwise samples action at random. Note that agent __can__ occasionally sample optimal action during random sampling by pure chance." 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -232,7 +550,7 @@ "Let's now see how our algorithm compares against q-learning in case where we force agent to explore all the time.\n", "\n", "\n", - "
image by cs188
" + "
Image from CS188
" ] }, { @@ -241,23 +559,10 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", - "import gym.envs.toy_text\n", - "env = gym.envs.toy_text.CliffWalkingEnv()\n", - "n_actions = env.action_space.n\n", + "import gymnasium as gym\n", "\n", - "print(env.__doc__)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Our cliffworld has one difference from what's on the image: there is no wall.\n", - "# Agent can choose to go as close to the cliff as it wishes. x:start, T:exit, C:cliff, o: flat ground\n", - "env.render()" + "env = gym.make('CliffWalking-v0', render_mode='rgb_array')\n", + "n_actions = env.action_space.n" ] }, { @@ -266,26 +571,12 @@ "metadata": {}, "outputs": [], "source": [ - "def play_and_train(env, agent, t_max=10**4):\n", - " \"\"\"This function should \n", - " - run a full game, actions given by agent.get_action(s)\n", - " - train agent using agent.update(...) whenever possible\n", - " - return total reward\"\"\"\n", - " total_reward = 0.0\n", - " s = env.reset()\n", - "\n", - " for t in range(t_max):\n", - " a = agent.get_action(s)\n", + "# Our cliffworld has one difference from what's in the image: there is no wall.\n", + "# Agent can choose to go as close to the cliff as it wishes.\n", + "# x:start, T:exit, C:cliff, o: flat ground\n", "\n", - " next_s, r, done, _ = env.step(a)\n", - " agent.update(s, a, r, next_s)\n", - "\n", - " s = next_s\n", - " total_reward += r\n", - " if done:\n", - " break\n", - "\n", - " return total_reward" + "env.reset()\n", + "plt.imshow(env.render())" ] }, { @@ -346,17 +637,22 @@ "metadata": {}, "outputs": [], "source": [ - "def draw_policy(env, agent):\n", + "def draw_policy(agent):\n", " \"\"\" Prints CliffWalkingEnv policy with arrows. Hard-coded. 
\"\"\"\n", - " n_rows, n_cols = env._cliff.shape\n", + " \n", + " env = gym.make('CliffWalking-v0', render_mode='ansi')\n", + " env.reset()\n", + " grid = [x.split(' ') for x in env.render().split('\\n')[:4]]\n", "\n", + " n_rows, n_cols = 4, 12\n", + " start_state_index = 36\n", " actions = '^>v<'\n", "\n", " for yi in range(n_rows):\n", " for xi in range(n_cols):\n", - " if env._cliff[yi, xi]:\n", + " if grid[yi][xi] == 'C':\n", " print(\" C \", end='')\n", - " elif (yi * n_cols + xi) == env.start_state_index:\n", + " elif (yi * n_cols + xi) == start_state_index:\n", " print(\" X \", end='')\n", " elif (yi * n_cols + xi) == n_rows * n_cols - 1:\n", " print(\" T \", end='')\n", @@ -373,10 +669,69 @@ "outputs": [], "source": [ "print(\"Q-Learning\")\n", - "draw_policy(env, agent_ql)\n", + "draw_policy(agent_ql)\n", "\n", "print(\"SARSA\")\n", - "draw_policy(env, agent_sarsa)" + "draw_policy(agent_sarsa)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Expected Value SARSA for softmax policy (2 points)\n", + "\n", + "Implement an agent that would use a softmax policy for getting an action. Do not forget to also use softmax when calculating the expected value for value estimation. Draw the policy of the agent and see if the result is different compared to the previous approaches. 
Also, try using different temperatures ($\\tau$) and compare the results.\n", + "\n", + "$$ \\pi(a_i \\mid s) = \\operatorname{softmax} \\left( \\left\\{ {Q(s, a_j) \\over \\tau} \\right\\}_{j=1}^n \\right)_i = {\\operatorname{exp} \\left( Q(s,a_i) / \\tau \\right) \\over {\\sum_{j} \\operatorname{exp} \\left( Q(s,a_j) / \\tau \\right)}} $$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SoftmaxEVSarsaAgent(EVSarsaAgent):\n", + " def __init__(self, alpha, tau, discount, get_legal_actions):\n", + " super().__init__(alpha, None, discount, get_legal_actions)\n", + " assert tau > 0\n", + " self.tau = tau\n", + " \n", + " def get_value(self, state):\n", + " \"\"\" \n", + " Returns V_{pi} for current state under softmax policy:\n", + " V_{pi}(s) = sum _{over a_i} {pi(a_i | s) * Q(s, a_i)}\n", + "\n", + " Hint: all other methods from QLearningAgent are still accessible.\n", + " \"\"\"\n", + " possible_actions = self.get_legal_actions(state)\n", + "\n", + " # If there are no legal actions, return 0.0\n", + " if len(possible_actions) == 0:\n", + " return 0.0\n", + "\n", + " \n", + "\n", + " return value\n", + " \n", + " def get_action(self, state):\n", + " \"\"\"\n", + " Compute the action to take in the current state, including exploration. \n", + " We should sample an action at random, with probabilities given by the softmax of the Q-values.\n", + " \"\"\"\n", + " # Pick Action\n", + " possible_actions = self.get_legal_actions(state)\n", + " action = None\n", + "\n", + " # If there are no legal actions, return None\n", + " if len(possible_actions) == 0:\n", + " return None\n", + "\n", + " \n", + " \n", + "\n", + " return chosen_action" ] }, { @@ -388,8 +743,6 @@ "Here are some of the things you can do if you feel like it:\n", "\n", "* Play with epsilon. See learned how policies change if you set epsilon to higher/lower values (e.g. 
0.75).\n", - "* Expected Value SARSA for softmax policy __(2pts)__:\n", - "$$ \\pi(a_i \\mid s) = \\operatorname{softmax} \\left( \\left\\{ {Q(s, a_j) \\over \\tau} \\right\\}_{j=1}^n \\right)_i = {\\operatorname{exp} \\left( Q(s,a_i) / \\tau \\right) \\over {\\sum_{j} \\operatorname{exp} \\left( Q(s,a_j) / \\tau \\right)}} $$\n", "* Implement N-step algorithms and TD($\\lambda$): see [Sutton's book](http://incompleteideas.net/book/RLbook2020.pdf) chapter 7 and chapter 12.\n", "* Use those algorithms to train on CartPole in previous / next assignment for this week." ] @@ -398,7 +751,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Part II: experience replay (4 points)\n", + "## Part II: experience replay (2 points)\n", "\n", "There's a powerful technique that you can use to improve sample efficiency for off-policy algorithms: [spoiler] Experience replay :)\n", "\n", @@ -423,28 +776,6 @@ "metadata": {}, "outputs": [], "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "\n", "from IPython.display import clear_output" ] }, @@ -574,7 +905,7 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", + "import gymnasium as gym\n", "env = gym.make(\"Taxi-v3\")\n", "n_actions = env.action_space.n" ] @@ -596,13 +927,13 @@ " If None, do not use 
experience replay\n", " \"\"\"\n", " total_reward = 0.0\n", - " s = env.reset()\n", + " s, _ = env.reset()\n", "\n", " for t in range(t_max):\n", " # get agent to pick action given state s\n", " a = \n", "\n", - " next_s, r, done, _ = env.step(a)\n", + " next_s, r, done, _, _ = env.step(a)\n", "\n", " # update agent on current transition. Use agent.update\n", " \n", @@ -690,7 +1021,7 @@ "\n", "### Outro\n", "\n", - "We will use the code you just wrote extensively in the next week of our course. If you're feeling that you need more examples to understand how experience replay works, try using it for binarized state spaces (CartPole or other __[classic control envs](https://gym.openai.com/envs/#classic_control)__).\n", + "We will use the code you just wrote extensively in the next week of our course. If you're feeling that you need more examples to understand how experience replay works, try using it for discretized state spaces (CartPole or other __[classic control envs](https://gym.openai.com/envs/#classic_control)__).\n", "\n", "__Next week__ we're gonna explore how q-learning and similar algorithms can be applied for large state spaces, with deep learning models to approximate the Q function.\n", "\n", diff --git a/week03_model_free/seminar_qlearning.ipynb b/week03_model_free/seminar_qlearning.ipynb deleted file mode 100644 index 74f4e8ee4..000000000 --- a/week03_model_free/seminar_qlearning.ipynb +++ /dev/null @@ -1,495 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Q-learning (3 points)\n", - "\n", - "This notebook will guide you through implementation of vanilla Q-learning algorithm.\n", - "\n", - "You need to implement QLearningAgent (follow instructions for each method) and use it on a number of tests below." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - "\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import defaultdict\n", - "import random\n", - "import math\n", - "import numpy as np\n", - "\n", - "\n", - "class QLearningAgent:\n", - " def __init__(self, alpha, epsilon, discount, get_legal_actions):\n", - " \"\"\"\n", - " Q-Learning Agent\n", - " based on https://inst.eecs.berkeley.edu/~cs188/sp19/projects.html\n", - " Instance variables you have access to\n", - " - self.epsilon (exploration prob)\n", - " - self.alpha (learning rate)\n", - " - self.discount (discount rate aka gamma)\n", - "\n", - " Functions you should use\n", - " - self.get_legal_actions(state) {state, hashable -> list of actions, each is hashable}\n", - " which returns legal actions for a state\n", - " - self.get_qvalue(state,action)\n", - " which returns Q(state,action)\n", - " - self.set_qvalue(state,action,value)\n", - " which sets Q(state,action) := value\n", - " !!!Important!!!\n", - " Note: please avoid using self._qValues directly. 
\n", - " There's a special self.get_qvalue/set_qvalue for that.\n", - " \"\"\"\n", - "\n", - " self.get_legal_actions = get_legal_actions\n", - " self._qvalues = defaultdict(lambda: defaultdict(lambda: 0))\n", - " self.alpha = alpha\n", - " self.epsilon = epsilon\n", - " self.discount = discount\n", - "\n", - " def get_qvalue(self, state, action):\n", - " \"\"\" Returns Q(state,action) \"\"\"\n", - " return self._qvalues[state][action]\n", - "\n", - " def set_qvalue(self, state, action, value):\n", - " \"\"\" Sets the Qvalue for [state,action] to the given value \"\"\"\n", - " self._qvalues[state][action] = value\n", - "\n", - " #---------------------START OF YOUR CODE---------------------#\n", - "\n", - " def get_value(self, state):\n", - " \"\"\"\n", - " Compute your agent's estimate of V(s) using current q-values\n", - " V(s) = max_over_action Q(state,action) over possible actions.\n", - " Note: please take into account that q-values can be negative.\n", - " \"\"\"\n", - " possible_actions = self.get_legal_actions(state)\n", - "\n", - " # If there are no legal actions, return 0.0\n", - " if len(possible_actions) == 0:\n", - " return 0.0\n", - "\n", - " \n", - "\n", - " return value\n", - "\n", - " def update(self, state, action, reward, next_state):\n", - " \"\"\"\n", - " You should do your Q-Value update here:\n", - " Q(s,a) := (1 - alpha) * Q(s,a) + alpha * (r + gamma * V(s'))\n", - " \"\"\"\n", - "\n", - " # agent parameters\n", - " gamma = self.discount\n", - " learning_rate = self.alpha\n", - "\n", - " \n", - "\n", - " self.set_qvalue(state, action, )\n", - "\n", - " def get_best_action(self, state):\n", - " \"\"\"\n", - " Compute the best action to take in a state (using current q-values). 
\n", - " \"\"\"\n", - " possible_actions = self.get_legal_actions(state)\n", - "\n", - " # If there are no legal actions, return None\n", - " if len(possible_actions) == 0:\n", - " return None\n", - "\n", - " \n", - "\n", - " return best_action\n", - "\n", - " def get_action(self, state):\n", - " \"\"\"\n", - " Compute the action to take in the current state, including exploration. \n", - " With probability self.epsilon, we should take a random action.\n", - " otherwise - the best policy action (self.get_best_action).\n", - "\n", - " Note: To pick randomly from a list, use random.choice(list). \n", - " To pick True or False with a given probablity, generate uniform number in [0, 1]\n", - " and compare it with your probability\n", - " \"\"\"\n", - "\n", - " # Pick Action\n", - " possible_actions = self.get_legal_actions(state)\n", - " action = None\n", - "\n", - " # If there are no legal actions, return None\n", - " if len(possible_actions) == 0:\n", - " return None\n", - "\n", - " # agent parameters:\n", - " epsilon = self.epsilon\n", - "\n", - " \n", - "\n", - " return chosen_action" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Try it on taxi\n", - "\n", - "Here we use the qlearning agent on taxi env from openai gym.\n", - "You will need to insert a few agent functions here." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "env = gym.make(\"Taxi-v3\")\n", - "\n", - "n_actions = env.action_space.n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "agent = QLearningAgent(\n", - " alpha=0.5, epsilon=0.25, discount=0.99,\n", - " get_legal_actions=lambda s: range(n_actions))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def play_and_train(env, agent, t_max=10**4):\n", - " \"\"\"\n", - " This function should \n", - " - run a full game, actions given by agent's e-greedy policy\n", - " - train agent using agent.update(...) whenever it is possible\n", - " - return total reward\n", - " \"\"\"\n", - " total_reward = 0.0\n", - " s = env.reset()\n", - "\n", - " for t in range(t_max):\n", - " # get agent to pick action given state s.\n", - " a = \n", - "\n", - " next_s, r, done, _ = env.step(a)\n", - "\n", - " # train (update) agent for state s\n", - " \n", - "\n", - " s = next_s\n", - " total_reward += r\n", - " if done:\n", - " break\n", - "\n", - " return total_reward" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import clear_output\n", - "\n", - "rewards = []\n", - "for i in range(1000):\n", - " rewards.append(play_and_train(env, agent))\n", - " agent.epsilon *= 0.99\n", - "\n", - " if i % 100 == 0:\n", - " clear_output(True)\n", - " plt.title('eps = {:e}, mean reward = {:.1f}'.format(agent.epsilon, np.mean(rewards[-10:])))\n", - " plt.plot(rewards)\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Binarized state spaces\n", - "\n", - "Use agent to train efficiently on `CartPole-v0`. This environment has a continuous set of possible states, so you will have to group them into bins somehow.\n", - "\n", - "The simplest way is to use `round(x, n_digits)` (or `np.round`) to round a real number to a given amount of digits. The tricky part is to get the `n_digits` right for each state to train effectively.\n", - "\n", - "Note that you don't need to convert state to integers, but to __tuples__ of any kind of values." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def make_env():\n", - " return gym.make('CartPole-v0').env # .env unwraps the TimeLimit wrapper\n", - "\n", - "env = make_env()\n", - "n_actions = env.action_space.n\n", - "\n", - "print(\"first state: %s\" % (env.reset()))\n", - "plt.imshow(env.render('rgb_array'))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Play a few games\n", - "\n", - "We need to estimate observation distributions. To do so, we'll play a few games and record all states." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def visualize_cartpole_observation_distribution(seen_observations):\n", - " seen_observations = np.array(seen_observations)\n", - " \n", - " # The meaning of the observations is documented in\n", - " # https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py\n", - "\n", - " f, axarr = plt.subplots(2, 2, figsize=(16, 9), sharey=True)\n", - " for i, title in enumerate(['Cart Position', 'Cart Velocity', 'Pole Angle', 'Pole Velocity At Tip']):\n", - " ax = axarr[i // 2, i % 2]\n", - " ax.hist(seen_observations[:, i], bins=20)\n", - " ax.set_title(title)\n", - " xmin, xmax = ax.get_xlim()\n", - " ax.set_xlim(min(xmin, -xmax), max(-xmin, xmax))\n", - " ax.grid()\n", - " f.tight_layout()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "seen_observations = []\n", - "for _ in range(1000):\n", - " seen_observations.append(env.reset())\n", - " done = False\n", - " while not done:\n", - " s, r, done, _ = env.step(env.action_space.sample())\n", - " seen_observations.append(s)\n", - "\n", - "visualize_cartpole_observation_distribution(seen_observations)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Binarize environment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from gym.core import ObservationWrapper\n", - "\n", - "\n", - "class Binarizer(ObservationWrapper):\n", - " def observation(self, state):\n", - " # Hint: you can do that with round(x, n_digits).\n", - " # You may pick a different n_digits for each dimension.\n", - " state = \n", - "\n", - " return tuple(state)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env = Binarizer(make_env())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "seen_observations = []\n", - "for _ in range(1000):\n", - " seen_observations.append(env.reset())\n", - " done = False\n", - " while not done:\n", - " s, r, done, _ = env.step(env.action_space.sample())\n", - " seen_observations.append(s)\n", - " if done:\n", - " break\n", - "\n", - "visualize_cartpole_observation_distribution(seen_observations)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Learn binarized policy\n", - "\n", - "Now let's train a policy that uses binarized state space.\n", - "\n", - "__Tips:__\n", - "\n", - "* Note that increasing the number of digits for one dimension of the observations increases your state space by a factor of $10$.\n", - "* If your binarization is too fine-grained, your agent will take much 
longer than 10000 steps to converge. You can either increase the number of iterations and reduce epsilon decay or change binarization. In practice we found that this kind of mistake is rather frequent.\n", - "* If your binarization is too coarse, your agent may fail to find the optimal policy. In practice we found that on this particular environment this kind of mistake is rare.\n", - "* **Start with a coarse binarization** and make it more fine-grained if that seems necessary.\n", - "* Having $10^3$–$10^4$ distinct states is recommended (`len(agent._qvalues)`), but not required.\n", - "* If things don't work without annealing $\\varepsilon$, consider adding that, but make sure that it doesn't go to zero too quickly.\n", - "\n", - "A reasonable agent should attain an average reward of at least 50." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "def moving_average(x, span=100):\n", - " return pd.DataFrame({'x': np.asarray(x)}).x.ewm(span=span).mean().values" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "agent = QLearningAgent(\n", - " alpha=0.5, epsilon=0.25, discount=0.99,\n", - " get_legal_actions=lambda s: range(n_actions))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rewards = []\n", - "epsilons = []\n", - "\n", - "for i in range(10000):\n", - " reward = play_and_train(env, agent)\n", - " rewards.append(reward)\n", - " epsilons.append(agent.epsilon)\n", - " \n", - " # OPTIONAL: \n", - "\n", - " if i % 100 == 0:\n", - " rewards_ewma = moving_average(rewards)\n", - " \n", - " clear_output(True)\n", - " plt.plot(rewards, label='rewards')\n", - " plt.plot(rewards_ewma, label='rewards ewma@100')\n", - " plt.legend()\n", - " plt.grid()\n", - " plt.title('eps = {:e}, rewards ewma@100 = {:.1f}'.format(agent.epsilon, 
rewards_ewma[-1]))\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print('Your agent has learned {} Q-values.'.format(len(agent._qvalues)))" - ] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} diff --git a/week04_approx_rl/README.md b/week04_approx_rl/README.md index 5d399508f..480b9af0f 100644 --- a/week04_approx_rl/README.md +++ b/week04_approx_rl/README.md @@ -46,4 +46,4 @@ From now on, we have two tracks, for pytorch and tensorflow. However, pytorch tr Begin with `seminar_.ipynb` and then proceed with `homework_.ipynb`. -__Note: you're not required to submit assignments in all three frameworks. Pick one and go with it. Maybe switch it occasionally if you want more challenge. __ +__Note: you're not required to submit assignments in all three frameworks. Pick one and go with it. Maybe switch it occasionally if you want more challenge.__ diff --git a/week04_approx_rl/atari_wrappers.py b/week04_approx_rl/atari_wrappers.py index ea19dd880..260e8e994 100644 --- a/week04_approx_rl/atari_wrappers.py +++ b/week04_approx_rl/atari_wrappers.py @@ -1,13 +1,14 @@ -# taken from OpenAI baselines. +# taken from stable_baselines3. 
import numpy as np -import gym +from gymnasium import Wrapper, RewardWrapper, ObservationWrapper +from gymnasium.spaces import Box -class MaxAndSkipEnv(gym.Wrapper): +class MaxAndSkipEnv(Wrapper): def __init__(self, env, skip=4): """Return only every `skip`-th frame""" - gym.Wrapper.__init__(self, env) + super().__init__(env) # most recent raw observations (for max pooling across time steps) self._obs_buffer = np.zeros( (2,) + env.observation_space.shape, dtype=np.uint8) @@ -16,68 +17,62 @@ def __init__(self, env, skip=4): def step(self, action): """Repeat action, sum reward, and max over last observations.""" total_reward = 0.0 - done = None + terminated = truncated = False for i in range(self._skip): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: self._obs_buffer[1] = obs total_reward += reward - if done: + if terminated or truncated: break - # Note that the observation on the done=True frame + # Note that the observation on the terminated=True frame # doesn't matter max_frame = self._obs_buffer.max(axis=0) - return max_frame, total_reward, done, info - - def reset(self, **kwargs): - return self.env.reset(**kwargs) + return max_frame, total_reward, terminated, truncated, info -class ClipRewardEnv(gym.RewardWrapper): +class ClipRewardEnv(RewardWrapper): def __init__(self, env): - gym.RewardWrapper.__init__(self, env) + super().__init__(env) def reward(self, reward): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) -class FireResetEnv(gym.Wrapper): +class FireResetEnv(Wrapper): def __init__(self, env): """Take action on reset for environments that are fixed until firing.""" - gym.Wrapper.__init__(self, env) + super().__init__(env) assert env.unwrapped.get_action_meanings()[1] == 'FIRE' assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self, **kwargs): self.env.reset(**kwargs) - obs, _, done, 
_ = self.env.step(1) - if done: + obs, _, terminated, truncated, _ = self.env.step(1) + if terminated or truncated: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(2) - if done: + obs, _, terminated, truncated, _ = self.env.step(2) + if terminated or truncated: self.env.reset(**kwargs) - return obs - - def step(self, ac): - return self.env.step(ac) + return obs, {} -class EpisodicLifeEnv(gym.Wrapper): +class EpisodicLifeEnv(Wrapper): def __init__(self, env): """Make end-of-life == end-of-episode, but only reset on true game over. Done by DeepMind for the DQN and co. since it helps value estimation. """ - gym.Wrapper.__init__(self, env) + super().__init__(env) self.lives = 0 self.was_real_done = True def step(self, action): - obs, reward, done, info = self.env.step(action) - self.was_real_done = done + obs, reward, terminated, truncated, info = self.env.step(action) + self.was_real_done = terminated or truncated # check current lives, make loss of life terminal, # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() @@ -85,9 +80,9 @@ def step(self, action): # for Qbert sometimes we stay in lives == 0 condition for a few frames # so it's important to keep lives > 0, so that we only reset once # the environment advertises done. - done = True + terminated = True self.lives = lives - return obs, reward, done, info + return obs, reward, terminated, truncated, info def reset(self, **kwargs): """Reset only when lives are exhausted. @@ -95,25 +90,31 @@ def reset(self, **kwargs): and the learner need not know about any of this behind-the-scenes. 
""" if self.was_real_done: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, _, _ = self.env.step(0) + obs, _, terminated, truncated, info = self.env.step(0) + + # The no-op step can lead to a game over, so we need to check it again + # to see if we should reset the environment and avoid the + # monitor.py `RuntimeError: Tried to step environment that needs reset` + if terminated or truncated: + obs, info = self.env.reset(**kwargs) self.lives = self.env.unwrapped.ale.lives() - return obs + return obs, info # in torch imgs have shape [c, h, w] instead of common [h, w, c] -class AntiTorchWrapper(gym.ObservationWrapper): +class AntiTorchWrapper(ObservationWrapper): def __init__(self, env): - gym.ObservationWrapper.__init__(self, env) + super().__init__(env) self.img_size = [env.observation_space.shape[i] for i in [1, 2, 0] ] - self.observation_space = gym.spaces.Box(0.0, 1.0, self.img_size) + self.observation_space = Box(0.0, 1.0, self.img_size) def observation(self, img): """what happens to each observation""" img = img.transpose(1, 2, 0) - return img + return img \ No newline at end of file diff --git a/week04_approx_rl/framebuffer.py b/week04_approx_rl/framebuffer.py index fa8805d24..7e9b74313 100644 --- a/week04_approx_rl/framebuffer.py +++ b/week04_approx_rl/framebuffer.py @@ -1,12 +1,12 @@ import numpy as np -from gym.spaces.box import Box -from gym.core import Wrapper +from gymnasium.spaces import Box +from gymnasium import Wrapper class FrameBuffer(Wrapper): - def __init__(self, env, n_frames=4, dim_order='tensorflow'): - """A gym wrapper that reshapes, crops and scales image into the desired shapes""" - super(FrameBuffer, self).__init__(env) + def __init__(self, env, n_frames=4, dim_order='pytorch'): + """A gymnasium wrapper that reshapes, crops and scales image into the desired shapes""" + super().__init__(env) self.dim_order = dim_order if dim_order == 'tensorflow': 
height, width, n_channels = env.observation_space.shape @@ -20,17 +20,17 @@ def __init__(self, env, n_frames=4, dim_order='tensorflow'): self.observation_space = Box(0.0, 1.0, obs_shape) self.framebuffer = np.zeros(obs_shape, 'float32') - def reset(self): + def reset(self, **kwargs): """resets breakout, returns initial frames""" self.framebuffer = np.zeros_like(self.framebuffer) - self.update_buffer(self.env.reset()) - return self.framebuffer + self.update_buffer(self.env.reset(**kwargs)[0]) + return self.framebuffer, {} def step(self, action): """plays breakout for 1 step, returns frame buffer""" - new_img, reward, done, info = self.env.step(action) + new_img, reward, terminated, truncated, info = self.env.step(action) self.update_buffer(new_img) - return self.framebuffer, reward, done, info + return self.framebuffer, reward, terminated, truncated, info def update_buffer(self, img): if self.dim_order == 'tensorflow': @@ -42,4 +42,4 @@ def update_buffer(self, img): axis = 0 cropped_framebuffer = self.framebuffer[:-offset] self.framebuffer = np.concatenate( - [img, cropped_framebuffer], axis=axis) + [img, cropped_framebuffer], axis=axis) \ No newline at end of file diff --git a/week04_approx_rl/homework_pytorch_debug.ipynb b/week04_approx_rl/homework_pytorch_debug.ipynb index 2765a508f..fe3b77cda 100644 --- a/week04_approx_rl/homework_pytorch_debug.ipynb +++ b/week04_approx_rl/homework_pytorch_debug.ipynb @@ -1,818 +1,905 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Deep Q-Network implementation.\n", - "\n", - "This homework shamelessly demands you to implement DQN — an approximate Q-learning algorithm with experience replay and target networks — and see if it works any better this way.\n", - "\n", - "Original paper:\n", - "https://arxiv.org/pdf/1312.5602.pdf" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**This notebook is given for debug.** The main task is in the other notebook 
(**homework_pytorch_main**). The tasks are similar and share most of the code. The main difference is in environments. In main notebook it can take some 2 hours for the agent to start improving so it seems reasonable to launch the algorithm on a simpler env first. Here it is CartPole and it will train in several minutes.\n", - "\n", - "**We suggest the following pipeline:** First implement debug notebook then implement the main one.\n", - "\n", - "**About evaluation:** All points are given for the main notebook with one exception: if agent fails to beat the threshold in main notebook you can get 1 pt (instead of 3 pts) for beating the threshold in debug notebook." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " \n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/atari_wrappers.py\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/utils.py\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/replay_buffer.py\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/framebuffer.py\n", - "\n", - " !pip install gym[box2d]\n", - "\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "__Frameworks__ - we'll accept this homework in any 
deep learning framework. This particular notebook was designed for PyTorch, but you find it easy to adapt it to almost any Python-based deep learning framework." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "import numpy as np\n", - "import torch\n", - "import utils" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### CartPole again\n", - "\n", - "Another env can be used without any modification of the code. State space should be a single vector, actions should be discrete.\n", - "\n", - "CartPole is the simplest one. It should take several minutes to solve it.\n", - "\n", - "For LunarLander it can take 1-2 hours to get 200 points (a good score) on Colab and training progress does not look informative." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ENV_NAME = 'CartPole-v1'\n", - "\n", - "def make_env(seed=None):\n", - " # some envs are wrapped with a time limit wrapper by default\n", - " env = gym.make(ENV_NAME).unwrapped\n", - " if seed is not None:\n", - " env.seed(seed)\n", - " return env" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.6/dist-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", - " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAARSUlEQVR4nO3df6zddX3H8edLQHRqBsi16fpjRe1i\ncJnF3SFG/0CMCsSsmjgDW6QxJJclmGBitoFLpiYj0WTKZuaINTDr4kTmj9AQNsVKYvxDsMVaWxC5\nagltKi0KqDFjK773x/0Uz+ot99wfh9vPPc9HcnK+3/f38z3n/YmHl99++j09qSokSf14znI3IEma\nH4NbkjpjcEtSZwxuSeqMwS1JnTG4JakzIwvuJBcneSDJdJJrR/U+kjRuMor7uJOcAvwAeBNwAPg2\ncHlV3bfkbyZJY2ZUV9znA9NV9aOq+h/gFmDziN5LksbKqSN63TXAwwP7B4DXnGjw2WefXRs2bBhR\nK5LUn/379/Poo49mtmOjCu45JZkCpgDWr1/Pzp07l6sVSTrpTE5OnvDYqJZKDgLrBvbXttrTqmpr\nVU1W1eTExMSI2pCklWdUwf1tYGOSc5I8F7gM2D6i95KksTKSpZKqOprkPcBXgFOAm6tq3yjeS5LG\nzcjWuKvqDuCOUb2+JI0rvzkpSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmd\nMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1JnTG4Jakzi/rpsiT7gV8ATwFHq2oy\nyVnA54ENwH7gnVX12OLalCQdsxRX3G+oqk1VNdn2rwV2VNVGYEfblyQtkVEslWwGtrXtbcDbRvAe\nkjS2FhvcBXw1ya4kU622qqoOte2fAKsW+R6SpAGLWuMGXl9VB5O8BLgzyfcHD1ZVJanZTmxBPwWw\nfv36RbYhSeNjUVfcVXWwPR8GvgycDzySZDVAez58gnO3VtVkVU1OTEwspg1JGisLDu4kL0jyomPb\nwJuBvcB2YEsbtgW4bbFNSpJ+YzFLJauALyc59jr/XlX/leTbwK1JrgQeAt65+DYlSccsOLir6kfA\nq2ap/xR442KakiSdmN+clKTOGNyS1BmDW5I6Y3BLUm
cMbknqjMEtSZ0xuCWpMwa3JHXG4Jakzhjc\nktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjozZ3AnuTnJ4SR7B2pn\nJbkzyYPt+cxWT5KPJ5lOsifJq0fZvCSNo2GuuD8NXHxc7VpgR1VtBHa0fYBLgI3tMQXcuDRtSpKO\nmTO4q+obwM+OK28GtrXtbcDbBuqfqRnfAs5IsnqpmpUkLXyNe1VVHWrbPwFWte01wMMD4w602m9J\nMpVkZ5KdR44cWWAbkjR+Fv2Xk1VVQC3gvK1VNVlVkxMTE4ttQ5LGxkKD+5FjSyDt+XCrHwTWDYxb\n22qSpCWy0ODeDmxp21uA2wbqV7S7Sy4AnhhYUpEkLYFT5xqQ5HPAhcDZSQ4AHwA+DNya5ErgIeCd\nbfgdwKXANPAr4N0j6FmSxtqcwV1Vl5/g0BtnGVvA1YttSpJ0Yn5zUpI6Y3BLUmcMbknqjMEtSZ0x\nuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNb\nkjpjcEtSZ+YM7iQ3JzmcZO9A7YNJDibZ3R6XDhy7Lsl0kgeSvGVUjUvSuBrmivvTwMWz1G+oqk3t\ncQdAknOBy4BXtnP+JckpS9WsJGmI4K6qbwA/G/L1NgO3VNWTVfVjZn7t/fxF9CdJOs5i1rjfk2RP\nW0o5s9XWAA8PjDnQar8lyVSSnUl2HjlyZBFtSNJ4WWhw3wi8DNgEHAI+Ot8XqKqtVTVZVZMTExML\nbEOSxs+CgruqHqmqp6rq18Cn+M1yyEFg3cDQta0mSVoiCwruJKsHdt8OHLvjZDtwWZLTk5wDbATu\nWVyLkqRBp841IMnngAuBs5McAD4AXJhkE1DAfuAqgKral+RW4D7gKHB1VT01mtYlaTzNGdxVdfks\n5ZueYfz1wPWLaUqSdGJ+c1KSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1Zs7bAaWVbNfWq2at//HU\nJ5/lTqThecUtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1\nZs7gTrIuyV1J7kuyL8k1rX5WkjuTPNiez2z1JPl4kukke5K8etSTkKRxMswV91HgfVV1LnABcHWS\nc4FrgR1VtRHY0fYBLmHm1903AlPAjUvetSSNsTmDu6oOVdW9bfsXwP3AGmAzsK0N2wa8rW1vBj5T\nM74FnJFk9ZJ3Lkljal5r3Ek2AOcBdwOrqupQO/QTYFXbXgM8PHDagVY7/rWmkuxMsvPIkSPzbFuS\nxtfQwZ3khcAXgfdW1c8Hj1VVATWfN66qrVU1WVWTExMT8zlVksbaUMGd5DRmQvuzVfWlVn7k2BJI\nez7c6geBdQOnr201SdISGOaukgA3AfdX1ccGDm0HtrTtLcBtA/Ur2t0lFwBPDCypSJIWaZifLnsd\n8C7ge0l2t9r7gQ8Dtya5EngIeGc7dgdwKTAN/Ap495J2LEljbs7grqpvAjnB4TfOMr6AqxfZlyTp\nBPzmpHQcfyhYJzuDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1Jn\nDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLUGYNbkjozzI8Fr0tyV5L7kuxLck2rfzDJwSS72+PS\ngXOuSzKd5IEkbxnlBCRp3AzzY8FHgfdV1b1JXgTsSnJnO3ZDVf3D4OAk5wKXAa8Efg/4WpI/qKqn\nlrJxSRpXc15xV9Whqrq3bf8CuB9Y8wynbAZuqaonq+rHzPza+/lL0awkaZ5r3Ek2AOcBd7fSe5Ls\nSXJzkjNbbQ3w8MBpB3jmoJeWxa6tV/1WzR8KVg+GDu4kLwS+CLy3qn4O3Ai8DNgEHAI+Op83TjKV\nZGeSnUeOHJnPqZ
I01oYK7iSnMRPan62qLwFU1SNV9VRV/Rr4FL9ZDjkIrBs4fW2r/T9VtbWqJqtq\ncmJiYjFzkKSxMsxdJQFuAu6vqo8N1FcPDHs7sLdtbwcuS3J6knOAjcA9S9eyJI23Ye4qeR3wLuB7\nSXa32vuBy5NsAgrYD1wFUFX7ktwK3MfMHSlXe0eJJC2dOYO7qr4JZJZDdzzDOdcD1y+iL0nSCfjN\nSUnqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J6ozBLUmdMbglqTMGtyR1xuCW\npM4Y3FpRkgz9GMX50rPB4JakzgzzQwrSinX7oamnt9+6eusydiINzytuja3B0J5tXzpZGdyS1Jlh\nfiz4eUnuSfLdJPuSfKjVz0lyd5LpJJ9P8txWP73tT7fjG0Y7BUkaL8NccT8JXFRVrwI2ARcnuQD4\nCHBDVb0ceAy4so2/Enis1W9o46STzvFr2q5xqxfD/FhwAb9su6e1RwEXAX/e6tuADwI3ApvbNsAX\ngH9OkvY60klj8qqtwG/C+oPL1ok0P0PdVZLkFGAX8HLgE8APgcer6mgbcgBY07bXAA8DVNXRJE8A\nLwYePdHr79q1y/ti1R0/s1ouQwV3VT0FbEpyBvBl4BWLfeMkU8AUwPr163nooYcW+5LSsxqm/iFS\nozQ5OXnCY/O6q6SqHgfuAl4LnJHkWPCvBQ627YPAOoB2/HeBn87yWlurarKqJicmJubThiSNtWHu\nKploV9okeT7wJuB+ZgL8HW3YFuC2tr297dOOf931bUlaOsMslawGtrV17ucAt1bV7UnuA25J8vfA\nd4Cb2vibgH9LMg38DLhsBH1L0tga5q6SPcB5s9R/BJw/S/2/gT9bku4kSb/Fb05KUmcMbknqjMEt\nSZ3xn3XViuINTBoHXnFLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZg1uSOmNwS1JnDG5J\n6ozBLUmdMbglqTMGtyR1xuCWpM4M82PBz0tyT5LvJtmX5EOt/ukkP06yuz02tXqSfDzJdJI9SV49\n6klI0jgZ5t/jfhK4qKp+meQ04JtJ/rMd+6uq+sJx4y8BNrbHa4Ab27MkaQnMecVdM37Zdk9rj2f6\n1+o3A59p530LOCPJ6sW3KkmCIde4k5ySZDdwGLizqu5uh65vyyE3JDm91dYADw+cfqDVJElLYKjg\nrqqnqmoTsBY4P8kfAtcBrwD+BDgL+Jv5vHGSqSQ7k+w8cuTIPNuWpPE1r7tKqupx4C7g4qo61JZD\nngT+FTi/DTsIrBs4bW2rHf9aW6tqsqomJyYmFta9JI2hYe4qmUhyRtt+PvAm4PvH1q2TBHgbsLed\nsh24ot1dcgHwRFUdGkn3kjSGhrmrZDWwLckpzAT9rVV1e5KvJ5kAAuwG/rKNvwO4FJgGfgW8e+nb\nlqTxNWdwV9Ue4LxZ6hedYHwBVy++NUnSbPzmpCR1xuCWpM4Y3JLUGYNbkjpjcEtSZwxuSeqMwS1J\nnTG4JakzBrckdcbglqTOGNyS1BmDW5I6Y3BLUmcMbknqjMEtSZ0xuCWpMwa3JHXG4JakzhjcktQZ\ng1uSOmNwS1JnDG5J6kyqarl7IMkvgAeWu48RORt4dLmbGIGVOi9YuXNzXn35/aqamO3Aqc92Jyfw\nQFVNLncTo5Bk50qc20qdF6zcuTmvlcOlEknqjMEtSZ05WYJ763I3MEIrdW4rdV6wcufmvFaIk+Iv\nJyVJwztZrrglSUNa9uBOcnGSB5JMJ7l2ufuZryQ3JzmcZO9A7awkdyZ5sD2f2epJ8vE21z1JXr18\nnT+zJOuS3JXkviT7klzT6l3PLcnzktyT5LttXh9q9XOS3N36/3yS57b66W1/uh3fsJz9zyXJKUm+\nk+T2tr9S5rU/yfeS7E6ys9W6/iwuxrIGd5JTgE8AlwDnApcnOXc5e1qATwMXH1e7
FthRVRuBHW0f\nZua5sT2mgBufpR4X4ijwvqo6F7gAuLr9b9P73J4ELqqqVwGbgIuTXAB8BLihql4OPAZc2cZfCTzW\n6je0cSeza4D7B/ZXyrwA3lBVmwZu/ev9s7hwVbVsD+C1wFcG9q8DrlvOnhY4jw3A3oH9B4DVbXs1\nM/epA3wSuHy2cSf7A7gNeNNKmhvwO8C9wGuY+QLHqa3+9OcS+Arw2rZ9ahuX5e79BPNZy0yAXQTc\nDmQlzKv1uB84+7jaivkszvex3Esla4CHB/YPtFrvVlXVobb9E2BV2+5yvu2P0ecBd7MC5taWE3YD\nh4E7gR8Cj1fV0TZksPen59WOPwG8+NnteGj/CPw18Ou2/2JWxrwACvhqkl1Jplqt+8/iQp0s35xc\nsaqqknR7606SFwJfBN5bVT9P8vSxXudWVU8Bm5KcAXwZeMUyt7RoSd4KHK6qXUkuXO5+RuD1VXUw\nyUuAO5N8f/Bgr5/FhVruK+6DwLqB/bWt1rtHkqwGaM+HW72r+SY5jZnQ/mxVfamVV8TcAKrqceAu\nZpYQzkhy7EJmsPen59WO/y7w02e51WG8DvjTJPuBW5hZLvkn+p8XAFV1sD0fZub/bM9nBX0W52u5\ng/vbwMb2N9/PBS4Dti9zT0thO7ClbW9hZn34WP2K9rfeFwBPDPxR76SSmUvrm4D7q+pjA4e6nluS\niXalTZLnM7Nufz8zAf6ONuz4eR2b7zuAr1dbOD2ZVNV1VbW2qjYw89/R16vqL+h8XgBJXpDkRce2\ngTcDe+n8s7goy73IDlwK/ICZdca/Xe5+FtD/54BDwP8ys5Z2JTNrhTuAB4GvAWe1sWHmLpofAt8D\nJpe7/2eY1+uZWVfcA+xuj0t7nxvwR8B32rz2An/X6i8F7gGmgf8ATm/157X96Xb8pcs9hyHmeCFw\n+0qZV5vDd9tj37Gc6P2zuJiH35yUpM4s91KJJGmeDG5J6ozBLUmdMbglqTMGtyR1xuCWpM4Y3JLU\nGYNbkjrzf0Ew7Is+EjUWAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "SqZ2EwnTZdC8" + }, + "source": [ + "# Deep Q-Network implementation.\n", + "\n", + "This homework shamelessly demands you to implement DQN — an approximate Q-learning algorithm with experience replay and target networks — and see if it works any better this way.\n", + "\n", + "Original paper:\n", + "https://arxiv.org/pdf/1312.5602.pdf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zv7XJfXaZdC9" + }, + "source": [ + "**This notebook is given for debug.** The main task is in the other notebook (**homework_pytorch_main**). The tasks are similar and share most of the code. The main difference is in environments. In main notebook it can take some 2 hours for the agent to start improving so it seems reasonable to launch the algorithm on a simpler env first. Here it is CartPole and it will train in several minutes.\n", + "\n", + "**We suggest the following pipeline:** First implement debug notebook then implement the main one.\n", + "\n", + "**About evaluation:** All points are given for the main notebook with one exception: if agent fails to beat the threshold in main notebook you can get 1 pt (instead of 3 pts) for beating the threshold in debug notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ioIEVODJZdC9" + }, + "outputs": [], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + "\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/atari_wrappers.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/utils.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/replay_buffer.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/framebuffer.py\n", + "\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u8OFQOtGojc8" + }, + "outputs": [], + "source": [ + "!pip install gymnasium" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FDZqlI3kZdC9" + }, + "source": [ + "__Frameworks__ - we'll accept this homework in any deep learning framework. This particular notebook was designed for PyTorch, but you will find it easy to adapt it to almost any Python-based deep learning framework." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dsYq558wZdC-" + }, + "outputs": [], + "source": [ + "import random\n", + "import numpy as np\n", + "import torch\n", + "import utils\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6ypPZ8e6ZdC-" + }, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9j8EGNlSZdC-" + }, + "source": [ + "### CartPole again\n", + "\n", + "Another env can be used without any modification of the code. State space should be a single vector, actions should be discrete.\n", + "\n", + "CartPole is the simplest one. It should take several minutes to solve it.\n", + "\n", + "For LunarLander it can take 1-2 hours to get 200 points (a good score) on Colab and training progress does not look informative." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "v-5u-CcQZdC-" + }, + "outputs": [], + "source": [ + "ENV_NAME = \"CartPole-v1\"\n", + "\n", + "\n", + "def make_env():\n", + " # some envs are wrapped with a time limit wrapper by default\n", + " env = gym.make(ENV_NAME, render_mode=\"rgb_array\").unwrapped\n", + " return env\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AmFXRrkqZdC-" + }, + "outputs": [], + "source": [ + "env = make_env()\n", + "env.reset()\n", + "plt.imshow(env.render())\n", + "state_shape, n_actions = env.observation_space.shape, env.action_space.n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qOyWgOmvZdC-" + }, + "source": [ + "### Building a network" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XqpThLZXZdC-" + }, + "source": [ + "We now need to build a neural network that can map observations to state q-values.\n", + "The model does not have to be huge yet. 
1-2 hidden layers with < 200 neurons and ReLU activation will probably be enough. Batch normalization and dropout can spoil everything here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UVlpkvZOZdC-" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# those who have a GPU but feel unfair to use it can uncomment:\n", + "# device = torch.device('cpu')\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RFva1cpyZdC-" + }, + "outputs": [], + "source": [ + "class DQNAgent(nn.Module):\n", + " def __init__(self, state_shape, n_actions, epsilon=0):\n", + "\n", + " super().__init__()\n", + " self.epsilon = epsilon\n", + " self.n_actions = n_actions\n", + " self.state_shape = state_shape\n", + " # Define your network body here. Please make sure agent is fully contained here\n", + " assert len(state_shape) == 1\n", + " state_dim = state_shape[0]\n", + " \n", + "\n", + "\n", + " def forward(self, state_t):\n", + " \"\"\"\n", + " takes agent's observation (tensor), returns qvalues (tensor)\n", + " :param state_t: a batch states, shape = [batch_size, *state_dim=4]\n", + " \"\"\"\n", + " # Use your network to compute qvalues for given state\n", + " qvalues = \n", + "\n", + " assert qvalues.requires_grad, \"qvalues must be a torch tensor with grad\"\n", + " assert (\n", + " len(qvalues.shape) == 2 and\n", + " qvalues.shape[0] == state_t.shape[0] and\n", + " qvalues.shape[1] == n_actions\n", + " )\n", + "\n", + " return qvalues\n", + "\n", + " def get_qvalues(self, states):\n", + " \"\"\"\n", + " like forward, but works on numpy arrays, not tensors\n", + " \"\"\"\n", + " model_device = next(self.parameters()).device\n", + " states = torch.tensor(states, device=model_device, dtype=torch.float32)\n", + " qvalues = self.forward(states)\n", + " return 
qvalues.data.cpu().numpy()\n", + "\n", + " def sample_actions(self, qvalues):\n", + " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n", + " epsilon = self.epsilon\n", + " batch_size, n_actions = qvalues.shape\n", + "\n", + " random_actions = np.random.choice(n_actions, size=batch_size)\n", + " best_actions = qvalues.argmax(axis=-1)\n", + "\n", + " should_explore = np.random.choice(\n", + " [0, 1], batch_size, p=[1-epsilon, epsilon])\n", + " return np.where(should_explore, random_actions, best_actions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Bv1s5JKzZdC-" + }, + "outputs": [], + "source": [ + "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vazC0DPQZdC_" + }, + "source": [ + "Now let's try out our agent to see if it raises any errors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e-Sg1cqPZdC_" + }, + "outputs": [], + "source": [ + "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000, seed=None):\n", + " \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. \"\"\"\n", + " rewards = []\n", + " for _ in range(n_games):\n", + " s, _ = env.reset(seed=seed)\n", + " reward = 0\n", + " for _ in range(t_max):\n", + " qvalues = agent.get_qvalues([s])\n", + " action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n", + " s, r, terminated, truncated, _ = env.step(action)\n", + " reward += r\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " rewards.append(reward)\n", + " return np.mean(rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y_0NzjUEZdC_" + }, + "source": [ + "### Experience replay\n", + "For this assignment, we provide you with experience replay buffer. 
If you implemented experience replay buffer in last week's assignment, you can copy-paste it here in main notebook **to get 2 bonus points**.\n", + "\n", + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/exp_replay.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jHyCO4TuZdC_" + }, + "source": [ + "#### The interface is fairly simple:\n", + "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n", + "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n", + "* `len(exp_replay)` - returns number of elements stored in replay buffer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wQEHwR1AZdC_" + }, + "outputs": [], + "source": [ + "from replay_buffer import ReplayBuffer\n", + "exp_replay = ReplayBuffer(10)\n", + "\n", + "for _ in range(30):\n", + " exp_replay.add(env.reset()[0], env.action_space.sample(), 1.0, env.reset()[0], done=False)\n", + "\n", + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(5)\n", + "\n", + "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0RnFX5sfZdC_" + }, + "outputs": [], + "source": [ + "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n", + " \"\"\"\n", + " Play the game for exactly n_steps, record every (s,a,r,s', done) to replay buffer.\n", + " Whenever game ends due to termination or truncation, add record with done=terminated and reset the game.\n", + " It is guaranteed that env has terminated=False when passed to this function.\n", + "\n", + " PLEASE DO NOT RESET ENV UNLESS IT IS \"DONE\"\n", + "\n", + " :returns: return sum of rewards over time and the state in which the env stays\n", + " 
\"\"\"\n", + " s = initial_state\n", + " sum_rewards = 0\n", + "\n", + " # Play the game for n_steps as per instructions above\n", + " \n", + "\n", + " return sum_rewards, s" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZXXmFEKGZdC_", + "outputId": "d1b66847-a141-4406-9697-7ebd194fdb6a" + }, + "outputs": [], + "source": [ + "# testing your code.\n", + "exp_replay = ReplayBuffer(2000)\n", + "\n", + "state, _ = env.reset()\n", + "play_and_record(state, agent, env, exp_replay, n_steps=1000)\n", + "\n", + "# if you're using your own experience replay buffer, some of those tests may need correction.\n", + "# just make sure you know what your code does\n", + "assert len(exp_replay) == 1000, \\\n", + " \"play_and_record should have added exactly 1000 steps, \" \\\n", + " \"but instead added %i\" % len(exp_replay)\n", + "is_dones = list(zip(*exp_replay._storage))[-1]\n", + "\n", + "assert 0 < np.mean(is_dones) < 0.1, \\\n", + " \"Please make sure you restart the game whenever it is 'done' and \" \\\n", + " \"record the is_done correctly into the buffer. Got %f is_done rate over \" \\\n", + " \"%i steps. 
[If you think it's your tough luck, just re-run the test]\" % (\n", + " np.mean(is_dones), len(exp_replay))\n", + "\n", + "for _ in range(100):\n", + " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", + " assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n", + " assert act_batch.shape == (10,), \\\n", + " \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", + " assert reward_batch.shape == (10,), \\\n", + " \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", + " assert is_done_batch.shape == (10,), \\\n", + " \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", + " assert [int(i) in (0, 1) for i in is_dones], \\\n", + " \"is_done should be strictly True or False\"\n", + " assert [0 <= a < n_actions for a in act_batch], \"actions should be within [0, n_actions)\"\n", + "\n", + "print(\"Well done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uoVGsnHRZdC_" + }, + "source": [ + "### Target networks\n", + "\n", + "We also employ the so-called \"target network\" - a copy of neural network weights to be used for reference Q-values:\n", + "\n", + "The network itself is an exact copy of agent network, but its parameters are not trained.
Instead, they are moved here from agent's actual network every so often.\n", + "\n", + "$$ Q_{reference}(s,a) = r + \\gamma \\cdot \\max _{a'} Q_{target}(s',a') $$\n", + "\n", + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/target_net.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8BLJCNiuZdC_", + "outputId": "6181261a-60cf-4626-fbe6-930a6ccd9896" + }, + "outputs": [], + "source": [ + "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n", + "# This is how you can load weights from agent into target network\n", + "target_network.load_state_dict(agent.state_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I_GGShX3ZdC_" + }, + "source": [ + "### Learning with... Q-learning\n", + "Here we write a function similar to `agent.update` from tabular q-learning." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4hbg-xANZdC_" + }, + "source": [ + "Compute Q-learning TD error:\n", + "\n", + "$$ L = { 1 \\over N} \\sum_i [ Q_{\\theta}(s,a) - Q_{reference}(s,a) ] ^2 $$\n", + "\n", + "With Q-reference defined as\n", + "\n", + "$$ Q_{reference}(s,a) = r(s,a) + \\gamma \\cdot max_{a'} Q_{target}(s', a') $$\n", + "\n", + "Where\n", + "* $Q_{target}(s',a')$ denotes Q-value of next state and next action predicted by __target_network__\n", + "* $s, a, r, s'$ are current state, action, reward and next state respectively\n", + "* $\\gamma$ is a discount factor defined two cells above.\n", + "\n", + "\n", + "__Note 1:__ there's an example input below. Feel free to experiment with it before you write the function.\n", + "\n", + "__Note 2:__ compute_td_loss is a source of 99% of bugs in this homework. If reward doesn't improve, it often helps to go through it line by line [with a rubber duck](https://rubberduckdebugging.com/)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VxrEOC7mZdC_" + }, + "outputs": [], + "source": [ + "def compute_td_loss(states, actions, rewards, next_states, is_done,\n", + " agent, target_network,\n", + " gamma=0.99,\n", + " check_shapes=False,\n", + " device=device):\n", + " \"\"\" Compute td loss using torch operations only. Use the formulae above. \"\"\"\n", + " states = torch.tensor(states, device=device, dtype=torch.float32) # shape: [batch_size, *state_shape]\n", + " actions = torch.tensor(actions, device=device, dtype=torch.int64) # shape: [batch_size]\n", + " rewards = torch.tensor(rewards, device=device, dtype=torch.float32) # shape: [batch_size]\n", + " # shape: [batch_size, *state_shape]\n", + " next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n", + " is_done = torch.tensor(\n", + " is_done.astype('float32'),\n", + " device=device,\n", + " dtype=torch.float32,\n", + " ) # shape: [batch_size]\n", + " is_not_done = 1 - is_done\n", + "\n", + " # get q-values for all actions in current states\n", + " predicted_qvalues = agent(states) # shape: [batch_size, n_actions]\n", + "\n", + " # compute q-values for all actions in next states\n", + " predicted_next_qvalues = target_network(next_states) # shape: [batch_size, n_actions]\n", + "\n", + " # select q-values for chosen actions\n", + " predicted_qvalues_for_actions = predicted_qvalues[range(len(actions)), actions] # shape: [batch_size]\n", + "\n", + " # compute V*(next_states) using predicted next q-values\n", + " next_state_values = \n", + "\n", + " assert next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0], \\\n", + " \"must predict one value per state\"\n", + "\n", + " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", + " # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", + " # you can multiply next state values by 
is_not_done to achieve this.\n", + " target_qvalues_for_actions = \n", + "\n", + " # mean squared error loss to minimize\n", + " loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2)\n", + "\n", + " if check_shapes:\n", + " assert predicted_next_qvalues.data.dim() == 2, \\\n", + " \"make sure you predicted q-values for all actions in next state\"\n", + " assert next_state_values.data.dim() == 1, \\\n", + " \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", + " assert target_qvalues_for_actions.data.dim() == 1, \\\n", + " \"there's something wrong with target q-values, they must be a vector\"\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pgZKcPPnZdC_" + }, + "source": [ + "Sanity checks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yp8eREoDZdC_" + }, + "outputs": [], + "source": [ + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", + "\n", + "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n", + " agent, target_network,\n", + " gamma=0.99, check_shapes=True)\n", + "loss.backward()\n", + "\n", + "assert loss.requires_grad and tuple(loss.data.size()) == (), \\\n", + " \"you must return scalar loss - mean over batch\"\n", + "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() != 0), \\\n", + " \"loss must be differentiable w.r.t. network weights\"\n", + "assert np.all(next(target_network.parameters()).grad is None), \\\n", + " \"target network should not have grads\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8A1QtGVqZdC_" + }, + "source": [ + "### Main loop\n", + "\n", + "It's time to put everything together and see if it learns anything." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8lAUT94JZdC_" + }, + "outputs": [], + "source": [ + "from tqdm import trange\n", + "from IPython.display import clear_output\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YOk81bdZZdC_", + "outputId": "2fd2404e-19e5-4ebe-b0e1-c6f7593db790" + }, + "outputs": [], + "source": [ + "seed = \n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "13K5t2CTZdDA", + "outputId": "031d4a0a-99c3-4cc3-f7c3-77e4d8a5a331" + }, + "outputs": [], + "source": [ + "state_dim = env.observation_space.shape\n", + "n_actions = env.action_space.n\n", + "state, _ = env.reset(seed=seed)\n", + "\n", + "agent = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n", + "target_network = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n", + "target_network.load_state_dict(agent.state_dict())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iD7PAlwQZdDA", + "outputId": "aeb4bb67-4776-4b02-e558-d9d1a3306d47" + }, + "outputs": [], + "source": [ + "REPLAY_BUFFER_SIZE = 10**4\n", + "\n", + "exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n", + "for i in range(100):\n", + " if not utils.is_enough_ram(min_available_gb=0.1):\n", + " print(\"\"\"\n", + " Less than 100 Mb RAM available.\n", + " Make sure the buffer size in not too huge.\n", + " Also check, maybe other processes consume RAM heavily.\n", + " \"\"\"\n", + " )\n", + " break\n", + " play_and_record(state, agent, env, exp_replay, n_steps=10**2)\n", + " if len(exp_replay) == REPLAY_BUFFER_SIZE:\n", + " break\n", + "print(len(exp_replay))" + ] + }, + { 
+ "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Zl2VCEYQZdDA" + }, + "outputs": [], + "source": [ + "# # for something more complicated than CartPole\n", + "\n", + "# timesteps_per_epoch = 1\n", + "# batch_size = 32\n", + "# total_steps = 3 * 10**6\n", + "# decay_steps = 1 * 10**6\n", + "\n", + "# opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n", + "\n", + "# init_epsilon = 1\n", + "# final_epsilon = 0.1\n", + "\n", + "# loss_freq = 20\n", + "# refresh_target_network_freq = 1000\n", + "# eval_freq = 5000\n", + "\n", + "# max_grad_norm = 5000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x-sD-QyUZdDA" + }, + "outputs": [], + "source": [ + "timesteps_per_epoch = 1\n", + "batch_size = 32\n", + "total_steps = 4 * 10**4\n", + "decay_steps = 1 * 10**4\n", + "\n", + "opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n", + "\n", + "init_epsilon = 1\n", + "final_epsilon = 0.1\n", + "\n", + "loss_freq = 20\n", + "refresh_target_network_freq = 100\n", + "eval_freq = 1000\n", + "\n", + "max_grad_norm = 5000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "piqDfKQAZdDA" + }, + "outputs": [], + "source": [ + "mean_rw_history = []\n", + "td_loss_history = []\n", + "grad_norm_history = []\n", + "initial_state_v_history = []\n", + "step = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ks8NAV8AZdDA" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def wait_for_keyboard_interrupt():\n", + " try:\n", + " while True:\n", + " time.sleep(1)\n", + " except KeyboardInterrupt:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sU3GSGZqZdDA" + }, + "outputs": [], + "source": [ + "state, _ = env.reset()\n", + "with trange(step, total_steps + 1) as progress_bar:\n", + " for step in progress_bar:\n", + " if not utils.is_enough_ram():\n", + " print('less that 100 Mb RAM 
available, freezing')\n", + " print('make sure everything is ok and use KeyboardInterrupt to continue')\n", + " wait_for_keyboard_interrupt()\n", + "\n", + " agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n", + "\n", + " # play\n", + " _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n", + "\n", + " # train\n", + " \n", + "\n", + " loss = \n", + "\n", + " loss.backward()\n", + " grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)\n", + " opt.step()\n", + " opt.zero_grad()\n", + "\n", + " if step % loss_freq == 0:\n", + " td_loss_history.append(loss.data.cpu().item())\n", + " grad_norm_history.append(grad_norm)\n", + "\n", + " if step % refresh_target_network_freq == 0:\n", + " # Load agent weights into target_network\n", + " \n", + "\n", + " if step % eval_freq == 0:\n", + " mean_rw_history.append(evaluate(\n", + " make_env(), agent, n_games=3, greedy=True, t_max=1000, seed=step)\n", + " )\n", + " initial_state_q_values = agent.get_qvalues(\n", + " [make_env().reset(seed=step)[0]]\n", + " )\n", + " initial_state_v_history.append(np.max(initial_state_q_values))\n", + "\n", + " clear_output(True)\n", + " print(\"buffer size = %i, epsilon = %.5f\" %\n", + " (len(exp_replay), agent.epsilon))\n", + "\n", + " plt.figure(figsize=[16, 9])\n", + "\n", + " plt.subplot(2, 2, 1)\n", + " plt.title(\"Mean reward per episode\")\n", + " plt.plot(mean_rw_history)\n", + " plt.grid()\n", + "\n", + " assert not np.isnan(td_loss_history[-1])\n", + " plt.subplot(2, 2, 2)\n", + " plt.title(\"TD loss history (smoothened)\")\n", + " plt.plot(utils.smoothen(td_loss_history))\n", + " plt.grid()\n", + "\n", + " plt.subplot(2, 2, 3)\n", + " plt.title(\"Initial state V\")\n", + " plt.plot(initial_state_v_history)\n", + " plt.grid()\n", + "\n", + " plt.subplot(2, 2, 4)\n", + " plt.title(\"Grad norm history (smoothened)\")\n", + " plt.plot(utils.smoothen(grad_norm_history))\n", + " plt.grid()\n", + "\n", + " 
plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qwWFT2SBZdDA" + }, + "outputs": [], + "source": [ + "final_score = evaluate(\n", + " make_env(),\n", + " agent, n_games=30, greedy=True, t_max=1000\n", + ")\n", + "print('final score:', final_score)\n", + "assert final_score > 300, 'not good enough for DQN'\n", + "print('Well done')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G-feeX9YZdDA" + }, + "source": [ + "**Agent's predicted V-values vs their Monte-Carlo estimates**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rjVuSIrPZdDA" + }, + "outputs": [], + "source": [ + "eval_env = make_env()\n", + "record = utils.play_and_log_episode(eval_env, agent)\n", + "print('total reward for life:', np.sum(record['rewards']))\n", + "for key in record:\n", + " print(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FCacwLw6ZdDA" + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(5, 5))\n", + "ax = fig.add_subplot(1, 1, 1)\n", + "\n", + "ax.scatter(record['v_mc'], record['v_agent'])\n", + "ax.plot(sorted(record['v_mc']), sorted(record['v_mc']),\n", + " 'black', linestyle='--', label='x=y')\n", + "\n", + "ax.grid()\n", + "ax.legend()\n", + "ax.set_title('State Value Estimates')\n", + "ax.set_xlabel('Monte-Carlo')\n", + "ax.set_ylabel('Agent')\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } - ], - "source": [ - "env = make_env()\n", - "env.reset()\n", - "plt.imshow(env.render(\"rgb_array\"))\n", - "state_shape, n_actions = env.observation_space.shape, env.action_space.n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Building a network" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now need 
to build a neural network that can map observations to state q-values.\n", - "The model does not have to be huge yet. 1-2 hidden layers with < 200 neurons and ReLU activation will probably be enough. Batch normalization and dropout can spoil everything here." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", - "# those who have a GPU but feel unfair to use it can uncomment:\n", - "# device = torch.device('cpu')\n", - "device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class DQNAgent(nn.Module):\n", - " def __init__(self, state_shape, n_actions, epsilon=0):\n", - "\n", - " super().__init__()\n", - " self.epsilon = epsilon\n", - " self.n_actions = n_actions\n", - " self.state_shape = state_shape\n", - " # Define your network body here. 
Please make sure agent is fully contained here\n", - " assert len(state_shape) == 1\n", - " state_dim = state_shape[0]\n", - " \n", - "\n", - " \n", - " def forward(self, state_t):\n", - " \"\"\"\n", - " takes agent's observation (tensor), returns qvalues (tensor)\n", - " :param state_t: a batch states, shape = [batch_size, *state_dim=4]\n", - " \"\"\"\n", - " # Use your network to compute qvalues for given state\n", - " qvalues = \n", - "\n", - " assert qvalues.requires_grad, \"qvalues must be a torch tensor with grad\"\n", - " assert (\n", - " len(qvalues.shape) == 2 and \n", - " qvalues.shape[0] == state_t.shape[0] and \n", - " qvalues.shape[1] == n_actions\n", - " )\n", - "\n", - " return qvalues\n", - "\n", - " def get_qvalues(self, states):\n", - " \"\"\"\n", - " like forward, but works on numpy arrays, not tensors\n", - " \"\"\"\n", - " model_device = next(self.parameters()).device\n", - " states = torch.tensor(states, device=model_device, dtype=torch.float32)\n", - " qvalues = self.forward(states)\n", - " return qvalues.data.cpu().numpy()\n", - "\n", - " def sample_actions(self, qvalues):\n", - " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n", - " epsilon = self.epsilon\n", - " batch_size, n_actions = qvalues.shape\n", - "\n", - " random_actions = np.random.choice(n_actions, size=batch_size)\n", - " best_actions = qvalues.argmax(axis=-1)\n", - "\n", - " should_explore = np.random.choice(\n", - " [0, 1], batch_size, p=[1-epsilon, epsilon])\n", - " return np.where(should_explore, random_actions, best_actions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's try out our agent to see if it raises any errors." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000):\n", - " \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. \"\"\"\n", - " rewards = []\n", - " for _ in range(n_games):\n", - " s = env.reset()\n", - " reward = 0\n", - " for _ in range(t_max):\n", - " qvalues = agent.get_qvalues([s])\n", - " action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n", - " s, r, done, _ = env.step(action)\n", - " reward += r\n", - " if done:\n", - " break\n", - "\n", - " rewards.append(reward)\n", - " return np.mean(rewards)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "evaluate(env, agent, n_games=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Experience replay\n", - "For this assignment, we provide you with experience replay buffer. If you implemented experience replay buffer in last week's assignment, you can copy-paste it here in main notebook **to get 2 bonus points**.\n", - "\n", - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/exp_replay.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### The interface is fairly simple:\n", - "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n", - "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n", - "* `len(exp_replay)` - returns number of elements stored in replay buffer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from replay_buffer import ReplayBuffer\n", - "exp_replay = ReplayBuffer(10)\n", - "\n", - "for _ in range(30):\n", - " exp_replay.add(env.reset(), env.action_space.sample(), 1.0, env.reset(), done=False)\n", - "\n", - "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(5)\n", - "\n", - "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n", - " \"\"\"\n", - " Play the game for exactly n_steps, record every (s,a,r,s', done) to replay buffer. \n", - " Whenever game ends, add record with done=True and reset the game.\n", - " It is guaranteed that env has done=False when passed to this function.\n", - "\n", - " PLEASE DO NOT RESET ENV UNLESS IT IS \"DONE\"\n", - "\n", - " :returns: return sum of rewards over time and the state in which the env stays\n", - " \"\"\"\n", - " s = initial_state\n", - " sum_rewards = 0\n", - "\n", - " # Play the game for n_steps as per instructions above\n", - " \n", - "\n", - " return sum_rewards, s" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# testing your code.\n", - "exp_replay = ReplayBuffer(2000)\n", - "\n", - "state = env.reset()\n", - "play_and_record(state, agent, env, exp_replay, n_steps=1000)\n", - "\n", - "# if you're using your own experience replay buffer, some of those tests may need correction.\n", - "# just make sure you know what your code does\n", - "assert len(exp_replay) == 1000, \\\n", - " \"play_and_record should have added exactly 1000 steps, \" \\\n", - " \"but instead added %i\" % len(exp_replay)\n", - "is_dones = list(zip(*exp_replay._storage))[-1]\n", - 
"\n", - "assert 0 < np.mean(is_dones) < 0.1, \\\n", - " \"Please make sure you restart the game whenever it is 'done' and \" \\\n", - " \"record the is_done correctly into the buffer. Got %f is_done rate over \" \\\n", - " \"%i steps. [If you think it's your tough luck, just re-run the test]\" % (\n", - " np.mean(is_dones), len(exp_replay))\n", - "\n", - "for _ in range(100):\n", - " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", - " assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n", - " assert act_batch.shape == (10,), \\\n", - " \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", - " assert reward_batch.shape == (10,), \\\n", - " \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", - " assert is_done_batch.shape == (10,), \\\n", - " \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", - " assert [int(i) in (0, 1) for i in is_dones], \\\n", - " \"is_done should be strictly True or False\"\n", - " assert [0 <= a < n_actions for a in act_batch], \"actions should be within [0, n_actions)\"\n", - "\n", - "print(\"Well done!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Target networks\n", - "\n", - "We also employ the so called \"target network\" - a copy of neural network weights to be used for reference Q-values:\n", - "\n", - "The network itself is an exact copy of agent network, but it's parameters are not trained. 
Instead, they are moved here from agent's actual network every so often.\n", - "\n", - "$$ Q_{reference}(s,a) = r + \\gamma \\cdot \\max _{a'} Q_{target}(s',a') $$\n", - "\n", - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/target_net.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n", - "# This is how you can load weights from agent into target network\n", - "target_network.load_state_dict(agent.state_dict())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Learning with... Q-learning\n", - "Here we write a function similar to `agent.update` from tabular q-learning." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compute Q-learning TD error:\n", - "\n", - "$$ L = { 1 \\over N} \\sum_i [ Q_{\\theta}(s,a) - Q_{reference}(s,a) ] ^2 $$\n", - "\n", - "With Q-reference defined as\n", - "\n", - "$$ Q_{reference}(s,a) = r(s,a) + \\gamma \\cdot max_{a'} Q_{target}(s', a') $$\n", - "\n", - "Where\n", - "* $Q_{target}(s',a')$ denotes Q-value of next state and next action predicted by __target_network__\n", - "* $s, a, r, s'$ are current state, action, reward and next state respectively\n", - "* $\\gamma$ is a discount factor defined two cells above.\n", - "\n", - "\n", - "__Note 1:__ there's an example input below. Feel free to experiment with it before you write the function.\n", - "\n", - "__Note 2:__ compute_td_loss is a source of 99% of bugs in this homework. If reward doesn't improve, it often helps to go through it line by line [with a rubber duck](https://rubberduckdebugging.com/)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def compute_td_loss(states, actions, rewards, next_states, is_done,\n", - " agent, target_network,\n", - " gamma=0.99,\n", - " check_shapes=False,\n", - " device=device):\n", - " \"\"\" Compute td loss using torch operations only. Use the formulae above. \"\"\"\n", - " states = torch.tensor(states, device=device, dtype=torch.float32) # shape: [batch_size, *state_shape]\n", - " actions = torch.tensor(actions, device=device, dtype=torch.int64) # shape: [batch_size]\n", - " rewards = torch.tensor(rewards, device=device, dtype=torch.float32) # shape: [batch_size]\n", - " # shape: [batch_size, *state_shape]\n", - " next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n", - " is_done = torch.tensor(\n", - " is_done.astype('float32'),\n", - " device=device,\n", - " dtype=torch.float32,\n", - " ) # shape: [batch_size]\n", - " is_not_done = 1 - is_done\n", - "\n", - " # get q-values for all actions in current states\n", - " predicted_qvalues = agent(states) # shape: [batch_size, n_actions]\n", - "\n", - " # compute q-values for all actions in next states\n", - " predicted_next_qvalues = target_network(next_states) # shape: [batch_size, n_actions]\n", - " \n", - " # select q-values for chosen actions\n", - " predicted_qvalues_for_actions = predicted_qvalues[range(len(actions)), actions] # shape: [batch_size]\n", - "\n", - " # compute V*(next_states) using predicted next q-values\n", - " next_state_values = \n", - "\n", - " assert next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0], \\\n", - " \"must predict one value per state\"\n", - "\n", - " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", - " # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", - " # you can multiply next state values by is_not_done to achieve 
this.\n", - " target_qvalues_for_actions = \n", - "\n", - " # mean squared error loss to minimize\n", - " loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2)\n", - "\n", - " if check_shapes:\n", - " assert predicted_next_qvalues.data.dim() == 2, \\\n", - " \"make sure you predicted q-values for all actions in next state\"\n", - " assert next_state_values.data.dim() == 1, \\\n", - " \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", - " assert target_qvalues_for_actions.data.dim() == 1, \\\n", - " \"there's something wrong with target q-values, they must be a vector\"\n", - "\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Sanity checks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", - "\n", - "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n", - " agent, target_network,\n", - " gamma=0.99, check_shapes=True)\n", - "loss.backward()\n", - "\n", - "assert loss.requires_grad and tuple(loss.data.size()) == (), \\\n", - " \"you must return scalar loss - mean over batch\"\n", - "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() != 0), \\\n", - " \"loss must be differentiable w.r.t. network weights\"\n", - "assert np.all(next(target_network.parameters()).grad is None), \\\n", - " \"target network should not have grads\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Main loop\n", - "\n", - "It's time to put everything together and see if it learns anything." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from tqdm import trange\n", - "from IPython.display import clear_output\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "seed = \n", - "random.seed(seed)\n", - "np.random.seed(seed)\n", - "torch.manual_seed(seed)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env = make_env(seed)\n", - "state_dim = env.observation_space.shape\n", - "n_actions = env.action_space.n\n", - "state = env.reset()\n", - "\n", - "agent = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n", - "target_network = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n", - "target_network.load_state_dict(agent.state_dict())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "REPLAY_BUFFER_SIZE = 10**4\n", - "\n", - "exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n", - "for i in range(100):\n", - " if not utils.is_enough_ram(min_available_gb=0.1):\n", - " print(\"\"\"\n", - " Less than 100 Mb RAM available. 
\n", - " Make sure the buffer size in not too huge.\n", - " Also check, maybe other processes consume RAM heavily.\n", - " \"\"\"\n", - " )\n", - " break\n", - " play_and_record(state, agent, env, exp_replay, n_steps=10**2)\n", - " if len(exp_replay) == REPLAY_BUFFER_SIZE:\n", - " break\n", - "print(len(exp_replay))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# # for something more complicated than CartPole\n", - "\n", - "# timesteps_per_epoch = 1\n", - "# batch_size = 32\n", - "# total_steps = 3 * 10**6\n", - "# decay_steps = 1 * 10**6\n", - "\n", - "# opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n", - "\n", - "# init_epsilon = 1\n", - "# final_epsilon = 0.1\n", - "\n", - "# loss_freq = 20\n", - "# refresh_target_network_freq = 1000\n", - "# eval_freq = 5000\n", - "\n", - "# max_grad_norm = 5000" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "timesteps_per_epoch = 1\n", - "batch_size = 32\n", - "total_steps = 4 * 10**4\n", - "decay_steps = 1 * 10**4\n", - "\n", - "opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n", - "\n", - "init_epsilon = 1\n", - "final_epsilon = 0.1\n", - "\n", - "loss_freq = 20\n", - "refresh_target_network_freq = 100\n", - "eval_freq = 1000\n", - "\n", - "max_grad_norm = 5000" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mean_rw_history = []\n", - "td_loss_history = []\n", - "grad_norm_history = []\n", - "initial_state_v_history = []\n", - "step = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "\n", - "def wait_for_keyboard_interrupt():\n", - " try:\n", - " while True:\n", - " time.sleep(1)\n", - " except KeyboardInterrupt:\n", - " pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - 
"source": [ - "state = env.reset()\n", - "with trange(step, total_steps + 1) as progress_bar:\n", - " for step in progress_bar:\n", - " if not utils.is_enough_ram():\n", - " print('less that 100 Mb RAM available, freezing')\n", - " print('make sure everything is ok and use KeyboardInterrupt to continue')\n", - " wait_for_keyboard_interrupt()\n", - "\n", - " agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n", - "\n", - " # play\n", - " _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n", - "\n", - " # train\n", - " \n", - "\n", - " loss = \n", - "\n", - " loss.backward()\n", - " grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)\n", - " opt.step()\n", - " opt.zero_grad()\n", - "\n", - " if step % loss_freq == 0:\n", - " td_loss_history.append(loss.data.cpu().item())\n", - " grad_norm_history.append(grad_norm)\n", - "\n", - " if step % refresh_target_network_freq == 0:\n", - " # Load agent weights into target_network\n", - " \n", - "\n", - " if step % eval_freq == 0:\n", - " mean_rw_history.append(evaluate(\n", - " make_env(seed=step), agent, n_games=3, greedy=True, t_max=1000)\n", - " )\n", - " initial_state_q_values = agent.get_qvalues(\n", - " [make_env(seed=step).reset()]\n", - " )\n", - " initial_state_v_history.append(np.max(initial_state_q_values))\n", - "\n", - " clear_output(True)\n", - " print(\"buffer size = %i, epsilon = %.5f\" %\n", - " (len(exp_replay), agent.epsilon))\n", - "\n", - " plt.figure(figsize=[16, 9])\n", - "\n", - " plt.subplot(2, 2, 1)\n", - " plt.title(\"Mean reward per episode\")\n", - " plt.plot(mean_rw_history)\n", - " plt.grid()\n", - "\n", - " assert not np.isnan(td_loss_history[-1])\n", - " plt.subplot(2, 2, 2)\n", - " plt.title(\"TD loss history (smoothened)\")\n", - " plt.plot(utils.smoothen(td_loss_history))\n", - " plt.grid()\n", - "\n", - " plt.subplot(2, 2, 3)\n", - " plt.title(\"Initial state V\")\n", - " plt.plot(initial_state_v_history)\n", 
- " plt.grid()\n", - "\n", - " plt.subplot(2, 2, 4)\n", - " plt.title(\"Grad norm history (smoothened)\")\n", - " plt.plot(utils.smoothen(grad_norm_history))\n", - " plt.grid()\n", - "\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "final_score = evaluate(\n", - " make_env(),\n", - " agent, n_games=30, greedy=True, t_max=1000\n", - ")\n", - "print('final score:', final_score)\n", - "assert final_score > 300, 'not good enough for DQN'\n", - "print('Well done')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Agent's predicted V-values vs their Monte-Carlo estimates**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "eval_env = make_env()\n", - "record = utils.play_and_log_episode(eval_env, agent)\n", - "print('total reward for life:', np.sum(record['rewards']))\n", - "for key in record:\n", - " print(key)" - ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = plt.figure(figsize=(5, 5))\n", - "ax = fig.add_subplot(1, 1, 1)\n", - "\n", - "ax.scatter(record['v_mc'], record['v_agent'])\n", - "ax.plot(sorted(record['v_mc']), sorted(record['v_mc']),\n", - " 'black', linestyle='--', label='x=y')\n", - "\n", - "ax.grid()\n", - "ax.legend()\n", - "ax.set_title('State Value Estimates')\n", - "ax.set_xlabel('Monte-Carlo')\n", - "ax.set_ylabel('Agent')\n", - "\n", - "plt.show()" - ] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 1 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/week04_approx_rl/homework_pytorch_main.ipynb b/week04_approx_rl/homework_pytorch_main.ipynb index 05ec667a0..3dc30fbcf 100644 --- a/week04_approx_rl/homework_pytorch_main.ipynb +++ b/week04_approx_rl/homework_pytorch_main.ipynb @@ -1,1329 +1,1531 @@ { - "cells": 
[ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Deep Q-Network implementation.\n", - "\n", - "This homework shamelessly demands you to implement DQN — an approximate Q-learning algorithm with experience replay and target networks — and see if it works any better this way.\n", - "\n", - "Original paper:\n", - "https://arxiv.org/pdf/1312.5602.pdf" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**This notebook is the main notebook.** Another notebook is given for debug. (**homework_pytorch_main**). The tasks are similar and share most of the code. The main difference is in environments. In main notebook it can take some 2 hours for the agent to start improving so it seems reasonable to launch the algorithm on a simpler env first. In debug one it is CartPole and it will train in several minutes.\n", - "\n", - "**We suggest the following pipeline:** First implement debug notebook then implement the main one.\n", - "\n", - "**About evaluation:** All points are given for the main notebook with one exception: if agent fails to beat the threshold in main notebook you can get 1 pt (instead of 3 pts) for beating the threshold in debug notebook." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " \n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/atari_wrappers.py\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/utils.py\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/replay_buffer.py\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/framebuffer.py\n", - "\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "__Frameworks__ - we'll accept this homework in any deep learning framework. This particular notebook was designed for PyTorch, but you find it easy to adapt it to almost any Python-based deep learning framework." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "import numpy as np\n", - "import torch\n", - "import utils" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Let's play some old videogames\n", - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/nerd.png)\n", - "\n", - "This time we're gonna apply approximate Q-learning to an Atari game called Breakout. It's not the hardest thing out there, but it's definitely way more complex than anything we tried before.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ENV_NAME = \"BreakoutNoFrameskip-v4\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocessing (3 pts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's see what observations look like." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAA6UAAAH3CAYAAABD+PmTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dbazkd3nf/8/1twMPNlQ2N7Us29SG\nOqmgah1YuVYL/GlpyGJFMfQBtVUFJ0VdkEBK5FSVCVJBlSK1aTASaupoEdaaKjHQOgSrclwcNwqq\nUhPWxDHmxtgmRni12AVXQJYIYvvbB+e3MFl295w9c/Od+Z7XSxqdOb+ZOXNZ+/bRXmfm/LZaawEA\nAIAe/r/eAwAAALB3WUoBAADoxlIKAABAN5ZSAAAAurGUAgAA0I2lFAAAgG6WtpRW1YGqeqiqHqmq\nG5f1PNCDvhmdxhmdxhmdxtkktYx/p7Sqzkny5SQ/neTxJJ9Jcl1r7QsLfzJYMX0zOo0zOo0zOo2z\naZb1SumVSR5prX2ltfb9JB9Jcs2SngtWTd+MTuOMTuOMTuNslHOX9HUvSvK1mc8fT/IPZu9QVQeT\nHJw+feWS5oBZ32itvWgBX2fbvhON04XGGd3KGtc3HSyq70TjrKfTNr6spXRbrbVDSQ4lSVUt/j3E\n8KO+uson0zgdaJzRraxxfdOB7+GM7rSNL+vtu0eTXDLz+cXTMRiBvhmdxhmdxhmdxtkoy1pKP5Pk\n8qq6rKqek+TaJHcs6blg1fTN6DTO6DTO6DTORlnK23dba09X1TuT/I8k5yS5pbX2+WU816LddNNN\nO77vDTfcsOvHnvz4eR47r57PfbKTZ1nmc+3WJvedaHzVz30yjS+fxlf73CfT+HLpe7XPfbJN6DvR\n+G4er/EtvRpf2u+UttbuTHLnsr4+9KRvRqdxRqdxRqdxNkm3Ex1tgkX+9OVsHz/vc89jXX/qx+Jp\nnNFpnJHpm9FpfO9Y1u+UAgAAwLa8UsqP2O4nQXvxpzeMReOMTuOMTN+Mbi827pVSAAAAuvFKKdv+\ntGWV76GHZdA4o9M4I9M3o9O4V0oBAADoyCulZzDvTyXmefwqfyKyF376wqlpnNFpnJHpm9FpfO+o\n1lrvGVJV/YdgL7ivtba/xxNrnBXROKPr0ri+WRHfwxndaRv39l0AAAC6WYu371588cVDntqY9dKz\nMY2zChpndL0a0zer4Hs4oztTY14pBQAAoBtLKQAAAN1YSgEAAOjGUgoAAEA3llIAAAC62fVSWlWX\nVNUfVtUXqurzVfVL0/H3VtXRqrp/uly9uHFhdTTO6DTOyPTN6DTOSOb5J2GeTvIrrbXPVtXzktxX\nVXdPt72/tfYb848HXWmc0Wmckemb0WmcYex6KW2tHUtybLr+nar6YpKLFjUY9KZxRqdxRqZvRqdx\nRrKQ3ymtqkuT/FSST0+H3llVD1TVLVV1/mkec7CqjlTVkePHjy9iDFgajTM6jTMyfTM6jbPp5l5K\nq+rHk9ye5Jdba99OcnOSlya5Ils/vXnfqR7XWjvUWtvfWtu/b9++eceApdE4o9M4I9M3o9M4I5hr\nKa2qH8vW/wS/3Vr73SRprT3RWnumtfZskg8muXL+MaEPjTM6jTMyfTM6jTOKec6+W0k+lOSLrbWb\nZo5fOHO3NyV5cPfjQT8aZ3QaZ2T6ZnQaZyTznH33HyX5+SSfq6r7p2O/muS6qroiSUvyWJK3zTUh\n9KNxRqdxRqZvRqdxhjHP2Xf/V5I6xU137n4cWB8aZ3QaZ2T6ZnQaZyTzvFK6MjfccEPvEdgAN910\n
0/Z3WlMaZyc0zug2tXF9sxOb2neicXZmnsYX8k/CAAAAwG5YSgEAAOjGUgoAAEA3llIAAAC6sZQC\nAADQjaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjGUgoAAEA3llIAAAC6sZQCAADQjaUU\nAACAbs6d9wtU1WNJvpPkmSRPt9b2V9Xzk3w0yaVJHkvy5tba/533uWDV9M3oNM7oNM7I9M0oFvVK\n6T9urV3RWts/fX5jkntaa5cnuWf6HDaVvhmdxhmdxhmZvtl4y3r77jVJbp2u35rkjUt6HuhB34xO\n44xO44xM32ycRSylLcknq+q+qjo4HbugtXZsuv71JBec/KCqOlhVR6rqyPHjxxcwBizFrvpONM7G\n0Dij8/cURuZ7OEOY+3dKk7yqtXa0qv5mkrur6kuzN7bWWlW1kx/UWjuU5FCSXHLJJT9yO6yJXfU9\n3aZxNoHGGZ2/pzAy38MZwtyvlLbWjk4fn0zy8SRXJnmiqi5Mkunjk/M+D/Sgb0ancUancUamb0Yx\n11JaVfuq6nknrid5fZIHk9yR5Prpbtcn+cQ8zwM96JvRaZzRaZyR6ZuRzPv23QuSfLyqTnyt32mt\n3VVVn0nysap6a5KvJnnznM8DPeib0Wmc0WmckembYcy1lLbWvpLk75/i+DeTvG6erw296ZvRaZzR\naZyR6ZuRLOJER0t374EDvUdgA/xx7wHmoHF2QuOMblMb1zc7sal9JxpnZ+ZpfFn/TikAAABsy1IK\nAABAN5ZSAAAAurGUAgAA0I2lFAAAgG424uy7z/7tb/ceAZZK44xO44xM34xO4yybV0oBAADoxlIK\nAABAN5ZSAAAAurGUAgAA0I2lFAAAgG424uy7T/2N7/YeAZZK44xO44xM34xO4yybV0oBAADoxlIK\nAABAN7t++25V/WSSj84cekmSf5vkvCT/Ksn/mY7/amvtzl1PCJ1onNFpnNFpnJHpm5HseiltrT2U\n5IokqapzkhxN8vEkv5jk/a2131jIhNCJxhmdxhmdxhmZvhnJok509Lokj7bWvlpVC/qSP/TU3/n+\nwr8mA/rGUr+6xulP44xuQxvXNzuyoX0nGmeH5mh8Ub9Tem2S22Y+f2dVPVBVt1TV+Qt6DuhJ44xO\n44xO44xM32y0uZfSqnpOkp9L8l+nQzcneWm23k5wLMn7TvO4g1V1pKqOHD9+fN4xYGk0zug0zuh2\n07i+2RS+hzOCRbxS+oYkn22tPZEkrbUnWmvPtNaeTfLBJFee6kGttUOttf2ttf379u1bwBiwNBpn\ndBpndGfduL7ZIL6Hs/EWsZRel5m3C1TVhTO3vSnJgwt4DuhJ44xO44xO44xM32y8uU50VFX7kvx0\nkrfNHP71qroiSUvy2Em3wUbROKPTOKPTOCPTN6OYayltrR1P8oKTjv38XBOdwu88++JFf0kG9Pol\nfE2Ns040zug2tXF9sxOb2neicXZmnsYXdfZdAAAAOGuWUgAAALqxlAIAANCNpRQAAIBuLKUAAAB0\nM9fZd1fl+x95b+8R2ASv/+PeE+yaxtkRjTO6DW1c3+zIhvadaJwdmqNxr5QCAADQjaUUAACAbiyl\nAAAAdGMpBQAAoBtLKQAAAN1sxNl3/+ddV/UegQ3ws6+/qfcIu6ZxdkLjjG5TG9c3O7GpfScaZ2fm\nadwrpQAAAHRjKQUAAKAbSykAAADd7GgprapbqurJqnpw5tjzq+ruqnp4+nj+dLyq6gNV9UhVPVBV\nr1jW8LAI+mZ0Gmd0Gmdk+mYv2OkrpYeTHDjp2I1J7mmtXZ7knunzJHlDksuny8EkN88/JizV4eib\nsR2Oxhnb4WiccR2OvhncjpbS1tqnkjx10uFrktw6Xb81yRtnjn+4bbk3yXlVdeEihoVl0Dej0zij\n0zgj0zd7wTy/U3pBa+3YdP3rSS6Yrl+U5Gsz93t8OgabRN+MTu
OMTuOMTN8MZSEnOmqttSTtbB5T\nVQer6khVHTl+/PgixoCl2E3ficbZHBpndP6ewsh8D2cE8yylT5x4O8D08cnp+NEkl8zc7+Lp2F/T\nWjvUWtvfWtu/b9++OcaApZir70TjrD2NMzp/T2FkvoczlHmW0juSXD9dvz7JJ2aOv2U6+9dVSb41\n8/YC2BT6ZnQaZ3QaZ2T6Zijn7uROVXVbktcmeWFVPZ7kPUn+fZKPVdVbk3w1yZunu9+Z5OokjyT5\nbpJfXPDMsFD6ZnQaZ3QaZ2T6Zi/Y0VLaWrvuNDe97hT3bUneMc9QsEr6ZnQaZ3QaZ2T6Zi9YyImO\nAAAAYDcspQAAAHRjKQUAAKAbSykAAADdWEoBAADoxlIKAABAN5ZSAAAAurGUAgAA0I2lFAAAgG4s\npQAAAHRjKQUAAKAbSykAAADdWEoBAADoxlIKAABAN5ZSAAAAutl2Ka2qW6rqyap6cObYf6yqL1XV\nA1X18ao6bzp+aVX9ZVXdP11+a5nDwyJofL3ce+BA7j1woPcYQ9E4I9M3o9M4e8FOXik9nOTkvyHe\nneTvttb+XpIvJ3nXzG2PttaumC5vX8yYsFSHo3HGdjgaXwt+6LIUh6NvxnY4Gmdw2y6lrbVPJXnq\npGOfbK09PX16b5KLlzAbrITGGZ3GGZm+14MfuCyPxtkLFvE7pf8yye/PfH5ZVf1pVf1RVb16AV8f\netM4o9M4I9M3o9P4kvmhy/KdO8+Dq+rdSZ5O8tvToWNJXtxa+2ZVvTLJ71XVy1tr3z7FYw8mOZgk\n559//jxjwNJofPWuuuuu3iPsKRpnZPpmdBpnFLt+pbSqfiHJzyb5F621liStte+11r45Xb8vyaNJ\nfuJUj2+tHWqt7W+t7d+3b99ux4Cl0Tij0/jqXXXXXX7wsiL6ZnQaZyS7eqW0qg4k+TdJ/v/W2ndn\njr8oyVOttWeq6iVJLk/ylYVMCiukcUancUam79Xzw5bV0jij2XYprarbkrw2yQur6vEk78nWGb6e\nm+TuqkqSe6eze70myb+rqr9K8mySt7fWnjrlF4Y1oXFGp3FGpm9Gp/H+/NBl+bZdSltr153i8IdO\nc9/bk9w+71CwShpndBpnZPpmdBpnL1jE2XcBAABgVyylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjG\nUgoAAEA3llIAAAC6sZQCAADQjaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjGUgoAAEA3\nllIAAAC6sZQCAADQzbZLaVXdUlVPVtWDM8feW1VHq+r+6XL1zG3vqqpHquqhqvqZZQ0Oi6JxRqdx\nRqdxRqZv9oKdvFJ6OMmBUxx/f2vtiulyZ5JU1cuSXJvk5dNj/nNVnbOoYWFJDkfjjO1wNM7YDkfj\njOtw9M3gtl1KW2ufSvLUDr/eNUk+0lr7Xmvtz5M8kuTKOeaDpdM4o9M4o9M4I9M3e8E8v1P6zqp6\nYHpLwfnTsYuSfG3mPo9Px35EVR2sqiNVdeT48eNzjAFLo3FGp3FGt+vG9c0G8D2cYex2Kb05yUuT\nXJHkWJL3ne0XaK0daq3tb63t37dv3y7HgKXROKPTOKObq3F9s+Z8D2cou1pKW2tPtNaeaa09m+SD\n+eHbAo4muWTmrhdPx2CjaJzRaZzRaZyR6ZvR7GopraoLZz59U5ITZwO7I8m1VfXcqrosyeVJ/mS+\nEWH1NM7oNM7oNM7I9M1ozt3uDlV1W5LXJnlhVT2e5D1JXltVVyRpSR5L8rYkaa19vqo+luQLSZ5O\n8o7W2jPLGR0WQ+OMTuOMTuOMTN/sBdsupa21605x+ENnuP+vJfm1eYaCVdI4o9M4o9M4I9M3e8E8\nZ98FAACAuVhKAQAA6MZSCgAAQDeWUgAAALqxlAIAANCNpRQAAIBuLKUAAAB0YykFAACgG0spAAAA\n3VhKAQAA6MZSCgAAQDeWUg
AAALqxlAIAANCNpRQAAIBuLKUAAAB0s+1SWlW3VNWTVfXgzLGPVtX9\n0+Wxqrp/On5pVf3lzG2/tczhYRE0zug0zsj0zeg0zl5w7g7uczjJf0ry4RMHWmv//MT1qnpfkm/N\n3P/R1toVixoQVuBwNM7YDkfjjOtw9M3YDkfjDG7bpbS19qmquvRUt1VVJXlzkn+y2LFgdTTO6DTO\nyPTN6DTOXjDv75S+OskTrbWHZ45dVlV/WlV/VFWvPt0Dq+pgVR2pqiPHjx+fcwxYGo0zOo0zMn0z\nOo0zhJ28ffdMrkty28znx5K8uLX2zap6ZZLfq6qXt9a+ffIDW2uHkhxKkksuuaTNOQcsi8YZncYZ\nmb4ZncYZwq5fKa2qc5P8syQfPXGstfa91to3p+v3JXk0yU/MOyT0oHFGp3FGpm9Gp3FGMs/bd/9p\nki+11h4/caCqXlRV50zXX5Lk8iRfmW9E6EbjjE7jjEzfjE7jDGMn/yTMbUn+d5KfrKrHq+qt003X\n5q+/XSBJXpPkgem01P8tydtba08tcmBYNI0zOo0zMn0zOo2zF+zk7LvXneb4L5zi2O1Jbp9/LFgd\njTM6jTMyfTM6jbMXzHv2XQAAANg1SykAAADdWEoBAADoxlIKAABAN5ZSAAAAurGUAgAA0I2lFAAA\ngG62/XdKV+Fb5zyb/37eX/QeY0+698CBuR5/1V13LWiS+f3DT36y9winpfF+NL4aGu9H48un7370\nvRoa70fjW7xSCgAAQDeWUgAAALqxlAIAANDNWvxOKf2s0/vQYRk0zug0zsj0zeg0vsVSyjD8T83o\nNM7oNM7I9M3o5mm8WmsLHGWXQ1T1H4K94L7W2v4eT6xxVkTjjK5L4/pmRXwPZ3SnbdzvlAIAANDN\ntktpVV1SVX9YVV+oqs9X1S9Nx59fVXdX1cPTx/On41VVH6iqR6rqgap6xbL/I2AeGmd0Gmdk+mZ0\nGmcv2MkrpU8n+ZXW2suSXJXkHVX1siQ3JrmntXZ5knumz5PkDUkuny4Hk9y88KlhsTTO6DTOyPTN\n6DTO8LZdSltrx1prn52ufyfJF5NclOSaJLdOd7s1yRun69ck+XDbcm+S86rqwoVPDguicUancUam\nb0ancfaCs/qd0qq6NMlPJfl0kgtaa8emm76e5ILp+kVJvjbzsMenYyd/rYNVdaSqjpzlzLA0Gmd0\nGmdk+mZ0GmdUO15Kq+rHk9ye5Jdba9+eva1tncL3rM7a1Vo71Frb3+ssY3AyjTM6jTMyfTM6jTOy\nHS2lVfVj2fqf4Ldba787HX7ixFsBpo9PTsePJrlk5uEXT8dgbWmc0Wmckemb0Wmc0e3k7LuV5ENJ\nvthau2nmpjuSXD9dvz7JJ2aOv2U689dVSb4189YCWDsaZ3QaZ2T6ZnQaZ09orZ3xkuRV2Xo7wANJ\n7p8uVyd5QbbO9PVwkj9I8vzp/pXkN5M8muRzSfbv4Dmai8sKLkc07jL4ReMuo19+pPHo22Wci+/h\nLqNfTtl4ay01hdhVVfUfgr3gvl6/N6FxVkTjjK5L4/pmRXwPZ3Snbfyszr4LAAAAi2QpBQAAoBtL\nKQAAAN2c23uAyTeSHJ8+bqoXxvw97WT+v7WKQU5D4/3thfl7Nv4XSR7q+Pzz2gt9rLt1btz38P72\nwvz+njKfvdDIOpur8bU40VGSVNWRTf7He83f1ybMvwkznon5+1r3+dd9vu2Yv791/29Y9/m2Y/6+\nNmH+TZjxTMzf17zze/suAAAA3VhKAQAA6GadltJDvQeYk/n72oT5N2HGMzF/X+s+/7rPtx3z97fu\n/w3rPt92zN/XJsy/CTOeifn7mmv+tfmdUgAAAPaedXqlFAAAgD3GUgoAAEA33ZfSqjpQVQ9V1SNV\ndWPveXaiqh6rqs9V1f1VdWQ69vyquruqHp4+nt97zllVdUtVPVlVD84cO+XMteUD05/JA1X1
in6T\n/2DWU83/3qo6Ov053F9VV8/c9q5p/oeq6mf6TP2DWTS+ZPruS+PLp/F+NrHvROOrpvHV2rS+E41v\n+wSttW6XJOckeTTJS5I8J8mfJXlZz5l2OPdjSV540rFfT3LjdP3GJP+h95wnzfeaJK9I8uB2Mye5\nOsnvJ6kkVyX59JrO/94k//oU933Z1NJzk1w2NXZOp7k13q8Pfa9mdo33a0Tjy597I/ueZtd4//k1\nvry5N6rvMzSi8enS+5XSK5M80lr7Smvt+0k+kuSazjPt1jVJbp2u35rkjR1n+RGttU8leeqkw6eb\n+ZokH25b7k1yXlVduJpJT+0085/ONUk+0lr7Xmvtz5M8kq3WetD4Cui7W9+JxldC476HL4jGl0Tj\na2Ft+040nm0a772UXpTkazOfPz4dW3ctySer6r6qOjgdu6C1dmy6/vUkF/QZ7aycbuZN+nN55/S2\nhltm3qaxTvOv0yxnY4TG9b0a6zbPTml8Pax74+s0y9nS+HrQ+HKM0Hei8R/ovZRuqle11l6R5A1J\n3lFVr5m9sW29br1R/9bOJs6c5OYkL01yRZJjSd7Xd5yhDNX4ps070fdyabw/jS+XxvvT+PIM1Xey\nmTNngY33XkqPJrlk5vOLp2NrrbV2dPr4ZJKPZ+vl6CdOvKw+fXyy34Q7drqZN+LPpbX2RGvtmdba\ns0k+mB++LWCd5l+nWXZskMb1vRrrNs+OaLy/DWl8nWY5KxrvT+PLM0jficZ/oPdS+pkkl1fVZVX1\nnCTXJrmj80xnVFX7qup5J64neX2SB7M19/XT3a5P8ok+E56V0818R5K3TGf+uirJt2beWrA2Tnpv\n/Zuy9eeQbM1/bVU9t6ouS3J5kj9Z9XwTjfej79XQeD8aX76N6zvR+LrQ+HIM1Hei8R8601mQVnHJ\n1tmlvpytszK9u/c8O5j3Jdk6m9SfJfn8iZmTvCDJPUkeTvIHSZ7fe9aT5r4tWy+r/1W23tf91tPN\nnK0zff3m9GfyuST713T+/zLN98AU/4Uz93/3NP9DSd7QeXaN9+lD36ubX+N9GtH4ambfqL6nmTW+\nHvNrfDnzblzfZ2hE49OlpgcBAADAyvV++y4AAAB7mKUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1Y\nSgEAAOjGUgoAAEA3llIAAAC6sZQCAADQjaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjG\nUgoAAEA3llIAAAC6sZQCAADQjaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjGUgoAAEA3\nllIAAAC6sZQCAADQjaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjGUgoAAEA3llIAAAC6\nsZQCAADQjaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN1YSgEAAOjGUgoAAEA3llIAAAC6sZQCAADQ\njaUUAACAbiylAAAAdGMpBQAAoBtLKQAAAN0sbSmtqgNV9VBVPVJVNy7reaAHfTM6jTM6jTM6jbNJ\nqrW2+C9adU6SLyf56SSPJ/lMkutaa19Y+JPBiumb0Wmc0Wmc0WmcTbOsV0qvTPJIa+0rrbXvJ/lI\nkmuW9FywavpmdBpndBpndBpno5y7pK97UZKvzXz+eJJ/MHuHqjqY5OD06SuXNAfM+kZr7UUL+Drb\n9p1onC40zuhW1ri+6WBRfScaZz2dtvFlLaXbaq0dSnIoSapq8e8hhh/11VU+mcbpQOOMbmWN65sO\nfA9ndKdtfFlv3z2a5JKZzy+ejsEI9M3oNM7oNM7oNM5GWdZS+pkkl1fVZVX1nCTXJrljSc8Fq6Zv\nRqdxRqdxRqdxNspS3r7bWnu6qt6Z5H8kOSfJLa21zy/juRbtpptu2vF9b7jhhl0/9uTHz/PYefV8\n7pOdPMsyn2u3NrnvROOrfu6TaXz5NL7a5z6ZxpdL36t9
7pNtQt+JxnfzeI1v6dX40n6ntLV2Z5I7\nl/X1oSd9MzqNMzqNMzqNs0m6nehoEyzypy9n+/h5n3se6/pTPxZP44xO44xM34xO43vHsn6nFAAA\nALbllVJ+xHY/CdqLP71hLBpndBpnZPpmdHuxca+UAgAA0I1XStn2py2rfA89LIPGGZ3GGZm+GZ3G\nvVIKAABAR14pPYN5fyoxz+NX+RORvfDTF05N44xO44xM34xO43tHtdZ6z5Cq6j8Ee8F9rbX9PZ5Y\n46yIxhldl8b1zYr4Hs7oTtu4t+8CAADQzVq8fffiiy8e8tTGrJeejWmcVdA4o+vVmL5ZBd/DGd2Z\nGvNKKQAAAN1YSgEAAOjGUgoAAEA3llIAAAC6sZQCAADQza6X0qq6pKr+sKq+UFWfr6pfmo6/t6qO\nVtX90+XqxY0Lq6NxRqdxRqZvRqdxRjLPPwnzdJJfaa19tqqel+S+qrp7uu39rbXfmH886ErjjE7j\njEzfjE7jDGPXS2lr7ViSY9P171TVF5NctKjBoDeNMzqNMzJ9MzqNM5KF/E5pVV2a5KeSfHo69M6q\neqCqbqmq80/zmINVdaSqjhw/fnwRY8DSaJzRaZyR6ZvRaZxNN/dSWlU/nuT2JL/cWvt2kpuTvDTJ\nFdn66c37TvW41tqh1tr+1tr+ffv2zTsGLI3GGZ3GGZm+GZ3GGcFcS2lV/Vi2/if47dba7yZJa+2J\n1tozrbVnk3wwyZXzjwl9aJzRaZyR6ZvRaZxRzHP23UryoSRfbK3dNHP8wpm7vSnJg7sfD/rROKPT\nOCPTN6PTOCOZ5+y7/yjJzyf5XFXdPx371STXVdUVSVqSx5K8ba4JoR+NMzqNMzJ9MzqNM4x5zr77\nv5LUKW66c/fjwPrQOKPTOCPTN6PTOCOZ55XSlbnhhht6j8AGuOmmm7a/05rSODuhcUa3qY3rm53Y\n1L4TjbMz8zS+kH8SBgAAAHbDUgoAAEA3llIAAAC6sZQCAADQjaUUAACAbiylAAAAdGMpBQAAoBtL\nKQAAAN1YSgEAAOjGUgoAAEA3llIAAAC6sZQCAADQjaUUAACAbiylAAAAdHPuvF+gqh5L8p0kzyR5\nurW2v6qen+SjSS5N8liSN7fW/u+8zwWrpm9Gp3FGp3FGpm9GsahXSv9xa+2K1tr+6fMbk9zTWrs8\nyT3T57Cp9M3oNM7oNM7I9M3GW9bbd69Jcut0/dYkb1zS80AP+mZ0Gmd0Gmdk+mbjLGIpbUk+WVX3\nVdXB6dgFrbVj0/WvJ7ng5AdV1cGqOlJVR44fP76AMWApdtV3onE2hsYZnb+nMDLfwxnC3L9TmuRV\nrbWjVfU3k9xdVV+avbG11qqqnfyg1tqhJIeS5JJLLvmR22FN7Krv6TaNswk0zuj8PYWR+R7OEOZ+\npbS1dnT6+GSSjye5MskTVXVhkkwfn5z3eaAHfTM6jTM6jTMyfTOKuZbSqtpXVc87cT3J65M8mOSO\nJNdPd7s+ySfmeR7oQd+MTmOrsbMAABLfSURBVOOMTuOMTN+MZN63716Q5ONVdeJr/U5r7a6q+kyS\nj1XVW5N8Ncmb53we6EHfjE7jjE7jjEzfDGOupbS19pUkf/8Ux7+Z5HXzfG3oTd+MTuOMTuOMTN+M\nZBEnOlq6ew8c6D0CG+CPew8wB42zExpndJvauL7ZiU3tO9E4OzNP48v6d0oBAABgW5ZSAAAAurGU\nAgAA0I2lFAAAgG4spQAAAHSzEWffffZvf7v3CLBUGmd0Gmdk+mZ0GmfZvFIKAABAN5ZSAAAAurGU\nAgAA0I2lFAAAgG4spQAAAHSzEWfffepvfLf3CLBUGmd0Gmdk+mZ0GmfZvFIKAABAN5ZSAAAAutn1\n23er6ieTfHTm0EuS/Nsk5yX5V0n+z3T8V1trd+56QuhE44xO44xO44xM34xk10tpa+2hJFckSVWd\nk+Roko8n+cUk72+t
/cZCJoRONM7oNM7oNM7I9M1IFnWio9clebS19tWqWtCX/KGn/s73F/41GdA3\nlvrVNU5/Gmd0G9q4vtmRDe070Tg7NEfji/qd0muT3Dbz+Tur6oGquqWqzj/VA6rqYFUdqaojx48f\nX9AYsDQaZ3QaZ3Rn1bi+2TC+h7PR5l5Kq+o5SX4uyX+dDt2c5KXZejvBsSTvO9XjWmuHWmv7W2v7\n9+3bN+8YsDQaZ3QaZ3S7aVzfbArfwxnBIl4pfUOSz7bWnkiS1toTrbVnWmvPJvlgkisX8BzQk8YZ\nncYZncYZmb7ZeItYSq/LzNsFqurCmdvelOTBBTwH9KRxRqdxRqdxRqZvNt5cJzqqqn1JfjrJ22YO\n/3pVXZGkJXnspNtgo2ic0Wmc0WmckembUcy1lLbWjid5wUnHfn6uiU7hd5598aK/JAN6/RK+psZZ\nJxpndJvauL7ZiU3tO9E4OzNP44s6+y4AAACcNUspAAAA3VhKAQAA6MZSCgAAQDeWUgAAALqZ6+y7\nq/L9j7y39whsgtf/ce8Jdk3j7IjGGd2GNq5vdmRD+040zg7N0bhXSgEAAOjGUgoAAEA3llIAAAC6\nsZQCAADQjaUUAACAbjbi7Lv/866reo/ABvjZ19/Ue4Rd0zg7oXFGt6mN65ud2NS+E42zM/M07pVS\nAAAAurGUAgAA0I2lFAAAgG52tJRW1S1V9WRVPThz7PlVdXdVPTx9PH86XlX1gap6pKoeqKpXLGt4\nWAR9MzqNMzqNMzJ9sxfs9JXSw0kOnHTsxiT3tNYuT3LP9HmSvCHJ5dPlYJKb5x8Tlupw9M3YDkfj\njO1wNM64DkffDG5HS2lr7VNJnjrp8DVJbp2u35rkjTPHP9y23JvkvKq6cBHDwjLom9FpnNFpnJHp\nm71gnt8pvaC1dmy6/vUkF0zXL0rytZn7PT4d+2uq6mBVHamqI8ePH59jDFiKufpONM7a0zij8/cU\nRuZ7OENZyImOWmstSTvLxxxqre1vre3ft2/fIsaApdhN39PjNM5G0Dij8/cURuZ7OCOYZyl94sTb\nAaaPT07Hjya5ZOZ+F0/HYJPom9FpnNFpnJHpm6HMs5TekeT66fr1ST4xc/wt09m/rkryrZm3F8Cm\n0Dej0zij0zgj0zdDOXcnd6qq25K8NskLq+rxJO9J8u+TfKyq3prkq0nePN39ziRXJ3kkyXeT/OKC\nZ4aF0jej0zij0zgj0zd7wY6W0tbadae56XWnuG9L8o55hoJV0jej0zij0zgj0zd7wUJOdAQAAAC7\nYSkFAACgG0spAAAA3VhKAQAA6MZSCgAAQDeWUgAAALqxlAIAANCNpRQAAIBuLKUAAAB0YykFAACg\nG0spAAAA3VhKAQAA6MZSCgAAQDeWUgAAALqxlAIAANDNtktpVd1SVU9W1YMzx/5jVX2pqh6oqo9X\n1XnT8Uur6i+r6v7p8lvLHB4WQeOMTuOMTN+MTuPsBTt5pfRwkgMnHbs7yd9trf29JF9O8q6Z2x5t\nrV0xXd6+mDFhqQ5H44ztcDTOuA5H34ztcDTO4LZdSltrn0ry1EnHPtlae3r69N4kFy9hNlgJjTM6\njTMyfTM6jbMXLOJ3Sv9lkt+f+fyyqvrTqvqjqnr1Ar4+9KZxRqdxRqZvRqdxNt658zy4qt6d5Okk\nvz0dOpbkxa21b1bVK5P8XlW9vLX27VM89mCSg0ly/vnnzzPG0O49sPVujavuuqvzJHuTxhmdxhmZ\nvhmdxhnFrl8prapfSPKzSf5Fa60lSWvte621b07X70vyaJKfONXjW2uHWmv7W2v79+3bt9sxYGk0\nvnz3Hjjwgx+8sHoaZ2T6ZnQaZyS7Wkqr6kCSf5Pk51pr3505/qKqOme6/pIklyf5yiIGhVXSOKPT\n+PL5oUs/+mZ0Gmc02759t6puS/LaJC+sqseTvCdbZ/h6bpK7qypJ7p3O7vWaJP+uqv
4qybNJ3t5a\ne+qUXxjWhMYZncYZmb5Xz68WrZbG2Qu2XUpba9ed4vCHTnPf25PcPu9Q/JBv+MuncUancUamb0an\n8dXyQ5c+5jrREcA8fMMHAMBSCgBL4IcuALAzllIAAHbMD1yARbOUAgAAxA9detn1v1MKAAAA87KU\nAgAA0I2lFAAAgG4spQAAAHRjKQUAAKAbSykAAADdWEoBAADoxlIKAABAN5ZSAAAAurGUAgAA0I2l\nFAAAgG62XUqr6paqerKqHpw59t6qOlpV90+Xq2due1dVPVJVD1XVzyxrcFgUjTM6jTM6jTMyfbMX\n7OSV0sNJDpzi+Ptba1dMlzuTpKpeluTaJC+fHvOfq+qcRQ0LS3I4Gmdsh6NxxnY4Gmdch6NvBrft\nUtpa+1SSp3b49a5J8pHW2vdaa3+e5JEkV84xHyydxhmdxhmdxhmZvtkL5vmd0ndW1QPTWwrOn45d\nlORrM/d5fDoGm0jjjE7jjE7jjEzfDGO3S+nNSV6a5Iokx5K872y/QFUdrKojVXXk+PHjuxwDlkbj\njE7jjG6uxvXNmvM9nKHsailtrT3RWnumtfZskg/mh28LOJrkkpm7XjwdO9XXONRa299a279v377d\njAFLo3FGp3FGN2/j+mad+R7OaHa1lFbVhTOfvinJibOB3ZHk2qp6blVdluTyJH8y34iwehpndBpn\ndBpnZPpmNOdud4equi3Ja5O8sKoeT/KeJK+tqiuStCSPJXlbkrTWPl9VH0vyhSRPJ3lHa+2Z5YwO\ni6FxRqdxRqdxRqZv9oJtl9LW2nWnOPyhM9z/15L82jxDwSppnNFpnNFpnJHpm71gnrPvAgAAwFws\npQAAAHRjKQUAAKAbSykAAADdWEoBAADoxlIKAABAN5ZSAAAAurGUAgAA0I2lFAAAgG4spQAAAHRj\nKQUAAKAbSykAAADdWEoBAADoxlIKAABAN5ZSAAAAutl2Ka2qW6rqyap6cObYR6vq/unyWFXdPx2/\ntKr+cua231rm8LAIGmd0Gmdk+mZ0GmcvOHcH9zmc5D8l+fCJA621f37ielW9L8m3Zu7/aGvtikUN\nCCtwOBpnbIejccZ1OPpmbIejcQa37VLaWvtUVV16qtuqqpK8Ock/WexYsDoaZ3QaZ2T6ZnQaZy+Y\n93dKX53kidbawzPHLquqP62qP6qqV8/59aE3jTM6jTMyfTM6jTOEnbx990yuS3LbzOfHkry4tfbN\nqnplkt+rqpe31r598gOr6mCSg0ly/vnnzzkGLI3GGZ3GGZm+GZ3GGcKuXymtqnOT/LMkHz1xrLX2\nvdbaN6fr9yV5NMlPnOrxrbVDrbX9rbX9+/bt2+0YsDQaZ3QaZ2T6ZnQaZyTzvH33nyb5Umvt8RMH\nqupFVXXOdP0lSS5P8pX5RoRuNM7oNM7I9M3oNM4wdvJPwtyW5H8n+cmqeryq3jrddG3++tsFkuQ1\nSR6YTkv935K8vbX21CIHhkXTOKPTOCPTN6PTOHvBTs6+e91pjv/CKY7dnuT2+ceC1dE4o9M4I9M3\no9M4e8G8Z98FAACAXbOUAgAA0I2lFAAAgG4spQAAAHRjKQUAAKAbSykAAADdWEoBAADoxlIKAABA\nN+f2HiBJvnXOs/nv5/1F7zH2jHsPHJjr8VfdddeCJlmsf/jJT/Ye4bQ0vjqj9p1onLH7Tta3cX2v\nhr770fhqaPz0vFIKAABAN5ZSAAAAurGUAgAA0M1a/E4pq7Xu70eHeeibkembkemb0Wn89CylDMP/\n6IxO44xO44xM34xunsartbbAUXY5RFX/IdgL7mut7e/xxBpnRTTO6Lo0rm9WxPdwRnfaxrf9ndKq\nuqSq/rCqvlBVn6+qX5qOP7+q7q6qh6eP50/Hq6o+UFWPVNUDVfWKxf63wGJpnNFpnJHpm9FpnL1g\nJyc6ejrJr7TWXpbkqiTvqKqXJbkxyT2ttcuT3D
N9niRvSHL5dDmY5OaFTw2LpXFGp3FGpm9Gp3GG\nt+1S2lo71lr77HT9O0m+mOSiJNckuXW6261J3jhdvybJh9uWe5OcV1UXLnxyWBCNMzqNMzJ9MzqN\nsxec1T8JU1WXJvmpJJ9OckFr7dh009eTXDBdvyjJ12Ye9vh0DNaexhmdxhmZvhmdxhnVjs++W1U/\nnuT2JL/cWvt2Vf3gttZaO9tfkK6qg9l6SwGsBY0zOo0zMn0zOo0zsh29UlpVP5at/wl+u7X2u9Ph\nJ068FWD6+OR0/GiSS2YefvF07K9prR1qre3vdZYxmKVxRqdxRqZvRqdxRreTs+9Wkg8l+WJr7aaZ\nm+5Icv10/fokn5g5/pbpzF9XJfnWzFsLYO1onNFpnJHpm9FpnD2htXbGS5JXJWlJHkhy/3S5OskL\nsnWmr4eT/EGS50/3ryS/meTRJJ9Lsn8Hz9FcXFZwOaJxl8EvGncZ/fIjjUffLuNcfA93Gf1yysZb\na6kpxK7O9j3wsEv+UWpGp3FG16VxfbMivoczutM2flZn3wUAAIBFspQCAADQjaUUAACAbiylAAAA\ndHNu7wEm30hyfPq4qV4Y8/e0k/n/1ioGOQ2N97cX5u/Z+F8keajj889rL/Sx7ta5cd/D+9sL8/t7\nynz2QiPrbK7G1+Lsu0lSVUc2+R/vNX9fmzD/Jsx4Jubva93nX/f5tmP+/tb9v2Hd59uO+fvahPk3\nYcYzMX9f887v7bsAAAB0YykFAACgm3VaSg/1HmBO5u9rE+bfhBnPxPx9rfv86z7fdszf37r/N6z7\nfNsxf1+bMP8mzHgm5u9rrvnX5ndKAQAA2HvW6ZVSAAAA9hhLKQAAAN10X0qr6kBVPVRVj1TVjb3n\n2YmqeqyqPldV91fVkenY86vq7qp6ePp4fu85Z1XVLVX1ZFU9OHPslDPXlg9MfyYPVNUr+k3+g1lP\nNf97q+ro9Odwf1VdPXPbu6b5H6qqn+kz9Q9m0fiS6bsvjS+fxvvZxL4Tja+axldr0/pONL7tE7TW\nul2SnJPk0SQvSfKcJH+W5GU9Z9rh3I8leeFJx349yY3T9RuT/Ifec54032uSvCLJg9vNnOTqJL+f\npJJcleTTazr/e5P861Pc92VTS89NctnU2Dmd5tZ4vz70vZrZNd6vEY0vf+6N7HuaXeP959f48ube\nqL7P0IjGp0vvV0qvTPJIa+0rrbXvJ/lIkms6z7Rb1yS5dbp+a5I3dpzlR7TWPpXkqZMOn27ma5J8\nuG25N8l5VXXhaiY9tdPMfzrXJPlIa+17rbU/T/JItlrrQeMroO9ufScaXwmN+x6+IBpfEo2vhbXt\nO9F4tmm891J6UZKvzXz++HRs3bUkn6yq+6rq4HTsgtbasen615Nc0Ge0s3K6mTfpz+Wd09sabpl5\nm8Y6zb9Os5yNERrX92qs2zw7pfH1sO6Nr9MsZ0vj60HjyzFC34nGf6D3UrqpXtVae0WSNyR5R1W9\nZvbGtvW69Ub9WzubOHOSm5O8NMkVSY4leV/fcYYyVOObNu9E38ul8f40vlwa70/jyzNU38lmzpwF\nNt57KT2a5JKZzy+ejq211trR6eOTST6erZejnzjxsvr08cl+E+7Y6WbeiD+X1toTrbVnWmvPJvlg\nfvi2gHWaf51m2bFBGtf3aqzbPDui8f42pPF1muWsaLw/jS/PIH0nGv+B3kvpZ5JcXlWXVdVzklyb\n5I7OM51RVe2rqueduJ7k9UkezNbc1093uz7JJ/pMeFZON/MdSd4ynfnrqiTfmnlrwdo46b31b8rW\nn0OyNf+1VfXcqrosyeVJ/mTV80003o++V0Pj/Wh8+Tau70Tj60LjyzFQ34nGf+hMZ0FaxSVbZ5f6\ncrbOyvTu3vPsYN6XZOtsUn+W5PMnZk7ygiT3JHk4yR8keX7vWU+a+7Zsvaz+V9l6X/dbTzdzts70\n9ZvTn8nnku
xf0/n/yzTfA1P8F87c/93T/A8leUPn2TXepw99r25+jfdpROOrmX2j+p5m1vh6zK/x\n5cy7cX2foRGNT5eaHgQAAAAr1/vtuwAAAOxhllIAAAC6sZQCAADQjaUUAACAbiylAAAAdGMpBQAA\noBtLKQAAAN38P4MFx8hKCA3KAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1NDjJiqysoT-" + }, + "source": [ + "# Deep Q-Network implementation.\n", + "\n", + "This homework shamelessly demands you to implement DQN — an approximate Q-learning algorithm with experience replay and target networks — and see if it works any better this way.\n", + "\n", + "Original paper:\n", + "https://arxiv.org/pdf/1312.5602.pdf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BcLhaXMKsoT_" + }, + "source": [ + "**This notebook is the main notebook.** Another notebook is given for debug. (**homework_pytorch_debug**). The tasks are similar and share most of the code. The main difference is in environments. In main notebook it can take some 2 hours for the agent to start improving so it seems reasonable to launch the algorithm on a simpler env first. In debug one it is CartPole and it will train in several minutes.\n", + "\n", + "**We suggest the following pipeline:** First implement debug notebook then implement the main one.\n", + "\n", + "**About evaluation:** All points are given for the main notebook with one exception: if agent fails to beat the threshold in main notebook you can get 1 pt (instead of 3 pts) for beating the threshold in debug notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IVo0UxTWsoT_" + }, + "outputs": [], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + "\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/atari_wrappers.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/utils.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/replay_buffer.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/framebuffer.py\n", + "\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KkrBeP7YsoUA" + }, + "source": [ + "__Frameworks__ - we'll accept this homework in any deep learning framework. This particular notebook was designed for PyTorch, but you will find it easy to adapt it to almost any Python-based deep learning framework." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0ABRgp2-sq5a" + }, + "outputs": [], + "source": [ + "!pip install gymnasium[atari,accept-rom-license]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XauE94NisoUA" + }, + "outputs": [], + "source": [ + "import random\n", + "import numpy as np\n", + "import torch\n", + "import utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P8WoWe9DsoUA" + }, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6GQBgViKsoUA" + }, + "source": [ + "### Let's play some old videogames\n", + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/nerd.png)\n", + "\n", + "This time we're gonna apply approximate Q-learning to an Atari game called Breakout. It's not the hardest thing out there, but it's definitely way more complex than anything we tried before.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "S_zvw_31soUA" + }, + "outputs": [], + "source": [ + "ENV_NAME = \"BreakoutNoFrameskip-v4\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xT9BvasNsoUA" + }, + "source": [ + "## Preprocessing (3 pts)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iwN8jA0OsoUA" + }, + "source": [ + "Let's see what observations look like." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rUZHU2HdsoUB" + }, + "outputs": [], + "source": [ + "env = gym.make(ENV_NAME, render_mode=\"rgb_array\")\n", + "env.reset()\n", + "\n", + "n_cols = 5\n", + "n_rows = 2\n", + "fig = plt.figure(figsize=(16, 9))\n", + "\n", + "for row in range(n_rows):\n", + " for col in range(n_cols):\n", + " ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)\n", + " ax.imshow(env.render())\n", + " env.step(env.action_space.sample())\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hLNt1fbnsoUB" + }, + "source": [ + "**Let's play a little.**\n", + "\n", + "Pay attention to zoom and fps args of play function. Control: A, D, space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WOIL47azsoUB" + }, + "outputs": [], + "source": [ + "# # Does not work in Colab.\n", + "# # Use KeyboardInterrupt (Kernel → Interrupt in Jupyter) to continue.\n", + "\n", + "# from gymnasium.utils.play import play\n", + "\n", + "# play(env=gym.make(ENV_NAME, render_mode=\"rgb_array\"), zoom=4, fps=40)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5DPrxQuXsoUB" + }, + "source": [ + "### Processing game image\n", + "\n", + "Raw Atari images are large, 210x160x3 by default. However, we don't need that level of detail in order to learn from them.\n", + "\n", + "We can thus save a lot of time by preprocessing game image, including\n", + "* Resizing to a smaller shape, 64x64\n", + "* Converting to grayscale\n", + "* Cropping irrelevant image parts (top, bottom and edges)\n", + "\n", + "Also please keep one dimension for channel so that final shape would be 1x64x64.\n", + "\n", + "Tip: You can implement your own grayscale converter and assign a huge weight to the red channel. This dirty trick is not necessary but it will speed up learning." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kHBZgKV_soUB", + "outputId": "08909179-4c7a-448f-b2ee-fccaa2c9afde" + }, + "outputs": [], + "source": [ + "from gymnasium import ObservationWrapper\n", + "from gymnasium.spaces import Box\n", + "import cv2\n", + "\n", + "\n", + "class PreprocessAtariObs(ObservationWrapper):\n", + " def __init__(self, env):\n", + " \"\"\"A gym wrapper that crops, scales image into the desired shapes and grayscales it.\"\"\"\n", + " super().__init__(env)\n", + "\n", + " self.img_size = (1, 64, 64)\n", + " self.observation_space = Box(0.0, 1.0, self.img_size)\n", + "\n", + "\n", + " def _to_gray_scale(self, rgb, channel_weights=[0.8, 0.1, 0.1]):\n", + " \n", + "\n", + "\n", + " def observation(self, img):\n", + " \"\"\"what happens to each observation\"\"\"\n", + "\n", + " # Here's what you need to do:\n", + " # * crop image, remove irrelevant parts\n", + " # * resize image to self.img_size\n", + " # (Use imresize from any library you want,\n", + " # e.g. opencv, PIL, keras. 
Don't use skimage.imresize\n", + " # because it is extremely slow.)\n", + " # * cast image to grayscale\n", + " # * convert image pixels to (0,1) range, float32 type\n", + " \n", + " return " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dkdoKM4ZsoUB" + }, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "# spawn game instance for tests\n", + "env = gym.make(ENV_NAME, render_mode=\"rgb_array\") # create raw env\n", + "env = PreprocessAtariObs(env)\n", + "observation_shape = env.observation_space.shape\n", + "n_actions = env.action_space.n\n", + "env.reset()\n", + "obs, _, _, _, _ = env.step(env.action_space.sample())\n", + "\n", + "# test observation\n", + "assert obs.ndim == 3, \"observation must be [channel, h, w] even if there's just one channel\"\n", + "assert obs.shape == observation_shape, obs.shape\n", + "assert obs.dtype == 'float32'\n", + "assert len(np.unique(obs)) > 2, \"your image must not be binary\"\n", + "assert 0 <= np.min(obs) and np.max(\n", + " obs) <= 1, \"convert image pixels to [0,1] range\"\n", + "\n", + "assert np.max(obs) >= 0.5, \"It would be easier to see a brighter observation\"\n", + "assert np.mean(obs) >= 0.1, \"It would be easier to see a brighter observation\"\n", + "\n", + "print(\"Formal tests seem fine. Here's an example of what you'll get.\")\n", + "\n", + "n_cols = 5\n", + "n_rows = 2\n", + "fig = plt.figure(figsize=(16, 9))\n", + "obs, _ = env.reset()\n", + "for row in range(n_rows):\n", + " for col in range(n_cols):\n", + " ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)\n", + " ax.imshow(obs[0, :, :], interpolation='none', cmap='gray')\n", + " obs, _, _, _, _ = env.step(env.action_space.sample())\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WGSKOffIsoUB" + }, + "source": [ + "### Wrapping." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MRQmPw5DsoUB" + }, + "source": [ + "**About the game:** You have 5 lives and get points for breaking the wall. Higher bricks cost more than the lower ones. There are 4 actions: start game (should be called at the beginning and after each life is lost), move left, move right and do nothing. There are some common wrappers used for Atari environments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oyYRDowcsoUB" + }, + "outputs": [], + "source": [ + "import atari_wrappers\n", + "\n", + "def PrimaryAtariWrap(env, clip_rewards=True):\n", + " assert 'NoFrameskip' in env.spec.id\n", + "\n", + " # This wrapper holds the same action for `skip` frames and outputs\n", + " # the maximal pixel value of the last 2 frames (to handle blinking\n", + " # in some envs)\n", + " env = atari_wrappers.MaxAndSkipEnv(env, skip=4)\n", + "\n", + " # This wrapper sends done=True when each life is lost\n", + " # (not all the 5 lives that are given by the game rules).\n", + " # It should make it easier for the agent to understand that losing is bad.\n", + " env = atari_wrappers.EpisodicLifeEnv(env)\n", + "\n", + " # This wrapper launches the ball when an episode starts.\n", + " # Without it the agent has to learn this action, too.\n", + " # Actually it can but learning would take longer.\n", + " env = atari_wrappers.FireResetEnv(env)\n", + "\n", + " # This wrapper transforms rewards to {-1, 0, 1} according to their sign\n", + " if clip_rewards:\n", + " env = atari_wrappers.ClipRewardEnv(env)\n", + "\n", + " # This wrapper is yours :)\n", + " env = PreprocessAtariObs(env)\n", + " return env" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4iJM3IAwsoUB" + }, + "source": [ + "**Let's see if the game is still playable after applying the wrappers.**\n", + "While playing, the EpisodicLifeEnv wrapper may seem not to work, but it actually does (because when a life is lost a new ball is dropped 
automatically - it means that FireResetEnv wrapper understands that a new episode began)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dFtHgcLJsoUC" + }, + "outputs": [], + "source": [ + "# # Does not work in Colab.\n", + "# # Use KeyboardInterrupt (Kernel → Interrupt in Jupyter) to continue.\n", + "\n", + "# from gymnasium.utils.play import play\n", + "\n", + "# def make_play_env():\n", + "# env = gym.make(ENV_NAME, render_mode=\"rgb_array\")\n", + "# env = PrimaryAtariWrap(env)\n", + "# # in torch imgs have shape [c, h, w] instead of common [h, w, c]\n", + "# env = atari_wrappers.AntiTorchWrapper(env)\n", + "# return env\n", + "\n", + "# play(make_play_env(), zoom=4, fps=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RJAkvALbsoUC" + }, + "source": [ + "### Frame buffer\n", + "\n", + "Our agent can only process one observation at a time, so we gotta make sure it contains enough information to find optimal actions. For instance, agent has to react to moving objects so it must be able to measure object's velocity.\n", + "\n", + "To do so, we introduce a buffer that stores 4 last images. 
This time everything is pre-implemented for you, not really by the staff of the course :)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1ucMNoYysoUC", + "outputId": "1cbba2cd-c3d3-4475-f369-15301db3109b" + }, + "outputs": [], + "source": [ + "from framebuffer import FrameBuffer\n", + "\n", + "def make_env(clip_rewards=True):\n", + " env = gym.make(ENV_NAME, render_mode=\"rgb_array\") # create raw env\n", + " env = PrimaryAtariWrap(env, clip_rewards)\n", + " env = FrameBuffer(env, n_frames=4, dim_order='pytorch')\n", + " return env\n", + "\n", + "env = make_env()\n", + "env.reset()\n", + "n_actions = env.action_space.n\n", + "state_shape = env.observation_space.shape\n", + "n_actions, state_shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PabpXH59soUC" + }, + "outputs": [], + "source": [ + "for _ in range(12):\n", + " obs, _, _, _, _ = env.step(env.action_space.sample())\n", + "\n", + "plt.figure(figsize=[12,10])\n", + "plt.title(\"Game image\")\n", + "plt.imshow(env.render())\n", + "plt.show()\n", + "\n", + "plt.figure(figsize=[15,15])\n", + "plt.title(\"Agent observation (4 frames top to bottom)\")\n", + "plt.imshow(utils.img_by_obs(obs, state_shape), cmap='gray')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jhiOKsQvsoUC" + }, + "source": [ + "## DQN as it is (4 pts)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aspwJFiGsoUC" + }, + "source": [ + "### Building a network\n", + "\n", + "We now need to build a neural network that can map images to state q-values. This network will be called on every agent's step so it better not be resnet-152 unless you have an array of GPUs. 
Instead, you can use strided convolutions with a small number of features to save time and memory.\n", + "\n", + "You can build any architecture you want, but for reference, here's something that will more or less work:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cjVMIUG7soUC" + }, + "source": [ + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/dqn_arch.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YbZIucfksoUC" + }, + "source": [ + "**Dueling network: (+2 pts)**\n", + "$$Q_{\\theta}(s, a) = V_{\\eta}(f_{\\xi}(s)) + A_{\\psi}(f_{\\xi}(s), a) - \\frac{\\sum_{a'}A_{\\psi}(f_{\\xi}(s), a')}{N_{actions}},$$\n", + "where $\\xi$, $\\eta$, and $\\psi$ are, respectively, the parameters of the\n", + "shared encoder $f_{\\xi}$, of the value stream $V_\\eta$, and of the advantage stream $A_\\psi$; and $\\theta = \\{\\xi, \\eta, \\psi\\}$ is their concatenation.\n", + "\n", + "For the architecture in the image, the $V$ and $A$ heads can follow the dense layer instead of $Q$. Please don't worry that the model becomes a little bigger." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SPPmY6wIsoUC", + "outputId": "717e2355-008e-4994-b5f2-1c8cf98ac445" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "# those who have a GPU but feel unfair to use it can uncomment:\n", + "# device = torch.device('cpu')\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FvaNwSxhsoUC" + }, + "outputs": [], + "source": [ + "def conv2d_size_out(size, kernel_size, stride):\n", + " \"\"\"\n", + " common use case:\n", + " cur_layer_img_w = conv2d_size_out(cur_layer_img_w, kernel_size, stride)\n", + " cur_layer_img_h = conv2d_size_out(cur_layer_img_h, kernel_size, stride)\n", + " to understand the shape for dense layer's input\n", + " \"\"\"\n", + " return (size - (kernel_size - 1) - 1) // stride + 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dmLl6IkhsoUC" + }, + "outputs": [], + "source": [ + "class DQNAgent(nn.Module):\n", + " def __init__(self, state_shape, n_actions, epsilon=0):\n", + "\n", + " super().__init__()\n", + " self.epsilon = epsilon\n", + " self.n_actions = n_actions\n", + " self.state_shape = state_shape\n", + "\n", + " # Define your network body here. 
Please make sure agent is fully contained here\n", + " # nn.Flatten() can be useful\n", + " \n", + " \n", + "\n", + " def forward(self, state_t):\n", + " \"\"\"\n", + " takes agent's observation (tensor), returns qvalues (tensor)\n", + " :param state_t: a batch of 4-frame buffers, shape = [batch_size, 4, h, w]\n", + " \"\"\"\n", + " # Use your network to compute qvalues for given state\n", + " qvalues = \n", + "\n", + " assert qvalues.requires_grad, \"qvalues must be a torch tensor with grad\"\n", + " assert (\n", + " len(qvalues.shape) == 2 and \n", + " qvalues.shape[0] == state_t.shape[0] and \n", + " qvalues.shape[1] == n_actions\n", + " )\n", + "\n", + " return qvalues\n", + "\n", + " def get_qvalues(self, states):\n", + " \"\"\"\n", + " like forward, but works on numpy arrays, not tensors\n", + " \"\"\"\n", + " model_device = next(self.parameters()).device\n", + " states = torch.tensor(states, device=model_device, dtype=torch.float32)\n", + " qvalues = self.forward(states)\n", + " return qvalues.data.cpu().numpy()\n", + "\n", + " def sample_actions(self, qvalues):\n", + " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n", + " epsilon = self.epsilon\n", + " batch_size, n_actions = qvalues.shape\n", + "\n", + " random_actions = np.random.choice(n_actions, size=batch_size)\n", + " best_actions = qvalues.argmax(axis=-1)\n", + "\n", + " should_explore = np.random.choice(\n", + " [0, 1], batch_size, p=[1-epsilon, epsilon])\n", + " return np.where(should_explore, random_actions, best_actions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BUFMLKX1soUC" + }, + "outputs": [], + "source": [ + "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XbsIT2EdsoUC" + }, + "source": [ + "Now let's try out our agent to see if it raises any errors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pZR3qE2esoUC" + }, + "outputs": [], + "source": [ + "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000, seed=None):\n", + " \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. \"\"\"\n", + " rewards = []\n", + " for _ in range(n_games):\n", + " s, _ = env.reset(seed=seed)\n", + " reward = 0\n", + " for _ in range(t_max):\n", + " qvalues = agent.get_qvalues([s])\n", + " action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n", + " s, r, terminated, truncated, _ = env.step(action)\n", + " reward += r\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " rewards.append(reward)\n", + " return np.mean(rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-1OZLPwXsoUC", + "outputId": "f615e1f2-d847-420d-8ac9-d3caab30d91b" + }, + "outputs": [], + "source": [ + "evaluate(env, agent, n_games=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2BiRixA-soUC" + }, + "source": [ + "### Experience replay\n", + "For this assignment, we provide you with experience replay buffer. 
If you implemented experience replay buffer in last week's assignment, you can copy-paste it here **to get 2 bonus points**.\n", + "\n", + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/exp_replay.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jTBZo5BVsoUC" + }, + "source": [ + "#### The interface is fairly simple:\n", + "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n", + "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n", + "* `len(exp_replay)` - returns number of elements stored in replay buffer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ydi0KK9LsoUC" + }, + "outputs": [], + "source": [ + "from replay_buffer import ReplayBuffer\n", + "exp_replay = ReplayBuffer(10)\n", + "\n", + "for _ in range(30):\n", + " exp_replay.add(env.reset()[0], env.action_space.sample(), 1.0, env.reset()[0], done=False)\n", + "\n", + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(5)\n", + "\n", + "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cEXv69KWsoUC" + }, + "outputs": [], + "source": [ + "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n", + " \"\"\"\n", + " Play the game for exactly n_steps, record every (s,a,r,s', done) to replay buffer.\n", + " Whenever game ends due to termination or truncation, add record with done=terminated and reset the game.\n", + " It is guaranteed that env has terminated=False when passed to this function.\n", + "\n", + " PLEASE DO NOT RESET ENV UNLESS IT IS \"DONE\"\n", + "\n", + " :returns: return sum of rewards over time and the state in which the env stays\n", + " \"\"\"\n", + " s = 
initial_state\n", + " sum_rewards = 0\n", + "\n", + " # Play the game for n_steps as per instructions above\n", + " \n", + "\n", + " return sum_rewards, s" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GaFynKaMsoUF", + "outputId": "765187d0-4391-4532-9a42-e12b7cff194c" + }, + "outputs": [], + "source": [ + "# testing your code.\n", + "exp_replay = ReplayBuffer(2000)\n", + "\n", + "state, _ = env.reset()\n", + "play_and_record(state, agent, env, exp_replay, n_steps=1000)\n", + "\n", + "# if you're using your own experience replay buffer, some of those tests may need correction.\n", + "# just make sure you know what your code does\n", + "assert len(exp_replay) == 1000, \\\n", + " \"play_and_record should have added exactly 1000 steps, \" \\\n", + " \"but instead added %i\" % len(exp_replay)\n", + "is_dones = list(zip(*exp_replay._storage))[-1]\n", + "\n", + "assert 0 < np.mean(is_dones) < 0.1, \\\n", + " \"Please make sure you restart the game whenever it is 'done' and \" \\\n", + " \"record the is_done correctly into the buffer. Got %f is_done rate over \" \\\n", + " \"%i steps. 
[If you think it's your tough luck, just re-run the test]\" % (\n", + " np.mean(is_dones), len(exp_replay))\n", + "\n", + "for _ in range(100):\n", + " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", + " assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n", + " assert act_batch.shape == (10,), \\\n", + " \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", + " assert reward_batch.shape == (10,), \\\n", + " \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", + " assert is_done_batch.shape == (10,), \\\n", + " \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", + " assert [int(i) in (0, 1) for i in is_dones], \\\n", + " \"is_done should be strictly True or False\"\n", + " assert [0 <= a < n_actions for a in act_batch], \"actions should be within [0, n_actions)\"\n", + "\n", + "print(\"Well done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y5zyryPOsoUF" + }, + "source": [ + "### Target networks\n", + "\n", + "We also employ the so-called \"target network\" - a copy of neural network weights to be used for reference Q-values:\n", + "\n", + "The network itself is an exact copy of the agent network, but its parameters are not trained. 
Instead, they are moved here from agent's actual network every so often.\n", + "\n", + "$$ Q_{reference}(s,a) = r + \\gamma \\cdot \\max _{a'} Q_{target}(s',a') $$\n", + "\n", + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/target_net.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EMtDyN9fsoUF", + "outputId": "d8ff3f8e-d508-4047-9eaa-bb35f949a58c" + }, + "outputs": [], + "source": [ + "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n", + "# This is how you can load weights from agent into target network\n", + "target_network.load_state_dict(agent.state_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2idY8QX0soUF" + }, + "source": [ + "### Learning with... Q-learning\n", + "Here we write a function similar to `agent.update` from tabular q-learning." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k46MPwwwsoUF" + }, + "source": [ + "Compute Q-learning TD error:\n", + "\n", + "$$ L = { 1 \\over N} \\sum_i [ Q_{\\theta}(s,a) - Q_{reference}(s,a) ] ^2 $$\n", + "\n", + "With Q-reference defined as\n", + "\n", + "$$ Q_{reference}(s,a) = r(s,a) + \\gamma \\cdot max_{a'} Q_{target}(s', a') $$\n", + "\n", + "Where\n", + "* $Q_{target}(s',a')$ denotes Q-value of next state and next action predicted by __target_network__\n", + "* $s, a, r, s'$ are current state, action, reward and next state respectively\n", + "* $\\gamma$ is a discount factor defined two cells above.\n", + "\n", + "\n", + "__Note 1:__ there's an example input below. Feel free to experiment with it before you write the function.\n", + "\n", + "__Note 2:__ compute_td_loss is a source of 99% of bugs in this homework. 
If reward doesn't improve, it often helps to go through it line by line [with a rubber duck](https://rubberduckdebugging.com/).\n", + "\n", + "**Double DQN (+2 pts)**\n", + "\n", + "$$ Q_{reference}(s,a) = r(s, a) + \\gamma \\cdot\n", + "Q_{target}(s',argmax_{a'}Q_\\theta(s', a')) $$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "V02HcUYasoUG", + "outputId": "7a11e3d3-d030-40be-8f14-59b5481749fb" + }, + "outputs": [], + "source": [ + "def compute_td_loss(states, actions, rewards, next_states, is_done,\n", + " agent, target_network,\n", + " gamma=0.99,\n", + " check_shapes=False,\n", + " device=device):\n", + " \"\"\" Compute td loss using torch operations only. Use the formulae above. \"\"\"\n", + " states = torch.tensor(states, device=device, dtype=torch.float32) # shape: [batch_size, *state_shape]\n", + " actions = torch.tensor(actions, device=device, dtype=torch.int64) # shape: [batch_size]\n", + " rewards = torch.tensor(rewards, device=device, dtype=torch.float32) # shape: [batch_size]\n", + " # shape: [batch_size, *state_shape]\n", + " next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n", + " is_done = torch.tensor(\n", + " is_done.astype('float32'),\n", + " device=device,\n", + " dtype=torch.float32,\n", + " ) # shape: [batch_size]\n", + " is_not_done = 1 - is_done\n", + "\n", + " # get q-values for all actions in current states\n", + " predicted_qvalues = agent(states) # shape: [batch_size, n_actions]\n", + "\n", + " # compute q-values for all actions in next states\n", + " predicted_next_qvalues = target_network(next_states) # shape: [batch_size, n_actions]\n", + " \n", + " # select q-values for chosen actions\n", + " predicted_qvalues_for_actions = predicted_qvalues[range(len(actions)), actions] # shape: [batch_size]\n", + "\n", + " # compute V*(next_states) using predicted next q-values\n", + " next_state_values = \n", + "\n", + " 
assert next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0], \\\n", + " \"must predict one value per state\"\n", + "\n", + " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", + " # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", + " # you can multiply next state values by is_not_done to achieve this.\n", + " target_qvalues_for_actions = \n", + "\n", + " # mean squared error loss to minimize\n", + " loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2)\n", + "\n", + " if check_shapes:\n", + " assert predicted_next_qvalues.data.dim() == 2, \\\n", + " \"make sure you predicted q-values for all actions in next state\"\n", + " assert next_state_values.data.dim() == 1, \\\n", + " \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", + " assert target_qvalues_for_actions.data.dim() == 1, \\\n", + " \"there's something wrong with target q-values, they must be a vector\"\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x8AvquAtsoUG" + }, + "source": [ + "Sanity checks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5nRoOn30soUG" + }, + "outputs": [], + "source": [ + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", + "\n", + "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n", + " agent, target_network,\n", + " gamma=0.99, check_shapes=True)\n", + "loss.backward()\n", + "\n", + "assert loss.requires_grad and tuple(loss.data.size()) == (), \\\n", + " \"you must return scalar loss - mean over batch\"\n", + "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() != 0), \\\n", + " \"loss must be differentiable w.r.t. 
network weights\"\n", + "assert np.all(next(target_network.parameters()).grad is None), \\\n", + " \"target network should not have grads\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KIplL0hSsoUG" + }, + "source": [ + "## Main loop (3 pts)\n", + "\n", + "**If deadline is tonight and it has not converged:** It is ok. Send the notebook today and when it converges send it again.\n", + "If the code is exactly the same points will not be discounted.\n", + "\n", + "It's time to put everything together and see if it learns anything." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-JV-ulB-soUG" + }, + "outputs": [], + "source": [ + "from tqdm import trange\n", + "from IPython.display import clear_output\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HY9NluVqsoUG", + "outputId": "d2ec6d34-54f5-4b49-bf52-c6c5045d70c2" + }, + "outputs": [], + "source": [ + "seed = \n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-eurxA-_soUG", + "outputId": "4eb73eb0-771c-442e-a086-50bae3ebd9d8" + }, + "outputs": [], + "source": [ + "env = make_env(seed)\n", + "state_shape = env.observation_space.shape\n", + "n_actions = env.action_space.n\n", + "state, _ = env.reset()\n", + "\n", + "agent = DQNAgent(state_shape, n_actions, epsilon=1).to(device)\n", + "target_network = DQNAgent(state_shape, n_actions).to(device)\n", + "target_network.load_state_dict(agent.state_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WZg25kIasoUG" + }, + "source": [ + "Buffer of size $10^4$ fits into 5 Gb RAM.\n", + "\n", + "Larger sizes ($10^5$ and $10^6$ are common) can be used. 
It can improve the learning, but $10^4$ is quite enough. $10^2$ will probably fail learning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hWyMxfN4soUG", + "outputId": "a0d4147b-56b2-4f69-802a-0da87ad82bdb" + }, + "outputs": [], + "source": [ + "REPLAY_BUFFER_SIZE = 10**4\n", + "N_STEPS = 100\n", + "\n", + "exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n", + "for i in trange(REPLAY_BUFFER_SIZE // N_STEPS):\n", + " if not utils.is_enough_ram(min_available_gb=0.1):\n", + " print(\"\"\"\n", + " Less than 100 Mb RAM available.\n", + " Make sure the buffer size in not too huge.\n", + " Also check, maybe other processes consume RAM heavily.\n", + " \"\"\"\n", + " )\n", + " break\n", + " play_and_record(state, agent, env, exp_replay, n_steps=N_STEPS)\n", + " if len(exp_replay) == REPLAY_BUFFER_SIZE:\n", + " break\n", + "print(len(exp_replay))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_ca9vbW4soUG" + }, + "outputs": [], + "source": [ + "timesteps_per_epoch = 1\n", + "batch_size = 16\n", + "total_steps = 3 * 10**6\n", + "decay_steps = 10**6\n", + "\n", + "opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n", + "\n", + "init_epsilon = 1\n", + "final_epsilon = 0.1\n", + "\n", + "loss_freq = 50\n", + "refresh_target_network_freq = 5000\n", + "eval_freq = 5000\n", + "\n", + "max_grad_norm = 50\n", + "\n", + "n_lives = 5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oJWs0q-6soUG" + }, + "outputs": [], + "source": [ + "mean_rw_history = []\n", + "td_loss_history = []\n", + "grad_norm_history = []\n", + "initial_state_v_history = []\n", + "step = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "675-JU0hsoUG" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def wait_for_keyboard_interrupt():\n", + " try:\n", + " while True:\n", + " 
time.sleep(1)\n", + " except KeyboardInterrupt:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FgQ1vK3CsoUG" + }, + "outputs": [], + "source": [ + "state, _ = env.reset()\n", + "with trange(step, total_steps + 1) as progress_bar:\n", + " for step in progress_bar:\n", + " if not utils.is_enough_ram():\n", + " print('less that 100 Mb RAM available, freezing')\n", + " print('make sure everything is ok and use KeyboardInterrupt to continue')\n", + " wait_for_keyboard_interrupt()\n", + "\n", + " agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n", + "\n", + " # play\n", + " _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n", + "\n", + " # train\n", + " \n", + "\n", + " loss = \n", + "\n", + " loss.backward()\n", + " grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)\n", + " opt.step()\n", + " opt.zero_grad()\n", + "\n", + " if step % loss_freq == 0:\n", + " td_loss_history.append(loss.data.cpu().item())\n", + " grad_norm_history.append(grad_norm.cpu())\n", + "\n", + " if step % refresh_target_network_freq == 0:\n", + " # Load agent weights into target_network\n", + " \n", + "\n", + " if step % eval_freq == 0:\n", + " mean_rw_history.append(evaluate(\n", + " make_env(clip_rewards=True), agent, n_games=3 * n_lives, greedy=True, seed=step)\n", + " )\n", + " initial_state_q_values = agent.get_qvalues(\n", + " [make_env().reset(seed=step)[0]]\n", + " )\n", + " initial_state_v_history.append(np.max(initial_state_q_values))\n", + "\n", + " clear_output(True)\n", + " print(\"buffer size = %i, epsilon = %.5f\" %\n", + " (len(exp_replay), agent.epsilon))\n", + "\n", + " plt.figure(figsize=[16, 9])\n", + "\n", + " plt.subplot(2, 2, 1)\n", + " plt.title(\"Mean reward per life\")\n", + " plt.plot(mean_rw_history)\n", + " plt.grid()\n", + "\n", + " assert not np.isnan(td_loss_history[-1])\n", + " plt.subplot(2, 2, 2)\n", + " plt.title(\"TD loss 
history (smoothened)\")\n", + " plt.plot(utils.smoothen(td_loss_history))\n", + " plt.grid()\n", + "\n", + " plt.subplot(2, 2, 3)\n", + " plt.title(\"Initial state V\")\n", + " plt.plot(initial_state_v_history)\n", + " plt.grid()\n", + "\n", + " plt.subplot(2, 2, 4)\n", + " plt.title(\"Grad norm history (smoothened)\")\n", + " plt.plot(utils.smoothen(grad_norm_history))\n", + " plt.grid()\n", + "\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZEDQhQrdsoUG" + }, + "source": [ + "Agent is evaluated for 1 life, not for a whole episode of 5 lives. Rewards in evaluation are also truncated. Cuz this is what environment the agent is learning in and in this way mean rewards per life can be compared with initial state value\n", + "\n", + "**The goal is to get 15 points in the real env**. So 3 or better 4 points in the preprocessed one will probably be enough. You can interrupt learning then." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s0jLjYGwsoUG" + }, + "source": [ + "Final scoring is done on a whole episode with all 5 lives." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xTGVrwwQsoUG" + }, + "outputs": [], + "source": [ + "final_score = evaluate(\n", + " make_env(clip_rewards=False),\n", + " agent, n_games=30, greedy=True, t_max=10 * 1000, seed=9\n", + ")\n", + "print('final score:', final_score)\n", + "assert final_score >= 3, 'not as cool as DQN can'\n", + "print('Cool!')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ovaG8N4lsoUH" + }, + "source": [ + "## How to interpret plots:\n", + "\n", + "This aint no supervised learning so don't expect anything to improve monotonously.\n", + "* **TD loss** is the MSE between agent's current Q-values and target Q-values. It may slowly increase or decrease, it's ok. 
The \"not ok\" behavior includes going NaN or staying at exactly zero before agent has perfect performance.\n", + "* **grad norm** just shows the intensity of training. Not ok is growing to values of about 100 (or maybe even 50) though it depends on network architecture.\n", + "* **mean reward** is the expected sum of r(s,a) agent gets over the full game session. It will oscillate, but on average it should get higher over time (after a few thousand iterations...).\n", + " * In basic q-learning implementation it takes about 40k steps to \"warm up\" agent before it starts to get better.\n", + "* **Initial state V** is the expected discounted reward for episode in the opinion of the agent. It should behave more smoothly than **mean reward**. It should get higher over time but sometimes can experience drawdowns because of the agent's overestimates.\n", + "* **buffer size** - this one is simple. It should go up and cap at max size.\n", + "* **epsilon** - agent's willingness to explore. If you see that agent's already at 0.01 epsilon before its average reward is above 0 - it means you need to increase epsilon. Set it back to some 0.2 - 0.5 and decrease the pace at which it goes down.\n", + "* Smoothing of plots is done with a gaussian kernel\n", + "\n", + "At first your agent will lose quickly. Then it will learn to suck less and at least hit the ball a few times before it loses. Finally it will learn to actually score points.\n", + "\n", + "**Training will take time.** A lot of it actually.
Probably you will not see any improvement during first **150k** time steps (note that by default in this notebook agent is evaluated every 5000 time steps).\n", + "\n", + "But hey, long training time isn't _that_ bad:\n", + "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/training.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kVV72AB-soUH" + }, + "source": [ + "## About hyperparameters:\n", + "\n", + "The task has something in common with supervised learning: loss is optimized through the buffer (instead of Train dataset). But the distribution of states and actions in the buffer **is not stationary** and depends on the policy that generated it. It can even happen that the mean TD error across the buffer is very low but the performance is extremely poor (imagine the agent collecting data to the buffer always manages to avoid the ball).\n", + "\n", + "* Total timesteps and training time: It seems to be so huge, but actually it is normal for RL.\n", + "\n", + "* $\\epsilon$ decay schedule was taken from the original paper and is fairly traditional for epsilon-greedy policies. At the beginning of the training the agent's greedy policy is poor so many random actions should be taken.\n", + "\n", + "* Optimizer: In the original paper RMSProp was used (they did not have Adam in 2013) and it can work not worse than Adam. For us Adam was default and it worked.\n", + "\n", + "* lr: $10^{-3}$ would probably be too huge\n", + "\n", + "* batch size: This one can be very important: if it is too small the agent can fail to learn. Huge batch takes more time to process. If batch of size 8 can not be processed on the hardware you use take 2 (or even 4) batches of size 4, divide the loss on them by 2 (or 4) and make optimization step after both backward() calls in torch.\n", + "\n", + "* target network update frequency: has something in common with learning rate. Too frequent updates can lead to divergence.
Too rare can lead to slow learning. For millions of total timesteps thousands of inner steps seem ok. One iteration of target network updating is an iteration of the (this time approximate) $\\gamma$-contraction that stands behind Q-learning. The more inner steps it makes the more accurate is the contraction.\n", + "* max_grad_norm - just huge enough. In torch clip_grad_norm also evaluates the norm before clipping and it can be convenient for logging." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Plp8WC_esoUH" + }, + "source": [ + "### Video" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DdExc_AssoUH" + }, + "outputs": [], + "source": [ + "# record sessions\n", + "from gymnasium.wrappers import RecordVideo\n", + "\n", + "with make_env() as env, RecordVideo(\n", + " env=env, video_folder=\"./videos\", episode_trigger=lambda episode_number: True\n", + ") as env_monitor:\n", + " sessions = [\n", + " evaluate(env_monitor, agent, n_games=n_lives, greedy=True) for _ in range(10)\n", + " ]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lt6xg1n_soUH" + }, + "outputs": [], + "source": [ + "# Show video. This may not work in some setups.
If it doesn't\n", + "# work for you, you can download the videos and view them locally.\n", + "\n", + "from pathlib import Path\n", + "from base64 import b64encode\n", + "from IPython.display import HTML\n", + "\n", + "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", + "video_path = video_paths[-1] # You can also try other indices\n", + "\n", + "if 'google.colab' in sys.modules:\n", + " # https://stackoverflow.com/a/57378660/1214547\n", + " with video_path.open('rb') as fp:\n", + " mp4 = fp.read()\n", + " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", + "else:\n", + " data_url = str(video_path)\n", + "\n", + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(data_url))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fLPx2aI7soUH" + }, + "source": [ + "## Let's have a closer look at this.\n", + "\n", + "If average episode score is below 200 using all 5 lives, then probably DQN has not converged fully. But anyway let's make a more complete record of an episode." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q9hjXI6WsoUH" + }, + "outputs": [], + "source": [ + "eval_env = make_env(clip_rewards=False)\n", + "record = utils.play_and_log_episode(eval_env, agent)\n", + "print('total reward for life:', np.sum(record['rewards']))\n", + "for key in record:\n", + " print(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HD10tYWgsoUH" + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(5, 5))\n", + "ax = fig.add_subplot(1, 1, 1)\n", + "\n", + "ax.scatter(record['v_mc'], record['v_agent'])\n", + "ax.plot(sorted(record['v_mc']), sorted(record['v_mc']),\n", + " 'black', linestyle='--', label='x=y')\n", + "\n", + "ax.grid()\n", + "ax.legend()\n", + "ax.set_title('State Value Estimates')\n", + "ax.set_xlabel('Monte-Carlo')\n", + "ax.set_ylabel('Agent')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_j3Sopf0soUH" + }, + "source": [ + "$\\hat V_{Monte-Carlo}(s_t) = \\sum_{\\tau=t}^{episode~end} \\gamma^{\\tau-t}r_\\tau$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "46yLblzXsoUH" + }, + "source": [ + "Is there a big bias? It's ok, anyway it works."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FjKayU39soUH" + }, + "source": [ + "## Bonus I (2 pts)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iUZPD4RUsoUH" + }, + "source": [ + "**1.** Plot several (say 3) states with high and low spreads of Q estimate by actions i.e.\n", + "$$\\max_a \\hat Q(s,a) - \\min_a \\hat Q(s,a)$$\n", + "Please take those states from different episodes to make sure that the states are really different.\n", + "\n", + "What should high and low spread mean at least in the world of perfect Q-functions?\n", + "\n", + "Comment the states you like most.\n", + "\n", + "**2.** Plot several (say 3) states with high td-error and several states with high values of\n", + "$$| \\hat V_{Monte-Carlo}(s) - \\hat V_{agent}(s)|,$$\n", + "$$\\hat V_{agent}(s)=\\max_a \\hat Q(s,a).$$ Please take those states from different episodes to make sure that the states are really different. Which part (i.e. beginning, middle, end) of an episode did these states come from?\n", + "\n", + "Comment the states you like most." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "se-cjDOasoUH" + }, + "outputs": [], + "source": [ + "from utils import play_and_log_episode, img_by_obs\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e36bU0u8soUH" + }, + "source": [ + "## Bonus II (1-5 pts). Get High Score!\n", + "\n", + "1 point to you for each 50 points of your agent. Truncated by 5 points. Starting with 50 points, **not** 50 + threshold.\n", + "\n", + "One way is to train for several days and use heavier hardware (why not actually).\n", + "\n", + "Another way is to apply modifications (see **Bonus III**)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "78e4nRoSsoUH" + }, + "source": [ + "## Bonus III (2+ pts).
Apply modifications to DQN.\n", + "\n", + "For inspiration see [Rainbow](https://arxiv.org/abs/1710.02298) - a version of q-learning that combines lots of them.\n", + "\n", + "Points for Bonus II and Bonus III fully stack. So if modified agent gets score 250+ you get 5 pts for Bonus II + points for modifications. If the final score is 40 then you get the points for modifications.\n", + "\n", + "\n", + "Some modifications:\n", + "* [Prioritized experience replay](https://arxiv.org/abs/1511.05952) (5 pts for your own implementation, 3 pts for using a ready one)\n", + "* [double q-learning](https://arxiv.org/abs/1509.06461) (2 pts)\n", + "* [dueling q-learning](https://arxiv.org/abs/1511.06581) (2 pts)\n", + "* multi-step heuristics (see [Rainbow](https://arxiv.org/abs/1710.02298)) (3 pts)\n", + "* [Noisy Nets](https://arxiv.org/abs/1706.10295) (3 pts)\n", + "* [distributional RL](https://arxiv.org/abs/1707.06887)(distributional and distributed stand for different things here) (5 pts)\n", + "* Other modifications (2+ pts depending on complexity)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j_3RXboysoUH" + }, + "source": [ + "## Bonus IV (4+ pts). Distributed RL.\n", + "\n", + "Solve the task in a distributed way. It can strongly speed up learning. See [article](https://arxiv.org/pdf/1602.01783.pdf) or some guides." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7RdBNRyfsoUH" + }, + "source": [ + "**As usual bonus points for all the tasks fully stack.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f9X5aB56soUI" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } - ], - "source": [ - "env = gym.make(ENV_NAME)\n", - "env.reset()\n", - "\n", - "n_cols = 5\n", - "n_rows = 2\n", - "fig = plt.figure(figsize=(16, 9))\n", - "\n", - "for row in range(n_rows):\n", - " for col in range(n_cols):\n", - " ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)\n", - " ax.imshow(env.render('rgb_array'))\n", - " env.step(env.action_space.sample())\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Let's play a little.**\n", - "\n", - "Pay attention to zoom and fps args of play function. Control: A, D, space." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# # Does not work in Colab.\n", - "# # Use KeyboardInterrupt (Kernel → Interrupt in Jupyter) to continue.\n", - "\n", - "# from gym.utils.play import play\n", - "\n", - "# play(env=gym.make(ENV_NAME), zoom=5, fps=30)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Processing game image \n", - "\n", - "Raw Atari images are large, 210x160x3 by default. 
However, we don't need that level of detail in order to learn from them.\n", - "\n", - "We can thus save a lot of time by preprocessing game image, including\n", - "* Resizing to a smaller shape, 64x64\n", - "* Converting to grayscale\n", - "* Cropping irrelevant image parts (top, bottom and edges)\n", - "\n", - "Also please keep one dimension for channel so that final shape would be 1x64x64.\n", - "\n", - "Tip: You can implement your own grayscale converter and assign a huge weight to the red channel. This dirty trick is not necessary but it will speed up learning." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from gym.core import ObservationWrapper\n", - "from gym.spaces import Box\n", - "\n", - "\n", - "class PreprocessAtariObs(ObservationWrapper):\n", - " def __init__(self, env):\n", - " \"\"\"A gym wrapper that crops, scales image into the desired shapes and grayscales it.\"\"\"\n", - " ObservationWrapper.__init__(self, env)\n", - "\n", - " self.img_size = (1, 64, 64)\n", - " self.observation_space = Box(0.0, 1.0, self.img_size)\n", - "\n", - "\n", - " def _to_gray_scale(self, rgb, channel_weights=[0.8, 0.1, 0.1]):\n", - " \n", - "\n", - "\n", - " def observation(self, img):\n", - " \"\"\"what happens to each observation\"\"\"\n", - "\n", - " # Here's what you need to do:\n", - " # * crop image, remove irrelevant parts\n", - " # * resize image to self.img_size\n", - " # (Use imresize from any library you want,\n", - " # e.g. opencv, PIL, keras. 
Don't use skimage.imresize\n", - " # because it is extremely slow.)\n", - " # * cast image to grayscale\n", - " # * convert image pixels to (0,1) range, float32 type\n", - " \n", - " return " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "# spawn game instance for tests\n", - "env = gym.make(ENV_NAME) # create raw env\n", - "env = PreprocessAtariObs(env)\n", - "observation_shape = env.observation_space.shape\n", - "n_actions = env.action_space.n\n", - "env.reset()\n", - "obs, _, _, _ = env.step(env.action_space.sample())\n", - "\n", - "# test observation\n", - "assert obs.ndim == 3, \"observation must be [channel, h, w] even if there's just one channel\"\n", - "assert obs.shape == observation_shape, obs.shape\n", - "assert obs.dtype == 'float32'\n", - "assert len(np.unique(obs)) > 2, \"your image must not be binary\"\n", - "assert 0 <= np.min(obs) and np.max(\n", - " obs) <= 1, \"convert image pixels to [0,1] range\"\n", - "\n", - "assert np.max(obs) >= 0.5, \"It would be easier to see a brighter observation\"\n", - "assert np.mean(obs) >= 0.1, \"It would be easier to see a brighter observation\"\n", - "\n", - "print(\"Formal tests seem fine. Here's an example of what you'll get.\")\n", - "\n", - "n_cols = 5\n", - "n_rows = 2\n", - "fig = plt.figure(figsize=(16, 9))\n", - "obs = env.reset()\n", - "for row in range(n_rows):\n", - " for col in range(n_cols):\n", - " ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)\n", - " ax.imshow(obs[0, :, :], interpolation='none', cmap='gray')\n", - " obs, _, _, _ = env.step(env.action_space.sample())\n", - "plt.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Wrapping." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**About the game:** You have 5 lives and get points for breaking the wall. Higher bricks cost more than the lower ones. 
There are 4 actions: start game (should be called at the beginning and after each life is lost), move left, move right and do nothing. There are some common wrappers used for Atari environments." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import atari_wrappers\n", - "\n", - "def PrimaryAtariWrap(env, clip_rewards=True):\n", - " assert 'NoFrameskip' in env.spec.id\n", - "\n", - " # This wrapper holds the same action for frames and outputs\n", - " # the maximal pixel value of 2 last frames (to handle blinking\n", - " # in some envs)\n", - " env = atari_wrappers.MaxAndSkipEnv(env, skip=4)\n", - "\n", - " # This wrapper sends done=True when each life is lost\n", - " # (not all the 5 lives that are givern by the game rules).\n", - " # It should make easier for the agent to understand that losing is bad.\n", - " env = atari_wrappers.EpisodicLifeEnv(env)\n", - "\n", - " # This wrapper laucnhes the ball when an episode starts.\n", - " # Without it the agent has to learn this action, too.\n", - " # Actually it can but learning would take longer.\n", - " env = atari_wrappers.FireResetEnv(env)\n", - "\n", - " # This wrapper transforms rewards to {-1, 0, 1} according to their sign\n", - " if clip_rewards:\n", - " env = atari_wrappers.ClipRewardEnv(env)\n", - "\n", - " # This wrapper is yours :)\n", - " env = PreprocessAtariObs(env)\n", - " return env" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Let's see if the game is still playable after applying the wrappers.**\n", - "At playing the EpisodicLifeEnv wrapper seems not to work but actually it does (because after when life finishes a new ball is dropped automatically - it means that FireResetEnv wrapper understands that a new episode began)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# # Does not work in Colab.\n", - "# # Use KeyboardInterrupt (Kernel → Interrupt in Jupyter) to continue.\n", - "\n", - "# from gym.utils.play import play\n", - "\n", - "# def make_play_env():\n", - "# env = gym.make(ENV_NAME)\n", - "# env = PrimaryAtariWrap(env)\n", - "# # in PyTorch images have shape [c, h, w] instead of common [h, w, c]\n", - "# env = atari_wrappers.AntiTorchWrapper(env)\n", - "# return env\n", - "\n", - "# play(make_play_env(), zoom=10, fps=3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Frame buffer\n", - "\n", - "Our agent can only process one observation at a time, so we gotta make sure it contains enough information to find optimal actions. For instance, agent has to react to moving objects so it must be able to measure object's velocity.\n", - "\n", - "To do so, we introduce a buffer that stores 4 last images. This time everything is pre-implemented for you, not really by the staff of the course :)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from framebuffer import FrameBuffer\n", - "\n", - "def make_env(clip_rewards=True, seed=None):\n", - " env = gym.make(ENV_NAME) # create raw env\n", - " if seed is not None:\n", - " env.seed(seed)\n", - " env = PrimaryAtariWrap(env, clip_rewards)\n", - " env = FrameBuffer(env, n_frames=4, dim_order='pytorch')\n", - " return env\n", - "\n", - "env = make_env()\n", - "env.reset()\n", - "n_actions = env.action_space.n\n", - "state_shape = env.observation_space.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for _ in range(12):\n", - " obs, _, _, _ = env.step(env.action_space.sample())\n", - "\n", - "plt.figure(figsize=[12,10])\n", - "plt.title(\"Game image\")\n", - "plt.imshow(env.render(\"rgb_array\"))\n", - 
"plt.show()\n", - "\n", - "plt.figure(figsize=[15,15])\n", - "plt.title(\"Agent observation (4 frames top to bottom)\")\n", - "plt.imshow(utils.img_by_obs(obs, state_shape), cmap='gray')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## DQN as it is (4 pts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Building a network\n", - "\n", - "We now need to build a neural network that can map images to state q-values. This network will be called on every agent's step so it better not be resnet-152 unless you have an array of GPUs. Instead, you can use strided convolutions with a small number of features to save time and memory.\n", - "\n", - "You can build any architecture you want, but for reference, here's something that will more or less work:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/dqn_arch.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Dueling network: (+2 pts)**\n", - "$$Q_{\\theta}(s, a) = V_{\\eta}(f_{\\xi}(s)) + A_{\\psi}(f_{\\xi}(s), a) - \\frac{\\sum_{a'}A_{\\psi}(f_{\\xi}(s), a')}{N_{actions}},$$\n", - "where $\\xi$, $\\eta$, and $\\psi$ are, respectively, the parameters of the\n", - "shared encoder $f_ξ$ , of the value stream $V_\\eta$ , and of the advan\n", - "tage stream $A_\\psi$; and $\\theta = \\{\\xi, \\eta, \\psi\\}$ is their concatenation.\n", - "\n", - "For the architecture on the image $V$ and $A$ heads can follow the dense layer instead of $Q$. Please don't worry that the model becomes a little bigger." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", - "# those who have a GPU but feel unfair to use it can uncomment:\n", - "# device = torch.device('cpu')\n", - "device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def conv2d_size_out(size, kernel_size, stride):\n", - " \"\"\"\n", - " common use case:\n", - " cur_layer_img_w = conv2d_size_out(cur_layer_img_w, kernel_size, stride)\n", - " cur_layer_img_h = conv2d_size_out(cur_layer_img_h, kernel_size, stride)\n", - " to understand the shape for dense layer's input\n", - " \"\"\"\n", - " return (size - (kernel_size - 1) - 1) // stride + 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class DQNAgent(nn.Module):\n", - " def __init__(self, state_shape, n_actions, epsilon=0):\n", - "\n", - " super().__init__()\n", - " self.epsilon = epsilon\n", - " self.n_actions = n_actions\n", - " self.state_shape = state_shape\n", - "\n", - " # Define your network body here. 
Please make sure agent is fully contained here\n", - " # nn.Flatten() can be useful\n", - " \n", - " \n", - "\n", - " def forward(self, state_t):\n", - " \"\"\"\n", - " takes agent's observation (tensor), returns qvalues (tensor)\n", - " :param state_t: a batch of 4-frame buffers, shape = [batch_size, 4, h, w]\n", - " \"\"\"\n", - " # Use your network to compute qvalues for given state\n", - " qvalues = \n", - "\n", - " assert qvalues.requires_grad, \"qvalues must be a torch tensor with grad\"\n", - " assert (\n", - " len(qvalues.shape) == 2 and \n", - " qvalues.shape[0] == state_t.shape[0] and \n", - " qvalues.shape[1] == n_actions\n", - " )\n", - "\n", - " return qvalues\n", - "\n", - " def get_qvalues(self, states):\n", - " \"\"\"\n", - " like forward, but works on numpy arrays, not tensors\n", - " \"\"\"\n", - " model_device = next(self.parameters()).device\n", - " states = torch.tensor(states, device=model_device, dtype=torch.float32)\n", - " qvalues = self.forward(states)\n", - " return qvalues.data.cpu().numpy()\n", - "\n", - " def sample_actions(self, qvalues):\n", - " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n", - " epsilon = self.epsilon\n", - " batch_size, n_actions = qvalues.shape\n", - "\n", - " random_actions = np.random.choice(n_actions, size=batch_size)\n", - " best_actions = qvalues.argmax(axis=-1)\n", - "\n", - " should_explore = np.random.choice(\n", - " [0, 1], batch_size, p=[1-epsilon, epsilon])\n", - " return np.where(should_explore, random_actions, best_actions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's try out our agent to see if it raises any errors." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000):\n", - " \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. \"\"\"\n", - " rewards = []\n", - " for _ in range(n_games):\n", - " s = env.reset()\n", - " reward = 0\n", - " for _ in range(t_max):\n", - " qvalues = agent.get_qvalues([s])\n", - " action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n", - " s, r, done, _ = env.step(action)\n", - " reward += r\n", - " if done:\n", - " break\n", - "\n", - " rewards.append(reward)\n", - " return np.mean(rewards)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "evaluate(env, agent, n_games=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Experience replay\n", - "For this assignment, we provide you with experience replay buffer. If you implemented experience replay buffer in last week's assignment, you can copy-paste it here **to get 2 bonus points**.\n", - "\n", - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/exp_replay.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### The interface is fairly simple:\n", - "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n", - "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n", - "* `len(exp_replay)` - returns number of elements stored in replay buffer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from replay_buffer import ReplayBuffer\n", - "exp_replay = ReplayBuffer(10)\n", - "\n", - "for _ in range(30):\n", - " exp_replay.add(env.reset(), env.action_space.sample(), 1.0, env.reset(), done=False)\n", - "\n", - "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(5)\n", - "\n", - "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n", - " \"\"\"\n", - " Play the game for exactly n_steps, record every (s,a,r,s', done) to replay buffer. \n", - " Whenever game ends, add record with done=True and reset the game.\n", - " It is guaranteed that env has done=False when passed to this function.\n", - "\n", - " PLEASE DO NOT RESET ENV UNLESS IT IS \"DONE\"\n", - "\n", - " :returns: return sum of rewards over time and the state in which the env stays\n", - " \"\"\"\n", - " s = initial_state\n", - " sum_rewards = 0\n", - "\n", - " # Play the game for n_steps as per instructions above\n", - " \n", - "\n", - " return sum_rewards, s" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# testing your code.\n", - "exp_replay = ReplayBuffer(2000)\n", - "\n", - "state = env.reset()\n", - "play_and_record(state, agent, env, exp_replay, n_steps=1000)\n", - "\n", - "# if you're using your own experience replay buffer, some of those tests may need correction.\n", - "# just make sure you know what your code does\n", - "assert len(exp_replay) == 1000, \\\n", - " \"play_and_record should have added exactly 1000 steps, \" \\\n", - " \"but instead added %i\" % len(exp_replay)\n", - "is_dones = list(zip(*exp_replay._storage))[-1]\n", - 
"\n", - "assert 0 < np.mean(is_dones) < 0.1, \\\n", - " \"Please make sure you restart the game whenever it is 'done' and \" \\\n", - " \"record the is_done correctly into the buffer. Got %f is_done rate over \" \\\n", - " \"%i steps. [If you think it's your tough luck, just re-run the test]\" % (\n", - " np.mean(is_dones), len(exp_replay))\n", - "\n", - "for _ in range(100):\n", - " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", - " assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n", - " assert act_batch.shape == (10,), \\\n", - " \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", - " assert reward_batch.shape == (10,), \\\n", - " \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", - " assert is_done_batch.shape == (10,), \\\n", - " \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", - " assert [int(i) in (0, 1) for i in is_dones], \\\n", - " \"is_done should be strictly True or False\"\n", - " assert [0 <= a < n_actions for a in act_batch], \"actions should be within [0, n_actions)\"\n", - "\n", - "print(\"Well done!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Target networks\n", - "\n", - "We also employ the so called \"target network\" - a copy of neural network weights to be used for reference Q-values:\n", - "\n", - "The network itself is an exact copy of agent network, but it's parameters are not trained. 
Instead, they are moved here from agent's actual network every so often.\n", - "\n", - "$$ Q_{reference}(s,a) = r + \\gamma \\cdot \\max _{a'} Q_{target}(s',a') $$\n", - "\n", - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/target_net.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n", - "# This is how you can load weights from agent into target network\n", - "target_network.load_state_dict(agent.state_dict())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Learning with... Q-learning\n", - "Here we write a function similar to `agent.update` from tabular q-learning." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compute Q-learning TD error:\n", - "\n", - "$$ L = { 1 \\over N} \\sum_i [ Q_{\\theta}(s,a) - Q_{reference}(s,a) ] ^2 $$\n", - "\n", - "With Q-reference defined as\n", - "\n", - "$$ Q_{reference}(s,a) = r(s,a) + \\gamma \\cdot max_{a'} Q_{target}(s', a') $$\n", - "\n", - "Where\n", - "* $Q_{target}(s',a')$ denotes Q-value of next state and next action predicted by __target_network__\n", - "* $s, a, r, s'$ are current state, action, reward and next state respectively\n", - "* $\\gamma$ is a discount factor defined two cells above.\n", - "\n", - "\n", - "__Note 1:__ there's an example input below. Feel free to experiment with it before you write the function.\n", - "\n", - "__Note 2:__ compute_td_loss is a source of 99% of bugs in this homework. 
If reward doesn't improve, it often helps to go through it line by line [with a rubber duck](https://rubberduckdebugging.com/).\n", - "\n", - "**Double DQN (+2 pts)**\n", - "\n", - "$$ Q_{reference}(s,a) = r(s, a) + \\gamma \\cdot\n", - "Q_{target}(s',argmax_{a'}Q_\\theta(s', a')) $$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def compute_td_loss(states, actions, rewards, next_states, is_done,\n", - " agent, target_network,\n", - " gamma=0.99,\n", - " check_shapes=False,\n", - " device=device):\n", - " \"\"\" Compute td loss using torch operations only. Use the formulae above. \"\"\"\n", - " states = torch.tensor(states, device=device, dtype=torch.float32) # shape: [batch_size, *state_shape]\n", - " actions = torch.tensor(actions, device=device, dtype=torch.int64) # shape: [batch_size]\n", - " rewards = torch.tensor(rewards, device=device, dtype=torch.float32) # shape: [batch_size]\n", - " # shape: [batch_size, *state_shape]\n", - " next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n", - " is_done = torch.tensor(\n", - " is_done.astype('float32'),\n", - " device=device,\n", - " dtype=torch.float32,\n", - " ) # shape: [batch_size]\n", - " is_not_done = 1 - is_done\n", - "\n", - " # get q-values for all actions in current states\n", - " predicted_qvalues = agent(states) # shape: [batch_size, n_actions]\n", - "\n", - " # compute q-values for all actions in next states\n", - " predicted_next_qvalues = target_network(next_states) # shape: [batch_size, n_actions]\n", - " \n", - " # select q-values for chosen actions\n", - " predicted_qvalues_for_actions = predicted_qvalues[range(len(actions)), actions] # shape: [batch_size]\n", - "\n", - " # compute V*(next_states) using predicted next q-values\n", - " next_state_values = \n", - "\n", - " assert next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0], \\\n", - " \"must predict one value per state\"\n", - 
"\n", - " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", - " # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", - " # you can multiply next state values by is_not_done to achieve this.\n", - " target_qvalues_for_actions = \n", - "\n", - " # mean squared error loss to minimize\n", - " loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2)\n", - "\n", - " if check_shapes:\n", - " assert predicted_next_qvalues.data.dim() == 2, \\\n", - " \"make sure you predicted q-values for all actions in next state\"\n", - " assert next_state_values.data.dim() == 1, \\\n", - " \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", - " assert target_qvalues_for_actions.data.dim() == 1, \\\n", - " \"there's something wrong with target q-values, they must be a vector\"\n", - "\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Sanity checks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", - "\n", - "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n", - " agent, target_network,\n", - " gamma=0.99, check_shapes=True)\n", - "loss.backward()\n", - "\n", - "assert loss.requires_grad and tuple(loss.data.size()) == (), \\\n", - " \"you must return scalar loss - mean over batch\"\n", - "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() != 0), \\\n", - " \"loss must be differentiable w.r.t. 
network weights\"\n", - "assert np.all(next(target_network.parameters()).grad is None), \\\n", - " \"target network should not have grads\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Main loop (3 pts)\n", - "\n", - "**If deadline is tonight and it has not converged:** It is ok. Send the notebook today and when it converges send it again.\n", - "If the code is exactly the same points will not be discounted.\n", - "\n", - "It's time to put everything together and see if it learns anything." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from tqdm import trange\n", - "from IPython.display import clear_output\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "seed = \n", - "random.seed(seed)\n", - "np.random.seed(seed)\n", - "torch.manual_seed(seed)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env = make_env(seed)\n", - "state_shape = env.observation_space.shape\n", - "n_actions = env.action_space.n\n", - "state = env.reset()\n", - "\n", - "agent = DQNAgent(state_shape, n_actions, epsilon=1).to(device)\n", - "target_network = DQNAgent(state_shape, n_actions).to(device)\n", - "target_network.load_state_dict(agent.state_dict())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Buffer of size $10^4$ fits into 5 Gb RAM.\n", - "\n", - "Larger sizes ($10^5$ and $10^6$ are common) can be used. It can improve the learning, but $10^4$ is quite enough. $10^2$ will probably fail learning." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "REPLAY_BUFFER_SIZE = 10**4\n", - "N_STEPS = 100\n", - "\n", - "exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n", - "for i in trange(REPLAY_BUFFER_SIZE // N_STEPS):\n", - " if not utils.is_enough_ram(min_available_gb=0.1):\n", - " print(\"\"\"\n", - " Less than 100 Mb RAM available. \n", - " Make sure the buffer size in not too huge.\n", - " Also check, maybe other processes consume RAM heavily.\n", - " \"\"\"\n", - " )\n", - " break\n", - " play_and_record(state, agent, env, exp_replay, n_steps=N_STEPS)\n", - " if len(exp_replay) == REPLAY_BUFFER_SIZE:\n", - " break\n", - "print(len(exp_replay))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "timesteps_per_epoch = 1\n", - "batch_size = 16\n", - "total_steps = 3 * 10**6\n", - "decay_steps = 10**6\n", - "\n", - "opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n", - "\n", - "init_epsilon = 1\n", - "final_epsilon = 0.1\n", - "\n", - "loss_freq = 50\n", - "refresh_target_network_freq = 5000\n", - "eval_freq = 5000\n", - "\n", - "max_grad_norm = 50\n", - "\n", - "n_lives = 5" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mean_rw_history = []\n", - "td_loss_history = []\n", - "grad_norm_history = []\n", - "initial_state_v_history = []\n", - "step = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "\n", - "def wait_for_keyboard_interrupt():\n", - " try:\n", - " while True:\n", - " time.sleep(1)\n", - " except KeyboardInterrupt:\n", - " pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "state = env.reset()\n", - "with trange(step, total_steps + 1) as progress_bar:\n", - " for step in progress_bar:\n", - " if not 
utils.is_enough_ram():\n", - " print('less that 100 Mb RAM available, freezing')\n", - " print('make sure everything is ok and use KeyboardInterrupt to continue')\n", - " wait_for_keyboard_interrupt()\n", - "\n", - " agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n", - "\n", - " # play\n", - " _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n", - "\n", - " # train\n", - " \n", - "\n", - " loss = \n", - "\n", - " loss.backward()\n", - " grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)\n", - " opt.step()\n", - " opt.zero_grad()\n", - "\n", - " if step % loss_freq == 0:\n", - " td_loss_history.append(loss.data.cpu().item())\n", - " grad_norm_history.append(grad_norm.cpu())\n", - "\n", - " if step % refresh_target_network_freq == 0:\n", - " # Load agent weights into target_network\n", - " \n", - "\n", - " if step % eval_freq == 0:\n", - " mean_rw_history.append(evaluate(\n", - " make_env(clip_rewards=True, seed=step), agent, n_games=3 * n_lives, greedy=True)\n", - " )\n", - " initial_state_q_values = agent.get_qvalues(\n", - " [make_env(seed=step).reset()]\n", - " )\n", - " initial_state_v_history.append(np.max(initial_state_q_values))\n", - "\n", - " clear_output(True)\n", - " print(\"buffer size = %i, epsilon = %.5f\" %\n", - " (len(exp_replay), agent.epsilon))\n", - "\n", - " plt.figure(figsize=[16, 9])\n", - "\n", - " plt.subplot(2, 2, 1)\n", - " plt.title(\"Mean reward per life\")\n", - " plt.plot(mean_rw_history)\n", - " plt.grid()\n", - "\n", - " assert not np.isnan(td_loss_history[-1])\n", - " plt.subplot(2, 2, 2)\n", - " plt.title(\"TD loss history (smoothened)\")\n", - " plt.plot(utils.smoothen(td_loss_history))\n", - " plt.grid()\n", - "\n", - " plt.subplot(2, 2, 3)\n", - " plt.title(\"Initial state V\")\n", - " plt.plot(initial_state_v_history)\n", - " plt.grid()\n", - "\n", - " plt.subplot(2, 2, 4)\n", - " plt.title(\"Grad norm history (smoothened)\")\n", - " 
plt.plot(utils.smoothen(grad_norm_history))\n", - " plt.grid()\n", - "\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Agent is evaluated for 1 life, not for a whole episode of 5 lives. Rewards in evaluation are also truncated. Cuz this is what environment the agent is learning in and in this way mean rewards per life can be compared with initial state value\n", - "\n", - "**The goal is to get 15 points in the real env**. So 3 or better 4 points in the preprocessed one will probably be enough. You can interrupt learning then." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Final scoring is done on a whole episode with all 5 lives." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "final_score = evaluate(\n", - " make_env(clip_rewards=False, seed=9),\n", - " agent, n_games=30, greedy=True, t_max=10 * 1000\n", - ")\n", - "print('final score:', final_score)\n", - "assert final_score >= 3, 'not as cool as DQN can'\n", - "print('Cool!')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## How to interpret plots:\n", - "\n", - "This aint no supervised learning so don't expect anything to improve monotonously. \n", - "* **TD loss** is the MSE between agent's current Q-values and target Q-values. It may slowly increase or decrease, it's ok. The \"not ok\" behavior includes going NaN or stayng at exactly zero before agent has perfect performance.\n", - "* **grad norm** just shows the intensivity of training. Not ok is growing to values of about 100 (or maybe even 50) though it depends on network architecture.\n", - "* **mean reward** is the expected sum of r(s,a) agent gets over the full game session. It will oscillate, but on average it should get higher over time (after a few thousand iterations...). 
\n", - " * In basic q-learning implementation it takes about 40k steps to \"warm up\" agent before it starts to get better.\n", - "* **Initial state V** is the expected discounted reward for episode in the oppinion of the agent. It should behave more smoothly than **mean reward**. It should get higher over time but sometimes can experience drawdowns because of the agaent's overestimates.\n", - "* **buffer size** - this one is simple. It should go up and cap at max size.\n", - "* **epsilon** - agent's willingness to explore. If you see that agent's already at 0.01 epsilon before it's average reward is above 0 - it means you need to increase epsilon. Set it back to some 0.2 - 0.5 and decrease the pace at which it goes down.\n", - "* Smoothing of plots is done with a gaussian kernel\n", - "\n", - "At first your agent will lose quickly. Then it will learn to suck less and at least hit the ball a few times before it loses. Finally it will learn to actually score points.\n", - "\n", - "**Training will take time.** A lot of it actually. Probably you will not see any improvment during first **150k** time steps (note that by default in this notebook agent is evaluated every 5000 time steps).\n", - "\n", - "But hey, long training time isn't _that_ bad:\n", - "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/training.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## About hyperparameters:\n", - "\n", - "The task has something in common with supervised learning: loss is optimized through the buffer (instead of Train dataset). But the distribution of states and actions in the buffer **is not stationary** and depends on the policy that generated it. 
It can even happen that the mean TD error across the buffer is very low but the performance is extremely poor (imagine the agent collecting data to the buffer always manages to avoid the ball).\n", - "\n", - "* Total timesteps and training time: It seems to be so huge, but actually it is normal for RL.\n", - "\n", - "* $\\epsilon$ decay shedule was taken from the original paper and is like traditional for epsilon-greedy policies. At the beginning of the training the agent's greedy policy is poor so many random actions should be taken.\n", - "\n", - "* Optimizer: In the original paper RMSProp was used (they did not have Adam in 2013) and it can work not worse than Adam. For us Adam was default and it worked.\n", - "\n", - "* lr: $10^{-3}$ would probably be too huge\n", - "\n", - "* batch size: This one can be very important: if it is too small the agent can fail to learn. Huge batch takes more time to process. If batch of size 8 can not be processed on the hardware you use take 2 (or even 4) batches of size 4, divide the loss on them by 2 (or 4) and make optimization step after both backward() calls in torch.\n", - "\n", - "* target network update frequency: has something in common with learning rate. Too frequent updates can lead to divergence. Too rare can lead to slow leraning. For millions of total timesteps thousands of inner steps seem ok. One iteration of target network updating is an iteration of the (this time approximate) $\\gamma$-compression that stands behind Q-learning. The more inner steps it makes the more accurate is the compression.\n", - "* max_grad_norm - just huge enough. In torch clip_grad_norm also evaluates the norm before clipping and it can be convenient for logging." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Video" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Record sessions\n", - "\n", - "import gym.wrappers\n", - "\n", - "with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n", - " sessions = [evaluate(env_monitor, agent, n_games=n_lives, greedy=True) for _ in range(10)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Show video. This may not work in some setups. If it doesn't\n", - "# work for you, you can download the videos and view them locally.\n", - "\n", - "from pathlib import Path\n", - "from base64 import b64encode\n", - "from IPython.display import HTML\n", - "\n", - "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", - "video_path = video_paths[-1] # You can also try other indices\n", - "\n", - "if 'google.colab' in sys.modules:\n", - " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open('rb') as fp:\n", - " mp4 = fp.read()\n", - " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", - "else:\n", - " data_url = str(video_path)\n", - "\n", - "HTML(\"\"\"\n", - "\n", - "\"\"\".format(data_url))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Let's have a closer look at this.\n", - "\n", - "If average episode score is below 200 using all 5 lives, then probably DQN has not converged fully. But anyway let's make a more complete record of an episode." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "eval_env = make_env(clip_rewards=False)\n", - "record = utils.play_and_log_episode(eval_env, agent)\n", - "print('total reward for life:', np.sum(record['rewards']))\n", - "for key in record:\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = plt.figure(figsize=(5, 5))\n", - "ax = fig.add_subplot(1, 1, 1)\n", - "\n", - "ax.scatter(record['v_mc'], record['v_agent'])\n", - "ax.plot(sorted(record['v_mc']), sorted(record['v_mc']),\n", - " 'black', linestyle='--', label='x=y')\n", - "\n", - "ax.grid()\n", - "ax.legend()\n", - "ax.set_title('State Value Estimates')\n", - "ax.set_xlabel('Monte-Carlo')\n", - "ax.set_ylabel('Agent')\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "$\\hat V_{Monte-Carlo}(s_t) = \\sum_{\\tau=0}^{episode~end} \\gamma^{\\tau-t}r_t$" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Is there a big bias? It's ok, anyway it works." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus I (2 pts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**1.** Plot several (say 3) states with high and low spreads of Q estimate by actions i.e.\n", - "$$\\max_a \\hat Q(s,a) - \\min_a \\hat Q(s,a)\\$$\n", - "Please take those states from different episodes to make sure that the states are really different.\n", - "\n", - "What should high and low spread mean at least in the world of perfect Q-fucntions?\n", - "\n", - "Comment the states you like most.\n", - "\n", - "**2.** Plot several (say 3) states with high td-error and several states with high values of\n", - "$$| \\hat V_{Monte-Carlo}(s) - \\hat V_{agent}(s)|,$$ \n", - "$$\\hat V_{agent}(s)=\\max_a \\hat Q(s,a).$$ Please take those states from different episodes to make sure that the states are really different. From what part (i.e. beginning, middle, end) of an episode did these states come from?\n", - "\n", - "Comment the states you like most." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from utils import play_and_log_episode, img_by_obs\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus II (1-5 pts). Get High Score!\n", - "\n", - "1 point to you for each 50 points of your agent. Truncated by 5 points. Starting with 50 points, **not** 50 + threshold.\n", - "\n", - "One way is to train for several days and use heavier hardware (why not actually).\n", - "\n", - "Another way is to apply modifications (see **Bonus III**)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus III (2+ pts). Apply modifications to DQN.\n", - "\n", - "For inspiration see [Rainbow](https://arxiv.org/abs/1710.02298) - a version of q-learning that combines lots of them.\n", - "\n", - "Points for Bonus II and Bonus III fully stack. 
So if modified agent gets score 250+ you get 5 pts for Bonus II + points for modifications. If the final score is 40 then you get the points for modifications.\n", - "\n", - "\n", - "Some modifications:\n", - "* [Prioritized experience replay](https://arxiv.org/abs/1511.05952) (5 pts for your own implementation, 3 pts for using a ready one)\n", - "* [double q-learning](https://arxiv.org/abs/1509.06461) (2 pts)\n", - "* [dueling q-learning](https://arxiv.org/abs/1511.06581) (2 pts)\n", - "* multi-step heuristics (see [Rainbow](https://arxiv.org/abs/1710.02298)) (3 pts)\n", - "* [Noisy Nets](https://arxiv.org/abs/1706.10295) (3 pts)\n", - "* [distributional RL](https://arxiv.org/abs/1707.06887)(distributional and distributed stand for different things here) (5 pts)\n", - "* Other modifications (2+ pts depending on complexity)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus IV (4+ pts). Distributed RL.\n", - "\n", - "Solve the task in a distributed way. It can strongly speed up learning. See [article](https://arxiv.org/pdf/1602.01783.pdf) or some guides." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**As usual bonus points for all the tasks fully stack.**" - ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 1 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/week04_approx_rl/seminar_pytorch.ipynb b/week04_approx_rl/seminar_pytorch.ipynb index 9161a92ec..1cab93585 100644 --- a/week04_approx_rl/seminar_pytorch.ipynb +++ b/week04_approx_rl/seminar_pytorch.ipynb @@ -1,392 +1,454 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Approximate q-learning\n", - "\n", - "In this notebook you will teach a __PyTorch__ neural network to do Q-learning." 
- ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "kr_aKWMGEmh-" + }, + "source": [ + "# Approximate q-learning\n", + "\n", + "In this notebook you will teach a __PyTorch__ neural network to do Q-learning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oaMu65ONEmh_" + }, + "outputs": [], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "avILCRKkEpaX" + }, + "outputs": [], + "source": [ + "!pip install gymnasium[classic_control]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "K_SRk2ASEmh_" + }, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x2YvkgprEmh_" + }, + "outputs": [], + "source": [ + "env = gym.make(\"CartPole-v0\", render_mode=\"rgb_array\").env\n", + "env.reset()\n", + "n_actions = env.action_space.n\n", + "state_dim = env.observation_space.shape\n", + "\n", + "plt.imshow(env.render())\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sI8W19CwEmh_" + }, + "source": [ + "# Approximate Q-learning: building the network\n", + "\n", + "To train a neural network policy one must have a neural network policy. 
Let's build it.\n", + "\n", + "\n", + "Since we're working with pre-extracted features (cart positions, angles and velocities), we don't need a complicated network yet. In fact, let's build something like this for starters:\n", + "\n", + "![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/yet_another_week/_resource/qlearning_scheme.png)\n", + "\n", + "For your first run, please only use linear layers (`nn.Linear`) and activations. Stuff like batch normalization or dropout may ruin everything if used haphazardly.\n", + "\n", + "Also please avoid using nonlinearities like sigmoid & tanh: since agent's observations are not normalized, sigmoids might be saturated at initialization. Instead, use non-saturating nonlinearities like ReLU.\n", + "\n", + "Ideally you should start small with maybe 1-2 hidden layers with < 200 neurons and then increase network size if agent doesn't beat the target score." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "YdWXv8WJEmiA" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "y2-PcaIQEmiA" + }, + "outputs": [], + "source": [ + "network = nn.Sequential()\n", + "\n", + "network.add_module('layer1', )\n", + "\n", + "\n", + "\n", + "# hint: use state_dim[0] as input size" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "8xuWPGriEmiA" + }, + "outputs": [], + "source": [ + "def get_action(state, epsilon=0):\n", + " \"\"\"\n", + " sample actions with epsilon-greedy policy\n", + " recap: with p = epsilon pick random action, else pick action with highest Q(s,a)\n", + " \"\"\"\n", + " state = torch.tensor(state[None], dtype=torch.float32)\n", + " q_values = network(state).detach().numpy()\n", + "\n", + " \n", + "\n", + " return int( )" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": { + "id": "wroEfSRNEmiA" + }, + "outputs": [], + "source": [ + "s, _ = env.reset()\n", + "assert tuple(network(torch.tensor([s]*3, dtype=torch.float32)).size()) == (\n", + " 3, n_actions), \"please make sure your model maps state s -> [Q(s,a0), ..., Q(s, a_last)]\"\n", + "assert isinstance(list(network.modules(\n", + "))[-1], nn.Linear), \"please make sure you predict q-values without nonlinearity (ignore if you know what you're doing)\"\n", + "assert isinstance(get_action(s), int), \"get_action(s) must return int, not %s. try int(action)\" % (type(get_action(s)))\n", + "\n", + "# test epsilon-greedy exploration\n", + "for eps in [0., 0.1, 0.5, 1.0]:\n", + " state_frequencies = np.bincount(\n", + " [get_action(s, epsilon=eps) for i in range(10000)], minlength=n_actions)\n", + " best_action = state_frequencies.argmax()\n", + " assert abs(state_frequencies[best_action] -\n", + " 10000 * (1 - eps + eps / n_actions)) < 200\n", + " for other_action in range(n_actions):\n", + " if other_action != best_action:\n", + " assert abs(state_frequencies[other_action] -\n", + " 10000 * (eps / n_actions)) < 200\n", + " print('e=%.1f tests passed' % eps)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f88ovLBQEmiA" + }, + "source": [ + "### Q-learning via gradient descent\n", + "\n", + "We shall now train our agent's Q-function by minimizing the TD loss:\n", + "$$ L = { 1 \\over N} \\sum_i (Q_{\\theta}(s,a) - [r(s,a) + \\gamma \\cdot max_{a'} Q_{-}(s', a')]) ^2 $$\n", + "\n", + "\n", + "Where\n", + "* $s, a, r, s'$ are current state, action, reward and next state respectively\n", + "* $\\gamma$ is a discount factor (the `gamma` argument of `compute_td_loss` in the next cell, 0.99 by default).\n", + "\n", + "The tricky part is with $Q_{-}(s',a')$. From an engineering standpoint, it's the same as $Q_{\\theta}$ - the output of your neural network policy. 
However, when doing gradient descent, __we won't propagate gradients through it__ to make training more stable (see lectures).\n", + "\n", + "To do so, we shall use `x.detach()` function which basically says \"consider this thing constant when doing backprop\"." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "bOIpO142EmiB" + }, + "outputs": [], + "source": [ + "def compute_td_loss(states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False):\n", + " \"\"\" Compute td loss using torch operations only. Use the formula above. \"\"\"\n", + " states = torch.tensor(\n", + " states, dtype=torch.float32) # shape: [batch_size, state_size]\n", + " actions = torch.tensor(actions, dtype=torch.long) # shape: [batch_size]\n", + " rewards = torch.tensor(rewards, dtype=torch.float32) # shape: [batch_size]\n", + " # shape: [batch_size, state_size]\n", + " next_states = torch.tensor(next_states, dtype=torch.float32)\n", + " is_done = torch.tensor(is_done, dtype=torch.uint8) # shape: [batch_size]\n", + "\n", + " # get q-values for all actions in current states\n", + " predicted_qvalues = network(states) # shape: [batch_size, n_actions]\n", + "\n", + " # select q-values for chosen actions\n", + " predicted_qvalues_for_actions = predicted_qvalues[ # shape: [batch_size]\n", + " range(states.shape[0]), actions\n", + " ]\n", + "\n", + " # compute q-values for all actions in next states\n", + " predicted_next_qvalues = \n", + "\n", + " # compute V*(next_states) using predicted next q-values\n", + " next_state_values = \n", + " assert next_state_values.dtype == torch.float32\n", + "\n", + " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", + " target_qvalues_for_actions = \n", + "\n", + " # at the last state we shall use simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", + " target_qvalues_for_actions = torch.where(\n", + " is_done, rewards, 
target_qvalues_for_actions)\n", + "\n", + " # mean squared error loss to minimize\n", + " loss = torch.mean((predicted_qvalues_for_actions -\n", + " target_qvalues_for_actions.detach()) ** 2)\n", + "\n", + " if check_shapes:\n", + " assert predicted_next_qvalues.data.dim(\n", + " ) == 2, \"make sure you predicted q-values for all actions in next state\"\n", + " assert next_state_values.data.dim(\n", + " ) == 1, \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", + " assert target_qvalues_for_actions.data.dim(\n", + " ) == 1, \"there's something wrong with target q-values, they must be a vector\"\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "id": "lKi6AK3DEmiB" + }, + "outputs": [], + "source": [ + "# sanity checks\n", + "s, _ = env.reset()\n", + "a = env.action_space.sample()\n", + "next_s, r, terminated, _, _ = env.step(a)\n", + "loss = compute_td_loss([s], [a], [r], [next_s], [terminated], check_shapes=True)\n", + "loss.backward()\n", + "\n", + "assert len(loss.size()) == 0, \"you must return scalar loss - mean over batch\"\n", + "assert np.any(next(network.parameters()).grad.detach().numpy() !=\n", + " 0), \"loss must be differentiable w.r.t. 
network weights\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LgL6G5lFEmiB" + }, + "source": [ + "### Playing the game" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "zsHb_fjjEmiB" + }, + "outputs": [], + "source": [ + "opt = torch.optim.Adam(network.parameters(), lr=1e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "nJ_-xtsjEmiB" + }, + "outputs": [], + "source": [ + "def generate_session(env, t_max=1000, epsilon=0, train=False):\n", + " \"\"\"play env with approximate q-learning agent and train it at the same time\"\"\"\n", + " total_reward = 0\n", + " s, _ = env.reset()\n", + "\n", + " for t in range(t_max):\n", + " a = get_action(s, epsilon=epsilon)\n", + " next_s, r, terminated, truncated, _ = env.step(a)\n", + "\n", + " if train:\n", + " opt.zero_grad()\n", + " compute_td_loss([s], [a], [r], [next_s], [terminated]).backward()\n", + " opt.step()\n", + "\n", + " total_reward += r\n", + " s = next_s\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " return total_reward" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "id": "40mKYuVIEmiB" + }, + "outputs": [], + "source": [ + "epsilon = 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EXy8ij00EmiB" + }, + "outputs": [], + "source": [ + "for i in range(1000):\n", + " session_rewards = [generate_session(env, epsilon=epsilon, train=True) for _ in range(100)]\n", + " print(\"epoch #{}\\tmean reward = {:.3f}\\tepsilon = {:.3f}\".format(i, np.mean(session_rewards), epsilon))\n", + "\n", + " epsilon *= 0.99\n", + " assert epsilon >= 1e-4, \"Make sure epsilon is always nonzero during training\"\n", + "\n", + " if np.mean(session_rewards) > 300:\n", + " print(\"You Win!\")\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XJPoF9XtEmiB" + }, + "source": [ + "### How to interpret results\n", + "\n", + 
"\n", + "Welcome to the f.. world of deep f...n reinforcement learning. Don't expect agent's reward to smoothly go up. Hope for it to increase eventually. If it deems you worthy.\n", + "\n", + "Seriously though,\n", + "* __mean reward__ is the average reward per game. For a correct implementation it may stay low for some 10 epochs, then start growing while oscillating insanely and converge by ~50-100 steps depending on the network architecture.\n", + "* If it never reaches target score by the end of the for loop, try increasing the number of hidden neurons or look at the epsilon.\n", + "* __epsilon__ - agent's willingness to explore. If you see that agent's already at < 0.01 epsilon before mean reward is at least 200, just reset it back to 0.1 - 0.5." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lhKiN-qOEmiB" + }, + "source": [ + "### Record videos\n", + "\n", + "As usual, we now use `gymnasium.wrappers.RecordVideo` to record a video of our agent playing the game. Unlike our previous attempts with state binarization, this time we expect our agent to act ~~(or fail)~~ more smoothly since there's no more binarization error at play.\n", + "\n", + "As you already did with tabular q-learning, we set epsilon=0 for final evaluation to prevent agent from exploring himself to death." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2yqPkj6HEmiB" + }, + "outputs": [], + "source": [ + "# Record sessions\n", + "\n", + "from gymnasium.wrappers import RecordVideo\n", + "\n", + "with gym.make(\"CartPole-v0\", render_mode=\"rgb_array\") as record_env, RecordVideo(\n", + " record_env, video_folder=\"videos\"\n", + ") as env_monitor:\n", + " sessions = [\n", + " generate_session(env_monitor, epsilon=0, train=False) for _ in range(100)\n", + " ]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "afqi2qomEmiC" + }, + "outputs": [], + "source": [ + "# Show video. 
This may not work in some setups. If it doesn't\n", + "# work for you, you can download the videos and view them locally.\n", + "\n", + "from pathlib import Path\n", + "from base64 import b64encode\n", + "from IPython.display import HTML\n", + "\n", + "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", + "video_path = video_paths[-1] # You can also try other indices\n", + "\n", + "if 'google.colab' in sys.modules:\n", + " # https://stackoverflow.com/a/57378660/1214547\n", + " with video_path.open('rb') as fp:\n", + " mp4 = fp.read()\n", + " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", + "else:\n", + " data_url = str(video_path)\n", + "\n", + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(data_url))" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys, os\n", - "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " !touch .setup_complete\n", - "\n", - "# This code creates a virtual display to draw game images on.\n", - "# It will have no effect if your machine has a monitor.\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym\n", - "import numpy as np\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env = gym.make(\"CartPole-v0\").env\n", - 
"env.reset()\n", - "n_actions = env.action_space.n\n", - "state_dim = env.observation_space.shape\n", - "\n", - "plt.imshow(env.render(\"rgb_array\"))\n", - "env.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Approximate Q-learning: building the network\n", - "\n", - "To train a neural network policy one must have a neural network policy. Let's build it.\n", - "\n", - "\n", - "Since we're working with a pre-extracted features (cart positions, angles and velocities), we don't need a complicated network yet. In fact, let's build something like this for starters:\n", - "\n", - "![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/yet_another_week/_resource/qlearning_scheme.png)\n", - "\n", - "For your first run, please only use linear layers (`nn.Linear`) and activations. Stuff like batch normalization or dropout may ruin everything if used haphazardly. \n", - "\n", - "Also please avoid using nonlinearities like sigmoid & tanh: since agent's observations are not normalized, sigmoids might be saturated at initialization. Instead, use non-saturating nonlinearities like ReLU.\n", - "\n", - "Ideally you should start small with maybe 1-2 hidden layers with < 200 neurons and then increase network size if agent doesn't beat the target score." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "network = nn.Sequential()\n", - "\n", - "network.add_module('layer1', )\n", - "\n", - "\n", - "\n", - "# hint: use state_dim[0] as input size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_action(state, epsilon=0):\n", - " \"\"\"\n", - " sample actions with epsilon-greedy policy\n", - " recap: with p = epsilon pick random action, else pick action with highest Q(s,a)\n", - " \"\"\"\n", - " state = torch.tensor(state[None], dtype=torch.float32)\n", - " q_values = network(state).detach().numpy()\n", - "\n", - " \n", - "\n", - " return int( )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "s = env.reset()\n", - "assert tuple(network(torch.tensor([s]*3, dtype=torch.float32)).size()) == (\n", - " 3, n_actions), \"please make sure your model maps state s -> [Q(s,a0), ..., Q(s, a_last)]\"\n", - "assert isinstance(list(network.modules(\n", - "))[-1], nn.Linear), \"please make sure you predict q-values without nonlinearity (ignore if you know what you're doing)\"\n", - "assert isinstance(get_action(\n", - " s), int), \"get_action(s) must return int, not %s. 
try int(action)\" % (type(get_action(s)))\n", - "\n", - "# test epsilon-greedy exploration\n", - "for eps in [0., 0.1, 0.5, 1.0]:\n", - " state_frequencies = np.bincount(\n", - " [get_action(s, epsilon=eps) for i in range(10000)], minlength=n_actions)\n", - " best_action = state_frequencies.argmax()\n", - " assert abs(state_frequencies[best_action] -\n", - " 10000 * (1 - eps + eps / n_actions)) < 200\n", - " for other_action in range(n_actions):\n", - " if other_action != best_action:\n", - " assert abs(state_frequencies[other_action] -\n", - " 10000 * (eps / n_actions)) < 200\n", - " print('e=%.1f tests passed' % eps)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Q-learning via gradient descent\n", - "\n", - "We shall now train our agent's Q-function by minimizing the TD loss:\n", - "$$ L = { 1 \\over N} \\sum_i (Q_{\\theta}(s,a) - [r(s,a) + \\gamma \\cdot max_{a'} Q_{-}(s', a')]) ^2 $$\n", - "\n", - "\n", - "Where\n", - "* $s, a, r, s'$ are current state, action, reward and next state respectively\n", - "* $\\gamma$ is a discount factor defined two cells above.\n", - "\n", - "The tricky part is with $Q_{-}(s',a')$. From an engineering standpoint, it's the same as $Q_{\\theta}$ - the output of your neural network policy. However, when doing gradient descent, __we won't propagate gradients through it__ to make training more stable (see lectures).\n", - "\n", - "To do so, we shall use `x.detach()` function which basically says \"consider this thing constant when doingbackprop\"." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def compute_td_loss(states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False):\n", - " \"\"\" Compute td loss using torch operations only. Use the formula above. 
\"\"\"\n", - " states = torch.tensor(\n", - " states, dtype=torch.float32) # shape: [batch_size, state_size]\n", - " actions = torch.tensor(actions, dtype=torch.long) # shape: [batch_size]\n", - " rewards = torch.tensor(rewards, dtype=torch.float32) # shape: [batch_size]\n", - " # shape: [batch_size, state_size]\n", - " next_states = torch.tensor(next_states, dtype=torch.float32)\n", - " is_done = torch.tensor(is_done, dtype=torch.uint8) # shape: [batch_size]\n", - "\n", - " # get q-values for all actions in current states\n", - " predicted_qvalues = network(states) # shape: [batch_size, n_actions]\n", - "\n", - " # select q-values for chosen actions\n", - " predicted_qvalues_for_actions = predicted_qvalues[ # shape: [batch_size]\n", - " range(states.shape[0]), actions\n", - " ]\n", - "\n", - " # compute q-values for all actions in next states\n", - " predicted_next_qvalues = \n", - "\n", - " # compute V*(next_states) using predicted next q-values\n", - " next_state_values = \n", - " assert next_state_values.dtype == torch.float32\n", - "\n", - " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", - " target_qvalues_for_actions = \n", - "\n", - " # at the last state we shall use simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", - " target_qvalues_for_actions = torch.where(\n", - " is_done, rewards, target_qvalues_for_actions)\n", - "\n", - " # mean squared error loss to minimize\n", - " loss = torch.mean((predicted_qvalues_for_actions -\n", - " target_qvalues_for_actions.detach()) ** 2)\n", - "\n", - " if check_shapes:\n", - " assert predicted_next_qvalues.data.dim(\n", - " ) == 2, \"make sure you predicted q-values for all actions in next state\"\n", - " assert next_state_values.data.dim(\n", - " ) == 1, \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", - " assert target_qvalues_for_actions.data.dim(\n", - " ) == 1, \"there's something wrong with target 
q-values, they must be a vector\"\n", - "\n", - " return loss" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# sanity checks\n", - "s = env.reset()\n", - "a = env.action_space.sample()\n", - "next_s, r, done, _ = env.step(a)\n", - "loss = compute_td_loss([s], [a], [r], [next_s], [done], check_shapes=True)\n", - "loss.backward()\n", - "\n", - "assert len(loss.size()) == 0, \"you must return scalar loss - mean over batch\"\n", - "assert np.any(next(network.parameters()).grad.detach().numpy() !=\n", - " 0), \"loss must be differentiable w.r.t. network weights\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Playing the game" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "opt = torch.optim.Adam(network.parameters(), lr=1e-4)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_session(env, t_max=1000, epsilon=0, train=False):\n", - " \"\"\"play env with approximate q-learning agent and train it at the same time\"\"\"\n", - " total_reward = 0\n", - " s = env.reset()\n", - "\n", - " for t in range(t_max):\n", - " a = get_action(s, epsilon=epsilon)\n", - " next_s, r, done, _ = env.step(a)\n", - "\n", - " if train:\n", - " opt.zero_grad()\n", - " compute_td_loss([s], [a], [r], [next_s], [done]).backward()\n", - " opt.step()\n", - "\n", - " total_reward += r\n", - " s = next_s\n", - " if done:\n", - " break\n", - "\n", - " return total_reward" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "epsilon = 0.5" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(1000):\n", - " session_rewards = [generate_session(env, epsilon=epsilon, train=True) for _ in range(100)]\n", - " print(\"epoch 
#{}\\tmean reward = {:.3f}\\tepsilon = {:.3f}\".format(i, np.mean(session_rewards), epsilon))\n", - "\n", - " epsilon *= 0.99\n", - " assert epsilon >= 1e-4, \"Make sure epsilon is always nonzero during training\"\n", - "\n", - " if np.mean(session_rewards) > 300:\n", - " print(\"You Win!\")\n", - " break" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### How to interpret results\n", - "\n", - "\n", - "Welcome to the f.. world of deep f...n reinforcement learning. Don't expect agent's reward to smoothly go up. Hope for it to go increase eventually. If it deems you worthy.\n", - "\n", - "Seriously though,\n", - "* __ mean reward__ is the average reward per game. For a correct implementation it may stay low for some 10 epochs, then start growing while oscilating insanely and converges by ~50-100 steps depending on the network architecture. \n", - "* If it never reaches target score by the end of for loop, try increasing the number of hidden neurons or look at the epsilon.\n", - "* __ epsilon__ - agent's willingness to explore. If you see that agent's already at < 0.01 epsilon before it's is at least 200, just reset it back to 0.1 - 0.5." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Record videos\n", - "\n", - "As usual, we now use `gym.wrappers.Monitor` to record a video of our agent playing the game. Unlike our previous attempts with state binarization, this time we expect our agent to act ~~(or fail)~~ more smoothly since there's no more binarization error at play.\n", - "\n", - "As you already did with tabular q-learning, we set epsilon=0 for final evaluation to prevent agent from exploring himself to death." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Record sessions\n", - "\n", - "import gym.wrappers\n", - "\n", - "with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n", - " sessions = [generate_session(env_monitor, epsilon=0, train=False) for _ in range(100)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Show video. This may not work in some setups. If it doesn't\n", - "# work for you, you can download the videos and view them locally.\n", - "\n", - "from pathlib import Path\n", - "from base64 import b64encode\n", - "from IPython.display import HTML\n", - "\n", - "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", - "video_path = video_paths[-1] # You can also try other indices\n", - "\n", - "if 'google.colab' in sys.modules:\n", - " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open('rb') as fp:\n", - " mp4 = fp.read()\n", - " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", - "else:\n", - " data_url = str(video_path)\n", - "\n", - "HTML(\"\"\"\n", - "\n", - "\"\"\".format(data_url))" - ] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 1 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/week04_approx_rl/utils.py b/week04_approx_rl/utils.py index bc833c45c..4e398a537 100644 --- a/week04_approx_rl/utils.py +++ b/week04_approx_rl/utils.py @@ -26,7 +26,7 @@ def play_and_log_episode(env, agent, gamma=0.99, t_max=10000): td_errors = [] rewards = [] - s = env.reset() + s, _ = env.reset() for step in range(t_max): states.append(s) qvalues = agent.get_qvalues([s]) @@ -39,9 +39,9 @@ def play_and_log_episode(env, agent, gamma=0.99, t_max=10000): action = qvalues.argmax(axis=-1)[0] - s, r, done, _ = 
env.step(action) + s, r, terminated, truncated, _ = env.step(action) rewards.append(r) - if done: + if terminated or truncated: break td_errors.append(np.abs(rewards[-1] + gamma * v_agent[-1] - v_agent[-2])) @@ -54,7 +54,7 @@ def play_and_log_episode(env, agent, gamma=0.99, t_max=10000): 'q_spreads': np.array(q_spreads), 'td_errors': np.array(td_errors), 'rewards': np.array(rewards), - 'episode_finished': np.array(done) + 'episode_finished': np.array(terminated or truncated) } return return_pack diff --git a/week05_explore/README.md b/week05_explore/README.md index 70bba0413..5cbb50c25 100644 --- a/week05_explore/README.md +++ b/week05_explore/README.md @@ -22,6 +22,8 @@ * Same topics in russian - [video](https://www.youtube.com/watch?v=WCE9hhPbCmc) * Note: UCB-1 is not for bernoulli rewards, but for arbitrary r in [0,1], so you can just scale any reward to [0,1] to obtain a peace of mind. It's derived directly from Hoeffding's inequality. +* Very interesting blog post written by Lilian Weng that summarises this week's materials: [The Multi-Armed Bandit Problem and Its Solutions](https://lilianweng.github.io/posts/2018-01-23-multi-armed-bandit/) + ## Seminar In this seminar, you'll be solving basic and contextual bandits with uncertainty-based exploration like Bayesian UCB and Thompson Sampling. You will also get acquainted with Bayesian Neural Networks. 
diff --git a/week05_explore/action_rewards.npy b/week05_explore/action_rewards.npy deleted file mode 100644 index 231bcb18b..000000000 Binary files a/week05_explore/action_rewards.npy and /dev/null differ diff --git a/week05_explore/all_states.npy b/week05_explore/all_states.npy deleted file mode 100644 index 43940d9ba..000000000 Binary files a/week05_explore/all_states.npy and /dev/null differ diff --git a/week05_explore/bnn.png b/week05_explore/bnn.png deleted file mode 100644 index 6ff8059fb..000000000 Binary files a/week05_explore/bnn.png and /dev/null differ diff --git a/week05_explore/deep_see.png b/week05_explore/deep_see.png new file mode 100644 index 000000000..a1601b725 Binary files /dev/null and b/week05_explore/deep_see.png differ diff --git a/week05_explore/q_learning_agent.py b/week05_explore/q_learning_agent.py new file mode 100644 index 000000000..f7f52fca7 --- /dev/null +++ b/week05_explore/q_learning_agent.py @@ -0,0 +1,112 @@ +from collections import defaultdict +import random +import math +import numpy as np + + +class QLearningAgent: + def __init__(self, alpha, epsilon, discount, get_legal_actions): + """ + Q-Learning Agent + based on https://inst.eecs.berkeley.edu/~cs188/sp19/projects.html + Instance variables you have access to + - self.epsilon (exploration prob) + - self.alpha (learning rate) + - self.discount (discount rate aka gamma) + + Functions you should use + - self.get_legal_actions(state) {state, hashable -> list of actions, each is hashable} + which returns legal actions for a state + - self.get_qvalue(state,action) + which returns Q(state,action) + - self.set_qvalue(state,action,value) + which sets Q(state,action) := value + !!!Important!!! + Note: please avoid using self._qValues directly. + There's a special self.get_qvalue/set_qvalue for that. 
+ """ + + self.get_legal_actions = get_legal_actions + self._qvalues = defaultdict(lambda: defaultdict(lambda: 0)) + self.alpha = alpha + self.epsilon = epsilon + self.discount = discount + + def get_qvalue(self, state, action): + """ Returns Q(state,action) """ + return self._qvalues[state][action] + + def set_qvalue(self, state, action, value): + """ Sets the Qvalue for [state,action] to the given value """ + self._qvalues[state][action] = value + + def get_value(self, state): + """ + Compute your agent's estimate of V(s) using current q-values + V(s) = max_over_action Q(state,action) over possible actions. + Note: please take into account that q-values can be negative. + """ + possible_actions = self.get_legal_actions(state) + + # If there are no legal actions, return 0.0 + if len(possible_actions) == 0: + return 0.0 + + value = max([self.get_qvalue(state, a) for a in possible_actions]) + return value + + def update(self, state, action, reward, next_state, done): + """ + You should do your Q-Value update here: + Q(s,a) := (1 - alpha) * Q(s,a) + alpha * (r + gamma * V(s')) + """ + + # agent parameters + gamma = self.discount + learning_rate = self.alpha + + q = reward + gamma * (1 - done) * self.get_value(next_state) + q = (1 - learning_rate) * self.get_qvalue(state, action) + learning_rate * q + + self.set_qvalue(state, action, q) + + def get_best_action(self, state): + """ + Compute the best action to take in a state (using current q-values). + """ + possible_actions = self.get_legal_actions(state) + + # If there are no legal actions, return None + if len(possible_actions) == 0: + return None + + idx = np.argmax([self.get_qvalue(state, a) for a in possible_actions]) + + return possible_actions[idx] + + def get_action(self, state): + """ + Compute the action to take in the current state, including exploration. + With probability self.epsilon, we should take a random action. + otherwise - the best policy action (self.get_best_action). 
+ + Note: To pick randomly from a list, use random.choice(list). + To pick True or False with a given probablity, generate uniform number in [0, 1] + and compare it with your probability + """ + + # Pick Action + possible_actions = self.get_legal_actions(state) + action = None + + # If there are no legal actions, return None + if len(possible_actions) == 0: + return None + + # agent parameters: + epsilon = self.epsilon + + if np.random.rand() < epsilon: + return np.random.choice(possible_actions) + + return self.get_best_action(state) \ No newline at end of file diff --git a/week05_explore/replay_buffer.py b/week05_explore/replay_buffer.py new file mode 100644 index 000000000..9136dd078 --- /dev/null +++ b/week05_explore/replay_buffer.py @@ -0,0 +1,75 @@ +# This code is shamelessly stolen from +# https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py +import numpy as np +import random + + +class ReplayBuffer(object): + def __init__(self, size): + """Create Replay buffer. + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. 
+ """ + self._storage = [] + self._maxsize = size + self._next_idx = 0 + + def __len__(self): + return len(self._storage) + + def add(self, obs_t, action, reward, obs_tp1, done): + data = (obs_t, action, reward, obs_tp1, done) + + if self._next_idx >= len(self._storage): + self._storage.append(data) + else: + self._storage[self._next_idx] = data + self._next_idx = (self._next_idx + 1) % self._maxsize + + def _encode_sample(self, idxes): + obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] + for i in idxes: + data = self._storage[i] + obs_t, action, reward, obs_tp1, done = data + obses_t.append(np.array(obs_t, copy=False)) + actions.append(np.array(action, copy=False)) + rewards.append(reward) + obses_tp1.append(np.array(obs_tp1, copy=False)) + dones.append(done) + return ( + np.array(obses_t), + np.array(actions), + np.array(rewards), + np.array(obses_tp1), + np.array(dones) + ) + + def sample(self, batch_size): + """Sample a batch of experiences. + Parameters + ---------- + batch_size: int + How many transitions to sample. + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. 
+ """ + idxes = [ + random.randint(0, len(self._storage) - 1) + for _ in range(batch_size) + ] + return self._encode_sample(idxes) + diff --git a/week05_explore/river_swim.png b/week05_explore/river_swim.png deleted file mode 100644 index 233244c6c..000000000 Binary files a/week05_explore/river_swim.png and /dev/null differ diff --git a/week05_explore/und1.mp4 b/week05_explore/und1.mp4 new file mode 100644 index 000000000..d67190f54 Binary files /dev/null and b/week05_explore/und1.mp4 differ diff --git a/week05_explore/und2.mp4 b/week05_explore/und2.mp4 new file mode 100644 index 000000000..1e41469fe Binary files /dev/null and b/week05_explore/und2.mp4 differ diff --git a/week05_explore/week5.ipynb b/week05_explore/week5.ipynb index 22e3cf7bf..b4c5f3538 100644 --- a/week05_explore/week5.ipynb +++ b/week05_explore/week5.ipynb @@ -2,16 +2,62 @@ "cells": [ { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 98, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Looking in indexes: https://pypi.yandex-team.ru/simple/\n", + "Requirement already satisfied: bsuite in /home/npytincev/.local/lib/python3.8/site-packages (0.3.5)\n", + "Requirement already satisfied: plotnine in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (0.8.0)\n", + "Requirement already satisfied: matplotlib in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (3.5.1)\n", + "Requirement already satisfied: pandas in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (1.4.2)\n", + "Requirement already satisfied: termcolor in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (1.1.0)\n", + "Requirement already satisfied: absl-py in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (0.12.0)\n", + "Requirement already satisfied: 
scikit-image in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (0.19.2)\n", + "Requirement already satisfied: numpy in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (1.22.3)\n", + "Requirement already satisfied: six in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (1.15.0)\n", + "Requirement already satisfied: immutabledict in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (2.2.1)\n", + "Requirement already satisfied: dm-env in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (1.5)\n", + "Requirement already satisfied: scipy in /home/npytincev/.local/lib/python3.8/site-packages (from bsuite) (1.6.0)\n", + "Requirement already satisfied: dm-tree in /home/npytincev/.local/lib/python3.8/site-packages (from dm-env->bsuite) (0.1.6)\n", + "Requirement already satisfied: packaging>=20.0 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (20.9)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (2.8.2)\n", + "Requirement already satisfied: pyparsing>=2.2.1 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (2.4.7)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (4.31.2)\n", + "Requirement already satisfied: pillow>=6.2.0 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (8.4.0)\n", + "Requirement already satisfied: cycler>=0.10 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (0.10.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /home/npytincev/.local/lib/python3.8/site-packages (from matplotlib->bsuite) (1.3.1)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/npytincev/.local/lib/python3.8/site-packages (from pandas->bsuite) (2022.1)\n", + "Requirement already 
satisfied: descartes>=1.1.0 in /home/npytincev/.local/lib/python3.8/site-packages (from plotnine->bsuite) (1.1.0)\n", + "Requirement already satisfied: patsy>=0.5.1 in /home/npytincev/.local/lib/python3.8/site-packages (from plotnine->bsuite) (0.5.1)\n", + "Requirement already satisfied: mizani>=0.7.3 in /home/npytincev/.local/lib/python3.8/site-packages (from plotnine->bsuite) (0.7.4)\n", + "Requirement already satisfied: statsmodels>=0.12.1 in /home/npytincev/.local/lib/python3.8/site-packages (from plotnine->bsuite) (0.12.2)\n", + "Requirement already satisfied: tifffile>=2019.7.26 in /home/npytincev/.local/lib/python3.8/site-packages (from scikit-image->bsuite) (2022.3.25)\n", + "Requirement already satisfied: PyWavelets>=1.1.1 in /home/npytincev/.local/lib/python3.8/site-packages (from scikit-image->bsuite) (1.3.0)\n", + "Requirement already satisfied: networkx>=2.2 in /home/npytincev/.local/lib/python3.8/site-packages (from scikit-image->bsuite) (2.5.1)\n", + "Requirement already satisfied: imageio>=2.4.1 in /home/npytincev/.local/lib/python3.8/site-packages (from scikit-image->bsuite) (2.16.1)\n", + "Requirement already satisfied: palettable in /home/npytincev/.local/lib/python3.8/site-packages (from mizani>=0.7.3->plotnine->bsuite) (3.3.0)\n", + "Requirement already satisfied: decorator<5,>=4.3 in /usr/lib/python3/dist-packages (from networkx>=2.2->scikit-image->bsuite) (4.4.2)\n" + ] + } + ], "source": [ "import sys, os\n", "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week05_explore/action_rewards.npy\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week05_explore/all_states.npy\n", + " !wget -q 
https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week05_explore/q_learning_agent.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week05_explore/replay_buffer.py\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week05_explore/und1.mp4\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week05_explore/und2.mp4\n", "\n", + " !pip install -q gymnasium\n", + " !pip install -q shimmy[bsuite]\n", " !touch .setup_complete\n", "\n", "# This code creates a virtual display to draw game images on.\n", @@ -23,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -34,12 +80,75 @@ "np.set_printoptions(precision=3)\n", "np.set_printoptions(suppress=True)\n", "\n", - "import pandas\n", + "import pandas as pd\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import HTML\n", + "\n", + "HTML(\"\"\"\n", + " \n", + "\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 122, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "HTML(\"\"\"\n", + " \n", + "\"\"\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -47,14 +156,12 @@ "## Contents\n", "* [1. Bernoulli Bandit](#Part-1.-Bernoulli-Bandit)\n", " * [Bonus 1.1. Gittins index (5 points)](#Bonus-1.1.-Gittins-index-%285-points%29.)\n", - " * [HW 1.1. 
Nonstationary Bernoulli bandit](#HW-1.1.-Nonstationary-Bernoulli-bandit)\n", - "* [2. Contextual bandit](#Part-2.-Contextual-bandit)\n", - " * [2.1 Bulding a BNN agent](#2.1-Bulding-a-BNN-agent)\n", - " * [2.2 Training the agent](#2.2-Training-the-agent)\n", - " * [HW 2.1 Better exploration](#HW-2.1-Better-exploration)\n", - "* [3. Exploration in MDP](#Part-3.-Exploration-in-MDP)\n", - " * [Bonus 3.1 Posterior sampling RL (3 points)](#Bonus-3.1-Posterior-sampling-RL-%283-points%29)\n", - " * [Bonus 3.2 Bootstrapped DQN (10 points)](#Bonus-3.2-Bootstrapped-DQN-%2810-points%29)\n" + " * [HW 1.1. Nonstationary Bernoulli bandit (2 points)](#HW-1.1.-Nonstationary-Bernoulli-bandit)\n", + "* [2. Exploration in MDP](#Part-2.-Exploration-in-MDP)\n", + " * [2.1 Epsilon-greedy q-learning](#2.1-Epsilon-greedy-q-learning)\n", + " * [2.2 Reward shaping](#2.2-Reward-shaping)\n", + " * [2.3 Curiosity-driven Exploration](#2.3-Curiosity-driven-Exploration)\n", + " * [HW 2.1 Random network distillation (3 points)](#HW-2.1:-Random-network-distillation)\n" ] }, { @@ -67,9 +174,9 @@ "\n", "The bandit has $K$ actions. Action produce 1.0 reward $r$ with probability $0 \\le \\theta_k \\le 1$ which is unknown to agent, but fixed over time. 
Agent's objective is to minimize regret over fixed number $T$ of action selections:\n", "\n", - "$$\\rho = T\\theta^* - \\sum_{t=1}^T r_t$$\n", + "$$\\rho = T\\theta^* - \\sum_{t=1}^T \\theta_{a_t}$$\n", "\n", - "Where $\\theta^* = \\max_k\\{\\theta_k\\}$\n", + "Where $\\theta^* = \\max_k\\{\\theta_k\\}$ and $\\theta_{a_t}$ corresponds to the chosen action $a_t$ on each step.\n", "\n", "**Real-world analogy:**\n", "\n", @@ -80,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -102,6 +209,11 @@ " \"\"\"\n", " return np.max(self._probs)\n", "\n", + " def action_value(self, action):\n", + " \"\"\" Used for regret calculation\n", + " \"\"\"\n", + " return self._probs[action]\n", + "\n", " def step(self):\n", " \"\"\" Used in nonstationary version\n", " \"\"\"\n", @@ -114,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -181,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -190,7 +302,10 @@ " self._epsilon = epsilon\n", "\n", " def get_action(self):\n", - " \n", + " if np.random.random() < self._epsilon:\n", + " return np.random.randint(len(self._successes))\n", + " else:\n", + " return np.argmax(self._successes / (self._successes + self._failures + 0.1))\n", "\n", " @property\n", " def name(self):\n", @@ -228,13 +343,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ "class UCBAgent(AbstractAgent):\n", " def get_action(self):\n", - " " + " pulls = self._successes + self._failures + 0.1\n", + " return np.argmax(self._successes / pulls + np.sqrt(2 * np.log(self._total_pulls + 0.1) / pulls))" ] }, { @@ -268,18 +384,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 117, "metadata": {}, "outputs": [], "source": 
[ "class ThompsonSamplingAgent(AbstractAgent):\n", " def get_action(self):\n", - " \n" + " return np.argmax(np.random.beta(self._successes + 1, self._failures + 1))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ @@ -303,7 +419,7 @@ " action = agent.get_action()\n", " reward = env.pull(action)\n", " agent.update(action, reward)\n", - " scores[agent.name][i] += optimal_reward - reward\n", + " scores[agent.name][i] += optimal_reward - env.action_value(action)\n", "\n", " env.step() # change bandit's state if it is unstationary\n", "\n", @@ -326,26 +442,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 120, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":4: RuntimeWarning: invalid value encountered in sqrt\n", + " return np.argmax(self._successes / pulls + np.sqrt(2 * np.log(self._total_pulls + 0.1) / pulls))\n" + ] + }, { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ "# Uncomment agents\n", "agents = [\n", - " # EpsilonGreedyAgent(),\n", - " # UCBAgent(),\n", - " # ThompsonSamplingAgent()\n", + " EpsilonGreedyAgent(),\n", + " UCBAgent(),\n", + " ThompsonSamplingAgent()\n", "]\n", "\n", "regret = get_regret(BernoulliBandit(), agents, n_steps=10000, n_trials=10)\n", @@ -375,7 +501,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 300, "metadata": {}, "outputs": [], "source": [ @@ -424,17 +550,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 301, "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], @@ -447,7 +575,7 @@ " drifting_probs.append(drifting_env._probs)\n", "\n", "plt.figure(figsize=(17, 8))\n", - "plt.plot(pandas.DataFrame(drifting_probs).rolling(window=20).mean())\n", + "plt.plot(pd.DataFrame(drifting_probs).rolling(window=20).mean())\n", "\n", "plt.xlabel(\"steps\")\n", "plt.ylabel(\"Success probability\")\n", @@ -465,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 302, "metadata": {}, "outputs": [], "source": [ @@ -474,17 +602,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 303, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":5: RuntimeWarning: invalid value encountered in sqrt\n", + " np.sqrt(2 * np.log(self._total_pulls + 0.1) / (self._successes + self._failures + 0.1)))\n" + ] + }, { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], @@ -493,7 +631,7 @@ " ThompsonSamplingAgent(),\n", " EpsilonGreedyAgent(),\n", " UCBAgent(),\n", - " YourAgent()\n", + "# YourAgent()\n", "]\n", "\n", "regret = get_regret(DriftingBandit(), drifting_agents, n_steps=20000, n_trials=10)\n", @@ -504,287 +642,279 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Part 2. Contextual bandit (Associative search)\n", - "\n", - "So far we considered nonassociative bandits - that is, we had no need to associate different actions with different environment states. In a general reinforcement learning there is more than one state, and the goal\n", - "is to learn a policy: a mapping from states to the actions that are best in those\n", - "states.\n", - "\n", - "The simplest way in which nonassociative tasks can extend to the associative setting is called contextual bandits. In such task an agent at each step is provided with a state on which reward distribution of actions may depend. For example, consider a task where you have several different k-armed bandits and on each step your agent is provided with an id of current bandit.\n", - "\n", - "**Real-word analogy:**\n", - "> Contextual advertising. We have a lot of banners and a lot of different users. Users can have different features: age, gender, search requests. We want to show banner with highest click probability.\n", - "\n", - "**Question:** What is the difference between contextual bandits and full reinforcement learning task?" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If we want use strategies from above, we need somehow store reward distributions conditioned both on actions and bandit's state. \n", - "One way to do this - use bayesian neural networks. Instead of giving pointwise estimates of target, they maintain probability distributions\n", + "## Part 2. 
Exploration in MDP\n", "\n", - "\n", - "Picture from https://arxiv.org/pdf/1505.05424.pdf" + "The following problem, called \"deep sea\", illustrates the importance of exploration in the context of MDPs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 2.1 Making bayesian neural network\n", - "\n", - "\n", - "Code and formulas come from this [post](https://www.ritchievink.com/blog/2019/09/16/variational-inference-from-scratch/). You can read more on theory in this [paper](https://arxiv.org/abs/1505.05424).\n", - "\n", - "### 2.1.1 Some theory\n", - "We have data $D$ and model some model parameters $Z$. In the process of model fitting we want to get posterior distribution of model parameters, i.e. $P(Z | D)$. We can use Bayes rule:\n", - "\n", - "$$ P(Z | D) = \\frac{P(D|Z)P(Z)}{P(D)}$$\n", - "\n", - "Where $P(D|Z)$ is likelihood, $P(Z)$ - prior distribution of model parameters, $P(D)$ - probability of obtaining data from our model family.\n", - "\n", - "Direct computation of $P(Z | D)$ is intractable, so we will replace it with variational distribution $Q(Z)$ parametrized by $\\Theta$ and our learning process will minimize $D_{KL}(Q(Z | \\Theta)||P(Z | D))$ with $\\Theta$ as parameter.\n", - "\n", - "$$D_{KL}(Q(Z)||P(Z|D)) = \\int_{Z}Q(Z) log\\frac{Q(Z)}{P(Z|D)}dZ$$\n", - "\n", - "Rewrite posterior:\n", - "\n", - "$$D_{KL} = \\int_{Z}Q(Z) log\\frac{Q(Z)P(D)}{P(Z,D)}dZ$$\n", + "\n", "\n", - "Apply logarithm modification rule:\n", - "$$D_{KL} = \\int_{Z}Q(Z) log\\frac{Q(Z)}{P(Z,D)}dZ + \\int_{Z}Q(Z) logP(D)dZ$$\n", + "The deep sea problem is implemented as an $N×N$ grid with a one-hot encoding for state.\n", + "The agent begins each episode in the top left corner of the grid and descends one row\n", + "per timestep. Each episode terminates after N steps, when the agent reaches the bottom\n", + "row. In each state there is a random but fixed mapping between actions $A = \\{0,1\\}$ and\n", + "the transitions ‘left’ and ‘right’. 
At each timestep there is a small cost $r = −0.01/N$ of\n", + "moving right, and $r = 0$ for moving left. However, should the agent transition right at every\n", + "timestep of the episode it will be rewarded with an additional reward of $+1$.\n", "\n", - "As $P(D)$ is not parametrized by $Z$ we can rewrite:\n", - "$$D_{KL} = \\int_{Z}Q(Z) log\\frac{Q(Z)}{P(Z,D)}dZ + logP(D)$$\n", + "**Question:** Why is the deep sea a challenging exploration problem?\n", "\n", - "We do not care about $logP(D)$ because it does not depend on parameters. So we need to optimize only first part of sum which is called ELBO.\n", - "$$ELBO = E_{Z \\sim Q}[Q(Z) log\\frac{P(Z,D)}{Q(Z)}]$$\n", - "\n", - "Now, rewriting First we rewrite the joint probability $P(Z,D)$ into conditional probability $P(D|Z)P(Z)$ and applying logarithmic rule we have:\n", - "\n", - "$$ELBO = E_{Z \\sim Q}[logP(D|Z)] + E_{Z \\sim Q}[log\\frac{P(Z)}{Q(Z)}]$$\n", - "\n", - "If we rewrite second term in integral form:\n", - "\n", - "$$ELBO = E_{Z \\sim Q}[logP(D|Z)] - D_{KL}(Q(Z)||P(Z))$$\n", - "\n", - "First term is likelihood of data being received from our model and second term is called reconstruction error." + "See full paper [here](https://openreview.net/forum?id=rygf-kSYwH)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 304, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m\u001b[37mLoaded bsuite_id: deep_sea/0.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'size': 10, 'mapping_seed': 42}" + ] + }, + "execution_count": 304, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "### 2.1.2 The model\n", - "\n", - "We will model weights of our neural network with normal distribution. 
Our prior and likelihood:\n", - "\n", - "$$w \\sim N(0, 1)$$\n", - "$$y \\sim P(y|x, w)$$\n", - "\n", - "Our variational distribution is\n", + "import gymnasium as gym\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from torch import nn\n", + "from time import sleep\n", + "from tqdm import tqdm\n", + "from IPython.display import clear_output\n", "\n", - "$$w \\sim N(\\mu_w, \\sigma_w ^ 2)$$\n", + "from q_learning_agent import QLearningAgent\n", + "from replay_buffer import ReplayBuffer\n", "\n", - "And variational parameters are $\\Theta = (\\mu_w, \\sigma_w)$" + "env = gym.make(\"bsuite/deep_sea-v0\", size=10, seed=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "**Reparametrization trick**\n", - "\n", - "We cannot backpropagate directly throw $w \\sim N(\\mu_w, \\sigma_w ^ 2)$, so we rewrite this function as\n", - "\n", - "$$w = \\mu + \\sigma * \\epsilon$$\n", - "\n", - "$$\\epsilon \\sim N(0,1)$$\n", - "\n", - "We will also model $\\sigma_w$ as $log(1 + e^{p_w})$ to be able to optimize throw it without constraints.\n", - "\n", - "Let's code it up:" + "## 2.1 Epsilon-greedy q-learning " ] }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 150, "metadata": {}, "outputs": [], "source": [ - "import torch\n", - "from torch import nn\n", + "def get_state_number(s):\n", + " return np.argmax(s.flatten())\n", + "\n", + "def test_agent(agent, greedy=True, delay=.5):\n", + " v = get_all_states_value(agent)\n", + " s, _ = env.reset()\n", + " done = False\n", + " while not done:\n", + " fig, ax = plt.subplots(ncols=2)\n", + " ax[0].imshow(s)\n", + " ax[0].set_title('State')\n", + " im = ax[1].imshow(v)\n", + " plt.colorbar(im)\n", + " ax[1].set_title('Value function')\n", + " clear_output(True)\n", + " plt.show()\n", + " s = get_state_number(s)\n", + " if greedy:\n", + " a = agent.get_best_action(s)\n", + " else:\n", + " a = agent.get_action(s)\n", + "\n", + " s, r, terminated, truncated, _ 
= env.step(a)\n", + " done = terminated or truncated\n", + " sleep(delay)\n", + "\n", + "def get_all_states_value(agent):\n", + " s_shape = env.observation_space.shape\n", + " s_shape_flatten = np.prod(s_shape)\n", + " v = np.zeros(s_shape_flatten)\n", + " for i in range(s_shape_flatten):\n", + " v[i] = agent.get_value(i)\n", + " v = v.reshape(s_shape)\n", + " return v\n", "\n", - "def reparameterize(mu, p):\n", - " sigma = torch.log(1 + torch.exp(p))\n", - " eps = torch.randn_like(sigma)\n", - " return mu + (eps * sigma)" + "def to_one_hot(x, ndims):\n", + " \"\"\" helper: take an integer vector and convert it to 1-hot matrix. \"\"\"\n", + " x = x.long().view(-1, 1)\n", + " x = torch.zeros(\n", + " x.size()[0], ndims).scatter_(1, x, 1)\n", + " return x" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 151, "metadata": {}, + "outputs": [], "source": [ - "Next, implement KL divergence $D_{KL}(Q(Z)||P(Z)$ in expectation form:" + "agent = QLearningAgent(\n", + " epsilon=1, \n", + " alpha=0.5, \n", + " discount=1, \n", + " get_legal_actions=lambda s: range(env.action_space.n)\n", + ")" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 152, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ - "def kl_divergence(z, mu_theta, p_theta):\n", - " log_prior = torch.distributions.Normal(0, 1).log_prob(z) \n", - " log_p_q = torch.distributions.Normal(mu_theta, torch.log(1 + torch.exp(p_theta))).log_prob(z) \n", - " return (log_p_q - log_prior).sum()" + "test_agent(agent, greedy=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "No we can implement full bayesian layer" + "Let's try to solve this by q-learning with high epsilon!" ] }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "class LinearVariational(nn.Module):\n", - " def __init__(self, in_features, out_features, parent, bias=True):\n", - " super().__init__()\n", - " self.in_features = in_features\n", - " self.out_features = out_features\n", - " self.include_bias = bias \n", - " self.parent = parent\n", - " \n", - " if getattr(parent, 'accumulated_kl_div', None) is None:\n", - " if getattr(parent.parent, 'accumulated_kl_div', None) is None:\n", - " parent.accumulated_kl_div = 0\n", - " else:\n", - " parent.accumulated_kl_div = parent.parent.accumulated_kl_div\n", - " \n", - " # Initialize the variational parameters.\n", - " # 𝑄(𝑤)=N(𝜇_𝜃,𝜎2_𝜃)\n", - " # Do some random initialization with 𝜎=0.001\n", - " self.w_mu = nn.Parameter(\n", - " torch.FloatTensor(in_features, out_features).normal_(mean=0, std=0.001)\n", - " )\n", - " # proxy for variance\n", - " # log(1 + exp(ρ))◦ eps\n", - " self.w_p = nn.Parameter(\n", - " torch.FloatTensor(in_features, out_features).normal_(mean=-3, std=0.001)\n", - " )\n", - " if self.include_bias:\n", - " self.b_mu = nn.Parameter(\n", - " torch.zeros(out_features)\n", - " )\n", - " # proxy for variance\n", - " self.b_p = nn.Parameter(\n", - " torch.zeros(out_features)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " w = reparameterize(self.w_mu, self.w_p)\n", - " \n", 
- " if self.include_bias:\n", - " b = reparameterize(self.b_mu, self.b_p)\n", - " else:\n", - " b = 0\n", - " \n", - " z = x @ w + b\n", - " \n", - " self.parent.accumulated_kl_div += kl_divergence(w, self.w_mu, self.w_p)\n", - " if self.include_bias:\n", - " self.parent.accumulated_kl_div += kl_divergence(b, self.b_mu, self.b_p)\n", - " return z" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And simple model" + "for i in range(5000):\n", + " s, _ = env.reset()\n", + " done = False\n", + " while not done:\n", + " i_s = get_state_number(s)\n", + " a = agent.get_action(i_s)\n", + " s_next, r, terminated, truncated, _ = env.step(a)\n", + " done = terminated or truncated\n", + " i_s_next = get_state_number(s_next)\n", + " agent.update(i_s, a, r, i_s_next, terminated)\n", + " s = s_next" ] }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ - "class KL:\n", - " accumulated_kl_div = 0\n", - "\n", - "class Model(nn.Module):\n", - " def __init__(self, in_size, hidden_size, out_size):\n", - " super().__init__()\n", - " self.kl_loss = KL\n", - " \n", - " self.layers = nn.Sequential(\n", - " LinearVariational(in_size, hidden_size, self.kl_loss),\n", - " nn.ReLU(),\n", - " LinearVariational(hidden_size, hidden_size, self.kl_loss),\n", - " nn.ReLU(),\n", - " LinearVariational(hidden_size, out_size, self.kl_loss)\n", - " )\n", - " \n", - " @property\n", - " def accumulated_kl_div(self):\n", - " return self.kl_loss.accumulated_kl_div\n", - " \n", - " def reset_kl_div(self):\n", - " self.kl_loss.accumulated_kl_div = 0\n", - " \n", - " def forward(self, x):\n", - " return self.layers(x)" + "test_agent(agent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Finall loss" + "But if we do bigger env:" ] }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 305, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m\u001b[37mLoaded bsuite_id: deep_sea/1.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'size': 12, 'mapping_seed': 42}" + ] + }, + "execution_count": 305, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def det_loss(y, y_pred, model):\n", - " batch_size = y.shape[0]\n", - " reconstruction_error = -torch.distributions.Normal(y_pred, .1).log_prob(y).sum()\n", - " kl = model.accumulated_kl_div\n", - " model.reset_kl_div()\n", - " return reconstruction_error + kl" + "env = gym.make(\"bsuite/deep_sea-v0\", size=12, seed=42)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 317, "metadata": {}, + "outputs": [], "source": [ - "Toy experiment" + "agent = QLearningAgent(\n", + " epsilon=1,\n", + " alpha=0.5,\n", + " discount=1,\n", + " 
get_legal_actions=lambda s: range(env.action_space.n)\n", + ")\n", + "\n", + "for i in range(5000):\n", + " s, _ = env.reset()\n", + " done = False\n", + " while not done:\n", + " i_s = get_state_number(s)\n", + " a = agent.get_action(i_s)\n", + " s_next, r, terminated, truncated, _ = env.step(a)\n", + " done = terminated or truncated\n", + " i_s_next = get_state_number(s_next)\n", + " agent.update(i_s, a, r, i_s_next, terminated)\n", + " s = s_next" ] }, { "cell_type": "code", - "execution_count": 71, - "metadata": { - "scrolled": false - }, + "execution_count": 318, + "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "
" + "
" ] }, "metadata": { @@ -794,297 +924,140 @@ } ], "source": [ - "import numpy as np\n", - "import torch\n", - "from torch import nn\n", - "from sklearn import datasets\n", - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "# Generate dataset\n", - "w0 = 0.125\n", - "b0 = 5.\n", - "x_range = [-20, 60]\n", - "\n", - "def load_dataset(n=150, n_tst=150):\n", - " np.random.seed(43)\n", - "\n", - " def s(x):\n", - " g = (x - x_range[0]) / (x_range[1] - x_range[0])\n", - " return 3 * (0.25 + g**2.)\n", - "\n", - " x = (x_range[1] - x_range[0]) * np.random.rand(n) + x_range[0]\n", - " eps = np.random.randn(n) * s(x)\n", - " y = (w0 * x * (1. + np.sin(x)) + b0) + eps\n", - " y = (y - y.mean()) / y.std()\n", - " idx = np.argsort(x)\n", - " x = x[idx]\n", - " y = y[idx]\n", - " return y[:, None], x[:, None]\n", - "\n", - "y, x = load_dataset()\n", - "\n", - "\n", - "# Fit the model\n", - "X = torch.tensor(x, dtype=torch.float)\n", - "Y = torch.tensor(y, dtype=torch.float)\n", - "\n", - "\n", - "epochs = 2000\n", - "m = Model(1, 20, 1)\n", - "optim = torch.optim.Adam(m.parameters(), lr=0.01)\n", - "\n", - "for epoch in range(epochs):\n", - " optim.zero_grad()\n", - " y_pred = m(X)\n", - " loss = det_loss(y_pred, Y, m)\n", - " loss.backward()\n", - " optim.step()\n", - "\n", - "# Sample predictions from model and draw quantiles\n", - "with torch.no_grad():\n", - " trace = np.array([m(X).flatten().numpy() for _ in range(1000)]).T\n", - "\n", - "q_25, q_95 = np.quantile(trace, [0.05, 0.95], axis=1)\n", - "plt.figure(figsize=(16, 6))\n", - "plt.plot(X, trace.mean(1))\n", - "plt.scatter(X, Y)\n", - "plt.fill_between(X.flatten(), q_25, q_95, alpha=0.2)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1.3 The bandit" + "test_agent(agent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Load dataset for bandit simulation:" + "## 2.2 Reward shaping" ] }, { "cell_type": "code", - "execution_count": 72, + 
"execution_count": 308, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "State size: 60, actions: 10\n" - ] - } - ], + "outputs": [], "source": [ - "all_states = np.load(\"all_states.npy\")\n", - "action_rewards = np.load(\"action_rewards.npy\")\n", + "class BaseIntrinsicRewardModule(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", "\n", - "state_size = all_states.shape[1]\n", - "n_actions = action_rewards.shape[1]\n", + " def get_intrinsic_reward(self, state, action, next_state):\n", + " return 0.0\n", "\n", - "print(\"State size: %i, actions: %i\" % (state_size, n_actions))" + " def get_loss(self, state_batch, action_batch, next_state_batch):\n", + " pass" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 309, "metadata": {}, + "outputs": [], "source": [ - "Dataset consists of state vectors and reward vectors which contain reward for each state. \n", - "Reward distribution on arms depends on state vector.\n", + "def train_with_reward(env, agent, reward_module, n_episodes=100, update_reward_period=100, batch_size=100, n_iter=10):\n", + " buffer = ReplayBuffer(size=int(1e6))\n", + " \n", + " if list(reward_module.parameters()):\n", + " optimizer = torch.optim.Adam(reward_module.parameters())\n", + " else:\n", + " optimizer = None\n", "\n", - "$\\epsilon$-greedy contextual bandit:\n", + " losses = []\n", + " s, _ = env.reset()\n", "\n", - "1. Sample a new set of parameters from the model\n", - "2. With probability $\\epsilon$ pick random action, with probability $(1 - \\epsilon)$ pick action with best expected reward.\n", - "3. 
Update the model, go to 1\n", + " for i in range(n_episodes):\n", + " done = False\n", + " \n", + " while not done:\n", + " i_s = get_state_number(s)\n", + " a = agent.get_action(i_s)\n", + " s_next, r, terminated, truncated, _ = env.step(a)\n", + " done = terminated or truncated\n", + " i_s_next = get_state_number(s_next)\n", + " \n", + " state_t = torch.tensor(s).float().view(1, -1)\n", + " action_t = torch.tensor(a).float().view(1, -1)\n", + " next_state_t = torch.tensor(s_next).float().view(1, -1)\n", + "\n", + " r_intr = reward_module.get_intrinsic_reward(state_t, action_t, next_state_t)\n", + " r += r_intr\n", "\n", + " agent.update(i_s, a, r, i_s_next, terminated)\n", + " buffer.add(s, a, r, s_next, terminated)\n", "\n", - "Let's make an agent:" + " s = s_next\n", + "\n", + " if (i + 1) % update_reward_period == 0 and optimizer is not None:\n", + " \n", + " for _ in range(n_iter):\n", + " optimizer.zero_grad()\n", + " state_batch, action_batch, _, next_state_batch, _ = buffer.sample(batch_size)\n", + " \n", + " state_tensor = torch.tensor(state_batch).float().flatten(1, 2)\n", + " action_tensor = torch.tensor(action_batch).float().view(-1, 1)\n", + " next_state_tensor = torch.tensor(next_state_batch).float().flatten(1, 2)\n", + " \n", + " loss = reward_module.get_loss(state_tensor, action_tensor, next_state_tensor)\n", + " loss.backward()\n", + " optimizer.step()\n", + " losses.append(loss.item())\n", + " \n", + " fig, ax = plt.subplots(ncols=2)\n", + " ax[0].set_title('Value function after iter: %d' % i)\n", + " im = ax[0].imshow(get_all_states_value(agent))\n", + " ax[1].plot(losses)\n", + " ax[1].set_title('Random network distillation loss')\n", + " clear_output(True)\n", + " plt.show()" ] }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 310, "metadata": {}, "outputs": [], "source": [ - "class BNNAgent(nn.Module):\n", - " def __init__(self, in_size, hidden_size, out_size):\n", + "class GoRightReward(BaseIntrinsicRewardModule):\n", + 
" def __init__(self):\n", " super().__init__()\n", "\n", - " self.out_size = out_size\n", - " self.kl_loss = KL\n", - " \n", - " self.layers = nn.Sequential(\n", - " LinearVariational(in_size, hidden_size, self.kl_loss),\n", - " nn.Tanh(),\n", - " LinearVariational(hidden_size, hidden_size, self.kl_loss),\n", - " nn.Tanh(),\n", - " LinearVariational(hidden_size, out_size, self.kl_loss)\n", - " )\n", - " \n", - " self.epsilon = .25\n", - " \n", - " def forward(self, x):\n", - " return self.layers(x)\n", - " \n", - " def sample_prediction(self, states, n_samples=1):\n", - " return np.stack([\n", - " self(torch.Tensor(states)).detach().numpy() \n", - " for _ in range(n_samples)])\n", - " \n", - " def get_action(self, states):\n", - " n_samples = 100\n", - " \n", - " reward_samples = self.sample_prediction(states, n_samples=n_samples)\n", - " best_actions = reward_samples.mean(axis=0).argmax(axis=-1)\n", - " random_actions = np.random.randint(0, self.out_size, len(states))\n", - " \n", - " chosen_actions = np.array(\n", - " [\n", - " random_actions[i] if np.random.random() < self.epsilon \n", - " else a for i, a in enumerate(best_actions)\n", - " ]\n", - " )\n", - " \n", - " return chosen_actions\n", - " \n", - " @property\n", - " def accumulated_kl_div(self):\n", - " return self.kl_loss.accumulated_kl_div\n", - " \n", - " def reset_kl_div(self):\n", - " self.kl_loss.accumulated_kl_div = 0" + " def get_intrinsic_reward(self, state, action, next_state):\n", + " # " ] }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 311, "metadata": {}, "outputs": [], "source": [ - "from IPython.display import clear_output\n", - "\n", - "from pandas import DataFrame\n", - "moving_average = lambda x, **kw: DataFrame(\n", - " {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n", - "\n", - "\n", - "def get_new_samples(states, action_rewards, batch_size=8):\n", - " \"\"\"samples random minibatch, emulating new users\"\"\"\n", - " batch_ix = np.random.randint(0, 
len(states), batch_size)\n", - " return states[batch_ix], action_rewards[batch_ix]\n", - "\n", - "\n", - "def train_contextual_agent(agent, batch_size=32, n_iters=100):\n", - " total_samples = 0\n", - " optim = torch.optim.Adam(agent.parameters(), lr=0.01)\n", - " \n", - " rewards_history = []\n", - " \n", - " for i in range(n_iters):\n", - " # Sample batch of bandit states\n", - " b_states, b_action_rewards = get_new_samples(all_states, action_rewards, batch_size)\n", - " \n", - " # Get actions from bandit\n", - " b_actions = agent.get_action(b_states)\n", - " \n", - " # Get rewards\n", - " b_rewards = b_action_rewards[np.arange(batch_size), b_actions]\n", - " \n", - " mse, kl = 0, 0\n", - " \n", - " # Update model\n", - " for _ in range(25):\n", - " optim.zero_grad()\n", - "\n", - " action_preds = agent.forward(torch.Tensor(b_states)) \n", - " y = action_preds.gather(1, torch.LongTensor([b_actions]).T).T[0]\n", - " \n", - " # loglikelihood can be replaced with mse loss for normal distributions\n", - " loss = ((torch.Tensor(b_rewards) - y)**2).sum()\n", - " loss += agent.accumulated_kl_div / (total_samples + batch_size)\n", - " \n", - " kl += agent.accumulated_kl_div / (total_samples + batch_size)\n", - " mse += loss\n", + "agent = QLearningAgent(\n", + " epsilon=.1, \n", + " alpha=0.5, \n", + " discount=.9, \n", + " get_legal_actions=lambda s: range(env.action_space.n)\n", + ")\n", "\n", - " loss.backward()\n", - " agent.reset_kl_div()\n", - " optim.step() \n", + "go_right = GoRightReward()\n", "\n", - " rewards_history.append(b_rewards.mean())\n", - " total_samples += batch_size\n", - " \n", - " # Plot some graphs\n", - " if i % 10 == 0:\n", - " clear_output(True)\n", - " print(\"iteration #%i\\tmean reward=%.3f\\tmse=%.3f\\tkl=%.3f\" %\n", - " (i, np.mean(rewards_history[-10:]), mse, kl))\n", - " plt.plot(rewards_history)\n", - " plt.plot(moving_average(np.array(rewards_history), alpha=0.1))\n", - " plt.title(\"Reward per epesode\")\n", - " 
plt.xlabel(\"Episode\")\n", - " plt.ylabel(\"Reward\")\n", - " plt.show()\n", - "\n", - " samples = agent.sample_prediction(\n", - " b_states[:1], n_samples=100).T[:, 0, :]\n", - " for i in range(len(samples)):\n", - " plt.hist(samples[i], alpha=0.25, label=str(i))\n", - " plt.legend(loc='best')\n", - " print('Q(s,a) std:', ';'.join(\n", - " list(map('{:.3f}'.format, np.std(samples, axis=1)))))\n", - " print('correct', b_action_rewards[0].argmax())\n", - " print('rewards', b_action_rewards[0])\n", - " plt.title(\"p(Q(s, a))\")\n", - " plt.show()\n", - "\n", - " return moving_average(np.array(rewards_history), alpha=0.1)" + "train_with_reward(env, agent, go_right, n_episodes=500)" ] }, { "cell_type": "code", - "execution_count": 75, - "metadata": {}, + "execution_count": 312, + "metadata": { + "scrolled": true + }, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "iteration #490\tmean reward=0.703\tmse=60.592\tkl=28.053\n" - ] - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Q(s,a) std: 0.069;0.115;0.075;0.087;0.073;0.149;0.156;0.194;0.114;0.076\n", - "correct 1\n", - "rewards [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n" - ] - }, { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "
" + "
" ] }, "metadata": { @@ -1094,624 +1067,422 @@ } ], "source": [ - "N_ITERS=500\n", - "agent = BNNAgent(in_size=state_size, hidden_size=32, out_size=n_actions)\n", - "greedy_agent_rewards = train_contextual_agent(agent, batch_size=32, n_iters=N_ITERS)" + "test_agent(agent, greedy=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## HW 2.1 Better exploration\n", - "\n", - "Use strategies from first part to gain more reward in contextual setting" + "## 2.3 Curiosity-driven Exploration" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 313, "metadata": {}, "outputs": [], "source": [ - "class ThompsonBNNAgent(BNNAgent):\n", - " def get_action(self, states):\n", - " \"\"\"\n", - " picks action based by taking _one_ sample from BNN and taking action with highest sampled reward (yes, that simple)\n", - " This is exactly thompson sampling.\n", - " \"\"\"\n", - "\n", - " " + "class MLP(nn.Module):\n", + " def __init__(self, input_size, hidden_size, output_size):\n", + " super().__init__()\n", + " self.layers = nn.Sequential(\n", + " nn.Linear(input_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, output_size)\n", + " )\n", + " \n", + " def init_weights(tensor):\n", + " if isinstance(tensor, nn.Linear):\n", + " nn.init.xavier_uniform_(tensor.weight)\n", + " \n", + " self.layers.apply(init_weights)\n", + " \n", + " \n", + " def forward(self, x):\n", + " return self.layers(x)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "iteration #90\tmean reward=0.360\tmse=0.590\tkl=0.038\n" - ] - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Q(s,a) std: 0.000;0.028;0.277;0.000;0.044;0.059;0.063;0.093;0.000;0.018\n", - "correct 2\n" - ] - }, - { - 
"data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAFo9JREFUeJzt3XuQ1eWd5/H3NzRIRCIIzcVutBs1Cl4GEYKWLiE6GmRSeIEYjBlRIeymMjuw2dTq7FZtCmsqManNGk2y2aDxMnFKkmGthSWGDUEtjRnFBjQhYS1RVJogtC0gV/vis3/00bQE6KbP7/Tp/vX7VdV1frfzPN9zWj/8+umnnxMpJSRJ+fWxchcgSSotg16Scs6gl6ScM+glKecMeknKOYNeknLOoJeknDPoJSnnDHpJyrmKchcAMHz48FRTU1PuMiSpV1m3bt3bKaXKjq7rEUFfU1NDXV1ducuQpF4lIt7ozHUO3UhSzhn0kpRzBr0k5VyPGKOXpHJpbm6mvr6eQ4cOlbuUoxo4cCDV1dX079+/S8836CX1afX19QwePJiamhoiotzl/IWUEo2NjdTX11NbW9ulNjocuomIByJiZ0RsbHfslIhYHRGvFB6HFo5HRNwbEZsj4ncRMbFLVUlSNzl06BDDhg3rkSEPEBEMGzasqJ84OjNG/xAw/bBjdwBrUkpnAWsK+wBXA2cVvhYAP+pyZZLUTXpqyH+g2Po6DPqU0tPAO4cdvgZ4uLD9MHBtu+P/lNo8BwyJiNFFVShJKkpXx+hHppS2F7bfAkYWtquAre2uqy8c244k9QL/+mpjpu1dcsawDq9ZtWoVCxcupLW1lfnz53PHHXd0+JzjUfQvY1NKKSKO+xPGI2IBbcM7nHbaacWW0af964ofHPdzLpn5dyWoRNLxam1t5atf/SqrV6+murqayZMnM3PmTMaPH59ZH12dR7/jgyGZwuPOwvFtwJh211UXjv2FlNKSlNKklNKkysoOl2qQpFxau3YtZ555JmPHjmXAgAHMmTOH5cuXZ9pHV4N+BTC3sD0XWN7u+M2F2TcXA3vaDfFIkg6zbds2xoz58/1xdXU127Yd8f64yzocuomIR4FpwPCIqAe+AdwF/Dwi5gFvADcULn8cmAFsBg4At2ZarSTpuHUY9CmlG49y6oojXJuArxZblI7Pxn4nA3Be654yVyLpeFVVVbF165/nsNTX11NVVZVpH651I0llNHnyZF555RW2bNlCU1MTS5cuZebMmZn24RIIktROZ6ZDZqmiooIf/OAHfPazn6W1tZXbbruNc889N9s+Mm1NknTcZsyYwYwZM0rWvkM3kpRzBr0k5ZxBL0k5Z9BLUs4Z9JKUcwa9JOWc0ytzoOWdto8LaD74Nv1HjezgaknHtOWZbNur/TcdXnLbbbexcuVKRowYwcaNGzu8/nh5Ry9JZXbLLbewatWqkrVv0EtSmU2dOpVTTjmlZO0b9JKUcwa9JOWcQS9JOWfQ50BF6wAqWgcwoOUEKg74LZX0UU6vlKT2OjEdMms33ngjTz31FG+//TbV1dUsXryYefPmZda+QS9JZfboo4+WtH1/zpeknDPoJSnnDHpJyjmDXpJyzqCXpJwz6CUp55xemQP7328CYM/7h+j33j5aI33k/Cc+PqgcZUm90gtvvZBpe5NHTT7m+a1bt3LzzTezY8cOIoIFCxawcOHCTGsw6CWpjCoqKvjud7/LxIkT2bt3LxdddBFXXnkl48ePz6wPh25yoKW5iZbmJpreO1TuUiQdp9GjRzNx4kQABg8ezLhx49i2bVumfRj0ktRDvP7662zYsIEpU6Zk2q5BL0k9wL59+5g1axbf+973+MQnPpFp2wa9JJVZc3Mzs2bN4qabbuL666/PvH2DXpLKKKXEvHnzGDduHF/72tdK0oezbiSpnY6mQ2bt2Wef5ac//Snnn38+EyZMAO
Cb3/wmM2bMyKyPooI+Iv4DMB9IwO+BW4HRwFJgGLAO+NuUUlORdUpSLl122WWklDq+sAhdHrqJiCrg74FJKaXzgH7AHODbwN0ppTOBXUB2q+dLko5bsWP0FcDHI6ICOBHYDlwOLCucfxi4tsg+JElF6HLQp5S2Af8NeJO2gN9D21DN7pRSS+GyeqCq2CIlSV1XzNDNUOAaoBY4FRgETD+O5y+IiLqIqGtoaOhqGZKkDhQzdPPXwJaUUkNKqRl4DLgUGFIYygGoBo74t7wppSUppUkppUmVlZVFlCFJOpZigv5N4OKIODEiArgC+CPwJDC7cM1cYHlxJUqSitHl6ZUppecjYhmwHmgBNgBLgF8ASyPiHwvHfpJFoZLUHfY/vzbT9gZN+dQxzx86dIipU6fy3nvv0dLSwuzZs1m8eHGmNRQ1jz6l9A3gG4cdfg049iuTJAFwwgkn8MQTT3DSSSfR3NzMZZddxtVXX83FF1+cWR8ugSBJZRQRnHTSSUDbmjfNzc20jYZnx6CXpDJrbW1lwoQJjBgxgiuvvNJliiUpb/r168eLL75IfX09a9euZePGjZm2b9BLUg8xZMgQPvOZz7Bq1apM2zXoJamMGhoa2L17NwAHDx5k9erVnHPOOZn24TLFktROR9Mhs7Z9+3bmzp1La2sr77//PjfccAOf+9znMu3DoJekMrrgggvYsGFDSftw6EaScs6gl6ScM+glKecMeknKOYNeknLOoJeknHN6pSS1s+3lXZm2V3X20E5d19rayqRJk6iqqmLlypWZ1uAdvST1APfccw/jxo0rSdsGvSSVWX19Pb/4xS+YP39+Sdo36CWpzBYtWsR3vvMdPvax0kSyQS9JZbRy5UpGjBjBRRddVLI+DHpJKqNnn32WFStWUFNTw5w5c3jiiSf40pe+lGkfBr0kldG3vvUt6uvref3111m6dCmXX345jzzySKZ9OL1Sktrp7HTI3sSgl6QeYtq0aUybNi3zdh26kaScM+glKecMeknKOYNeknLOoJeknDPoJSnnnF4pSe1s/cPvMm1vzLkXdHhNTU0NgwcPpl+/flRUVFBXV5dpDQa9JPUATz75JMOHDy9J2w7dSFLOGfSSVGYRwVVXXcVFF13EkiVLMm/foRtJKrPf/OY3VFVVsXPnTq688krOOeccpk6dmln7Rd3RR8SQiFgWEf8vIjZFxCURcUpErI6IVwqP+VshSJIyVFVVBcCIESO47rrrWLt2babtFzt0cw+wKqV0DvBXwCbgDmBNSuksYE1hX5J0BPv372fv3r0fbv/qV7/ivPPOy7SPLg/dRMTJwFTgFoCUUhPQFBHXANMKlz0MPAXcXkyRktRdOjMdMks7duzguuuuA6ClpYUvfvGLTJ8+PdM+ihmjrwUagAcj4q+AdcBCYGRKaXvhmreAkcWVKEn5NXbsWF566aWS9lHM0E0FMBH4UUrpQmA/hw3TpJQSkI705IhYEBF1EVHX0NBQRBmSpGMpJujrgfqU0vOF/WW0Bf+OiBgNUHjceaQnp5SWpJQmpZQmVVZWFlGGJOlYuhz0KaW3gK0RcXbh0BXAH4EVwNzCsbnA8qIqlCQVpdh59P8e+OeIGAC8BtxK2z8eP4+IecAbwA1F9iFJKkJRQZ9SehGYdIRTVxTTriQpOy6BIEk55xIIktTOoVd3Z9rewDOGdHjN7t27mT9/Phs3biQieOCBB7jkkksyq8Ggl6QyW7hwIdOnT2fZsmU0NTVx4MCBTNs36CWpjPbs2cPTTz/NQw89BMCAAQMYMGBApn04Ri9JZbRlyxYqKyu59dZbufDCC5k/fz779+/PtA+DXpLKqKWlhfXr1/OVr3yFDRs2MGjQIO66665M+zDoJamMqqurqa6uZsqUKQDMnj2b9evXZ9qHQS9JZTRq1CjGjBnDyy+/DMCaNWsYP358pn34y1hJaqcz0yGz9v3vf5+bbrqJpqYmxo4dy4MPPphp+wa9JJXZhAkTqKurK1
n7Dt1IUs4Z9JKUcwa9JOWcQS9JOWfQS1LOGfSSlHNOr5SkdrZs2ZJpe7W1tcc8//LLL/OFL3zhw/3XXnuNO++8k0WLFmVWg0EvSWV09tln8+KLLwLQ2tpKVVUV1113XaZ9OHQjST3EmjVrOOOMMzj99NMzbdegl6QeYunSpdx4442Zt2vQS1IP0NTUxIoVK/j85z+fedsGvST1AL/85S+ZOHEiI0eOzLxtg16SeoBHH320JMM24KwbSfqIjqZDlsL+/ftZvXo1P/7xj0vSvkEvSWU2aNAgGhsbS9a+QzeSlHMGvSTlnEEvSTln0EtSzhn0kpRzBr0k5ZzTKyWpnV27nsu0vaFDL+7wmrvvvpv777+fiOD888/nwQcfZODAgZnVUPQdfUT0i4gNEbGysF8bEc9HxOaI+FlEDCi+TEnKp23btnHvvfdSV1fHxo0baW1tZenSpZn2kcXQzUJgU7v9bwN3p5TOBHYB8zLoQ5Jyq6WlhYMHD9LS0sKBAwc49dRTM22/qKCPiGrgb4D7C/sBXA4sK1zyMHBtMX1IUp5VVVXx9a9/ndNOO43Ro0dz8sknc9VVV2XaR7F39N8D/hPwfmF/GLA7pdRS2K8HqorsQ5Jya9euXSxfvpwtW7bwpz/9if379/PII49k2keXgz4iPgfsTCmt6+LzF0REXUTUNTQ0dLUMSerVfv3rX1NbW0tlZSX9+/fn+uuv57e//W2mfRRzR38pMDMiXgeW0jZkcw8wJCI+mM1TDWw70pNTSktSSpNSSpMqKyuLKEOSeq/TTjuN5557jgMHDpBSYs2aNYwbNy7TPro8vTKl9A/APwBExDTg6ymlmyLiX4DZtIX/XGB5BnVKUrfozHTILE2ZMoXZs2czceJEKioquPDCC1mwYEGmfZRiHv3twNKI+EdgA/CTEvQhSbmxePFiFi9eXLL2Mwn6lNJTwFOF7deAT2XRriSpeC6BIEk5Z9BLUs4Z9JKUcwa9JOWcQS9JOecyxZLUzrO79mba3qVDB3d4zT333MN9991HSokvf/nLLFq0KNMavKOXpDLauHEj9913H2vXruWll15i5cqVbN68OdM+DHpJKqNNmzYxZcoUTjzxRCoqKvj0pz/NY489lmkfBr0kldF5553HM888Q2NjIwcOHODxxx9n69atmfbhGL0kldG4ceO4/fbbueqqqxg0aBATJkygX79+mfbhHb0kldm8efNYt24dTz/9NEOHDuWTn/xkpu17Ry9JZbZz505GjBjBm2++yWOPPcZzz2X7AeUGvSS105npkFmbNWsWjY2N9O/fnx/+8IcMGTIk0/YNekkqs2eeeaak7TtGL0k5Z9BLUs4Z9JL6vJRSuUs4pmLrM+gl9WkDBw6ksbGxx4Z9SonGxkYGDhzY5Tb8ZaykPq26upr6+noaGhrKXcpRDRw4kOrq6i4/36CX1Kf179+f2tracpdRUg7dSFLOGfSSlHMGvSTlnEEvSTln0EtSzhn0kpRzBr0k5ZxBL0k5Z9BLUs4Z9JKUcy6BIPVRL7z1QqbtTR41OdP2lB3v6CUp5wx6Scq5Lgd9RIyJiCcj4o8R8YeIWFg4fkpErI6IVwqPQ7MrV5J0vIoZo28B/mNKaX1EDAbWRcRq4BZgTUrproi4A7gDuL34UiVlPa6uvqHLd/Qppe0ppfWF7b3AJqAKuAZ4uHDZw8C1xRYpSeq6TMboI6IGuBB4HhiZUtpeOPUWMPIoz1kQEXURUdeTP9lFknq7ooM+Ik4C/hewKKX0bvtzqe1DGI/4QYwppSUppUkppUmVlZXFliFJOoqi5tFHRH/aQv6fU0qPFQ7viIjRKaXtETEa2FlskVJv5Zi6eoJiZt0E8BNgU0rpv7c7tQKYW9ieCyzvenmSpGIVc0d/KfC3wO8j4sXCsf8M3AX8PCLmAW8ANxRXoiSpGF0O+pTSb4A4yukrutquJClb/mWsJOWcQS9JOWfQS1LOGfSSlHMGvSTlnEEvSTln0EtSzh
n0kpRzBr0k5ZxBL0k5Z9BLUs4Z9JKUcwa9JOWcQS9JOWfQS1LOFfVRgpL0gZ78sYmTR00udwll5R29JOWcQS9JOWfQS1LOGfSSlHMGvSTlnLNu+oB3D+7/i2Nbtmz5cLu2trY7y+nxevLsEakrvKOXpJwz6CUp5wx6Sco5g16Scs6gl6ScM+glKeecXtnHHNy7G4CG11+jsmYs8NGplodz6qXyIMsps71xgTTv6CUp57yjz5n3Duyjubn5w/2PDx7SLf1u/cPvPtwec+4F3dKnpM7xjl6Scq4kd/QRMR24B+gH3J9SuqsU/ejI+jf3p1/zuwCkAwcBaOnX9sgJp/zF9W9v3ffh9vAxJx2z7UOv7ubtrXs/cmz4mMHwVsufD5zbcY3dtcxA47bGY54fVjWsW+pQfmT93253jPlnfkcfEf2AHwJXA+OBGyNifNb9SJI6pxR39J8CNqeUXgOIiKXANcAfS9CXDtOv+V02DRoFA06ktaUfZ+/b0XaieXDb+UOF8fvXt9LwTisA+3a/B8Dgs2oAWH+w7e78vF2HOPR+2yydV9as/nMng0/9cLPhwDv8qfGEj9TQctiCaYde3f2R8wPP6J7fG0hqU4ox+ipga7v9+sIxSVIZlG3WTUQsABYUdvdFxMtdbGo48HY2VfVaff096OuvH3wP+urrP70zF5Ui6LcBY9rtVxeOfURKaQmwpNjOIqIupTSp2HZ6s77+HvT11w++B3399XekFEM3LwBnRURtRAwA5gArStCPJKkTMr+jTym1RMTfAf+XtumVD6SU/pB1P5KkzinJGH1K6XHg8VK0fQRFD//kQF9/D/r66wffg77++o8pUkrlrkGSVEIugSBJOdfrgj4iTomI1RHxSuFx6FGua42IFwtfvf6XwRExPSJejojNEXHHEc6fEBE/K5x/PiJqur/K0urEe3BLRDS0+77PL0edpRIRD0TEzojYeJTzERH3Ft6f30XExO6usZQ68fqnRcSedt///9rdNfZUvS7ogTuANSmls4A1hf0jOZhSmlD4mtl95WWvk8tKzAN2pZTOBO4Gvt29VZbWcSyt8bN23/f7u7XI0nsImH6M81cDZxW+FgA/6oaautNDHPv1AzzT7vt/ZzfU1Cv0xqC/Bni4sP0wcG0Za+kuHy4rkVJqAj5YVqK99u/LMuCKiIhurLHUOvMe5FpK6WngnWNccg3wT6nNc8CQiBjdPdWVXidev46iNwb9yJTS9sL2W8DIo1w3MCLqIuK5iOjt/xh0ZlmJD69JKbUAe4A8Lc3Y2aU1ZhWGLZZFxJgjnM8zlx+BSyLipYj4ZUR0Yh3VvqFHfvBIRPwaGHWEU/+l/U5KKUXE0aYNnZ5S2hYRY4EnIuL3KaVXs65VPcr/AR5NKb0XEf+Wtp9wLi9zTeo+62n7/35fRMwA/jdtw1h9Xo8M+pTSXx/tXETsiIjRKaXthR9Ldx6ljW2Fx9ci4ingQqC3Bn1nlpX44Jr6iKgATgaOvRh779Lhe5BSav967we+0w119SSdWn4kr1JK77bbfjwi/kdEDE8p9cU1cD6iNw7drADmFrbnAssPvyAihkbECYXt4cCl9O5lkjuzrET792U28ETK1x9JdPgeHDYePRPY1I319QQrgJsLs28uBva0G+bMvYgY9cHvpSLiU7TlW55udrqsR97Rd+Au4OcRMQ94A7gBICImAf8upTQfGAf8OCLep+2bfVdKqdcG/dGWlYiIO4G6lNIK4CfATyNiM22/sJpTvoqz18n34O8jYibQQtt7cEvZCi6BiHgUmAYMj4h64BtAf4CU0v+k7a/RZwCbgQPAreWptDQ68fpnA1+JiBbgIDAnZzc7XeZfxkpSzvXGoRtJ0nEw6CUp5wx6Sco5g16Scs6gl6ScM+glKecMeknKOYNeknLu/wNbUSqUIPBwvQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "" - ] - }, - "metadata": 
{}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:32: FutureWarning: pd.ewm_mean is deprecated for ndarrays and will be removed in a future version\n" - ] - } - ], "source": [ - "thompson_agent_rewards = train_contextual_agent(ThompsonBNNAgent(in_size=state_size, hidden_size=32, out_size=n_actions),\n", - " batch_size=10, n_iters=N_ITERS)" + "### 2.3.1 Uncertainty with forward dynamics" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 319, "metadata": {}, "outputs": [], "source": [ - "class BayesUCBBNNAgent(BNNAgent):\n", - " q = 90\n", - "\n", - " def get_action(self, states):\n", - " \"\"\"\n", - " Compute q-th percentile of rewards P(r|s,a) for all actions\n", - " Take actions that have highest percentiles.\n", + "class ForwardDynamics(BaseIntrinsicRewardModule):\n", + " def __init__(self, states_size, actions_size, hidden_size, alpha=.1):\n", + " super().__init__()\n", + " self.module = MLP(\n", + " actions_size + states_size,\n", + " hidden_size,\n", + " states_size\n", + " )\n", + " self.alpha = alpha\n", + " self.mean_reward = 0\n", + " \n", + " def forward(self, s, a):\n", + " sa = torch.cat([s, a], dim=-1)\n", + " return s + self.module(sa)\n", "\n", - " This implements bayesian UCB strategy\n", - " \"\"\"\n", + " def get_intrinsic_reward(self, state, action, next_state):\n", + " with torch.no_grad():\n", + " r = # \n", + " r_centered = r - self.mean_reward\n", + " self.mean_reward = self.alpha * (r) + (1 - self.alpha) * self.mean_reward\n", + " return r_centered\n", "\n", - " " + " def get_loss(self, state_batch, action_batch, next_state_batch):\n", + " # " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 320, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "iteration #90\tmean reward=0.630\tmse=0.354\tkl=0.047\n" - ] - }, { "data": { - 
"image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Q(s,a) std: 0.067;0.027;0.093;0.069;0.014;0.148;0.173;0.026;0.043;0.101\n", - "correct 5\n" - ] - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "" - ] + "metadata": { + "needs_background": "light" }, - "metadata": {}, "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:32: FutureWarning: pd.ewm_mean is deprecated for ndarrays and will be removed in a future version\n" - ] } ], "source": [ - "ucb_agent_rewards = train_contextual_agent(BayesUCBBNNAgent(in_size=state_size, hidden_size=32, out_size=n_actions),\n", - " batch_size=10, n_iters=N_ITERS)" + "agent = QLearningAgent(\n", + " epsilon=.1, \n", + " alpha=0.5, \n", + " discount=.9, \n", + " get_legal_actions=lambda s: range(env.action_space.n)\n", + ")\n", + "\n", + "forward_dynamics = ForwardDynamics(\n", + " np.prod(env.observation_space.shape), \n", + " 1, \n", + " 16\n", + ")\n", + "\n", + "train_with_reward(env, agent, forward_dynamics, n_episodes=2000, update_reward_period=100, batch_size=100, n_iter=25)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 321, "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYQAAADnCAYAAAAeqiGTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAdwklEQVR4nO3de7RdZXnv8e8vIRcuuWEUkERAiCK2FGgKUqxiQQ3IINZDPeEMQThYhFNa8dYDbYeeYUuP9oLKwcJJgQKKXIqXpo4cQrwgIpKSQLhGNEYwiYEQAkmAQJK9n/PHnDvO7Ky917v3mmutudb6fcaYI3vN+e4537lhrWfN9/K8igjMzMzGtLsCZmZWDQ4IZmYGOCCYmVnOAcHMzAAHBDMzy+3R7gqYmXWD975r73huY19S2WUPv7ooIuY0uUoj5oBgZlaCDRv7WLJoRlLZcQf8YnqTqzMqDghmZqUI+qK/3ZVoiAOCmVkJAuinsyf6OiCYmZWkHz8hmJn1vCDY7iYjMzMLoM9NRmZmBu5DMDMz8ieEDs8e7YBgZlaSzu5BcOoKM7NSBEFf4jYcSTMl/UDS45Iek/SxGmUk6QpJKyU9LOmYMu7BTwhmZiWIgO3ltBjtAD4ZEQ9ImgQsk7Q4Ih4vlDkFmJVvxwFX5f82xE8IZmalEH2J23AiYl1EPJD/vAVYARw4qNhc4MbI3AdMlXRAo3fgJwQzsxIE0J/+hDBd0tLC6/kRMX9wIUkHA0cDSwYdOhBYXXi9Jt+3LrkGNTggmJmVpN63/4INETF7uAKS9gG+AVwcEZsbrVsKBwQzsxJkE9OSA8KwJI0jCwY3RcQ3axRZC8wsvJ6R72uI+xDMzEoQwPYYk7QNR5KAa4EVEXH5EMUWAGfno43eBmyKiIaai8BPCGZmpQhEXznfsU8AzgIekbQ83/eXwBsAIuJqYCFwKrASeBk4t4wLOyCYmZWkPxpvMoqIe2D4tqeICOBPG77YIG4yagFJb5d0r6RNkjZK+rGk35N0jqR7RnCegyWFJAdys4oZ6ENodNhpO/mDpckkTQa+A1wI3AaMB/4AeLWd9TKzsom+Ov0DVdfZte8MbwKIiJsjoi8itkbEncB24GrgeEkvSnoBQNL7JD0oabOk1ZL+V+Fcd+f/vpD/zvH57/x3SSskPS9pkaSDWnZ3ZgYMrJg2JmmrqurWrHv8DOiTdIOkUyRNA4iIFcAFwE8iYp+ImJqXfwk4G5gKvA+4UNL782PvyP+dmv/OTyTNJetw+gDwWuBHwM3Nvy0zK4oQ22Js0lZVDghNlk8oeTvZF4h/AZ6VtEDSfkOUvysiHomI/oh4mOzD/Z3DXOIC4H9HxIqI2AH8HXCUnxLMWq8fJW1V5YDQAvmH9TkRMQP4LeD1wJdqlZV0XJ7p8FlJm8g+8KcPc/qDgC9LeiFvdtpINkJhcO4TM2uirFN5TNJWVdWtWZeKiJ8C15MFhlqZT75ONulkZkRMIetnGPhKUav8auCjETG1sO0ZEfeWX3szG1rWqZyyVVV1a9YlJB0u6ZOSZuSvZwJnAvcBzwAzJI0v/MokYGNEvCLpWOC/FY49S7YGxxsL+64GLpX01vz8UyT9cfPuyMxq6YZOZQ87bb4tZHnKPyFpKvAC2TDUTwOvAI8BT0vqj4jpwP8A/knSlcAPyYaqTgWIiJclXQb8OM91MicivpUnwbol7zfYBCwG/q11t2hmAH0lTExrJ0WHrwFqZlYFh/723vF33zoiqey8WUuX1ct22g5+QjAzK8FAp3Inc0AwMytBoI5vMnJAMDMrSZU7jFM4IJiZlSCCSg8pTdHSgDBeE2Iie7fyktZDXuEltsWrnf3M3gBJAcyKiJUln3c/slFrR5Ot/fvJMs9f59ovAkdGxKpWXXO0ArG9wmkpUrQ0IExkb47TSa28pPWQJfG9dlehIZLuAP4zIj4zaP9c4P8CM/L0JK12PrABmBxNHJYo6S7gaxFxzcC+iNinWddrhk7vVO7
s2pt1lxuAD+VLKBadRba2bjuCAWTpUR5vZjDoBoHoj7StqhoKCJLmSHpC0kpJl5RVKbMe9W3gNWTrZQCQZ8c9DbhR0rGSfpLnrVon6cpBs9wp/N5dkj5SeL3LYkz5DPrF+YJNT0j64BDnuR74MPAXecr1kyVdL+lvC2VOlLSm8PpJSZ+S9HC+KNStkiYWjs+VtDxP8f6L/HPksvy+r8yvc2VeNiQdlv88RdKNeZ6vpyT9taQxxfuT9I95GvhfSjol/U9fjp7NZSRpLPAV4BTgCOBMSWmzMsxsNxGxlWxm+tmF3R8EfhoRDwF9wMfJkh0eD5xENrN9RCTtTTab/evA64B5wD/Xev9GxDnATcDf5ynXv5t4mQ8Cc4BDgCOBc/JrHwvcSDZTfypZSvcnI+KvyFK3X5Rf56Ia5/w/wBSy1C3vJPs7FdcSPg54guzv8/fAtTWetpomgP4Yk7RVVSM1OxZYGRGrImIbcAswt5xqmfWsG4AzCt+oz873ERHLIuK+iNgREU+S9SsMlxp9KKeRfQj/a36uB4FvAGXmwLoiIn4dERuB/wCOyvefB1wXEYvzFO9r84SPw8q/gM4DLo2ILfn9/xNZc9qApyLiXyKij+xvdgBQM818c6Qtn9mtS2geSJZpc8Aasgi9C0nnk3VKMZG9GricWfeLiHskbQDeL+l+si9eHwCQ9CbgcmA2sBfZ+3fZKC5zEHDcwCp9uT2ArzZQ9cGeLvz8MlnKd4CZwMJRnG86MA54qrDvKXZN877zmnneL4CWdUoHeJRRPRExH5gPMFn7ulPKrL4byZ4M3gwsiohn8v1XAQ8CZ0bEFkkXA2cMcY6XYJdvYPsXfl4N/DAi3j3K+g137npWA4cOcWy4z4cNZMvOHgQ8nu97A7B2BNduqghVujkoRSO1X0sW7QfMoEL/ccw62I3AycCfkDcX5SYBm4EXJR0OXDjMOZYDH5C0V94pe17h2HeAN0k6S9K4fPs9SW9JrN9y4FRJ+0raH7g48fcArgXOlXSSpDGSDszvBbJ08G+s9Ut5M9BtwGWSJuWZfT8BfG0E1266Xl4P4X5glqRD8pEO88gWdjGzBuTt4/cCe7Pre+pTZOtjbCFbjvXWYU7zRWAb2YfsDWQdwwPn3wK8h+w9+2uyppYvABMSq/hV4CHgSeDOOvXYRUT8J1lH8BfJUrX/kOxbP8CXyfpPnpd0RY1f/zOyp5NVwD1kneLXpV672bL1EDp7Cc2G0l9LOpVsKcixZB1Flw1XfrL2DU9Ms2ZZEt9jc2ys7rvNutrr3zotzrvlxKSyf3vkt7sv/XVELGR0HURmZl0lG3ba2d9HqtuYZWbWQQZyGaVs9Ui6TtJ6SY8OcfzEfNLf8nz7TK1yI+Vsp2ZmJSkx/fX1wJVkAwyG8qOIOK2sC4IDgplZKbL01+U0GUXE3ZIOLuVkI+CAYNagceP3jokTp9Ut1z+us9uXe93LG9dsiIjXDldmBH0I0yUtLbyen8/ZGonjJT1ENlLsUxHx2Ah/fzcOCGYNmjhxGrOPq5V6Z1cv7T+uBbWxZrn/a596arjjWbbT5CajDQ2OMnoAOCgiXsxHe34bmNXA+QB3KpvV5Ey+NlJZ6ooxSVvD14rYHBEv5j8vBMZJmt7oeR0QzAZxJl8bHbUs26mk/QcyueYZZMcAzzV6XjcZme1uZyZfAEkDmXwfH/a3rOeVNQtZ0s3AiWR9DWuAz5Il9yMiribLYXWhpB3AVmBeGQsYOSCY7a5uJt9iFt8JE6e2rGJWXSWPMjqzzvEryYallsoBwWwUill8J02e4Sy+BtDx2U4dEMx250y+NmIDayp3MgcEs93tzORLFgjmkWUZNRtSADv8hGDWXSJih6SLgEX8JpNvw5N+rPu5ycisC40kk++YA7Yz4a/X1S23Zn3aMPFpC7zUbEcKNxmZmRm/WSCnkzkgmJmVxE8IZmbWFQvkjDogSJpJlqt7P7K/xfyI+HJZFTMz6ySB2NHfu53KO4B
PRsQDkiYByyQtjghP7zezntSzfQgRsQ5Yl/+8RdIKsin/Dghm1nuih5uMivKVfY4GltQ4tjPny0Q8nM7MulNP9yEMkLQP8A3g4ojYPPh4MefLZO3rnC9m1rV6OiBIGkcWDG6KiG+WUyUzs84TiL5e7VTOF2e4FlgREZeXVyWzziIFE8dur1tu8fH/nHS+yw87MancfZc3sgKjNUOndyo3Es5OAM4C/lDS8nw7taR6mZl1lMg7lVO2qmpklNE90OHh0KwGz7Gx0YoKf9in8Exls915jo2NQrW//adwQDAbxHNsbLT8hGDWxYabY2NWFAF9/Q4IZl1puDk2xQmXe+63TxtqZ1XUy6OMzLpWvTk2ETE/ImZHxOzxU/dsfQWtcoKsyShlqyo/IZgN4jk2Njqd36nsJwSz3XmOjY1KRNpWVX5CMBtkpHNs9tnjVY7fd1Xdcp/61dyk871pn/VJ5e77h6uTygG87dMXJJe10atyc1AKBwQzsxJko4w6u9Gls2tvZlYhZTUZSbpO0npJjw5xXJKukLRS0sOSjimj/g4IZmYlKXGU0fXAnGGOnwLMyrfzgasarjwOCGZmpQjSgkFKQIiIu4GNwxSZC9wYmfuAqZIOaPQeHBDMzEoSiVsJDgRWF16vyfc1xJ3KZmZlCIj01BXTJS0tvJ6fry7ZVg4IZmYlGcGw0w0R0cgKR2uBmYXXM/J9DXGTkZlZSVo4MW0BcHY+2uhtwKY8S29DGn5CkDQWWAqsjYjTGj2fmVknGshlVAZJNwMnkjUtrQE+C4wDiIirgYXAqcBK4GXg3DKuW0aT0ceAFcDkEs5lZtaZAigpIETEmXWOB/CnpVysoKGAIGkG8D7gMuATpdTIrMNs3j6RO585om65P5l5d9L5po55Oancsle3JZUDeO7ItA+q1zxc4UQ7HaDKeYpSNPqE8CXgL4BJQxUo5o2fyF4NXs7MrKo0klFGlTTqTmVJpwHrI2LZcOWKeePHMWG0lzNrOUljJT0o6Tvtrot1iBZORGiGRp4QTgBOz9MCTwQmS/paRHyonKqZtZ37xyxddH6201E/IUTEpRExIyIOBuYB33cwsG5R6B+7pt11sQ7S4U8InodgVtuXyPrH+msdlHS+pKWSlm7btLWlFbMqU+JWTaUEhIi4y3MQrFuk9I/tsqbyFK+pbLn+xK2inLrCbHfuH7ORK3EeQru4ychsEPeP2Wh5TWUzM8tU+MM+hQOC2TAi4i7gruHK9PWP4YWt9fsRVmxNS1f/lj3TklZe/qv3JJUDuOyPvp5U7pL9zkgq97rF45Kv3VM6vMnIAcHMrCTyE4KZmRGCDk9d4YBgZlYWPyGYmRnggGBmZjkHBDMz64aJaQ4IZmYl8SgjMzPLOCCYmRn4CcGs5/WFeOmV8XXL/fvV70w6340nvJJU7tPH3JlUDuCSO+cllXv7765IKvfj5+uvIQ3w2qVJxbqH+xDMzKzqi9+kcLZTsxokTZV0u6SfSloh6fh218k6QC+vmOY3jXWxLwN3RMThwO+Qra1sNiz1p21V1WiT0cCb5gxJ44G9SqiTWVtJmgK8AzgHICK2AdvaWSfrEBX+9p9i1E8IhTfNtZC9aSLihZLqZdZOhwDPAv8q6UFJ10jau1iguKZy36aX21NLqxRF+lZVjTQZ1X3TwK5vnO282sDlzFpmD+AY4KqIOBp4CbikWKC4pvLYKX4wtlwobauoRgJC3TcN7PrGGceEBi5n1jJrgDURsSR/fTvZ/+tmwyupU1nSHElPSFopabfPVUnnSHpW0vJ8+0gZ1W8kIPhNY10pIp4GVkt6c77rJODxNlbJOkQZTUaSxgJfAU4BjgDOlFRr4setEXFUvl1TRv1H3akcEU9LWi3pzRHxBH7TWHf5M+CmfLDEKuDcNtfHqi5KG0F0LLAyIlYBSLoFmEsLPl8bHWXkN411pYhYDsxOKvzKWPqemFS32PNH9iWdbs+J25PKXTV/blI5gCP/6BdJ5X704OF
J5bRHWs/oplljk8pN+XmFx2KORHqH8XRJxXnc8yNifv7zgcDqwrE1wHE1zvFfJL0D+Bnw8YhYXaPMiDQUEEb0pjEz63bpAWFDRDTy2fkfwM0R8aqkjwI3AH/YwPkAz1Q2MytNScNO1wIzC69n5Pt2iojnImJg2OY1wO+WUX8HBDOzarkfmCXpkLw5fh6woFhA0gGFl6dT0kx6J7czMytLCZPOImKHpIuARcBY4LqIeEzS54ClEbEA+HNJpwM7gI3ks+ob5YBgZlaG8kYZERELgYWD9n2m8POlwKXlXO03HBDMzMpS4bQUKRwQzMxKIKqdpyiFA4KZWVkcEMzMjIpnMk3hgGDWIO2A8S/Uz2A5cX3a223buslJ5bYcmjbzGeCJ7x2aVG6Pw9NSee9zb1qG162vS/uE3PjWtAyg+z5W8U/cDp9w7YBgZlaSTn9C8MQ0sxokfVzSY5IelXSzpIntrpN1gF5eU9msG0k6EPhzYHZE/BbZ5KB57a2VVV5qMKhwQHCTkVltewB7StpOtlb4r9tcH+sAbjIy6zIRsRb4R+BXwDpgU0TcWSyzy5rKW19qRzWtijr8CcEBwWwQSdPIFiQ5BHg9sLekDxXL7LKm8p67LSVuPUr9aVtVOSCY7e5k4JcR8WxEbAe+Cfx+m+tkVdcFfQgOCGa7+xXwNkl7SRLZ8rClpBe27qURbFXVUEDw0DzrRhGxBLgdeAB4hOx9Mn/YXzKDjn9CGPUoo8LQvCMiYquk28iG5l1fUt3M2iYiPgt8Nqlw4te+aSvT1krec82LSeXWvXNaUjmAbVPSyh1w2/ikci/un/ap9oY70mY+P3942sznzYekf4ed/MvWN9Z3+iijRoedemiemdmADg8Io24yShmaB7sOz9vOq4MPm5l1h+jhUUYpQ/Ng1+F545gw+pqamVVdh/chNNKp7KF5ZmYFirStqhoJCB6aZ2ZW1OFPCKPuVI6IJZIGhubtAB7EQ/PMrIdV+dt/ioZGGY1oaJ6ZWTcLvECOmZnl01F6+QnBzMwKHBDMzAxA0dkRwQHBrEHjN/UxY9HzdcttOWxy0vn0ctoEzr3WpzdYT/tZX1K5SMy8tv8PNySV65uUlt5s+r89mlRu0vFvTioHsPYPxiWVm/bTkj7EKz6CKIWznZqZlaSseQiS5kh6QtJKSZfUOD5B0q358SWSDi6j/g4I1rMkXSdpvaRHC/v2lbRY0s/zf9MzyFnPKyN1haSxwFeAU4AjgDMlHTGo2HnA8xFxGPBF4Atl1N8BwXrZ9cCcQfsuAb4XEbOA7+WvzdKUMzHtWGBlRKyKiG3ALWRpgormAjfkP98OnJRPEG6IA4L1rIi4G9g4aHfxjXYD8P5W1sk6WGJzUd5kNH0g6We+nV8404HA6sLrNfk+apWJiB3AJuA1jd6CO5XNdrVfRKzLf34a2K9WofwNfD7AxHGJiw1Y90vvVN4QEbObWJNR8ROC2RAiYsgH/GIW3/F7pC3uYt1tYGJaCZ3Ka4GZhdcz8n01y0jaA5gCPNfoPTggmO3qGUkHAOT/rm9zfayDqD+StjruB2ZJOkTSeLKVKBcMKrMA+HD+8xnA9/MvMA1xQDDbVfGN9mHg39tYF+skqR3KdT628z6Bi4BFZBmkb4uIxyR9TtLpebFrgddIWgl8gpIGP7gPwXqWpJuBE8k6+NaQJWr8PHCbpPOAp4APtq+G1mnKWg0tIhYCCwft+0zh51eAPy7nar/hgGA9KyLOHOLQSSM60fYd6Nf1Z+5OWrUm7XxveH1SsWl3/TLtfAAT01YrjAnj0873wuakYmM3bkoqp8mTksqNu3NpUjmAQx96XVK51WcdlnzOujp8prIDgplZSZzt1MzM8v6Bzo4IdTuVPb3fzCxNGakr2illlNH1eHq/mdmwSpyH0DZ1A4Kn95uZJYhI3ypqtH0ISdP7YdAUfzyj08y6V5W//adoeGLacNP78+M7p/iPI23om5lZRyon22nbjDYgeHq/mdk
gXd+HMARP7zczKwqgL9K2iqrbh+Dp/WbD27zj2Q2L1l/1VI1D04G0xYeL0pYXbqXR3Ue7PV1z7+738g/JZzyoXoEqf/tPUTcglDa936xLRcRra+2XtLSKOe9HqlvuA1pwLxUeQZTCM5VbaNGvl4+o/Htff1RT6mFmzdH1TwhmZpag4iOIUjggmDXP/HZXoCTdch/QxHsRoAp3GKdwQDBrkojoig/SbrkPaP69yH0IZmbmJiMzM8tVO09RCq+pbFYySXMkPSFppaSOzgQs6UlJj0haLil9ubIKaEfq/l6dqWxmNUgaC3wFOAU4AjhT0hHtrVXD3hURR3XgXITraXXq/g7PduqAYFauY4GVEbEqIrYBt5Cli7cWa3nq/shGGaVsVeWAYFauA4HVhddr8n2dKoA7JS3LU9l3uuTU/aPS4dlO3alsZsN5e0SslfQ6YLGkn+bfvDteRIRUbot+pw879ROCWbnWAjMLr2fk+zpSRKzN/10PfIusSayTNTd1v/sQzKzgfmCWpEMkjQfmkaWL7ziS9pY0aeBn4D1UMRfryDQvdX8A/YlbRbW0yWgLz2/4btxeXprg8rTk+mMPGOn1VzavMnWv3TJlXr9ueuJmi4gdki4CFgFjgesi4rE2V2u09gO+JQmyz4qvR8Qd7a1Sulan7hfR8U1GLQ0IVU0T3MvX7+V7b5aIWAgsbHc9GhURq4DfaXc9Rqstqfv7m//1X9K+wK3AwcCTwAcj4vka5fqAR/KXv4qI0+ud201GZmZlaF2TUepciq35/JGjUoIBOCCYmZVGEUlbg5o2l6IqAaHd2RR7+fq9fO9m5UofZTRd0tLCNpI5HqlzKSbm575P0vtTTlyJeQjtTq/by9fv5Xs3K9eIhpRuGK7vTNJ3gf1rHPqrXa44/FyKg/I5JG8Evi/pkYj4xXCVqkRAMDPreAGUlJYiIk4e6pikZyQdEBHrhptLUZhDskrSXcDRwLABoSpNRmZmHa9FfQh151JImiZpQv7zdOAE4PF6J25pQKiXFljSBEm35seXSDq4xGvPlPQDSY9LekzSx2qUOVHSpjzV73JJnynx+sOmEVbmivzeH5Z0TInXfnPhnpZL2izp4kFlSr33RlIPS/pwXubnkj5cq4xZJbVmpvLngXdL+jlwcv4aSbMlXZOXeQuwVNJDwA+Az0dE3YDQsiajQlrgd5Ml/Lpf0oJBlTwPeD4iDpM0D/gC8F9LqsIO4JMR8UA++3KZpMU1/kg/iojTSrrmYO+KiKEmYZ0CzMq344Cr8n8bFhFPAEfBzv8Oa8nSEAxW5r1fD1wJ3FjYNzBc7vP5F4JLgP9Z/KV8jPVngdlkD+HL8v9PdhtnbVYpAfQ3f2JaRDxHjbkUEbEU+Ej+873Ab4/03K18QkhJC1wcTnU7cJLyaZKNioh1EfFA/vMWYAXVykI5F7gxMvcBUwdyrpTsJOAXEVFrxnhpGkg9/F5gcURszIPAYnbPaW9WQYlPBxWezdzKgJCSFnhnmYjYAWwCXlN2RfKmqKOBJTUOHy/pIUn/T9JbS7xsvTTCrUqbPA+4eYhjzbr3ASnD5botfbT1kg4PCD03ykjSPsA3gIsjYvOgww+QDdV6UdKpwLfJmnDK0PY0wnmytdOBS2scbua976YZqYfN2iqAvgpnrkvQyieElLTAO8tI2gOYAjxXVgUkjSMLBjdFxDcHH4+IzRHxYv7zQmBc3kPfsIQ0wq1Im3wK8EBEPFOjfk2794KU1MNdlT7aeklA9KdtFdXKgJCSFrg4nOoM4PsR5Txf5X0R1wIrIuLyIcrsP9BnIelYsr9PwwFJaWmEFwBn56ON3gZsKjSvlOVMhmguata9D5KSengR8J582Nw0sr/VopLrYdYcbjJKM1RaYEmfA5ZGxAKyD+yvSlpJ1iE5r8QqnACcBTwiaXm+7y+BN+T1u5osCF0oaQewFZhXUkCqmUZY0gWFay8ETiXLef0ycG4J190pD0TvBj5a2Fe8fqn3rhG
kHpY0G7ggIj4SERsl/Q3ZFwiAz0XE4M5ps+pp0SijZlJJX8DNzHralPH7xe/vl/Yd9o41VyyrYtr3nutUNjNrmg7/gu2AYGZWhgjo62t3LRrigGBmVhY/IZiZGeCAYGZmANHxo4wcEMzMyhAQFZ50lsIBwcysLB2eusIBwcysDBHQ74BgZmbgTmUzM8uEnxDMzGznAjkdzAHBzKwMXZDczgHBzKwEAYRTV5iZWbbWgfsQzMwMCDcZmZkZ0PFPCF4gx8ysBJLuAFLXId8QEXOaWZ/RcEAwMzMgW0jdzMzMAcHMzDIOCGZmBjggmJlZzgHBzMwA+P8lHOtpgeAAdQAAAABJRU5ErkJggg==", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ - "plt.figure(figsize=(17, 8))\n", - "\n", - "plt.plot(greedy_agent_rewards)\n", - "plt.plot(thompson_agent_rewards)\n", - "plt.plot(ucb_agent_rewards)\n", - "\n", - "plt.legend([\n", - " \"Greedy BNN\",\n", - " \"Thompson sampling BNN\",\n", - " \"UCB BNN\"\n", - "])\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Part 3. Exploration in MDP\n", - "\n", - "The following problem, called \"river swim\", illustrates importance of exploration in context of mdp's." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "\n", - "Picture from https://arxiv.org/abs/1306.0940" + "test_agent(agent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Rewards and transition probabilities are unknown to an agent. Optimal policy is to swim against current, while easiest way to gain reward is to go left." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class RiverSwimEnv:\n", - " LEFT_REWARD = 5.0 / 1000\n", - " RIGHT_REWARD = 1.0\n", - "\n", - " def __init__(self, intermediate_states_count=4, max_steps=16):\n", - " self._max_steps = max_steps\n", - " self._current_state = None\n", - " self._steps = None\n", - " self._interm_states = intermediate_states_count\n", - " self.reset()\n", - "\n", - " def reset(self):\n", - " self._steps = 0\n", - " self._current_state = 1\n", - " return self._current_state, 0.0, False\n", - "\n", - " @property\n", - " def n_actions(self):\n", - " return 2\n", - "\n", - " @property\n", - " def n_states(self):\n", - " return 2 + self._interm_states\n", - "\n", - " def _get_transition_probs(self, action):\n", - " if action == 0:\n", - " if self._current_state == 0:\n", - " return [0, 1.0, 0]\n", - " else:\n", - " return [1.0, 0, 0]\n", - "\n", - " elif action == 1:\n", - " if self._current_state == 0:\n", - " return [0, .4, .6]\n", - " if self._current_state == self.n_states - 1:\n", - " return [.4, .6, 0]\n", - " else:\n", - " return [.05, .6, .35]\n", - " else:\n", - " raise RuntumeError(\n", - " \"Unknown action {}. 
Max action is {}\".format(action, self.n_actions))\n", - "\n", - " def step(self, action):\n", - " \"\"\"\n", - " :param action:\n", - " :type action: int\n", - " :return: observation, reward, is_done\n", - " :rtype: (int, float, bool)\n", - " \"\"\"\n", - " reward = 0.0\n", - "\n", - " if self._steps >= self._max_steps:\n", - " return self._current_state, reward, True\n", - "\n", - " transition = np.random.choice(\n", - " range(3), p=self._get_transition_probs(action))\n", - " if transition == 0:\n", - " self._current_state -= 1\n", - " elif transition == 1:\n", - " pass\n", - " else:\n", - " self._current_state += 1\n", - "\n", - " if self._current_state == 0:\n", - " reward = self.LEFT_REWARD\n", - " elif self._current_state == self.n_states - 1:\n", - " reward = self.RIGHT_REWARD\n", - "\n", - " self._steps += 1\n", - " return self._current_state, reward, False" + "### 2.3.2 Uncertainty with inverse dynamics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Let's implement q-learning agent with epsilon-greedy exploration strategy and see how it performs." 
+ "[Curiosity-driven Exploration by Self-supervised Prediction (Pathak et al., 2017)](https://arxiv.org/pdf/1705.05363.pdf)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 322, "metadata": {}, "outputs": [], "source": [ - "class QLearningAgent:\n", - " def __init__(self, n_states, n_actions, lr=0.2, gamma=0.95, epsilon=0.1):\n", - " self._gamma = gamma\n", - " self._epsilon = epsilon\n", - " self._q_matrix = np.zeros((n_states, n_actions))\n", - " self._lr = lr\n", - "\n", - " def get_action(self, state):\n", - " if np.random.random() < self._epsilon:\n", - " return np.random.randint(0, self._q_matrix.shape[1])\n", - " else:\n", - " return np.argmax(self._q_matrix[state])\n", - "\n", - " def get_q_matrix(self):\n", - " \"\"\" Used for policy visualization\n", - " \"\"\"\n", - "\n", - " return self._q_matrix\n", - "\n", - " def start_episode(self):\n", - " \"\"\" Used in PSRL agent\n", - " \"\"\"\n", - " pass\n", + "class InverseDynamics(BaseIntrinsicRewardModule):\n", + " def __init__(self, states_size, n_actions, hidden_size, alpha=0.1):\n", + " super().__init__()\n", + " self.module = MLP(\n", + " 2 * states_size,\n", + " hidden_size,\n", + " n_actions\n", + " )\n", + " self.alpha = alpha\n", + " self.mean_reward = 0\n", + " self.n_actions = n_actions\n", + " \n", + " def forward(self, s, s_next):\n", + " # \n", + " \n", + " \n", + " def get_intrinsic_reward(self, state, action, next_state):\n", + " with torch.no_grad():\n", + " r = # \n", + " \n", + " r_centered = r - self.mean_reward\n", + " self.mean_reward = self.alpha * (r) + (1 - self.alpha) * self.mean_reward\n", + " return r_centered\n", "\n", - " def update(self, state, action, reward, next_state):\n", - " \n", - " # Finish implementation of q-learnig agent" + " def get_loss(self, state_batch, action_batch, next_state_batch): \n", + " a_pred_proba = self.forward(state_batch, next_state_batch)\n", + " a_one_hot = to_one_hot(action_batch, self.n_actions)\n", + " return 
-(torch.log(a_pred_proba) * a_one_hot).sum(dim=-1).mean()" ] }, { "cell_type": 
"code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 323, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ - "def train_mdp_agent(agent, env, n_episodes):\n", - " episode_rewards = []\n", - "\n", - " for ep in range(n_episodes):\n", - " state, ep_reward, is_done = env.reset()\n", - " agent.start_episode()\n", - " while not is_done:\n", - " action = agent.get_action(state)\n", - "\n", - " next_state, reward, is_done = env.step(action)\n", - " agent.update(state, action, reward, next_state)\n", + "agent = QLearningAgent(\n", + " epsilon=.1, \n", + " alpha=0.5, \n", + " discount=.9, \n", + " get_legal_actions=lambda s: range(env.action_space.n)\n", + ")\n", "\n", - " state = next_state\n", - " ep_reward += reward\n", + "inverse_dynamics = InverseDynamics(\n", + " np.prod(env.observation_space.shape), \n", + " env.action_space.n, \n", + " 16\n", + ")\n", "\n", - " episode_rewards.append(ep_reward)\n", - " return episode_rewards" + "train_with_reward(env, agent, inverse_dynamics, n_episodes=3000, \n", + " update_reward_period=100, batch_size=100, n_iter=25)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 324, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:6: FutureWarning: pd.ewm_mean is deprecated for ndarrays and will be removed in a future version\n", - " \n" - ] - }, { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ - "env = RiverSwimEnv()\n", - "agent = QLearningAgent(env.n_states, env.n_actions)\n", - "rews = train_mdp_agent(agent, env, 1000)\n", - "plt.figure(figsize=(15, 8))\n", - "\n", - "plt.plot(moving_average(np.array(rews), alpha=.1))\n", - "plt.xlabel(\"Episode count\")\n", - "plt.ylabel(\"Reward\")\n", - "plt.show()" + "test_agent(agent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Let's visualize our policy:" + "## 2.3.3 Intrinsic Curiosity Module algorithm" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 336, "metadata": {}, "outputs": [], "source": [ - "def plot_policy(agent):\n", - " fig = plt.figure(figsize=(15, 8))\n", - " ax = fig.add_subplot(111)\n", - " ax.matshow(agent.get_q_matrix().T)\n", - " ax.set_yticklabels(['', 'left', 'right'])\n", - " plt.xlabel(\"State\")\n", - " plt.ylabel(\"Action\")\n", - " plt.title(\"Values of state-action pairs\")\n", - " plt.show()" + "class Embedder(nn.Module):\n", + " def __init__(self, states_size, embedding_size, hidden_size):\n", + " super().__init__()\n", + " self.module = MLP(\n", + " states_size,\n", + " hidden_size,\n", + " embedding_size\n", + " )\n", + " \n", + " def forward(self, s):\n", + " return self.module(s)\n", + " \n", + "class ICMModule(BaseIntrinsicRewardModule):\n", + " def __init__(self, states_size, n_actions, hidden_size, embedding_size):\n", + " super().__init__()\n", + " # \n", + " \n", + " def get_intrinsic_reward(self, state, action, next_state):\n", + " with torch.no_grad(): \n", + " # \n", + "\n", + " def get_loss(self, state_batch, action_batch, next_state_batch):\n", + " # " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 337, "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ - "plot_policy(agent)" + "agent = QLearningAgent(\n", + " epsilon=.1, \n", + " alpha=0.5, \n", + " discount=1, \n", + " get_legal_actions=lambda s: range(env.action_space.n)\n", + ")\n", + "\n", + "icm = ICMModule(\n", + " states_size=np.prod(env.observation_space.shape), \n", + " n_actions=env.action_space.n, \n", + " hidden_size=16, embedding_size=10\n", + ")\n", + "\n", + "train_with_reward(env, agent, icm, n_episodes=3000, update_reward_period=100, batch_size=100, n_iter=200)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 338, "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ - "As your see, agent uses suboptimal policy of going left and does not explore the right state." + "test_agent(agent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Bonus 3.1 Posterior sampling RL (3 points)" + "## HW 2.1: Random network distillation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now we will implement Thompson Sampling for MDP!\n", - "\n", - "General algorithm:\n", - "\n", - ">**for** episode $k = 1,2,...$ **do**\n", - ">> sample $M_k \\sim f(\\bullet\\ |\\ H_k)$\n", - "\n", - ">> compute policy $\\mu_k$ for $M_k$\n", - "\n", - ">> **for** time $t = 1, 2,...$ **do**\n", - "\n", - ">>> take action $a_t$ from $\\mu_k$ \n", - "\n", - ">>> observe $r_t$ and $s_{t+1}$\n", - ">>> update $H_k$\n", - "\n", - ">> **end for**\n", - "\n", - ">**end for**\n", - "\n", - "In our case we will model $M_k$ with two matrices: transition and reward. Transition matrix is sampled from dirichlet distribution. 
Reward matrix is sampled from normal-gamma distribution.\n", - "\n", - "Distributions are updated with bayes rule - see continuous distribution section at https://en.wikipedia.org/wiki/Conjugate_prior\n", - "\n", - "Article on PSRL - https://arxiv.org/abs/1306.0940" + "Implement algorithm from [this](https://arxiv.org/abs/1810.12894) paper" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 328, "metadata": {}, "outputs": [], "source": [ - "def sample_normal_gamma(mu, lmbd, alpha, beta):\n", - " \"\"\" https://en.wikipedia.org/wiki/Normal-gamma_distribution\n", - " \"\"\"\n", - " tau = np.random.gamma(shape=alpha, scale=1 / beta)\n", - " mu = np.random.normal(mu, 1.0 / np.sqrt(lmbd * tau))\n", - " return mu, tau\n", - "\n", - "\n", - "class PsrlAgent:\n", - " def __init__(self, n_states, n_actions, horizon=10):\n", - " self._n_states = n_states\n", - " self._n_actions = n_actions\n", - " self._horizon = horizon\n", - "\n", - " # params for transition sampling - Dirichlet distribution\n", - " self._transition_counts = np.zeros(\n", - " (n_states, n_states, n_actions)) + 1.0\n", - "\n", - " # params for reward sampling - Normal-gamma distribution\n", - " self._mu_matrix = np.zeros((n_states, n_actions)) + 1.0\n", - " self._state_action_counts = np.zeros(\n", - " (n_states, n_actions)) + 1.0 # lambda\n", - "\n", - " self._alpha_matrix = np.zeros((n_states, n_actions)) + 1.0\n", - " self._beta_matrix = np.zeros((n_states, n_actions)) + 1.0\n", - "\n", - " def _value_iteration(self, transitions, rewards):\n", - " # YOU CODE HERE\n", - " state_values = \n", - " return state_values\n", - "\n", - " def start_episode(self):\n", - " # sample new mdp\n", - " self._sampled_transitions = np.apply_along_axis(\n", - " np.random.dirichlet, 1, self._transition_counts)\n", - "\n", - " sampled_reward_mus, sampled_reward_stds = sample_normal_gamma(\n", - " self._mu_matrix,\n", - " self._state_action_counts,\n", - " self._alpha_matrix,\n", - " 
self._beta_matrix\n", - " )\n", - "\n", - " self._sampled_rewards = sampled_reward_mus\n", - " self._current_value_function = self._value_iteration(\n", - " self._sampled_transitions, self._sampled_rewards)\n", - "\n", - " def get_action(self, state):\n", - " return np.argmax(self._sampled_rewards[state] +\n", - " self._current_value_function.dot(self._sampled_transitions[state]))\n", - "\n", - " def update(self, state, action, reward, next_state):\n", - " \n", - " # update rules - https://en.wikipedia.org/wiki/Conjugate_prior\n", - "\n", - " def get_q_matrix(self):\n", - " return self._sampled_rewards + self._current_value_function.dot(self._sampled_transitions)" + "class RandomNetworkDistilationModule(BaseIntrinsicRewardModule):\n", + " # " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 332, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:7: FutureWarning: pd.ewm_mean is deprecated for ndarrays and will be removed in a future version\n", - " import sys\n" - ] - }, { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ - "from pandas import DataFrame\n", - "moving_average = lambda x, **kw: DataFrame(\n", - " {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n", - "\n", - "horizon = 20\n", - "env = RiverSwimEnv(max_steps=horizon)\n", - "agent = PsrlAgent(env.n_states, env.n_actions, horizon=horizon)\n", - "rews = train_mdp_agent(agent, env, 1000)\n", + "agent = QLearningAgent(\n", + " epsilon=.1, \n", + " alpha=0.5, \n", + " discount=.9, \n", + " get_legal_actions=lambda s: range(env.action_space.n)\n", + ")\n", "\n", - "plt.figure(figsize=(15, 8))\n", - "plt.plot(moving_average(np.array(rews), alpha=0.1))\n", + "rnd = RandomNetworkDistilationModule(\n", + " np.prod(env.observation_space.shape), \n", + " np.prod(env.observation_space.shape), \n", + " 16\n", + ")\n", "\n", - "plt.xlabel(\"Episode count\")\n", - "plt.ylabel(\"Reward\")\n", - "plt.show()" + "train_with_reward(env, agent, rnd, n_episodes=2000, update_reward_period=100, batch_size=100, n_iter=25)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 333, "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ - "plot_policy(agent)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bonus 3.2 Bootstrapped DQN (10 points)\n", - "\n", - "Implement Bootstrapped DQN algorithm and compare it's performance with ordinary DQN on BeamRider Atari game. Links:\n", - "- https://arxiv.org/abs/1602.04621" + "test_agent(agent)" ] } ], @@ -1731,9 +1502,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/week06_policy_based/a2c-optional.ipynb b/week06_policy_based/a2c-optional.ipynb index c336b4951..2ffe4deaf 100644 --- a/week06_policy_based/a2c-optional.ipynb +++ b/week06_policy_based/a2c-optional.ipynb @@ -1,315 +1,409 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "if 'google.colab' in sys.modules:\n", - " import os\n", - "\n", - " os.system('apt-get install -y xvfb')\n", - " os.system('wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb')\n", - " os.system('apt-get install -y python-opengl ffmpeg')\n", - " os.system('pip install pyglet==1.2.4')\n", - "\n", - " os.system('python -m pip install -U pygame --user')\n", - "\n", - " print('setup complete')\n", - "\n", - "# XVFB will be launched if you run on a server\n", - "import os\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "o4vBVdNx2EPr" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting virtual X frame buffer: Xvfb../xvfb: line 24: 
start-stop-daemon: command not found\n", + ".\n" + ] + } + ], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " # Install xvfb and our launcher script for it\n", + " !apt-get install -y xvfb\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb\n", + "\n", + " # Download dependencies from Github\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week06_policy_based/atari_wrappers.py\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week06_policy_based/env_batch.py\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week06_policy_based/runners.py\n", + "\n", + " # Update the gym environment to be compatible with the Atari environment\n", + " !pip install -q gymnasium[atari,accept-rom-license]\n", + " !pip install -q tensorboardX\n", + "\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O_iJbFWQ2EPs" + }, + "source": [ + "# Implementing Advantage-Actor Critic (A2C)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "16ownLDJ2EPs" + }, + "source": [ + "In this notebook you will implement Advantage Actor Critic algorithm that trains on a batch of Atari 2600 environments running in parallel.\n", + "\n", + "Firstly, we will use environment wrappers implemented in file `atari_wrappers.py`. These wrappers preprocess observations (resize, grayscale, take max between frames, skip frames and stack them together) and rewards. 
Some of the wrappers help to reset the environment and pass `done` flag equal to `True` when agent dies.\n", + "File `env_batch.py` includes implementation of `ParallelEnvBatch` class that allows to run multiple environments in parallel. To create an environment we can use `nature_dqn_env` function. Note that if you are using\n", + "PyTorch and not using `tensorboardX` you will need to implement a wrapper that will log **raw** total rewards that the *unwrapped* environment returns and redefine the implementation of `nature_dqn_env` function here.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uScP-zu12EPt" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import gymnasium as gym\n", + "from atari_wrappers import nature_dqn_env\n", + "\n", + "\n", + "env_name = \"SpaceInvadersNoFrameskip-v4\"\n", + "nenvs = 8 # change this if you have more than 8 CPU ;)\n", + "summaries = \"Tensorboard\"\n", + "\n", + "env = nature_dqn_env(env_name, nenvs=nenvs, summaries=summaries)\n", + "obs, _ = env.reset()\n", + "assert obs.shape == (nenvs, 4, 84, 84)\n", + "assert obs.dtype == np.float32\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jiWeYgmd2EPt" + }, + "source": [ + "Next, we will need to implement a model that predicts logits and values. It is suggested that you use the same model as in [Nature DQN paper](https://www.nature.com/articles/nature14236) with a modification that instead of having a single output layer, it will have two output layers taking as input the output of the last hidden layer. **Note** that this model is different from the model you used in homework where you implemented DQN. You can use your favorite deep learning framework here. We suggest that you use orthogonal initialization with parameter $\\sqrt{2}$ for kernels and initialize biases with zeros." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "FIkJ7z7TiWS4" + }, + "outputs": [], + "source": [ + "# import tensorflow as torch\n", + "# import torch as tf\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pA2VlyZ32EPt" + }, + "source": [ + "You will also need to define and use a policy that wraps the model. While the model computes logits for all actions, the policy will sample actions and also compute their log probabilities. `policy.act` should return a dictionary of all the arrays that are needed to interact with an environment and train the model.\n", + " Note that actions must be an `np.ndarray` while the other\n", + "tensors need to have the type determined by your deep learning framework." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "id": "dtHP-Fo72EPt" + }, + "outputs": [], + "source": [ + "class Policy:\n", + " def __init__(self, model):\n", + " self.model = model\n", + "\n", + " def act(self, inputs):\n", + " # Implement a policy by calling the model, sampling actions and computing their log probs.\n", + " # Should return a dict containing keys ['actions', 'logits', 'log_probs', 'values'].\n", + "\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2oPCQwsd2EPt" + }, + "source": [ + "Next we will pass the environment and policy to a runner that collects partial trajectories from the environment.\n", + "The class that does this is already implemented for you." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "fj-fKr_A2EPt" + }, + "outputs": [], + "source": [ + "from runners import EnvRunner" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_9JehIbH2EPt" + }, + "source": [ + "This runner interacts with the environment for a given number of steps and returns a dictionary containing\n", + "keys\n", + "\n", + "* 'observations'\n", + "* 'rewards'\n", + "* 'resets'\n", + "* 'actions'\n", + "* all other keys that you defined in `Policy`\n", + "\n", + "under each of these keys there is a python `list` of interactions with the environment. This list has length $T$ that is size of partial trajectory. Partial trajectory for given moment `t` is part of `ComputeValueTargets.__call__` input argument `trajectory` from moment `t` to the end (i.e. it's different at each iteration in the algorithm)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iY7FB6s72EPu" + }, + "source": [ + "To train the part of the model that predicts state values you will need to compute the value targets.\n", + "Any callable could be passed to `EnvRunner` to be applied to each partial trajectory after it is collected.\n", + "Thus, we can implement and use `ComputeValueTargets` callable.\n", + "The formula for the value targets is simple:\n", + "\n", + "$$\n", + "\\hat v(s_t) = \\left( \\sum_{t'=0}^{T - 1} \\gamma^{t'}r_{t+t'} \\right) + \\gamma^T \\hat{v}(s_{t+T}),\n", + "$$\n", + "\n", + "In implementation, however, do not forget to use\n", + "`trajectory['resets']` flags to check if you need to add the value targets at the next step when\n", + "computing value targets for the current step. You can access `trajectory['state']['latest_observation']`\n", + "to get last observations in partial trajectory — $s_{t+T}$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "id": "4CbDi3GZ2EPu" + }, + "outputs": [], + "source": [ + "class ComputeValueTargets:\n", + " def __init__(self, policy, gamma=0.99):\n", + " self.policy = policy\n", + " self.gamma = gamma\n", + "\n", + " def __call__(self, trajectory):\n", + " \"\"\"Compute value targets for a given partial trajectory.\"\"\"\n", + "\n", + " # This method should modify trajectory inplace by adding\n", + " # an item with key 'value_targets' to it.\n", + "\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9_d9OYyz2EPu" + }, + "source": [ + "After computing value targets we will transform lists of interactions into tensors\n", + "with the first dimension `batch_size` which is equal to `env_steps * num_envs`, i.e. you essentially need\n", + "to flatten the first two dimensions." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "id": "IEnqWlHh2EPu" + }, + "outputs": [], + "source": [ + "class MergeTimeBatch:\n", + " \"\"\" Merges first two axes typically representing time and env batch. \"\"\"\n", + " def __call__(self, trajectory):\n", + " # Modify trajectory inplace.\n", + "\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "id": "-2CwwzLl2EPu" + }, + "outputs": [], + "source": [ + "model = \n", + "policy = Policy(model)\n", + "runner = EnvRunner(\n", + " env=env,\n", + " policy=policy,\n", + " nsteps=5,\n", + " transforms=[\n", + " ComputeValueTargets(policy),\n", + " MergeTimeBatch(),\n", + " ],\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IuYy-8Ri2EPu" + }, + "source": [ + "Now is the time to implement the advantage actor critic algorithm itself. You can look into your lecture,\n", + "[Mnih et al. 2016](https://arxiv.org/abs/1602.01783) paper, and [lecture](https://www.youtube.com/watch?v=Tol_jw5hWnI&list=PLkFD6_40KJIxJMR-j5A1mkxK26gh_qg37&index=20) by Sergey Levine." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "id": "hxFLzyRX2EPu" + }, + "outputs": [], + "source": [ + "class A2C:\n", + " def __init__(self,\n", + " policy,\n", + " optimizer,\n", + " value_loss_coef=0.25,\n", + " entropy_coef=0.01,\n", + " max_grad_norm=0.5):\n", + " self.policy = policy\n", + " self.optimizer = optimizer\n", + " self.value_loss_coef = value_loss_coef\n", + " self.entropy_coef = entropy_coef\n", + " self.max_grad_norm = max_grad_norm\n", + "\n", + " def policy_loss(self, trajectory):\n", + " # You will need to compute advantages here.\n", + " \n", + "\n", + " def value_loss(self, trajectory):\n", + " \n", + "\n", + " def loss(self, trajectory):\n", + " \n", + "\n", + " def step(self, trajectory):\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JIMtFZuG2EPu" + }, + "source": [ + "Now you can train your model. With reasonable hyperparameters training on a single GTX1080 for 10 million steps across all batched environments (which translates to about 5 hours of wall clock time)\n", + "it should be possible to achieve *average raw reward over last 100 episodes* (the average is taken over 100 last\n", + "episodes in each environment in the batch) of about 600. You should plot this quantity with respect to\n", + "`runner.step_var` — the number of interactions with all environments. 
It is highly\n", + "encouraged to also provide plots of the following quantities (these are useful for debugging as well):\n", + "\n", + "* [Coefficient of Determination](https://en.wikipedia.org/wiki/Coefficient_of_determination) between\n", + "value targets and value predictions\n", + "* Entropy of the policy $\\pi$\n", + "* Value loss\n", + "* Policy loss\n", + "* Value targets\n", + "* Value predictions\n", + "* Gradient norm\n", + "* Advantages\n", + "* A2C loss\n", + "\n", + "For optimization we suggest you use RMSProp with learning rate starting from 7e-4 and linearly decayed to 0, smoothing constant (alpha in PyTorch and decay in TensorFlow) equal to 0.99 and epsilon equal to 1e-5." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#if you use TensorboardSummaries\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir logs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a2c = \n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZDbgUdMq2EPu" + }, + "source": [ + "### Target networks?\n", + "\n", + "You may recall a technique called \"target networks\" we used a few weeks ago when we trained a DQN agent to play Atari Breakout and wonder why we have not suggested using them here. The answer is that this is more historical than practical.\n", + "\n", + "While the \"chasing the target\" problem is still present in actor-critic value estimation and target networks do show up in follow-up papers, the original A3C/A2C papers do not mention them and do not explain this omission.\n", + "\n", + "The hypothesis why this may not be a big deal (compared to Q-learning) goes like this. An A3C/A2C agent selects actions based on policy, not an epsilon greedy exploration function, for which the argmax can change drastically due to tiny errors in function approximation. 
Therefore, errors in the value target caused by target chasing will cause less damage.\n", + "\n", + "Also, the actor-critic gradient relies on the advantage function $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$. Compare this to the $Q$-function $Q(s_t, a_t) = r(s_t, a_t) + \\gamma \\cdot \\mathbb{E}_{s_{t+1} \\mid s_t, a_t} V(s_{t+1})$ used in Q-learning and SARSA: we would expect that any bias in $V$-function approximation will be carried over from $V(s_{t+1})$ to $V(s_t)$ by gradient updates. However, in the formula for the advantage function the two approximations ($Q$-function and $V$-function) come with opposite signs, and thus the errors will cancel out.\n", + "\n", + "The last reason may be computational. Authors were concerned to beat existent algorithms in the wall-clock learning time, and any overhead of parameter copying (target network update) counted against this goal." + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Implementing Advantage-Actor Critic (A2C)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this notebook you will implement Advantage Actor Critic algorithm that trains on a batch of Atari 2600 environments running in parallel.\n", - "\n", - "Firstly, we will use environment wrappers implemented in file `atari_wrappers.py`. These wrappers preprocess observations (resize, grayscale, take max between frames, skip frames and stack them together) and rewards. 
Some of the wrappers help to reset the environment and pass `done` flag equal to `True` when agent dies.\n", - "File `env_batch.py` includes implementation of `ParallelEnvBatch` class that allows to run multiple environments in parallel. To create an environment we can use `nature_dqn_env` function. Note that if you are using\n", - "PyTorch and not using `tensorboardX` you will need to implement a wrapper that will log **raw** total rewards that the *unwrapped* environment returns and redefine the implemention of `nature_dqn_env` function here.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "from atari_wrappers import nature_dqn_env, NumpySummaries\n", - "\n", - "\n", - "env = nature_dqn_env(\"SpaceInvadersNoFrameskip-v4\", nenvs=8, summaries='Numpy')\n", - "obs = env.reset()\n", - "assert obs.shape == (8, 84, 84, 4)\n", - "assert obs.dtype == np.uint8" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we will need to implement a model that predicts logits and values. It is suggested that you use the same model as in [Nature DQN paper](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) with a modification that instead of having a single output layer, it will have two output layers taking as input the output of the last hidden layer. **Note** that this model is different from the model you used in homework where you implemented DQN. You can use your favorite deep learning framework here. We suggest that you use orthogonal initialization with parameter $\\sqrt{2}$ for kernels and initialize biases with zeros." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# import tensorflow as torch\n", - "# import torch as tf\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You will also need to define and use a policy that wraps the model. While the model computes logits for all actions, the policy will sample actions and also compute their log probabilities. `policy.act` should return a dictionary of all the arrays that are needed to interact with an environment and train the model.\n", - " Note that actions must be an `np.ndarray` while the other\n", - "tensors need to have the type determined by your deep learning framework." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class Policy:\n", - " def __init__(self, model):\n", - " self.model = model\n", - "\n", - " def act(self, inputs):\n", - " \n", - " # Should return a dict containing keys ['actions', 'logits', 'log_probs', 'values']." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next will pass the environment and policy to a runner that collects partial trajectories from the environment.\n", - "The class that does is is already implemented for you." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from runners import EnvRunner" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This runner interacts with the environment for a given number of steps and returns a dictionary containing\n", - "keys\n", - "\n", - "* 'observations'\n", - "* 'rewards'\n", - "* 'resets'\n", - "* 'actions'\n", - "* all other keys that you defined in `Policy`\n", - "\n", - "under each of these keys there is a python `list` of interactions with the environment. This list has length $T$ that is size of partial trajectory. 
Partial trajectory for given moment `t` is part of `ComputeValueTargets.__call__` input argument `trajectory` from moment `t` to the end (i.e. it's different at each iteration in the algorithm)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To train the part of the model that predicts state values you will need to compute the value targets.\n", - "Any callable could be passed to `EnvRunner` to be applied to each partial trajectory after it is collected.\n", - "Thus, we can implement and use `ComputeValueTargets` callable.\n", - "The formula for the value targets is simple:\n", - "\n", - "$$\n", - "\\hat v(s_t) = \\left( \\sum_{t'=0}^{T - 1} \\gamma^{t'}r_{t+t'} \\right) + \\gamma^T \\hat{v}(s_{t+T}),\n", - "$$\n", - "\n", - "In implementation, however, do not forget to use\n", - "`trajectory['resets']` flags to check if you need to add the value targets at the next step when\n", - "computing value targets for the current step. You can access `trajectory['state']['latest_observation']`\n", - "to get last observations in partial trajectory — $s_{t+T}$." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class ComputeValueTargets:\n", - " def __init__(self, policy, gamma=0.99):\n", - " self.policy = policy\n", - "\n", - " def __call__(self, trajectory):\n", - " # This method should modify trajectory inplace by adding\n", - " # an item with key 'value_targets' to it.\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "After computing value targets we will transform lists of interactions into tensors\n", - "with the first dimension `batch_size` which is equal to `env_steps * nenvs`, i.e. you essentially need\n", - "to flatten the first two dimensions." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MergeTimeBatch:\n", - " \"\"\" Merges first two axes typically representing time and env batch. \"\"\"\n", - " def __call__(self, trajectory):\n", - " # Modify trajectory inplace.\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = \n", - "policy = Policy(model)\n", - "runner = EnvRunner(\n", - " env, policy, nsteps=5,\n", - " transforms=[\n", - " ComputeValueTargets(),\n", - " MergeTimeBatch(),\n", - " ])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now is the time to implement the advantage actor critic algorithm itself. You can look into your lecture,\n", - "[Mnih et al. 2016](https://arxiv.org/abs/1602.01783) paper, and [lecture](https://www.youtube.com/watch?v=Tol_jw5hWnI&list=PLkFD6_40KJIxJMR-j5A1mkxK26gh_qg37&index=20) by Sergey Levine." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class A2C:\n", - " def __init__(self,\n", - " policy,\n", - " optimizer,\n", - " value_loss_coef=0.25,\n", - " entropy_coef=0.01,\n", - " max_grad_norm=0.5):\n", - " self.policy = policy\n", - " self.optimizer = optimizer\n", - " self.value_loss_coef = value_loss_coef\n", - " self.entropy_coef = entropy_coef\n", - " self.max_grad_norm = max_grad_norm\n", - "\n", - " def policy_loss(self, trajectory):\n", - " # You will need to compute advantages here.\n", - " \n", - "\n", - " def value_loss(self, trajectory):\n", - " \n", - "\n", - " def loss(self, trajectory):\n", - " \n", - "\n", - " def step(self, trajectory):\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now you can train your model. 
With reasonable hyperparameters training on a single GTX1080 for 10 million steps across all batched environments (which translates to about 5 hours of wall clock time)\n", - "it should be possible to achieve *average raw reward over last 100 episodes* (the average is taken over 100 last\n", - "episodes in each environment in the batch) of about 600. You should plot this quantity with respect to\n", - "`runner.step_var` — the number of interactions with all environments. It is highly\n", - "encouraged to also provide plots of the following quantities (these are useful for debugging as well):\n", - "\n", - "* [Coefficient of Determination](https://en.wikipedia.org/wiki/Coefficient_of_determination) between\n", - "value targets and value predictions\n", - "* Entropy of the policy $\\pi$\n", - "* Value loss\n", - "* Policy loss\n", - "* Value targets\n", - "* Value predictions\n", - "* Gradient norm\n", - "* Advantages\n", - "* A2C loss\n", - "\n", - "For optimization we suggest you use RMSProp with learning rate starting from 7e-4 and linearly decayed to 0, smoothing constant (alpha in PyTorch and decay in TensorFlow) equal to 0.99 and epsilon equal to 1e-5." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a2c = \n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Target networks?\n", - "\n", - "You may recall a technique called \"target networks\" we used a few weeks ago when we trained a DQN agent to play Atari Breakout and wonder why we have not suggested using them here. 
The answer is that this is more historical than practical.\n", - "\n", - "While the \"chasing the target\" problem is still present in actor-critic value estimation and target networks do show up in follow-up papers, the original A3C/A2C papers do not mention them and do not explain this omission.\n", - "\n", - "The hypothesis why this may not be a big deal (compared to Q-learning) goes like this. An A3C/A2C agent selects actions based on policy, not an epsilon greedy exploration function, for which the argmax can change drastically due to tiny errors in function approximation. Therefore, errors in the value target caused by target chasing will cause less damage.\n", - "\n", - "Also, the actor-critic gradient relies on the advantage function $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$. Compare this to the $Q$-function $Q(s_t, a_t) = r(s_t, a_t) + \\gamma \\cdot \\mathbb{E}_{s_{t+1} \\mid s_t, a_t} V(s_{t+1})$ used in Q-learning and SARSA: we would expect that any bias in $V$-function approximation will be carried over from $V(s_{t+1})$ to $V(s_t)$ by gradient updates. However, in the formula for the advantage function the two approximations ($Q$-function and $V$-function) come with opposite signs, and thus the errors will cancel out.\n", - "\n", - "The last reason may be computational. Authors were concerned to beat existent algorithms in the wall-clock learning time, and any overhead of parameter copying (target network update) counted against this goal." 
- ] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/week06_policy_based/atari_wrappers.py b/week06_policy_based/atari_wrappers.py index b1a9234de..ffd0ce25f 100644 --- a/week06_policy_based/atari_wrappers.py +++ b/week06_policy_based/atari_wrappers.py @@ -2,139 +2,149 @@ from collections import defaultdict, deque import cv2 -import gym -import gym.spaces as spaces -from gym.envs import atari +import gymnasium as gym import numpy as np +from gymnasium import ObservationWrapper, RewardWrapper, Wrapper +from gymnasium.spaces import Box +from gymnasium.wrappers import RecordVideo +from shimmy.atari_env import AtariEnv +from tensorboardX import SummaryWriter from env_batch import ParallelEnvBatch + cv2.ocl.setUseOpenCL(False) -class EpisodicLife(gym.Wrapper): - """ Sets done flag to true when agent dies. """ +class EpisodicLife(Wrapper): + """Sets done flag to true when agent dies.""" def __init__(self, env): - super(EpisodicLife, self).__init__(env) + super().__init__(env) self.lives = 0 self.real_done = True def step(self, action): - obs, rew, done, info = self.env.step(action) - self.real_done = done - info["real_done"] = done + obs, reward, terminated, truncated, info = self.env.step(action) + self.real_done = terminated or truncated + info["real_done"] = self.real_done lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: - done = True + terminated = True self.lives = lives - return obs, rew, done, info + return obs, reward, terminated, truncated, info def reset(self, **kwargs): if self.real_done: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: - obs, _, _, _ = self.env.step(0) + obs, _, terminated, truncated, info = self.env.step(0) + if terminated or truncated: + obs, info = self.env.reset(**kwargs) self.lives = self.env.unwrapped.ale.lives() - return obs + return obs, info 
-class FireReset(gym.Wrapper): - """ Makes fire action when reseting environment. +class FireReset(Wrapper): + """Makes fire action when reseting environment. Some environments are fixed until the agent makes the fire action, this wrapper makes this action so that the epsiode starts automatically. """ def __init__(self, env): - super(FireReset, self).__init__(env) + super().__init__(env) action_meanings = env.unwrapped.get_action_meanings() if len(action_meanings) < 3: raise ValueError( "env.unwrapped.get_action_meanings() must be of length >= 3" - f"but is of length {len(action_meanings)}") + f"but is of length {len(action_meanings)}" + ) if env.unwrapped.get_action_meanings()[1] != "FIRE": raise ValueError( "env.unwrapped.get_action_meanings() must have 'FIRE' " - f"under index 1, but is {action_meanings}") + f"under index 1, but is {action_meanings}" + ) def step(self, action): return self.env.step(action) def reset(self, **kwargs): self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(1) - if done: + obs, _, terminated, truncated, _ = self.env.step(1) + if terminated or truncated: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(2) - if done: + obs, _, terminated, truncated, _ = self.env.step(2) + if terminated or truncated: self.env.reset(**kwargs) - return obs + return obs, {} -class StartWithRandomActions(gym.Wrapper): - """ Makes random number of random actions at the beginning of each - episode. 
""" +class StartWithRandomActions(Wrapper): + """Makes random number of random actions at the beginning of each + episode.""" def __init__(self, env, max_random_actions=30): - super(StartWithRandomActions, self).__init__(env) + super().__init__(env) self.max_random_actions = max_random_actions self.real_done = True def step(self, action): - obs, rew, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) self.real_done = info.get("real_done", True) - return obs, rew, done, info + return obs, reward, terminated, truncated, info def reset(self, **kwargs): - obs = self.env.reset() + obs, info = self.env.reset(**kwargs) if self.real_done: - num_random_actions = np.random.randint(self.max_random_actions + 1) + num_random_actions = self.unwrapped.np_random.integers( + low=1, high=self.max_random_actions + 1 + ) for _ in range(num_random_actions): - obs, _, _, _ = self.env.step(self.env.action_space.sample()) + obs, _, _, _, info = self.env.step(self.env.action_space.sample()) self.real_done = False - return obs + return obs, info -class ImagePreprocessing(gym.ObservationWrapper): - """ Preprocesses image-observations by possibly grayscaling and resizing. 
""" +class ImagePreprocessing(ObservationWrapper): + """Preprocesses image-observations by possibly grayscaling and resizing.""" - def __init__(self, env, width=84, height=84, grayscale=True): - super(ImagePreprocessing, self).__init__(env) - self.width = width + def __init__(self, env, height=84, width=84, grayscale=True): + super().__init__(env) self.height = height + self.width = width self.grayscale = grayscale ospace = self.env.observation_space low, high, dtype = ospace.low.min(), ospace.high.max(), ospace.dtype if self.grayscale: - self.observation_space = spaces.Box( + self.observation_space = Box( low=low, high=high, - shape=(width, height), + shape=(height, width), dtype=dtype, ) else: - obs_shape = (width, height) + self.observation_space.shape[2:] - self.observation_space = spaces.Box(low=low, high=high, - shape=obs_shape, dtype=dtype) + self.observation_space = Box( + low=low, + high=high, + shape=(height, width, *self.observation_space.shape[2:]), + dtype=dtype, + ) def observation(self, observation): - """ Performs image preprocessing. """ + """Performs image preprocessing.""" if self.grayscale: observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY) - observation = cv2.resize(observation, (self.width, self.height), - cv2.INTER_AREA) + observation = cv2.resize(observation, (self.width, self.height), cv2.INTER_AREA) return observation -class MaxBetweenFrames(gym.ObservationWrapper): - """ Takes maximum between two subsequent frames. 
""" +class MaxBetweenFrames(ObservationWrapper): + """Takes maximum between two subsequent frames.""" def __init__(self, env): - if (isinstance(env.unwrapped, atari.AtariEnv) and - "NoFrameskip" not in env.spec.id): - raise ValueError( - "MaxBetweenFrames requires NoFrameskip in Atari env id") - super(MaxBetweenFrames, self).__init__(env) + if isinstance(env.unwrapped, AtariEnv) and "NoFrameskip" not in env.spec.id: + raise ValueError("MaxBetweenFrames requires NoFrameskip in atari env id") + super().__init__(env) self.last_obs = None def observation(self, observation): @@ -143,15 +153,15 @@ def observation(self, observation): return obs def reset(self, **kwargs): - self.last_obs = self.env.reset() - return self.last_obs + self.last_obs, info = self.env.reset(**kwargs) + return self.last_obs, info -class QueueFrames(gym.ObservationWrapper): - """ Queues specified number of frames together along new dimension. """ +class QueueFrames(ObservationWrapper): + """Queues specified number of frames together along new dimension.""" def __init__(self, env, nframes, concat=False): - super(QueueFrames, self).__init__(env) + super().__init__(env) self.obs_queue = deque([], maxlen=nframes) self.concat = concat ospace = self.observation_space @@ -159,110 +169,143 @@ def __init__(self, env, nframes, concat=False): oshape = ospace.shape[:-1] + (ospace.shape[-1] * nframes,) else: oshape = ospace.shape + (nframes,) - self.observation_space = spaces.Box( - ospace.low.min(), ospace.high.max(), oshape, ospace.dtype) + self.observation_space = Box( + ospace.low.min(), ospace.high.max(), oshape, ospace.dtype + ) def observation(self, observation): self.obs_queue.append(observation) - return (np.concatenate(self.obs_queue, -1) if self.concat - else np.dstack(self.obs_queue)) + return ( + np.concatenate(self.obs_queue, -1) + if self.concat + else np.dstack(self.obs_queue) + ) def reset(self, **kwargs): - obs = self.env.reset() + obs, info = self.env.reset(**kwargs) for _ in 
range(self.obs_queue.maxlen - 1): self.obs_queue.append(obs) - return self.observation(obs) + return self.observation(obs), info -class SkipFrames(gym.Wrapper): - """ Performs the same action for several steps and returns the final result. - """ +class SkipFrames(Wrapper): + """Performs the same action for several steps and returns the final result.""" def __init__(self, env, nskip=4): - super(SkipFrames, self).__init__(env) - if (isinstance(env.unwrapped, atari.AtariEnv) and - "NoFrameskip" not in env.spec.id): - raise ValueError("SkipFrames requires NoFrameskip in Atari env id") + super().__init__(env) + if isinstance(env.unwrapped, AtariEnv) and "NoFrameskip" not in env.spec.id: + raise ValueError("SkipFrames requires NoFrameskip in atari env id") self.nskip = nskip def step(self, action): total_reward = 0.0 for _ in range(self.nskip): - obs, rew, done, info = self.env.step(action) - total_reward += rew - if done: + obs, reward, terminated, truncated, info = self.env.step(action) + total_reward += reward + if terminated or truncated: break - return obs, total_reward, done, info + return obs, total_reward, terminated, truncated, info def reset(self, **kwargs): return self.env.reset(**kwargs) -class ClipReward(gym.RewardWrapper): - """ Modifes reward to be in {-1, 0, 1} by taking sign of it. 
""" +class ClipReward(RewardWrapper): + """Modifes reward to be in {-1, 0, 1} by taking sign of it.""" def reward(self, reward): return np.sign(reward) -class SummariesBase(gym.Wrapper): - """ Env summaries writer base.""" +class SwapImageAxes(ObservationWrapper): + """ + Image shape to num_channels x height x width and normalization + """ - def __init__(self, env, prefix=None, running_mean_size=100): + def __init__(self, env): + super().__init__(env) + old_shape = self.observation_space.shape + self.observation_space = Box( + low=0.0, + high=1.0, + shape=(old_shape[-1], old_shape[0], old_shape[1]), + dtype=np.float32, + ) + + def observation(self, observation): + return np.transpose(observation, (2, 0, 1)).astype(np.float32) / 255.0 + + +class SummariesBase(Wrapper): + """Env summaries writer base.""" + + def __init__(self, env, prefix=None, running_mean_size=100, step_var=None): super().__init__(env) self.episode_counter = 0 self.prefix = prefix or self.env.spec.id + self.step_var = step_var or 0 - nenvs = getattr(self.env.unwrapped, "nenvs", 1) - self.rewards = np.zeros(nenvs) - self.had_ended_episodes = np.zeros(nenvs, dtype=np.bool) - self.episode_lengths = np.zeros(nenvs) - self.reward_queues = [deque([], maxlen=running_mean_size) - for _ in range(nenvs)] + self.nenvs = getattr(self.env.unwrapped, "nenvs", 1) + self.rewards = np.zeros(self.nenvs) + self.had_ended_episodes = np.zeros(self.nenvs, dtype=bool) + self.episode_lengths = np.zeros(self.nenvs) + self.reward_queues = [ + deque([], maxlen=running_mean_size) for _ in range(self.nenvs) + ] def should_write_summaries(self): - """ Returns true if it's time to write summaries. """ + """Returns true if it's time to write summaries.""" return np.all(self.had_ended_episodes) def add_summaries(self): - """ Writes summaries. 
""" - self.add_summary_scalar( - f"{self.prefix}/total_reward", - np.mean([q[-1] for q in self.reward_queues])) - self.add_summary_scalar( - f"{self.prefix}/reward_mean_{self.reward_queues[0].maxlen}", - np.mean([np.mean(q) for q in self.reward_queues])) - self.add_summary_scalar( - f"{self.prefix}/episode_length", - np.mean(self.episode_lengths)) + """Writes summaries.""" + self.add_summary( + f"Episodes/total_reward", np.mean([q[-1] for q in self.reward_queues]) + ) + self.add_summary( + f"Episodes/reward_mean_{self.reward_queues[0].maxlen}", + np.mean([np.mean(q) for q in self.reward_queues]), + ) + self.add_summary(f"Episodes/episode_length", np.mean(self.episode_lengths)) if self.had_ended_episodes.size > 1: - self.add_summary_scalar( - f"{self.prefix}/min_reward", - min(q[-1] for q in self.reward_queues)) - self.add_summary_scalar( - f"{self.prefix}/max_reward", - max(q[-1] for q in self.reward_queues)) + self.add_summary( + f"Episodes/min_reward", + min(q[-1] for q in self.reward_queues), + ) + self.add_summary( + f"Episodes/max_reward", + max(q[-1] for q in self.reward_queues), + ) self.episode_lengths.fill(0) self.had_ended_episodes.fill(False) def step(self, action): - obs, rew, done, info = self.env.step(action) + obs, rew, terminated, truncated, info = self.env.step(action) self.rewards += rew self.episode_lengths[~self.had_ended_episodes] += 1 info_collection = [info] if isinstance(info, dict) else info - done_collection = [done] if isinstance(done, bool) else done - done_indices = [i for i, info in enumerate(info_collection) - if info.get("real_done", done_collection[i])] + terminated_collection = ( + [terminated] if isinstance(terminated, bool) else terminated + ) + truncated_collection = [truncated] if isinstance(truncated, bool) else truncated + done_indices = [ + i + for i, info in enumerate(info_collection) + if info.get( + "real_done", terminated_collection[i] or truncated_collection[i] + ) + ] for i in done_indices: if not 
self.had_ended_episodes[i]: self.had_ended_episodes[i] = True self.reward_queues[i].append(self.rewards[i]) self.rewards[i] = 0 + self.step_var += self.nenvs if self.should_write_summaries(): self.add_summaries() - return obs, rew, done, info + return obs, rew, terminated, truncated, info def reset(self, **kwargs): self.rewards.fill(0) @@ -271,30 +314,23 @@ def reset(self, **kwargs): return self.env.reset(**kwargs) -class TFSummaries(SummariesBase): - """ Writes env summaries using TensorFlow.""" +class TensorboardSummaries(SummariesBase): + """Writes env summaries using Tensorboard.""" def __init__(self, env, prefix=None, running_mean_size=100, step_var=None): + super().__init__(env, prefix, running_mean_size, step_var) + self.writer = SummaryWriter(f"logs/{self.prefix}") - super().__init__(env, prefix, running_mean_size) - - import tensorflow as tf - self.step_var = (step_var if step_var is not None - else tf.train.get_global_step()) - - def add_summary_scalar(self, name, value): - import tensorflow as tf - tf.contrib.summary.scalar(name, value, step = self.step_var) + def add_summary(self, name, value): + if isinstance(value, dict): + self.writer.add_scalars(name, value, self.step_var) + else: + self.writer.add_scalar(name, value, self.step_var) class NumpySummaries(SummariesBase): _summaries = defaultdict(list) - _summary_step = None - - @classmethod - def set_step(cls, step): - cls._summary_step = step @classmethod def get_values(cls, name): @@ -304,16 +340,44 @@ def get_values(cls, name): def clear(cls): cls._summaries = defaultdict(list) - def __init__(self, env, prefix = None, running_mean_size = 100): - super().__init__(env, prefix, running_mean_size) + def __init__(self, env, prefix=None, running_mean_size=100, step_var=None): + super().__init__(env, prefix, running_mean_size, step_var) + + def add_summary(self, name, value): + self._summaries[name].append((self.step_var, value)) + + +def get_summaries_class(summaries): + summaries_class_map = { + 
"Numpy": NumpySummaries, + "Tensorboard": TensorboardSummaries, + } + if summaries in summaries_class_map: + return summaries_class_map[summaries] + + raise NotImplementedError( + f"Unknown summaries: {summaries}. Supported summaries: {summaries_class_map.keys()}" + ) - def add_summary_scalar(self, name, value): - self._summaries[name].append((self._summary_step, value)) +# magic for parallel launching of environments +class _thunk: + def __init__(self, i, env_id, **kwargs): + self.env_id = env_id + self.i = i + self.kwargs = kwargs -def nature_dqn_env(env_id, nenvs=None, seed=None, - summaries='TensorFlow', clip_reward=True): - """ Wraps env as in Nature DQN paper. """ + def __call__(self): + return nature_dqn_env( + self.env_id, + summaries=False, + clip_reward=False, + **self.kwargs, + ) + + +def nature_dqn_env(env_id, nenvs=None, seed=None, summaries="Numpy", clip_reward=True): + """Wraps env as in Nature DQN paper.""" if "NoFrameskip" not in env_id: raise ValueError(f"env_id must have 'NoFrameskip' but is {env_id}") if nenvs is not None: @@ -322,25 +386,24 @@ def nature_dqn_env(env_id, nenvs=None, seed=None, if isinstance(seed, int): seed = [seed] * nenvs if len(seed) != nenvs: - raise ValueError(f"seed has length {len(seed)} but must have " - f"length equal to nenvs which is {nenvs}") - - env = ParallelEnvBatch([ - lambda i=i, env_seed=env_seed: nature_dqn_env( - env_id, seed=env_seed, summaries=False, clip_reward=False) - for i, env_seed in enumerate(seed) - ]) + raise ValueError( + f"seed has length {len(seed)} but must have " + f"length equal to nenvs which is {nenvs}" + ) + + thunks = [_thunk(i, env_id) for i in range(nenvs)] + env = ParallelEnvBatch(make_env=thunks, seeds=seed) + if summaries: - summaries_class = NumpySummaries if summaries == 'Numpy' else TFSummaries + summaries_class = get_summaries_class(summaries) env = summaries_class(env, prefix=env_id) if clip_reward: env = ClipReward(env) return env - env = gym.make(env_id) - env.seed(seed) + env 
= gym.make(env_id, render_mode="rgb_array") if summaries: - env = TFSummaries(env) + env = TensorboardSummaries(env) env = EpisodicLife(env) if "FIRE" in env.unwrapped.get_action_meanings(): env = FireReset(env) @@ -349,6 +412,7 @@ def nature_dqn_env(env_id, nenvs=None, seed=None, env = SkipFrames(env, 4) env = ImagePreprocessing(env, width=84, height=84, grayscale=True) env = QueueFrames(env, 4) + env = SwapImageAxes(env) if clip_reward: env = ClipReward(env) return env diff --git a/week06_policy_based/env_batch.py b/week06_policy_based/env_batch.py index 1e23913e9..b2bd163ac 100644 --- a/week06_policy_based/env_batch.py +++ b/week06_policy_based/env_batch.py @@ -1,8 +1,9 @@ # pylint: skip-file -from multiprocessing import Process, Pipe +from multiprocessing import Pipe, Process -from gym import Env, Wrapper, Space import numpy as np +from gymnasium import Env, Wrapper +from gymnasium.spaces import Space class SpaceBatch(Space): @@ -12,18 +13,26 @@ def __init__(self, spaces): first_dtype = spaces[0].dtype for space in spaces: if not isinstance(space, first_type): - raise TypeError("spaces have different types: {}, {}" - .format(first_type, type(space))) + raise TypeError( + "spaces have different types: {}, {}".format( + first_type, type(space) + ) + ) if first_shape != space.shape: - raise ValueError("spaces have different shapes: {}, {}" - .format(first_shape, space.shape)) + raise ValueError( + "spaces have different shapes: {}, {}".format( + first_shape, space.shape + ) + ) if first_dtype != space.dtype: - raise ValueError("spaces have different data types: {}, {}" - .format(first_dtype, space.dtype)) + raise ValueError( + "spaces have different data types: {}, {}".format( + first_dtype, space.dtype + ) + ) self.spaces = spaces - super(SpaceBatch, self).__init__(shape=self.spaces[0].shape, - dtype=self.spaces[0].dtype) + super().__init__(shape=self.spaces[0].shape, dtype=self.spaces[0].dtype) def sample(self): return np.stack([space.sample() for space in 
self.spaces]) @@ -39,16 +48,15 @@ def __init__(self, make_env, nenvs=None): self._nenvs = len(self.envs) # self.observation_space = SpaceBatch([env.observation_space # for env in self._envs]) - self.action_space = SpaceBatch([env.action_space - for env in self._envs]) + self.action_space = SpaceBatch([env.action_space for env in self._envs]) def _get_make_env_functions(self, make_env, nenvs): if nenvs is None and not isinstance(make_env, list): - raise ValueError("When nenvs is None make_env" - " must be a list of callables") - if nenvs is not None and not callable(make_env): raise ValueError( - "When nenvs is not None make_env must be callable") + "When nenvs is None make_env" " must be a list of callables" + ) + if nenvs is not None and not callable(make_env): + raise ValueError("When nenvs is not None make_env must be callable") if nenvs is not None: make_env = [make_env for _ in range(nenvs)] @@ -66,29 +74,41 @@ def _check_actions(self, actions): if not len(actions) == self.nenvs: raise ValueError( "number of actions is not equal to number of envs: " - "len(actions) = {}, nenvs = {}" - .format(len(actions), self.nenvs)) + "len(actions) = {}, nenvs = {}".format(len(actions), self.nenvs) + ) def step(self, actions): self._check_actions(actions) - obs, rews, resets, infos = [], [], [], [] + observations, rewards, terminated_list, truncated_list, infos = [], [], [], [], [] for env, action in zip(self._envs, actions): - ob, rew, done, info = env.step(action) - if done: - ob = env.reset() - obs.append(ob) - rews.append(rew) - resets.append(done) + obs, rew, terminated, truncated, info = env.step(action) + if terminated or truncated: + obs, info = env.reset() + observations.append(obs) + rewards.append(rew) + terminated_list.append(terminated) + truncated_list.append(truncated) infos.append(info) - return np.stack(obs), np.stack(rews), np.stack(resets), infos + return ( + np.stack(observations), + np.stack(rewards), + np.stack(terminated_list), + 
np.stack(truncated_list), + infos, + ) - def reset(self): - return np.stack([env.reset() for env in self.envs]) + def reset(self, **kwargs): + observations, infos = [], [] + for env in self.envs: + obs, info = env.reset(**kwargs) + observations.append(obs) + infos.append(info) + return np.stack(observations), infos class SingleEnvBatch(Wrapper, EnvBatch): def __init__(self, env): - super(SingleEnvBatch, self).__init__(env) + super().__init__(env) self.observation_space = SpaceBatch([self.env.observation_space]) self.action_space = SpaceBatch([self.env.action_space]) @@ -102,37 +122,38 @@ def envs(self): def step(self, actions): self._check_actions(actions) - ob, rew, done, info = self.env.step(actions[0]) - if done: - ob = self.env.reset() + obs, rew, terminated, truncated, info = self.env.step(actions[0]) + if terminated or truncated: + obs, info = self.env.reset() return ( - ob[None], + obs[None], np.expand_dims(rew, 0), - np.expand_dims(done, 0), + np.expand_dims(terminated, 0), + np.expand_dims(truncated, 0), [info], ) - def reset(self): - return self.env.reset()[None] + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + return obs[None], info -def worker(parent_connection, worker_connection, make_env_function, - send_spaces=True): +def worker(parent_connection, worker_connection, make_env_function, send_spaces=True): # Adapted from SubprocVecEnv github.com/openai/baselines parent_connection.close() env = make_env_function() if send_spaces: worker_connection.send((env.observation_space, env.action_space)) while True: - cmd, action = worker_connection.recv() + cmd, data = worker_connection.recv() if cmd == "step": - ob, rew, done, info = env.step(action) - if done: - ob = env.reset() - worker_connection.send((ob, rew, done, info)) + obs, rew, terminated, truncated, info = env.step(data) + if terminated or truncated: + obs, info = env.reset() + worker_connection.send((obs, rew, terminated, truncated, info)) elif cmd == "reset": - ob = env.reset() 
- worker_connection.send(ob) + obs, info = env.reset(seed=data) + worker_connection.send((obs, info)) elif cmd == "close": env.close() worker_connection.close() @@ -146,22 +167,26 @@ class ParallelEnvBatch(EnvBatch): An abstract batch of environments. """ - def __init__(self, make_env, nenvs=None): + def __init__(self, make_env, nenvs=None, seeds=None): make_env_functions = self._get_make_env_functions(make_env, nenvs) self._nenvs = len(make_env_functions) - self._parent_connections, self._worker_connections = zip(*[ - Pipe() for _ in range(self._nenvs) - ]) + self._parent_connections, self._worker_connections = zip( + *[Pipe() for _ in range(self._nenvs)] + ) + self._seeds = seeds or list(range(self._envs)) self._processes = [ Process( target=worker, args=(parent_connection, worker_connection, make_env), - daemon=True + daemon=True, + ) + for i, (parent_connection, worker_connection, make_env) in enumerate( + zip( + self._parent_connections, + self._worker_connections, + make_env_functions, + ) ) - for i, (parent_connection, worker_connection, make_env) - in enumerate(zip(self._parent_connections, - self._worker_connections, - make_env_functions)) ] for p in self._processes: p.start() @@ -187,13 +212,23 @@ def step(self, actions): for conn, a in zip(self._parent_connections, actions): conn.send(("step", a)) results = [conn.recv() for conn in self._parent_connections] - obs, rews, dones, infos = zip(*results) - return np.stack(obs), np.stack(rews), np.stack(dones), infos + obs, rews, terminated, truncated, infos = zip(*results) + return ( + np.stack(obs), + np.stack(rews), + np.stack(terminated), + np.stack(truncated), + infos, + ) - def reset(self): - for conn in self._parent_connections: - conn.send(("reset", None)) - return np.stack([conn.recv() for conn in self._parent_connections]) + def reset(self, **kwargs): + for env_idx, conn in enumerate(self._parent_connections): + conn.send(("reset", self._seeds[env_idx])) + + results = [remote.recv() for remote in 
self._parent_connections] + observations, infos = zip(*results) + + return np.stack(observations), infos def close(self): if self._closed: diff --git a/week06_policy_based/reinforce_pytorch.ipynb b/week06_policy_based/reinforce_pytorch.ipynb index d056ab659..0130118d5 100644 --- a/week06_policy_based/reinforce_pytorch.ipynb +++ b/week06_policy_based/reinforce_pytorch.ipynb @@ -20,6 +20,10 @@ "import sys, os\n", "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + " !pip install -q gymnasium\n", + " !pip install moviepy\n", + " !apt install ffmpeg\n", + " !pip install imageio-ffmpeg\n", " !touch .setup_complete\n", "\n", "# This code creates a virtual display to draw game images on.\n", @@ -35,12 +39,22 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", + "import gymnasium as gym\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# also you need to install ffmpeg if not installed\n", + "# for MacOS: ! 
brew install ffmpeg" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -54,7 +68,7 @@ "metadata": {}, "outputs": [], "source": [ - "env = gym.make(\"CartPole-v0\")\n", + "env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")\n", "\n", "# gym compatibility: unwrap TimeLimit\n", "if hasattr(env, '_max_episode_steps'):\n", @@ -64,7 +78,7 @@ "n_actions = env.action_space.n\n", "state_dim = env.observation_space.shape\n", "\n", - "plt.imshow(env.render(\"rgb_array\"))" + "plt.imshow(env.render())" ] }, { @@ -91,7 +105,8 @@ "outputs": [], "source": [ "import torch\n", - "import torch.nn as nn" + "import torch.nn as nn\n", + "import torch.nn.functional as F" ] }, { @@ -155,7 +170,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_states = np.array([env.reset() for _ in range(5)])\n", + "test_states = np.array([env.reset()[0] for _ in range(5)])\n", "test_probas = predict_probs(test_states)\n", "assert isinstance(test_probas, np.ndarray), \\\n", " \"you must return np array and not %s\" % type(test_probas)\n", @@ -180,13 +195,14 @@ "outputs": [], "source": [ "def generate_session(env, t_max=1000):\n", - " \"\"\" \n", + " \"\"\"\n", " Play a full session with REINFORCE agent.\n", " Returns sequences of states, actions, and rewards.\n", " \"\"\"\n", " # arrays to record session\n", " states, actions, rewards = [], [], []\n", - " s = env.reset()\n", + "\n", + " s = env.reset()[0]\n", "\n", " for t in range(t_max):\n", " # action probabilities array aka pi(a|s)\n", @@ -194,7 +210,8 @@ "\n", " # Sample action with given probabilities.\n", " a = \n", - " new_s, r, done, info = env.step(a)\n", + "\n", + " new_s, r, terminated, truncated, info = env.step(a)\n", "\n", " # record session history to train later\n", " states.append(s)\n", @@ -202,7 +219,7 @@ " rewards.append(r)\n", "\n", " s = new_s\n", - " if done:\n", + " if terminated or truncated:\n", " break\n", "\n", " return states, actions, rewards" @@ -300,20 +317,6 @@ "When you compute the gradient of that 
function with respect to network weights $\\theta$, it will become exactly the policy gradient." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def to_one_hot(y_tensor, ndims):\n", - " \"\"\" helper: take an integer vector and convert it to 1-hot matrix. \"\"\"\n", - " y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)\n", - " y_one_hot = torch.zeros(\n", - " y_tensor.size()[0], ndims).scatter_(1, y_tensor, 1)\n", - " return y_one_hot" - ] - }, { "cell_type": "code", "execution_count": null, @@ -333,7 +336,7 @@ "\n", " # cast everything into torch tensors\n", " states = torch.tensor(states, dtype=torch.float32)\n", - " actions = torch.tensor(actions, dtype=torch.int32)\n", + " actions = torch.tensor(actions, dtype=torch.int64)\n", " cumulative_returns = np.array(get_cumulative_rewards(rewards, gamma))\n", " cumulative_returns = torch.tensor(cumulative_returns, dtype=torch.float32)\n", "\n", @@ -347,7 +350,7 @@ "\n", " # select log-probabilities for chosen actions, log pi(a_i|s_i)\n", " log_probs_for_actions = torch.sum(\n", - " log_probs * to_one_hot(actions, env.action_space.n), dim=1)\n", + " log_probs * F.one_hot(actions, env.action_space.n), dim=1)\n", " \n", " # Compute loss here. 
Don't forgen entropy regularization with `entropy_coef` \n", " entropy = \n", @@ -376,7 +379,7 @@ "for i in range(100):\n", " rewards = [train_on_session(*generate_session(env)) for _ in range(100)] # generate new sessions\n", " \n", - " print(\"mean reward:%.3f\" % (np.mean(rewards)))\n", + " print(\"mean reward: %.3f\" % (np.mean(rewards)))\n", " \n", " if np.mean(rewards) > 500:\n", " print(\"You Win!\") # but you can train even further\n", @@ -398,10 +401,12 @@ "source": [ "# Record sessions\n", "\n", - "import gym.wrappers\n", + "from gymnasium.wrappers import RecordVideo\n", "\n", - "with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n", - " sessions = [generate_session(env_monitor) for _ in range(100)]" + "with gym.make(\"CartPole-v1\", render_mode=\"rgb_array\") as env, RecordVideo(\n", + " env=env, video_folder=\"./videos\"\n", + ") as env_monitor:\n", + " sessions = [generate_session(env_monitor) for _ in range(10)]\n" ] }, { @@ -437,9 +442,22 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "pygments_lexer": "ipython3" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" } }, "nbformat": 4, diff --git a/week06_policy_based/reinforce_tensorflow.ipynb b/week06_policy_based/reinforce_tensorflow.ipynb index 063f18b18..644340cb4 100644 --- a/week06_policy_based/reinforce_tensorflow.ipynb +++ b/week06_policy_based/reinforce_tensorflow.ipynb @@ -23,6 +23,7 @@ " \n", " if not os.path.exists('.setup_complete'):\n", " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + " !pip install -q gym[classic_control]==0.18.0\n", " !touch .setup_complete\n", "\n", "# This code creates a 
virtual display to draw game images on.\n", diff --git a/week06_policy_based/runners.py b/week06_policy_based/runners.py index 120c42484..2d608aa68 100644 --- a/week06_policy_based/runners.py +++ b/week06_policy_based/runners.py @@ -1,11 +1,10 @@ -""" RL env runner """ from collections import defaultdict import numpy as np class EnvRunner: - """ Reinforcement learning runner in an environment with given policy """ + """Reinforcement learning runner in an environment with given policy""" def __init__(self, env, policy, nsteps, transforms=None, step_var=None): self.env = env @@ -13,20 +12,25 @@ def __init__(self, env, policy, nsteps, transforms=None, step_var=None): self.nsteps = nsteps self.transforms = transforms or [] self.step_var = step_var if step_var is not None else 0 - self.state = {"latest_observation": self.env.reset()} + self.state = {"latest_observation": self.env.reset()[0]} @property def nenvs(self): - """ Returns number of batched envs or `None` if env is not batched """ + """Returns number of batched envs or `None` if env is not batched""" return getattr(self.env.unwrapped, "nenvs", None) - def reset(self): - """ Resets env and runner states. """ - self.state["latest_observation"] = self.env.reset() + def reset(self, **kwargs): + """Resets env and runner states.""" + self.state["latest_observation"] = self.env.reset(**kwargs)[0] self.policy.reset() + def add_summary(self, name, val): + """Writes logs""" + add_summary = self.env.get_wrapper_attr("add_summary") + add_summary(name, val) + def get_next(self): - """ Runs the agent in the environment. 
""" + """Runs the agent in the environment.""" trajectory = defaultdict(list, {"actions": []}) observations = [] rewards = [] @@ -37,27 +41,29 @@ def get_next(self): observations.append(self.state["latest_observation"]) act = self.policy.act(self.state["latest_observation"]) if "actions" not in act: - raise ValueError("result of policy.act must contain 'actions' " - f"but has keys {list(act.keys())}") + raise ValueError( + "result of policy.act must contain 'actions' " + f"but has keys {list(act.keys())}" + ) for key, val in act.items(): trajectory[key].append(val) - obs, rew, done, _ = self.env.step(trajectory["actions"][-1]) + obs, rew, terminated, truncated, _ = self.env.step( + trajectory["actions"][-1] + ) self.state["latest_observation"] = obs rewards.append(rew) - resets.append(done) + reset = np.logical_or(terminated, truncated) + resets.append(reset) self.step_var += self.nenvs or 1 # Only reset if the env is not batched. Batched envs should # auto-reset. - if not self.nenvs and np.all(done): + if not self.nenvs and np.all(reset): self.state["env_steps"] = i + 1 - self.state["latest_observation"] = self.env.reset() + self.state["latest_observation"] = self.env.reset()[0] - trajectory.update( - observations=observations, - rewards=rewards, - resets=resets) + trajectory.update(observations=observations, rewards=rewards, resets=resets) trajectory["state"] = self.state for transform in self.transforms: diff --git a/week08_pomdp/README.md b/week08_pomdp/README.md index 8154b5b49..6bb589a2f 100644 --- a/week08_pomdp/README.md +++ b/week08_pomdp/README.md @@ -5,6 +5,7 @@ _Links on all articles mentioned during the lecture could be found in "Reference ## Basics * Our [lecture](https://yadi.sk/i/AHzpTjiT3U8L8e) and [seminar](https://yadi.sk/i/Ka-I7nBp3U8LAG) (russian) +* A Lecture on Basics by Pavel Shvechikov (english) [Video](https://www.youtube.com/watch?v=aV4wz7FAXmo) * A lecture on basics by Andrew NG (english, LQ) - 
[video](https://www.youtube.com/watch?v=yCqPMD6coO8) * A lecture on basics by 5vision (russian) - [video](https://www.youtube.com/watch?v=_dkaynuKUFE) * _[alternative]_ Chalkboard-style 2-part lecture by B. Ravindran. - [part1](https://www.youtube.com/watch?v=9G_KevA8DFY), [part2](https://www.youtube.com/watch?v=dMOUp7YzUpQ) diff --git a/week08_pomdp/atari_util.py b/week08_pomdp/atari_util.py index baa77ec0d..11065cde6 100644 --- a/week08_pomdp/atari_util.py +++ b/week08_pomdp/atari_util.py @@ -1,14 +1,14 @@ import cv2 import numpy as np -from gym.core import Wrapper -from gym.spaces.box import Box +from gymnasium import Wrapper +from gymnasium.spaces import Box class PreprocessAtari(Wrapper): def __init__(self, env, height=42, width=42, color=False, crop=lambda img: img, n_frames=4, dim_order='pytorch', reward_scale=1): """A gym wrapper that reshapes, crops and scales image into the desired shapes""" - super(PreprocessAtari, self).__init__(env) + super().__init__(env) self.img_size = (height, width) self.crop = crop self.color = color @@ -25,18 +25,19 @@ def __init__(self, env, height=42, width=42, color=False, self.observation_space = Box(0.0, 1.0, obs_shape) self.framebuffer = np.zeros(obs_shape, 'float32') - def reset(self): + def reset(self, **kwargs): """Resets the game, returns initial frames""" self.framebuffer = np.zeros_like(self.framebuffer) - self.update_buffer(self.env.reset()) - return self.framebuffer + state, info = self.env.reset(**kwargs) + self.update_buffer(state) + return self.framebuffer, info def step(self, action): """Plays the game for 1 step, returns frame buffer""" - new_img, r, done, info = self.env.step(action) + new_img, r, terminated, truncated, info = self.env.step(action) self.update_buffer(new_img) - return self.framebuffer, r * self.reward_scale, done, info + return self.framebuffer, r * self.reward_scale, terminated, truncated, info ### image processing ### diff --git a/week08_pomdp/env_pool.py b/week08_pomdp/env_pool.py index 
2eda898b7..709dca407 100644 --- a/week08_pomdp/env_pool.py +++ b/week08_pomdp/env_pool.py @@ -1,5 +1,5 @@ """ -A thin wrapper for OpenAI gym environments that maintains a set of parallel games and has a method to generate +A thin wrapper for Farama gymnasium environments that maintains a set of parallel games and has a method to generate interaction sessions given agent one-step applier function. """ @@ -15,7 +15,7 @@ def __init__(self, agent, make_env, n_parallel_games=1): and is capable of some auxilary actions like evaluating agent on one game session (See .evaluate()). :param agent: Agent which interacts with the environment. - :param make_env: Factory that produces environments OR a name of the gym environment. + :param make_env: Factory that produces environments OR a name of the gymnasium environment. :param n_games: Number of parallel games. One game by default. :param max_size: Max pool size by default (if appending sessions). By default, pool is not constrained in size. """ @@ -25,17 +25,17 @@ def __init__(self, agent, make_env, n_parallel_games=1): self.envs = [self.make_env() for _ in range(n_parallel_games)] # Initial observations. - self.prev_observations = [env.reset() for env in self.envs] + self.prev_observations = [env.reset()[0] for env in self.envs] # Agent memory variables (if you use recurrent networks). self.prev_memory_states = agent.get_initial_state(n_parallel_games) - # Whether particular session has just been terminated and needs + # Whether particular session has just been terminated or truncated and needs # restarting. self.just_ended = [False] * len(self.envs) def interact(self, n_steps=100, verbose=False): - """Generate interaction sessions with ataries (OpenAI gym Atari environments) + """Generate interaction sessions with ataries (Farama gymnasium Atari environments) Sessions will have length n_steps. Each time one of games is finished, it is immediately getting reset and this time is recorded in is_alive_log (See returned values). 
@@ -46,9 +46,9 @@ def interact(self, n_steps=100, verbose=False): def env_step(i, action): if not self.just_ended[i]: - new_observation, cur_reward, is_done, info = \ + new_observation, cur_reward, terminated, truncated, info = \ self.envs[i].step(action) - if is_done: + if terminated or truncated: # Game ends now, will finalize on next tick. self.just_ended[i] = True @@ -58,7 +58,7 @@ def env_step(i, action): else: # Reset environment, get new observation to be used on next # tick. - new_observation = self.envs[i].reset() + new_observation = self.envs[i].reset()[0] # Reset memory for new episode. initial_memory_state = self.agent.get_initial_state( diff --git a/week08_pomdp/practice_pytorch.ipynb b/week08_pomdp/practice_pytorch.ipynb index dca333333..2d3d43b46 100644 --- a/week08_pomdp/practice_pytorch.ipynb +++ b/week08_pomdp/practice_pytorch.ipynb @@ -1,698 +1,681 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "if 'google.colab' in sys.modules:\n", - " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/atari_util.py\n", - " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/env_pool.py\n", - "\n", - "# If you are running on a server, launch xvfb to record game videos\n", - "# Please make sure you have xvfb installed\n", - "import os\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "from IPython.core import display\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Kung-Fu, recurrent 
style\n", - "\n", - "In this notebook we'll once again train RL agent for for Atari [KungFuMaster](https://gym.openai.com/envs/KungFuMaster-v0/), this time using recurrent neural networks.\n", - "\n", - "![img](https://upload.wikimedia.org/wikipedia/en/6/66/Kung_fu_master_mame.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", - "Observation shape: (1, 42, 42)\n", - "Num actions: 14\n", - "Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n" - ] - } - ], - "source": [ - "import gym\n", - "from atari_util import PreprocessAtari\n", - "\n", - "\n", - "def make_env():\n", - " env = gym.make(\"KungFuMasterDeterministic-v0\")\n", - " env = PreprocessAtari(env, height=42, width=42,\n", - " crop=lambda img: img[60:-30, 15:],\n", - " color=False, n_frames=1)\n", - " return env\n", - "\n", - "\n", - "env = make_env()\n", - "\n", - "obs_shape = env.observation_space.shape\n", - "n_actions = env.action_space.n\n", - "\n", - "print(\"Observation shape:\", obs_shape)\n", - "print(\"Num actions:\", n_actions)\n", - "print(\"Action names:\", env.env.env.get_action_meanings())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "if 'google.colab' in sys.modules:\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/atari_util.py\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/env_pool.py\n", + "\n", + " !pip install -q gymnasium[atari,accept-rom-license]\n", + "\n", + " !wget -q 
https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + " !touch .setup_complete\n", + "# If you are running on a server, launch xvfb to record game videos\n", + "# Please make sure you have xvfb installed\n", + "import os\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from IPython.core import display\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Kung-Fu, recurrent style\n", + "\n", + "In this notebook we'll once again train an RL agent for Atari [KungFuMaster](https://gymnasium.farama.org/environments/atari/kung_fu_master/), this time using recurrent neural networks.\n", + "\n", + "![img](https://upload.wikimedia.org/wikipedia/en/6/66/Kung_fu_master_mame.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\u001b[0m\n", + "Observation shape: (1, 42, 42)\n", + "Num actions: 14\n", + "Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n" + ] + } + ], + "source": [ + "import gymnasium as gym\n", + "from atari_util import PreprocessAtari\n", + "\n", + "\n", + "def make_env():\n", + " env = gym.make(\"KungFuMasterDeterministic-v0\", render_mode=\"rgb_array\")\n", + " env = PreprocessAtari(env, height=42, width=42,\n", + " crop=lambda img: img[60:-30, 15:],\n", + " color=False, n_frames=1)\n", + " return env\n", + "\n", + "\n", + "env = make_env()\n", + "\n", + "obs_shape = env.observation_space.shape\n", + "n_actions = env.action_space.n\n", + "\n", + "print(\"Observation shape:\", obs_shape)\n", + "print(\"Num actions:\", n_actions)\n", + "print(\"Action names:\", env.unwrapped.get_action_meanings())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jheuristic/anaconda3/lib/python3.6/site-packages/scipy/misc/pilutil.py:482: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n", + " if issubdtype(ts, int):\n", + "/home/jheuristic/anaconda3/lib/python3.6/site-packages/scipy/misc/pilutil.py:485: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + " elif issubdtype(type(size), float):\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAANEAAAEICAYAAADBfBG8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFmVJREFUeJzt3XvUHHV9x/H3hyBoASHcEgi3wAGO4CVGxFTKRbyFVAXaqsFWUWkJlVA80FMIKFLUAirQKBUImnIRQSqi1BNQCnhpEeRiCJcIJIAQckMIBAVpE7/9Y2Zhstl9nnl2dp+Z2f28ztmzszOzu99J5ru/3/xmnu8oIjCzzm1QdgBmdeckMivISWRWkJPIrCAnkVlBTiKzgpxEfUjSTpJ+J2lM2bEMAidRAZKmS7pd0u8lrUynPyVJZcYVEY9HxKYRsbbMOAaFk6hDkk4EZgNfBsYD44BjgP2AjUoMzUZbRPgxwgewOfB74C+HWe/PgV8Bq4EngNMzy3YBAvhEumwVSRK+FVgAPAuc3/R5nwQWpuv+CNi5zfc2PnvD9PVPgC8AtwK/A/4T2Aq4Io3tDmCXzPtnpzGtBu4C9s8sew1waRrDQuCfgCWZ5dsD1wBPAY8C/1D2/1fP94eyA6jjA5gKrGnspEOsdxDwBpIW/43ACuCwdFljR78QeDXwHuAPwPeBbYEJwErgwHT9w4BFwOuADYHPALe2+d5WSbQI2C39AXgAeAh4V/pZlwH/nnn/36RJtiFwIrAceHW67Czgp8BYYIc04ZekyzZIk+40ktZ4V+AR4L1l/5/1dH8oO4A6PtKdbHnTvFvT1uNF4IA27/tX4Lx0urGjT8gsfxr4cOb1NcCn0+nrgaMyyzYAXqBFa9QmiU7NLD8HuD7z+v3A/CG2dxXwpnR6naQA/jaTRG8DHm9676xsgvbjw8dEnXka2FrSho0ZEfH2iNgiXbYBgKS3SbpF0lOSniPprm3d9FkrMtMvtni9aTq9MzBb0rOSngWeAUTSYuWR93uQdKKkhZKeS79r80zc25N09Rqy0zsD2zdiTN97CsnxYt9yEnXmF8BLwKHDrPdt4Dpgx4jYnKTr1unI3RPAjIjYIvN4TUTc2uHntSRpf+Ak4EPA2PSH4TleiXsZSTeuYcemGB9tinGziJjWzRirxknUgYh4Fvhn4OuS/krSppI2kDQJ2CSz6mbAMxHxB0n7Ah8p8LUXArMk7Q0gaXNJHyzwee1sRnK89xSwoaTTgNdmll+dxjFW0gRgZmbZL4HVkk6S9BpJYyS9XtJbexBnZTiJOhQRXwJOIBmdWknSPbqI5Fe80Tp8CjhD0vMkB9tXF/i+a4GzgaskrQbuAw7peAPa+xHJ8ddDwG9IBjuyXbYzgCUkI2//BXyXpFUmkvNS7wcmpct/C3yDpDvYt5Qe/Jl1RNLfA9Mj4sCyYymLWyIbEUnbSdov7b7uSTIEfm3ZcZVpw+FXMVvHRiTd1okkQ/pXAV8vNaKS9aw7J2kqyZnvMcA3IuKsnnyRWcl6kkTp1cMPAe8mOQi9AzgiIh7o+peZlaxX3bl9gUUR8QiApKtIzqm0TCJJHt2wKvptRGwz3Eq9GliYwLrDoktoOrMu6WhJd0q6s0cxmBX1mzwr9aolanVWfp3WJiLmAHPALZHVW69aoiWseznIDsDSHn2XWal6lUR3ALtLmihpI2A6yTVkZn2nJ925iFgjaSbJJSRjgLkRcX8vvsusbJW47MfHRFZRd0XEPsOt5Mt+zAqqxWU/xx9/fNkh2ACaPXt2rvXcEpkVVIuWaLTMmDEDgIsuuqjtsqzm9ZrXGelyqye3RKlWSdJq2UUXXfTyzp+dn03ATpZbfTmJUm4VrFNOohyyCTZjxowhu3btllv/chKZFeSBhZyGGyRoXset0eBwS5RDn
oRw0gyuWlz2MxonW0c6PJ1nHQ9x19vs2bNzXfbjJDJrI28SuTtnVpCTyKwgj85VyNhZY9ebt+rMVSVEYiPhlqgiGgm06sxVLz+y8626nERmBXWcRJJ2TG9gtVDS/ZKOT+efLulJSfPTR1/fm8asyDHRGuDEiLhb0mbAXZJuTJedFxFfKR6eWfV1nEQRsYzkrmlExPOSFpL/1odmfaMrx0SSdgHeDNyezpopaYGkuZJaHhm7Auq6sgMJjUd2vlVX4SFuSZvyyl2uV0u6APg8ScXTz5PcqfqTze9zBdT1OWHqqVBLJOlVJAl0RUR8DyAiVkTE2oj4I3AxSXF7s75VZHROwDeBhRFxbmb+dpnVDie5t6hZ3yrSndsP+Chwr6T56bxTgCPSu2gH8BjgvxGwvlZkdO6/aX33h3mdh2NV5D/hGNrAXjt374NHrPP6DXteOaLl3fiMPN9RthkzZrSsMeFEeoUv+7EhOVmG5ySy3IYqbjnInESWm4tOtuYksiE5YYbnGgs2rEEdnctbY2FgR+csv0FJmk65O2dWkJPIrCAnkVlBA3NM1HyPoVZn4lstzz5nNc9rfNasWQ/3ahO64swzdy87hL4zUC3RcAfIeQ6gszfpyvse628DlUTDnfNoXt5q/Tzr2GAZqCRqbkVaLW+ebl6/1fvdGg22gUqiZp3c1a75Pa2Ol2yw+IoFszZG7YoFSY8BzwNrgTURsY+kLYHvALuQ/HXrhyLCVTisL3WrO/eOiJiUydqTgZsiYnfgpvS1WV/q1XmiQ4GD0ulLgZ8AJ/Xou0ZkJOeDWs1v9Z6sQ37+89HZkA5dv//+ZYfQd7qRRAH8OD2uuSitJzcurZBKRCyTtG0Xvqdrit4m0iyrG925/SJiMnAIcKykA/K8qcwKqCM9X9TpOjYYCidRRCxNn1cC15IUa1zRqD+XPq9s8b45EbFPntGPbhvplQvtXvv8kEHxCqibpHeEQNImwHtIijVeBxyZrnYk8IMi39Ntrc71DLXcbCiFzhNJ2pWk9YHk+OrbEfFFSVsBVwM7AY8DH4yIZ4b4HJ8nssoZlfNEEfEI8KYW858G3lnks83qohZXLJiVpH9qLEz+wuSyQ7ABdPdn7s61Xi2SaNsdKnWayWwdtUiiDa4e6IvNreJqkUTzd5g//EpmJalFEo3faXzZIdgAWsrSXOu5n2RWUC1aIg8sWJX5PJFZe7nOE7k7Z1aQk8isoFocE90w2Vcs2Oibene+KxbcEpkV5CQyK8hJZFZQLY6JJs3zFQtWgpy7nVsis4I6bokk7UlS5bRhV+A0YAvg74Cn0vmnRMS8jiMEPvLx04ZdZ9aJxwFw5jlfK/JVhTiGfosh327bcRJFxIPAJABJY4AnSeotfAI4LyK+0ulnd2LtSWuTiRKvEHIMgxlDt46J3gksjojfSOrSR47MmLPHJBPnlPL1jmGAY+hWEk0Hrsy8ninpY8CdwImjUcx+0H79HEN1Yig8sCBpI+ADwH+ksy4AdiPp6i2jzW9Btyugjjl7zCu/PiVxDIMZQzdaokOAuyNiBUDjGUDSxcAPW70prdk9J12v8FXcg/br5xiqE0M3kugIMl05Sds1itkDh5NURO25QeuHO4bqxFAoiST9CfBuIFtz90uSJpHcLeKxpmU9M2i/fo6hOjEUrYD6ArBV07yPFoqoQ4P26+cYqhNDLS77yWPQfv0cQ3Vi6JskGrRfP8dQnRj6JokG7dfPMVQnhr5JokH79XMM1Ymhb5Jo0H79HEN1YuibJBq0Xz/HUJ0Y+iaJBu3XzzFUJ4ZaFG9cvnzaaIVi9rLx4+e5eKPZaKhFd+6Wyb61ilWXWyKzgpxEZgU5icwKqsUx0TvunlR2CDaIxvtOeWajohYtUZ66c2bdl6/unFsis4JyJZGkuZJWSrovM29LSTdKejh9HpvOl6SvSlokaYEk31zI+lrelugSYGrTvJOBmyJid+Cm9DUk1X92Tx9Hk5TQMutbuZIoIn4GPNM0+
1Dg0nT6UuCwzPzLInEbsIWk7boRrFkVFTkmGtcojZU+N66XnQA8kVlvSTpvHd0u3mhWll6MzrUqxr3eVdrdLt5oVpYiLdGKRjctfV6Zzl8C7JhZbwcg31krsxoqkkTXAUem00cCP8jM/1g6SjcFeC5TEdWs7+Tqzkm6EjgI2FrSEuBzwFnA1ZKOAh4HPpiuPg+YBiwCXiC5X5FZ38qVRBFxRJtF72yxbgDHFgnKrE58xYJZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQcMmUZvqp1+W9Ou0wum1krZI5+8i6UVJ89PHhb0M3qwK8rREl7B+9dMbgddHxBuBh4BZmWWLI2JS+jimO2GaVdewSdSq+mlE/Dgi1qQvbyMpi2U2kLpxTPRJ4PrM64mSfiXpp5L2b/cmV0C1flGoAqqkU4E1wBXprGXAThHxtKS3AN+XtHdErG5+bzcroN58w5SXpw+eeluRj6p1DEOpenx11nFLJOlI4H3AX6dlsoiIlyLi6XT6LmAxsEc3Am0nu3OUpQoxjETd4q26jpJI0lTgJOADEfFCZv42ksak07uS3F7lkW4EmlcVdpAqxJBVtXj6zbDduTbVT2cBGwM3SgK4LR2JOwA4Q9IaYC1wTEQ035KlJxpdlDJ3mCrE0E6VY6u7YZOoTfXTb7ZZ9xrgmqJBdaKxc5TZ369CDK0cPPU2J08P1eLGx0M5eOptfO3tZ7z8+rhbBzOG4Sz41rSXpz/9Ld9Iupt82Y9ZQX2RRMfdeto6z4Maw1AarY9boe6rfXcOYI97FnAc5e4cZcVw/rmvBWDmCeudimux3lc4P72X+3DrW361b4n2uGfBOs+DFEMjgZqnh1ovz/o2MrVPoqwyE6lKMTScf+5rnSyjoLbduarsrGXG0eiSNRJluIRpXt+6oy9aoofe9MayQyg1huzxzcwTVrd83ZxAPibqntq2RNZacyvjVqf3+qIlstYtS3OrNNS61rnaJ9Ggd+WympOjMbCQTSYnUPfVPomyB/Zl7cxViGEo2WSy7qt9Etm6nCijr/YDC1X45a9CDFl77bXXeleS33zDlMpdXd4v3BKZFVTbJFo790DWzj1wnddlxVF2DMNxK9Rbte/OAex2/NiyQ6hEDA0HT71t3fND5z7gY6Ue6rQC6umSnsxUOp2WWTZL0iJJD0p6b68Cb6UKO3IVYmjmBOqtTiugApyXqXQ6D0DSXsB0YO/0PV9vFC7ptsWzV7F49ip2O34si2ev6sVX5I6j7BisXHlqLPxM0i45P+9Q4KqIeAl4VNIiYF/gFx1HmEMVduIqxGDlKDKwMDMtaD9XUqMPMwF4IrPOknTeerpVAbWx45bZjapCDFaeTpPoAmA3YBJJ1dNz0vlqsW7L6qYRMSci9omIfTqMYT1V2ImrEIMvOh1dHSVRRKyIiLUR8UfgYpIuGyQtz46ZVXcAlhYL0YrwoELvdVoBdbvMy8OBxsjddcB0SRtLmkhSAfWXxUIcWhV++asQg5Wn0wqoB0maRNJVewyYARAR90u6GniApND9sRGxtjehWyvuyo2+rlZATdf/IvDFIkHlUZVf/6rEYeWp7WU/rVRhiLkKMdjoUnpXlHKDGOb+RENd97Xf8icB+J/xLUfSR0UVYsiqak3wurn5hil35Rk9rsW1cydMbn/r19vnfRZIduS3Tfv8aIVUuRiybr4heR7q382G1/h3HE7tu3NV2GmrEEMr7/uX+WWHMBBq0Z0zK0n/dOd+eMqkskOwAZS3Ja99d86sbE4is4KcRGYFeWDBrD0PLJgV4YEFs1FSi+7c8uXThlps1hPjx8/rn+7cLZN95t2qy905s4KcRGYFOYnMCuq0Aup3MtVPH5M0P52/i6QXM8su7GXwZlWQZ2DhEuB84LLGjIj4cGNa0jnAc5n1F0dEV0/svONunyeyEozPV6iqUAVUSQI+B
Bw8gtBGbPz4eb38eLNCig5x7w+siIiHM/MmSvoVsBr4TET8vNUbJR0NHJ3nS67cfvuCYZqN3BFLu9QSDfc9wJWZ18uAnSLiaUlvAb4vae+IWK+CYETMAeaAr52zeus4iSRtCPwF8JbGvLSQ/Uvp9F2SFgN7AIXqbeeVPXZqnKBtNc8xlB/DaMTR7vu6/W9RZIj7XcCvI2JJY4akbRq3UpG0K0kF1EeKhTgyrf5RRvuKB8dQrRh6HUeeIe4rSW6NsqekJZKOShdNZ92uHMABwAJJ9wDfBY6JiGe6Fq1ZBXVaAZWI+HiLedcA1xQPy6w+fMWCWUF9mUTZ/m5ZV4A7hurE0Os4avGnECNRhasbHMNgxVCLP8rzyVYrwxFLl+b6o7xaJJFZSfrnL1uT619H5vI//WcAPvqLz3U7GMdQwxg6i2NmrrX6cmDBbDQ5icwKchKZFVSLY6Lx229Vynu7xTFUJwbIH8fyfH8J4ZbIrKhatETbjB/ZHbrPPfuznHDS5QBcfulnOeGk0b+TnWOoTgydxjGwLdEVl5zFuHGbvPx63LhNuOKSsxzDAMfQ6zjq0RJtu8WI39P8j9TJZxTlGKoTQy/jqMUVCyO9lfy3Lzljndcf+fhpIw+qIMdQnRg6jePmG6b0z2U/I00is27Im0R9d0xkNtry/Hn4jpJukbRQ0v2Sjk/nbynpRkkPp89j0/mS9FVJiyQtkDS51xthVqY8LdEa4MSIeB0wBThW0l7AycBNEbE7cFP6GuAQkgIlu5PUlbug61GbVciwSRQRyyLi7nT6eWAhMAE4FLg0Xe1S4LB0+lDgskjcBmwhabuuR25WESMa4k7LCb8ZuB0YFxHLIEk0Sdumq00Ansi8bUk6b1nTZ+WugHrzDVNGEqbZqMqdRJI2Jank8+mIWJ2U4W69aot5642+uQKq9Ytco3OSXkWSQFdExPfS2Ssa3bT0eWU6fwmwY+btOwA5L6Awq588o3MCvgksjIhzM4uuA45Mp48EfpCZ/7F0lG4K8Fyj22fWlyJiyAfwZyTdsQXA/PQxDdiKZFTu4fR5y3R9Af8GLAbuBfbJ8R3hhx8VfNw53L4bEfW4YsGsJL5iwWw0OInMCnISmRXkJDIrqCp/lPdb4Pfpc7/Ymv7Znn7aFsi/PTvn+bBKjM4BSLozz0hIXfTT9vTTtkD3t8fdObOCnERmBVUpieaUHUCX9dP29NO2QJe3pzLHRGZ1VaWWyKyWnERmBZWeRJKmSnowLWxy8vDvqB5Jj0m6V9J8SXem81oWcqkiSXMlrZR0X2ZebQvRtNme0yU9mf4fzZc0LbNsVro9D0p674i/MM+l3r16AGNI/mRiV2Aj4B5grzJj6nA7HgO2bpr3JeDkdPpk4Oyy4xwi/gOAycB9w8VP8mcw15P8ycsU4Pay48+5PacD/9hi3b3S/W5jYGK6P44ZyfeV3RLtCyyKiEci4n+Bq0gKnfSDdoVcKicifgY80zS7toVo2mxPO4cCV0XESxHxKLCIZL/MrewkalfUpG4C+LGku9ICLNBUyAXYtu27q6ld/HX+P5uZdkHnZrrXhben7CTKVdSkBvaLiMkkNfeOlXRA2QH1UF3/zy4AdgMmkVSeOiedX3h7yk6ivihqEhFL0+eVwLUk3YF2hVzqoq8K0UTEiohYGxF/BC7mlS5b4e0pO4nuAHaXNFHSRsB0kkIntSFpE0mbNaaB9wD30b6QS130VSGapuO2w0n+jyDZnumSNpY0kaRy7y9H9OEVGEmZBjxEMipyatnxdBD/riSjO/cA9ze2gTaFXKr4AK4k6eL8H8kv81Ht4qeDQjQV2Z7L03gXpImzXWb9U9PteRA4ZKTf58t+zAoquztnVntOIrOCnERmBTmJzApyEpkV5CQyK8hJZFbQ/wPTMFRqoBLrRQAAAABJRU5ErkJggg==", + "text/plain": [ + "" + ] + }, + "metadata": {}, + 
"output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFz5JREFUeJzt3X20HHV9x/H35z5xQxIeEkIMJAUfooItpi1G6sORolhEFGzViihpy7HtsfRYH9qqfcJWrZ6K2HP06EFFUquAjzVVasmJIIVaHsSIQagBBBMTEhACuXm6T9/+MXPL3jtzc/fe3Z3dze/zOmfP3f3N7M539u53Z+a3M7+vIgIzS09PuwMws/Zw8pslyslvlignv1minPxmiXLymyXKyZ8wSSdKCkl97Y5lNiRdIOm6dsfR7Zz8TSTpBkmPSTqswmWGpGdUtbyqlX1BRcQXIuLl7YzrUODkbxJJJwIvBgJ4dVuD6SDK+HPWgfxPaZ4Lgf8BrgTW1E6QtFjSv0t6QtJtkt4v6aaa6c+WtF7So5L+V9Lra6ZdKekTkr4labekWyQ9PZ92Yz7bDyUNSfrdqUFJ6pH015IelLRT0r9IOnLKbH8gaZuk7ZLeWfPc1ZJuz+PeIemjNdNOk/TfknZJ+qGk02um3SDpA5JuBvYC75V0+5S43i5pXX7/lZJ+kC9ni6RLamadWMdd+Tr+hqTfm/L+vSB/Xx/P/75gSiz/IOnm/P27TtIxU9+nJEWEb024AfcCbwV+HRgBltZMuzq/HQ6cDGwBbsqnzc8f/z7QB/wa8AjwnHz6lcCjwOp8+heAq2teO4BnHCSuP8hjexqwAPga8Pl82on586/K4/gV4GHgZfn07wFvzu8vAE7L7x8P/AI4m2wDcmb+eEk+/QbgZ8Bz8piPBHYDK2viug14Q37/9HzZPcApwA7gvCkx9tU89/dq3r9FwGPAm/NlnZ8/XlwTy33AM4F5+eMPtfvz0gk3b/mbQNKLgBOAL0XE98k+bG/Mp/UCvwP8XUTsjYgfA2trnn4O8EBEfC4iRiPiDuCrwGtr5vlaRNwaEaNkyb9qFuFdAHw0Iu6PiCHgPcAbpnTyvS8i9kTEj4DPkSUQZF9iz5B0TEQMRcT/5O1vAq6NiGsjYjwi1gO3k30ZTLgyIu7K1+lx4BsTrytpJfBsYB1ARNwQET/KX+tOsi+jl9S5fq8ENkfE5/NlXQXcA7yqZp7PRcRPImIf8CVm9/4dspz8zbEGuC4iHskff5End/2XkG2RttTMX3v/BOD5+e7zLkm7yBL2KTXzPFRzfy/ZVrhexwEP1jx+MI9n6TTxPJg/B+Aisi3mPfnu9Dk1Mb9uSswvApZN85qQvScTXypvBP4tIvYCSHq+pOslPSzpceCPgXp3zaeu38Q6HF/zuJH375DVVT/xdCJJ84DXA72SJj5khwFHSXousAkYBZYDP8mnr6h5iS3AdyPizBaFuI0sWSf8Uh7PjjymiXjuqZm+DSAiNgPn5x12vw18RdLiPObPR8RbDrLcqZeLXgccI2kV2ZfA22umfRH4OPCKiNgv6WM8mfwzXXY6df0m1uHbMzwved7yN+48YIzsWH5VfjsJ+C/gwogYIzvOvkTS4ZKeTdY5OOGbwDMlvVlSf357nqST6lz+DrLj+elcBbxd0lMlLQA+CFyTH0JM+Js8tueQ9T1cAyDpTZKWRMQ4sCufdwz4V+BVkn5LUq+kQUmnS1rONPLlfQX4J7Lj9PU1kxcCj+aJv5r8kCn3MDB+kHW8luz9e6OkvrzT82Sy99UOwsnfuDVkx5Q/i4iHJm5kW7IL8mPri8k6vR4CPk+WkAcAImI38HLgDWRbsYeAD5PtPdTjEmBtvvv9+pLpV+TLvBH4KbAf+NMp83yXrFNwA/CRiJg4geYs4C5JQ8A/k3XQ7Y+ILcC5wHvJknML8OfM/Hn6IvAy4MtTvnzeCvy9pN3A35IdlwOQHxp8ALg5X8fTal8wIn5B1m/yTrJOx78Az
qk5BLNpKO8RtQpJ+jDwlIhYM+PMZi3iLX8F8t/xT8nOd9Fqso60r7c7LkubO/yqsZBsV/84YCdwKdlPX2Zt491+s0R5t98sUQ3t9ks6i6wXuBf4TER86GDz9w/Mj8HBoxtZpJkdxP79jzEyvEf1zDvn5M9PW/0E2XndW4HbJK3LT18tNTh4NKeuvniuizSzGdx+68frnreR3f7VwL35OePDZBeunNvA65lZhRpJ/uOZfP72ViafTw2ApD/MLwu9fWRkTwOLM7NmaiT5y44rCj8dRMTlEXFqRJza3z+/gcWZWTM10uG3lckXqCwnvyBkOhraR//NmxpYpJkdjA7sq3veRrb8twEr8wtGBsjOTV/XwOuZWYXmvOWPiFFJFwP/SfZT3xURcVfTIjOzlmrod/6IuJbskkoz6zI+w88sUZVe2BNHzGP/i0+pcpFmSYn/uqHueb3lN0uUk98sUU5+s0Q5+c0S5eQ3S1Slvf3Di4It549Oaovx4iUCkkcXMoio/7Mxm3k7SbPjHr6r/ud6y2+WKCe/WaKc/GaJcvKbJaracftDxNjkDo4YKX7/lHUCtpsGxgttU9cFgLK2duor7wBST7E9hjtsW1ASI/3F/wNADPcW2zqsv6/0MzRa/Lw09PmfxXM77L9tZlVx8pslyslvlignv1miGq3Y8wCwGxgDRiPi1IPOPyL6fj657LzGGomgQmVfk2UdSh3WyTStsvUp70vrLNNtrro19ibHrZH6O/ya0dv/mxHxSBNex8wq5N1+s0Q1mvwBXCfp+5L+sGyG2oo9Y3tcscesUzS62//CiNgm6VhgvaR7IuLG2hki4nLgcoDB5Su65YjY7JDX6NDd2/K/OyV9nax4543Tzg/ElBOxVNbh0YFfEVPjhmk6Kzss9rK4gdJ9vtL/RTuV9F2NT7M+PZ0We4my2MtOYqzqMzTn3X5J8yUtnLgPvBxwLS6zLtHIln8p8HVJE6/zxYj4dlOiMrOWa6Rc1/3Ac5sYi5lVyD/1mSWq0kt6RUknWYd1kE2nGzr3ykx7BmUXdJCVvb893XJGaInS2Nv4GfKW3yxRTn6zRDn5zRLl5DdLlJPfLFGV9vYHEFO+bnx6b2tNe3pvyamzGi22tdUhdnpv6WeoG0/vNbPu5uQ3S5ST3yxRTn6zRFVbsacvGF1cR69Sh3WaAaWdTx0Z51SzKf7SaevTaPGjblifJscY/S7RbWYzcPKbJcrJb5YoJ79Zombs8JN0BXAOsDMifjlvWwRcA5wIPAC8PiIem3Fp40J7p47g2Wm9MmZdrMkluq8EzprS9m5gQ0SsBDbkj82si8yY/Pk4/I9OaT4XWJvfXwuc1+S4zKzF5nrMvzQitgPkf4+dbsZJFXuGXLHHrFO0vMMvIi6PiFMj4tTeBfNbvTgzq9Ncz/DbIWlZRGyXtAzYWdezAnpGpjY2ehqXmf2/WfSfz3XLvw5Yk99fA3xjjq9jZm0yY/JLugr4HvAsSVslXQR8CDhT0mbgzPyxmXWRGXf7I+L8aSa9tMmxmFmFfIafWaIqv6R3bHGhx8/MmqXPl/Sa2Qyc/GaJcvKbJcrJb5aoakt0D4uBrQNVLtIsKRpu7iW9ZnYIcvKbJcrJb5YoJ79Zopz8Zoly8pslyslvlignv1minPxmiapnJJ8rJO2UtKmm7RJJP5e0Mb+d3dowzazZ5lq0A+CyiFiV365tblhm1mpzLdphZl2ukWP+iyXdmR8WHN20iMysEnNN/k8CTwdWAduBS6ebcVLFnj2u2GPWKeaU/BGxIyLGImIc+DSw+iDzPlmxZ74r9ph1ijklf16lZ8JrgE3TzWtmnWnGwTzyoh2nA8dI2gr8HXC6pFVkxYEeAP6ohTGaWQvMtWjHZ1sQi5lVyGf4mSXKyW+WKCe/WaKc/GaJcvKbJcrJb5YoJ79Zopz8Zoly8pslyslvlignv1minPxmiXLymyXKyW+WKCe/WaKc/GaJcvKbJaqeij0rJF0v6W5Jd0l6W
96+SNJ6SZvzvx6+26yL1LPlHwXeGREnAacBfyLpZODdwIaIWAlsyB+bWZeop2LP9oi4I7+/G7gbOB44F1ibz7YWOK9VQZpZ883qmF/SicCvArcASyNiO2RfEMCx0zzHRTvMOlDdyS9pAfBV4M8i4ol6n+eiHWadqa7kl9RPlvhfiIiv5c07Jop35H93tiZEM2uFenr7RTZO/90R8dGaSeuANfn9NcA3mh+embXKjEU7gBcCbwZ+JGlj3vZe4EPAlyRdBPwMeF1rQjSzVqinYs9NgKaZ/NLmhmNmVfEZfmaJcvKbJaqeY/6m0Tj07Z18BDE2LwrzxXQHGW3Ut68Y1PhAcb7x3uL6dIve4fre+LGB7l1He5K3/GaJcvKbJcrJb5YoJ79Zoirt8Ot/aA/L//G/J7Vte9cLCvMNH9neDqWBJ4odX8ddekuhbdebVhfbVrYkpKZTyVu8Yv1Qoa1vZ/EyjvsvPK7Q1s0dnanylt8sUU5+s0Q5+c0S5eQ3S5ST3yxR1Z7ee9hh9J749Mlt41VGUJ+ekWJb39Ilhbbx3gqCaRGNF3/RGD7qsEJb7xNl5zCXvGAXvxep8pbfLFFOfrNEOfnNEtVIxZ5LJP1c0sb8dnbrwzWzZqmnw2+iYs8dkhYC35e0Pp92WUR8pN6F7V/Sx0/eMnl4/94DJaeFtvlM0QOLigFsfttTC20aK3lyGzswy8ZGmHZwhP3FeR84r2TeviMKTQMPdd7/zGavnjH8tgMTxTl2S5qo2GNmXayRij0AF0u6U9IV0xXqdMUes87USMWeTwJPB1aR7RlcWvY8V+wx60xzrtgTETsiYiwixoFPA8XrW82sY814zD9dxR5JyyYKdQKvATbN9Fo9IzBv5+ROpZGFnTeAp8aKAQw+XJxvZGGxrZ3XtZ+y+r5C29ED+0rn/e59xYEHLlv95ULbkt7i9fwXrntroa1vdweOumoH1UjFnvMlrSLr530A+KOWRGhmLdFIxZ5rmx+OmVXFZ/iZJcrJb5aoSi/pBegZnfy47NLSaPNgkBqts63sDL82Xtq6cePTCm1/dWZ55fS3/MZ3C233DC8rtG3at7zQ1lNnZR/rbN7ymyXKyW+WKCe/WaKc/GaJqrbDTzDeP7mp3Wfz1Wtq3NB5sfc/Xvwu/8Tml5TOe8ep1xTaHhrbW2j74I/PKrT1HJhDcNZxvOU3S5ST3yxRTn6zRDn5zRLl5DdLVKW9/dELw0dOOXW3Ayv2jA0WTy8em1cyYxcMWrnnzkWl7c/ce2GhbXS4+HHof7BYxccODd7ymyXKyW+WKCe/WaLqqdgzKOlWST/MK/a8L29/qqRbJG2WdI2kknKuZtap6unwOwCcERFD+Si+N0n6D+AdZBV7rpb0KeAisuG8p6XBMfpPmjwg5MjdxYow7e4EHDmi2JO38ITHC2177j2y0Na7r7PO+e0ZLY/nWct2FNru2XFsoS1wh9+hasYtf2SG8of9+S2AM4Cv5O1rgfNaEqGZtUS94/b35iP37gTWA/cBuyJiYnybrUxTwqu2Ys/oE8ULR8ysPepK/rw4xypgOVlxjpPKZpvmuf9fsafviMPnHqmZNdWsevsjYhdwA3AacJSkiT6D5cC25oZmZq1UT8WeJcBIROySNA94GfBh4HrgtcDVwBqgfKTIGhFibKzzf10s67QbGy+Ju9Mu6C8x3l9+GuJLj7mn0HbvL44ptI00PSLrFPX09i8D1krqJdtT+FJEfFPSj4GrJb0f+AFZSS8z6xL1VOy5k6ws99T2+3FxTrOu1fn74GbWEk5+s0RVeknvQN8oJyx+dFLb/b0LCvOpAy/zXXrE7kLbTw8rxt67v7O+T0fnl3f4/crglkLbnkeKP8X6nO1DV2d9Us2sMk5+s0Q5+c0S5eQ3S1SlHX6j4z3sHJrSSdaBnXtjC4tB9ajYcdYz0gVn+C0oqS0O3HWgeB2W9rWxvrhVzlt+s0Q5+c0S5eQ3S5ST3yxR1RbtGOpj+ObFk9oGy
/uj2mrgiWLH17atKwpthw9XEU1j5u0sP0fvW5eeVmg74szitqCsNLl1Lo3VP6+3/GaJcvKbJcrJb5YoJ79Zohqp2HOlpJ9K2pjfVrU+XDNrlkYq9gD8eUR85SDPnaRnBOZv74K61ocQjZe/30PPOrrQNvho8bTm6On8U5jtST2z+PWsnjH8Aiir2GNmXWxOFXsi4pZ80gck3SnpMkmlRd0mVezZv6dJYZtZo+ZUsUfSLwPvAZ4NPA9YBPzlNM99smLP4PwmhW1mjZprxZ6zImJ7XsTzAPA5PIy3WVeZc8UeScsiYrskkVXo3TTTa4VgzCNCVqy8w2500NfuH4pmU0SqkYo938m/GARsBP54DrGaWZs0UrHnjJZEZGaV8Bl+Zoly8pslqtLr+cf7Ye9SnzFm1iqzGX/BW36zRDn5zRLl5DdLlJPfLFGVdvj1jMDhO3xBoFmr9IzMYt7WhWFmnczJb5YoJ79Zopz8ZomqtMNPMbuKImY2OyWV5KflLb9Zopz8Zoly8pslyslvlqi6kz8fvvsHkr6ZP36qpFskbZZ0jSSPzmfWRWbT2/824G7giPzxh4HLIuJqSZ8CLgI+ebAX6BkOFm6ZUtS+pCLMeG/5Nf8Du4YLbRorVplphfGB4oCXIwuLF0+X9bb2DRXPuewZruZnj+gt/34fPqr4Xd0zVgxeo8W2vqHi/6Eqw0eXlodAJbFHX3HdB3YdKD55mqpGzTY2v/h5GRsoxjhdlaR6Pv+9++vPh3qLdiwHXgl8Jn8s4AxgolTXWrIRfM2sS9S72/8x4C+Aia+VxcCuiJioDLYVOL7sibUVe0ZGXLHHrFPUU6X3HGBnRHy/trlk1tJ9p9qKPf39rthj1inqOeZ/IfBqSWcDg2TH/B8DjpLUl2/9lwPbWhemmTVbPeP2v4esLh+STgfeFREXSPoy8FrgamAN8I2ZXmt0vtixenKHzciC4g7D+EB5B8xxNxU7e/qfqKbjbOj4YgfZL1aVxFnStPjO4tu8YGs1nWYjC8sr82x7cXGnr/dAcYdu4PFi27G3Nx7XXP38JeUjVPbtK8ZZ9tlasaG+Ts1WePTk4ud3aEVJR2VveTxP+d5goe2wxybX5J5NSfVGfuf/S+Adku4l6wP4bAOvZWYVm9WFPRFxA1mhTiLiflyc06xr+Qw/s0Q5+c0SpYjqBtSU9DDwYP7wGOCRyhbeWofSuoDXp9MdbH1OiIgl9bxIpck/acHS7RFxalsW3mSH0rqA16fTNWt9vNtvlignv1mi2pn8l7dx2c12KK0LeH06XVPWp23H/GbWXt7tN0uUk98sUZUnv6SzJP2vpHslvbvq5TdK0hWSdkraVNO2SNL6fEiz9ZKObmeMsyFphaTrJd0t6S5Jb8vbu26dJA1KulXSD/N1eV/e3tVDzrVqCL1Kk19SL/AJ4BXAycD5kk6uMoYmuBI4a0rbu4ENEbES2JA/7hajwDsj4iTgNOBP8v9JN67TAeCMiHgusAo4S9JpPDnk3ErgMbIh57rJxBB6E5qyPlVv+VcD90bE/RExTHY58LkVx9CQiLgReHRK87lkQ5lBlw1pFhHbI+KO/P5usg/Z8XThOkVmKH/Yn9+CLh5yrpVD6FWd/McDW2oeTzv8V5dZGhHbIUsm4Ng2xzMnkk4EfhW4hS5dp3wXeSOwE1gP3EedQ851qDkPoTeTqpO/7uG/rFqSFgBfBf4sIp5odzxzFRFjEbGKbHSp1cBJZbNVG9XcNDqE3kwqLdRJ9i21oubxoTL81w5JyyJiu6RlZFudriGpnyzxvxARX8ubu3qdImKXpBvI+jG6dci5lg6hV/WW/zZgZd5bOQC8AVhXcQytsI5sKDOoc0izTpEfQ34WuDsiPlozqevWSdISSUfl9+cBLyPrw7iebMg56JJ1gWwIvYhYHhEnkuXKdyLiApq1PhFR6Q04G/gJ2bHYX1W9/CbEfxWwHRgh25O5iOw4bAOwOf+7qN1xzmJ9XkS223gnsDG/n
d2N6wScAvwgX5dNwN/m7U8DbgXuBb4MHNbuWOewbqcD32zm+vj0XrNE+Qw/s0Q5+c0S5eQ3S5ST3yxRTn6zRDn5zRLl5DdL1P8B8FPBd33wU/8AAAAASUVORK5CYII=", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "s, _ = env.reset()\n", + "for _ in range(100):\n", + " s, _, _, _, _ = env.step(env.action_space.sample())\n", + "\n", + "plt.title('Game image')\n", + "plt.imshow(env.render())\n", + "plt.show()\n", + "\n", + "plt.title('Agent observation')\n", + "plt.imshow(s.reshape([42, 42]))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### POMDP setting\n", + "\n", + "The Atari game we're working with is actually a POMDP: your agent needs to know timing at which enemies spawn and move, but cannot do so unless it has some memory. \n", + "\n", + "Let's design another agent that has a recurrent neural net memory to solve this. Here's a sketch.\n", + "\n", + "![img](img1.jpg)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SimpleRecurrentAgent(nn.Module):\n", + " def __init__(self, obs_shape, n_actions, reuse=False):\n", + " \"\"\"A simple actor-critic agent\"\"\"\n", + " super(self.__class__, self).__init__()\n", + "\n", + " self.conv0 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2))\n", + " self.conv1 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))\n", + " self.conv2 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))\n", + " self.flatten = nn.Flatten()\n", + "\n", + " self.hid = nn.Linear(512, 128)\n", + " self.rnn = nn.LSTMCell(128, 128)\n", + "\n", + " self.logits = nn.Linear(128, n_actions)\n", + " self.state_value = nn.Linear(128, 1)\n", + "\n", + " def forward(self, prev_state, obs_t):\n", + " \"\"\"\n", + " 
Takes agent's previous hidden state and a new observation,\n", + " returns a new hidden state and whatever the agent needs to learn\n", + " \"\"\"\n", + "\n", + " # Apply the whole neural net for one step here.\n", + " # See docs on self.rnn(...).\n", + " # The recurrent cell should take the last feedforward dense layer as input.\n", + " \n", + "\n", + " new_state = \n", + " logits = \n", + " state_value = \n", + "\n", + " return new_state, (logits, state_value)\n", + "\n", + " def get_initial_state(self, batch_size):\n", + " \"\"\"Return a list of agent memory states at game start. Each state is a np array of shape [batch_size, ...]\"\"\"\n", + " return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n", + "\n", + " def sample_actions(self, agent_outputs):\n", + " \"\"\"pick actions given numeric agent outputs (np arrays)\"\"\"\n", + " logits, state_values = agent_outputs\n", + " probs = F.softmax(logits, dim=-1)\n", + " return torch.multinomial(probs, 1)[:, 0].data.numpy()\n", + "\n", + " def step(self, prev_state, obs_t):\n", + " \"\"\" like forward, but obs_t is a numpy array \"\"\"\n", + " obs_t = torch.tensor(np.asarray(obs_t), dtype=torch.float32)\n", + " (h, c), (l, s) = self.forward(prev_state, obs_t)\n", + " return (h.detach(), c.detach()), (l.detach(), s.detach())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_parallel_games = 5\n", + "gamma = 0.99\n", + "\n", + "agent = SimpleRecurrentAgent(obs_shape, n_actions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "state = [env.reset()[0]]\n", + "_, (logits, value) = agent.step(agent.get_initial_state(1), state)\n", + "print(\"action logits:\\n\", logits)\n", + "print(\"state values:\\n\", value)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's play!\n", + "Let's build a function that measures agent's average reward." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(agent, env, n_games=1):\n", + " \"\"\"Plays an entire game start to end, returns session rewards.\"\"\"\n", + "\n", + " game_rewards = []\n", + " for _ in range(n_games):\n", + " # initial observation and memory\n", + " observation, _ = env.reset()\n", + " prev_memories = agent.get_initial_state(1)\n", + "\n", + " total_reward = 0\n", + " while True:\n", + " new_memories, readouts = agent.step(\n", + " prev_memories, observation[None, ...])\n", + " action = agent.sample_actions(readouts)\n", + "\n", + " observation, reward, terminated, truncated, info = env.step(action[0])\n", + "\n", + " total_reward += reward\n", + " prev_memories = new_memories\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " game_rewards.append(total_reward)\n", + " return game_rewards" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gymnasium.wrappers import RecordVideo\n", + "\n", + "with make_env() as record_env, RecordVideo(record_env, video_folder=\"videos\") as env_monitor:\n", + " rewards = evaluate(agent, env_monitor, n_games=3)\n", + "\n", + "print(rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show video. This may not work in some setups. 
If it doesn't\n", + "# work for you, you can download the videos and view them locally.\n", + "\n", + "from pathlib import Path\n", + "from base64 import b64encode\n", + "from IPython.display import HTML\n", + "\n", + "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", + "video_path = video_paths[-1] # You can also try other indices\n", + "\n", + "if 'google.colab' in sys.modules:\n", + " # https://stackoverflow.com/a/57378660/1214547\n", + " with video_path.open('rb') as fp:\n", + " mp4 = fp.read()\n", + " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", + "else:\n", + " data_url = str(video_path)\n", + "\n", + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(data_url))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training on parallel games\n", + "\n", + "We introduce a class called EnvPool - it's a tool that handles multiple environments for you. Here's how it works:\n", + "![img](img2.jpg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from env_pool import EnvPool\n", + "pool = EnvPool(agent, make_env, n_parallel_games)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We gonna train our agent on a thing called __rollouts:__\n", + "![img](img3.jpg)\n", + "\n", + "A rollout is just a sequence of T observations, actions and rewards that agent took consequently.\n", + "* First __s0__ is not necessarily initial state for the environment\n", + "* Final state is not necessarily terminal\n", + "* We sample several parallel rollouts for efficiency" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# for each of n_parallel_games, take 10 steps\n", + "rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "print(\"Actions shape:\", rollout_actions.shape)\n", + "print(\"Rewards shape:\", rollout_rewards.shape)\n", + "print(\"Mask shape:\", rollout_mask.shape)\n", + "print(\"Observations shape: \", rollout_obs.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Actor-critic objective\n", + "\n", + "Here we define a loss function that uses rollout above to train advantage actor-critic agent.\n", + "\n", + "\n", + "Our loss consists of three components:\n", + "\n", + "* __The policy \"loss\"__\n", + " $$ \\hat J = {1 \\over T} \\cdot \\sum_t { \\log \\pi(a_t | s_t) } \\cdot A_{const}(s,a) $$\n", + " * This function has no meaning in and of itself, but it was built such that\n", + " * $ \\nabla \\hat J = {1 \\over N} \\cdot \\sum_t { \\nabla \\log \\pi(a_t | s_t) } \\cdot A(s,a) \\approx \\nabla E_{s, a \\sim \\pi} R(s,a) $\n", + " * Therefore if we __maximize__ J_hat with gradient descent we will maximize expected reward\n", + " \n", + " \n", + "* __The value \"loss\"__\n", + " $$ L_{td} = {1 \\over T} \\cdot \\sum_t { [r + \\gamma \\cdot V_{const}(s_{t+1}) - V(s_t)] ^ 2 }$$\n", + " * Ye Olde TD_loss from q-learning and alike\n", + " * If we minimize this loss, V(s) will converge to $V_\\pi(s) = E_{a \\sim \\pi(a | s)} R(s,a) $\n", + "\n", + "\n", + "* __Entropy Regularizer__\n", + " $$ H = - {1 \\over T} \\sum_t \\sum_a {\\pi(a|s_t) \\cdot \\log \\pi (a|s_t)}$$\n", + " * If we __maximize__ entropy we discourage agent from predicting zero probability to actions\n", + " prematurely (a.k.a. 
exploration)\n", + " \n", + " \n", + "So we optimize a linear combination of $L_{td}$ $- \\hat J$, $-H$\n", + " \n", + "```\n", + "\n", + "```\n", + "\n", + "```\n", + "\n", + "```\n", + "\n", + "```\n", + "\n", + "```\n", + "\n", + "\n", + "__One more thing:__ since we train on T-step rollouts, we can use the N-step formula for advantage for free:\n", + " * At the last step, $A(s_t,a_t) = r(s_t, a_t) + \\gamma \\cdot V(s_{t+1}) - V(s_t) $\n", + " * One step earlier, $A(s_t,a_t) = r(s_t, a_t) + \\gamma \\cdot r(s_{t+1}, a_{t+1}) + \\gamma ^ 2 \\cdot V(s_{t+2}) - V(s_t) $\n", + " * Et cetera, et cetera. This way the agent starts training much faster since its estimate of A(s,a) depends less on its (imperfect) value function and more on actual rewards. There's also a [nice generalization](https://arxiv.org/abs/1506.02438) of this.\n", + "\n", + "\n", + "__Note:__ it's also a good idea to scale rollout_len up to learn longer sequences. You may wish to set it to >=20 or to start at 10 and then scale up as time passes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "opt = torch.optim.Adam(agent.parameters(), lr=1e-5)\n", + "\n", + "\n", + "def train_on_rollout(states, actions, rewards, is_not_done, prev_memory_states, gamma=0.99):\n", + " \"\"\"\n", + " Takes a sequence of states, actions and rewards produced by generate_session.\n", + " Updates agent's weights by following the policy gradient above.\n", + " Please use Adam optimizer with default parameters.\n", + " \"\"\"\n", + "\n", + " # shape: [batch_size, time, c, h, w]\n", + " states = torch.tensor(np.asarray(states), dtype=torch.float32)\n", + " actions = torch.tensor(np.array(actions), dtype=torch.int64) # shape: [batch_size, time]\n", + " rewards = torch.tensor(np.array(rewards), dtype=torch.float32) # shape: [batch_size, time]\n", + " is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32) # shape: [batch_size, time]\n", + " rollout_length = rewards.shape[1] - 1\n", + "\n", + " # predict logits, probas and log-probas using an agent.\n", + " memory = [m.detach() for m in prev_memory_states]\n", + "\n", + " logits = [] # append logit sequence here\n", + " state_values = [] # append state values here\n", + " for t in range(rewards.shape[1]):\n", + " obs_t = states[:, t]\n", + "\n", + " # use agent to comute logits_t and state values_t.\n", + " # append them to logits and state_values array\n", + "\n", + " memory, (logits_t, values_t) = \n", + "\n", + " logits.append(logits_t)\n", + " state_values.append(values_t)\n", + "\n", + " logits = torch.stack(logits, dim=1)\n", + " state_values = torch.stack(state_values, dim=1)\n", + " probas = F.softmax(logits, dim=2)\n", + " logprobas = F.log_softmax(logits, dim=2)\n", + "\n", + " # select log-probabilities for chosen actions, log pi(a_i|s_i)\n", + " actions_one_hot = F.one_hot(actions, n_actions).view(\n", + " actions.shape[0], actions.shape[1], n_actions)\n", + " logprobas_for_actions = 
torch.sum(logprobas * actions_one_hot, dim=-1)\n", + "\n", + " # Now let's compute two loss components:\n", + " # 1) Policy gradient objective.\n", + " # Notes: Please don't forget to call .detach() on advantage term. Also please use mean, not sum.\n", + " # it's okay to use loops if you want\n", + " J_hat = 0 # policy objective as in the formula for J_hat\n", + "\n", + " # 2) Temporal difference MSE for state values\n", + " # Notes: Please don't forget to call .detach() on V(s') term. Also please use mean, not sum.\n", + " # it's okay to use loops if you want\n", + " value_loss = 0\n", + "\n", + " cumulative_returns = state_values[:, -1].detach()\n", + "\n", + " for t in reversed(range(rollout_length)):\n", + " r_t = rewards[:, t] # current rewards\n", + " # current state values\n", + " V_t = state_values[:, t]\n", + " V_next = state_values[:, t + 1].detach() # next state values\n", + " # log-probability of a_t in s_t\n", + " logpi_a_s_t = logprobas_for_actions[:, t]\n", + "\n", + " # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n", + " cumulative_returns = r_t + gamma * cumulative_returns\n", + "\n", + " # Compute temporal difference error (MSE for V(s))\n", + " value_loss += \n", + "\n", + " # compute advantage A(s_t, a_t) using cumulative returns and V(s_t) as baseline\n", + " advantage = \n", + " advantage = advantage.detach()\n", + "\n", + " # compute policy pseudo-loss aka -J_hat.\n", + " J_hat += \n", + "\n", + " # regularize with entropy\n", + " entropy_reg = \n", + "\n", + " # add-up three loss components and average over time\n", + " loss = -J_hat / rollout_length +\\\n", + " value_loss / rollout_length +\\\n", + " -0.01 * entropy_reg\n", + "\n", + " # Gradient descent step\n", + " \n", + "\n", + " return loss.data.numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# let's test it\n", + "memory = list(pool.prev_memory_states)\n", + "rollout_obs, rollout_actions, 
rollout_rewards, rollout_mask = pool.interact(10)\n", + "\n", + "train_on_rollout(rollout_obs, rollout_actions,\n", + " rollout_rewards, rollout_mask, memory)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train \n", + "\n", + "just run train step and see if agent learns any better" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "from tqdm import trange\n", + "from pandas import DataFrame\n", + "moving_average = lambda x, **kw: DataFrame(\n", + " {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n", + "\n", + "rewards_history = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in trange(15000):\n", + "\n", + " memory = list(pool.prev_memory_states)\n", + " rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n", + " 10)\n", + " train_on_rollout(rollout_obs, rollout_actions,\n", + " rollout_rewards, rollout_mask, memory)\n", + "\n", + " if i % 100 == 0:\n", + " rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n", + " clear_output(True)\n", + " plt.plot(rewards_history, label='rewards')\n", + " plt.plot(moving_average(np.array(rewards_history),\n", + " span=10), label='rewards ewma@10')\n", + " plt.legend()\n", + " plt.show()\n", + " if rewards_history[-1] >= 10000:\n", + " print(\"Your agent has just passed the minimum homework threshold\")\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Relax and grab some refreshments while your agent is locked in an infinite loop of violence and death.\n", + "\n", + "__How to interpret plots:__\n", + "\n", + "The session reward is the easy thing: it should in general go up over time, but it's okay if it fluctuates ~~like crazy~~. It's also OK if the reward doesn't increase substantially before some 10k initial steps. 
However, if reward reaches zero and doesn't seem to get up over 2-3 evaluations, something is going wrong.\n", + "\n", + "\n", + "Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n", + "\n", + "If it does, the culprit is likely:\n", + "* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot \\log p(a_i) $\n", + "* Your agent architecture converges too fast. Increase the entropy coefficient in the actor loss. \n", + "* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n", + "* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n", + "\n", + "If you're debugging, just run `_, (logits, values) = agent.step(prev_memory, batch_states)` and manually look into logits and values. This will reveal the problem 9 times out of 10: you'll likely see some NaNs or insanely large numbers or zeros. Try to catch the moment when this happens for the first time and investigate from there." + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/jheuristic/anaconda3/lib/python3.6/site-packages/scipy/misc/pilutil.py:482: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", - " elif issubdtype(type(size), float):\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### \"Final\" evaluation" + ] }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAANEAAAEICAYAAADBfBG8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFmVJREFUeJzt3XvUHHV9x/H3hyBoASHcEgi3wAGO4CVGxFTKRbyFVAXa\nqsFWUWkJlVA80FMIKFLUAirQKBUImnIRQSqi1BNQCnhpEeRiCJcIJIAQckMIBAVpE7/9Y2Zhstl9\nnnl2dp+Z2f28ztmzszOzu99J5ru/3/xmnu8oIjCzzm1QdgBmdeckMivISWRWkJPIrCAnkVlBTiKz\ngpxEfUjSTpJ+J2lM2bEMAidRAZKmS7pd0u8lrUynPyVJZcYVEY9HxKYRsbbMOAaFk6hDkk4EZgNf\nBsYD44BjgP2AjUoMzUZbRPgxwgewOfB74C+HWe/PgV8Bq4EngNMzy3YBAvhEumwVSRK+FVgAPAuc\n3/R5nwQWpuv+CNi5zfc2PnvD9PVPgC8AtwK/A/4T2Aq4Io3tDmCXzPtnpzGtBu4C9s8sew1waRrD\nQuCfgCWZ5dsD1wBPAY8C/1D2/1fP94eyA6jjA5gKrGnspEOsdxDwBpIW/43ACuCwdFljR78QeDXw\nHuAPwPeBbYEJwErgwHT9w4BFwOuADYHPALe2+d5WSbQI2C39AXgAeAh4V/pZlwH/nnn/36RJtiFw\nIrAceHW67Czgp8BYYIc04ZekyzZIk+40ktZ4V+AR4L1l/5/1dH8oO4A6PtKdbHnTvFvT1uNF4IA2\n7/tX4Lx0urGjT8gsfxr4cOb1NcCn0+nrgaMyyzYAXqBFa9QmiU7NLD8HuD7z+v3A/CG2dxXwpnR6\nnaQA/jaTRG8DHm9676xsgvbjw8dEnXka2FrSho0ZEfH2iNgiXbYBgKS3SbpF0lOSniPprm3d9Fkr\nMtMvtni9aTq9MzBb0rOSngWeAUTSYuWR93uQdKKkhZKeS79r80zc25N09Rqy0zsD2zdiTN97Csnx\nYt9yEnXmF8BLwKHDrPdt4Dpgx4jYnKTr1unI3RPAjIjYIvN4TUTc2uHntSRpf+Ak4EPA2PSH4Tle\niXsZSTeuYcemGB9tinGziJjWzRirxknUgYh4Fvhn4OuS/krSppI2kDQJ2CSz6mbAMxHxB0n7Ah8p\n8LUXArMk7Q0gaXNJHyzwee1sRnK89xSwoaTTgNdmll+dxjFW0gRgZmbZL4HVkk6S9BpJYyS9XtJb\nexBnZTiJOhQRXwJOIBmdWknSPbqI5Fe80Tp8CjhD0vMkB9tXF/i+a4GzgaskrQbuAw7peAPa+xHJ\n8ddDwG9IBjuyXbYzgCUkI2//BXyXpFUmkvNS7wcmpct/C3yDpDvYt5Qe/Jl1RNLfA9Mj4sCyYymL\nWyIbEUnbSdov7b7uSTIEfm3ZcZVpw+FXMVvHRiTd1okkQ/pXAV8vNaKS9aw7J2kqyZnvMcA3IuKs\nnnyRWcl6kkTp1cMPAe8mOQi9AzgiIh7o+peZlaxX3bl9gUUR8QiApKtIzqm0TCJJHt2wKvptRGwz\n3Eq9GliYwLrDoktoOrMu6WhJd0q6s0cxmBX1mzwr9aolanVWfp3WJiLmAHPALZHVW69aoiWseznI\nDsDSHn2XWal6lUR3ALtLmihpI2A6yTVkZn2nJ925iFgjaSbJJSRjgLkRcX8vvsusbJW47MfHRFZR\nd0XEPsOt5Mt+zAqqxWU/xx9/fNkh2ACaPXt2r
vXcEpkVVIuWaLTMmDEDgIsuuqjtsqzm9ZrXGely\nqye3RKlWSdJq2UUXXfTyzp+dn03ATpZbfTmJUm4VrFNOohyyCTZjxowhu3btllv/chKZFeSBhZyG\nGyRoXset0eBwS5RDnoRw0gyuWlz2MxonW0c6PJ1nHQ9x19vs2bNzXfbjJDJrI28SuTtnVpCTyKwg\nj85VyNhZY9ebt+rMVSVEYiPhlqgiGgm06sxVLz+y8626nERmBXWcRJJ2TG9gtVDS/ZKOT+efLulJ\nSfPTR1/fm8asyDHRGuDEiLhb0mbAXZJuTJedFxFfKR6eWfV1nEQRsYzkrmlExPOSFpL/1odmfaMr\nx0SSdgHeDNyezpopaYGkuZJaHhm7Auq6sgMJjUd2vlVX4SFuSZvyyl2uV0u6APg8ScXTz5PcqfqT\nze9zBdT1OWHqqVBLJOlVJAl0RUR8DyAiVkTE2oj4I3AxSXF7s75VZHROwDeBhRFxbmb+dpnVDie5\nt6hZ3yrSndsP+Chwr6T56bxTgCPSu2gH8BjgvxGwvlZkdO6/aX33h3mdh2NV5D/hGNrAXjt374NH\nrPP6DXteOaLl3fiMPN9RthkzZrSsMeFEeoUv+7EhOVmG5ySy3IYqbjnInESWm4tOtuYksiE5YYbn\nGgs2rEEdnctbY2FgR+csv0FJmk65O2dWkJPIrCAnkVlBA3NM1HyPoVZn4lstzz5nNc9rfNasWQ/3\nahO64swzdy87hL4zUC3RcAfIeQ6gszfpyvse628DlUTDnfNoXt5q/Tzr2GAZqCRqbkVaLW+ebl6/\n1fvdGg22gUqiZp3c1a75Pa2Ol2yw+IoFszZG7YoFSY8BzwNrgTURsY+kLYHvALuQ/HXrhyLCVTis\nL3WrO/eOiJiUydqTgZsiYnfgpvS1WV/q1XmiQ4GD0ulLgZ8AJ/Xou0ZkJOeDWs1v9Z6sQ37+89HZ\nkA5dv//+ZYfQd7qRRAH8OD2uuSitJzcurZBKRCyTtG0Xvqdrit4m0iyrG925/SJiMnAIcKykA/K8\nqcwKqCM9X9TpOjYYCidRRCxNn1cC15IUa1zRqD+XPq9s8b45EbFPntGPbhvplQvtXvv8kEHxCqib\npHeEQNImwHtIijVeBxyZrnYk8IMi39Ntrc71DLXcbCiFzhNJ2pWk9YHk+OrbEfFFSVsBVwM7AY8D\nH4yIZ4b4HJ8nssoZlfNEEfEI8KYW858G3lnks83qohZXLJiVpH9qLEz+wuSyQ7ABdPdn7s61Xi2S\naNsdKnWayWwdtUiiDa4e6IvNreJqkUTzd5g//EpmJalFEo3faXzZIdgAWsrSXOu5n2RWUC1aIg8s\nWJX5PJFZe7nOE7k7Z1aQk8isoFocE90w2Vcs2Oibene+KxbcEpkV5CQyK8hJZFZQLY6JJs3zFQtW\ngpy7nVsis4I6bokk7UlS5bRhV+A0YAvg74Cn0vmnRMS8jiMEPvLx04ZdZ9aJxwFw5jlfK/JVhTiG\nfosh327bcRJFxIPAJABJY4AnSeotfAI4LyK+0ulnd2LtSWuTiRKvEHIMgxlDt46J3gksjojfSOrS\nR47MmLPHJBPnlPL1jmGAY+hWEk0Hrsy8ninpY8CdwImjUcx+0H79HEN1Yig8sCBpI+ADwH+ksy4A\ndiPp6i2jzW9Btyugjjl7zCu/PiVxDIMZQzdaokOAuyNiBUDjGUDSxcAPW70prdk9J12v8FXcg/br\n5xiqE0M3kugIMl05Sds1itkDh5NURO25QeuHO4bqxFAoiST9CfBuIFtz90uSJpHcLeKxpmU9M2i/\nfo6hOjEUrYD6ArBV07yPFoqoQ4P26+cYqhNDLS77yWPQfv0cQ3Vi6JskGrRfP8dQnRj6JokG7dfP\nMVQnhr5JokH79XMM1Ymhb5Jo0H79HEN1YuibJBq0Xz/HUJ0Y+iaJBu3XzzFUJ4ZaFG9cvnzaaIVi\n9rLx4+e5e
KPZaKhFd+6Wyb61ilWXWyKzgpxEZgU5icwKqsUx0TvunlR2CDaIxvtOeWajohYtUZ66\nc2bdl6/unFsis4JyJZGkuZJWSrovM29LSTdKejh9HpvOl6SvSlokaYEk31zI+lrelugSYGrTvJOB\nmyJid+Cm9DUk1X92Tx9Hk5TQMutbuZIoIn4GPNM0+1Dg0nT6UuCwzPzLInEbsIWk7boRrFkVFTkm\nGtcojZU+N66XnQA8kVlvSTpvHd0u3mhWll6MzrUqxr3eVdrdLt5oVpYiLdGKRjctfV6Zzl8C7JhZ\nbwcg31krsxoqkkTXAUem00cCP8jM/1g6SjcFeC5TEdWs7+Tqzkm6EjgI2FrSEuBzwFnA1ZKOAh4H\nPpiuPg+YBiwCXiC5X5FZ38qVRBFxRJtF72yxbgDHFgnKrE58xYJZQU4is4KcRGYFOYnMCnISmRXk\nJDIryElkVpCTyKwgJ5FZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQcMmUZvqp1+W\n9Ou0wum1krZI5+8i6UVJ89PHhb0M3qwK8rREl7B+9dMbgddHxBuBh4BZmWWLI2JS+jimO2GaVdew\nSdSq+mlE/Dgi1qQvbyMpi2U2kLpxTPRJ4PrM64mSfiXpp5L2b/cmV0C1flGoAqqkU4E1wBXprGXA\nThHxtKS3AN+XtHdErG5+bzcroN58w5SXpw+eeluRj6p1DEOpenx11nFLJOlI4H3AX6dlsoiIlyLi\n6XT6LmAxsEc3Am0nu3OUpQoxjETd4q26jpJI0lTgJOADEfFCZv42ksak07uS3F7lkW4EmlcVdpAq\nxJBVtXj6zbDduTbVT2cBGwM3SgK4LR2JOwA4Q9IaYC1wTEQ035KlJxpdlDJ3mCrE0E6VY6u7YZOo\nTfXTb7ZZ9xrgmqJBdaKxc5TZ369CDK0cPPU2J08P1eLGx0M5eOptfO3tZ7z8+rhbBzOG4Sz41rSX\npz/9Ld9Iupt82Y9ZQX2RRMfdeto6z4Maw1AarY9boe6rfXcOYI97FnAc5e4cZcVw/rmvBWDmCeud\nimux3lc4P72X+3DrW361b4n2uGfBOs+DFEMjgZqnh1ovz/o2MrVPoqwyE6lKMTScf+5rnSyjoLbd\nuarsrGXG0eiSNRJluIRpXt+6oy9aoofe9MayQyg1huzxzcwTVrd83ZxAPibqntq2RNZacyvjVqf3\n+qIlstYtS3OrNNS61rnaJ9Ggd+WympOjMbCQTSYnUPfVPomyB/Zl7cxViGEo2WSy7qt9Etm6nCij\nr/YDC1X45a9CDFl77bXXeleS33zDlMpdXd4v3BKZFVTbJFo790DWzj1wnddlxVF2DMNxK9Rbte/O\nAex2/NiyQ6hEDA0HT71t3fND5z7gY6Ue6rQC6umSnsxUOp2WWTZL0iJJD0p6b68Cb6UKO3IVYmjm\nBOqtTiugApyXqXQ6D0DSXsB0YO/0PV9vFC7ptsWzV7F49ip2O34si2ev6sVX5I6j7BisXHlqLPxM\n0i45P+9Q4KqIeAl4VNIiYF/gFx1HmEMVduIqxGDlKDKwMDMtaD9XUqMPMwF4IrPOknTeerpVAbWx\n45bZjapCDFaeTpPoAmA3YBJJ1dNz0vlqsW7L6qYRMSci9omIfTqMYT1V2ImrEIMvOh1dHSVRRKyI\niLUR8UfgYpIuGyQtz46ZVXcAlhYL0YrwoELvdVoBdbvMy8OBxsjddcB0SRtLmkhSAfWXxUIcWhV+\n+asQg5Wn0wqoB0maRNJVewyYARAR90u6GniApND9sRGxtjehWyvuyo2+rlZATdf/IvDFIkHlUZVf\n/6rEYeWp7WU/rVRhiLkKMdjoUnpXlHKDGOb+RENd97Xf8icB+J/xLUfSR0UVYsiqak3wurn5hil3\n5Rk9rsW1cydMbn/r19vnfRZIduS3Tfv8aIVUuRiybr4heR7q382G1/h3HE7
tu3NV2GmrEEMr7/uX\n+WWHMBBq0Z0zK0n/dOd+eMqkskOwAZS3Ja99d86sbE4is4KcRGYFeWDBrD0PLJgV4YEFs1FSi+7c\n8uXThlps1hPjx8/rn+7cLZN95t2qy905s4KcRGYFOYnMCuq0Aup3MtVPH5M0P52/i6QXM8su7GXw\nZlWQZ2DhEuB84LLGjIj4cGNa0jnAc5n1F0dEV0/svONunyeyEozPV6iqUAVUSQI+BBw8gtBGbPz4\neb38eLNCig5x7w+siIiHM/MmSvoVsBr4TET8vNUbJR0NHJ3nS67cfvuCYZqN3BFLu9QSDfc9wJWZ\n18uAnSLiaUlvAb4vae+IWK+CYETMAeaAr52zeus4iSRtCPwF8JbGvLSQ/Uvp9F2SFgN7AIXqbeeV\nPXZqnKBtNc8xlB/DaMTR7vu6/W9RZIj7XcCvI2JJY4akbRq3UpG0K0kF1EeKhTgyrf5RRvuKB8dQ\nrRh6HUeeIe4rSW6NsqekJZKOShdNZ92uHMABwAJJ9wDfBY6JiGe6Fq1ZBXVaAZWI+HiLedcA1xQP\ny6w+fMWCWUF9mUTZ/m5ZV4A7hurE0Os4avGnECNRhasbHMNgxVCLP8rzyVYrwxFLl+b6o7xaJJFZ\nSfrnL1uT619H5vI//WcAPvqLz3U7GMdQwxg6i2NmrrX6cmDBbDQ5icwKchKZFVSLY6Lx229Vynu7\nxTFUJwbIH8fyfH8J4ZbIrKhatETbjB/ZHbrPPfuznHDS5QBcfulnOeGk0b+TnWOoTgydxjGwLdEV\nl5zFuHGbvPx63LhNuOKSsxzDAMfQ6zjq0RJtu8WI39P8j9TJZxTlGKoTQy/jqMUVCyO9lfy3Lzlj\nndcf+fhpIw+qIMdQnRg6jePmG6b0z2U/I00is27Im0R9d0xkNtry/Hn4jpJukbRQ0v2Sjk/nbynp\nRkkPp89j0/mS9FVJiyQtkDS51xthVqY8LdEa4MSIeB0wBThW0l7AycBNEbE7cFP6GuAQkgIlu5PU\nlbug61GbVciwSRQRyyLi7nT6eWAhMAE4FLg0Xe1S4LB0+lDgskjcBmwhabuuR25WESMa4k7LCb8Z\nuB0YFxHLIEk0Sdumq00Ansi8bUk6b1nTZ+WugHrzDVNGEqbZqMqdRJI2Jank8+mIWJ2U4W69aot5\n642+uQKq9Ytco3OSXkWSQFdExPfS2Ssa3bT0eWU6fwmwY+btOwA5L6Awq588o3MCvgksjIhzM4uu\nA45Mp48EfpCZ/7F0lG4K8Fyj22fWlyJiyAfwZyTdsQXA/PQxDdiKZFTu4fR5y3R9Af8GLAbuBfbJ\n8R3hhx8VfNw53L4bEfW4YsGsJL5iwWw0OInMCnISmRXkJDIrqCp/lPdb4Pfpc7/Ymv7Znn7aFsi/\nPTvn+bBKjM4BSLozz0hIXfTT9vTTtkD3t8fdObOCnERmBVUpieaUHUCX9dP29NO2QJe3pzLHRGZ1\nVaWWyKyWnERmBZWeRJKmSnowLWxy8vDvqB5Jj0m6V9J8SXem81oWcqkiSXMlrZR0X2ZebQvRtNme\n0yU9mf4fzZc0LbNsVro9D0p674i/MM+l3r16AGNI/mRiV2Aj4B5grzJj6nA7HgO2bpr3JeDkdPpk\n4Oyy4xwi/gOAycB9w8VP8mcw15P8ycsU4Pay48+5PacD/9hi3b3S/W5jYGK6P44ZyfeV3RLtCyyK\niEci4n+Bq0gKnfSDdoVcKicifgY80zS7toVo2mxPO4cCV0XESxHxKLCIZL/MrewkalfUpG4C+LGk\nu9ICLNBUyAXYtu27q6ld/HX+P5uZdkHnZrrXhben7CTKVdSkBvaLiMkkNfeOlXRA2QH1UF3/zy4A\ndgMmkVSeOiedX3h7yk6ivihqEhFL0+eVwLUk3YF2hVzqoq8K0UTEiohYGxF/BC7mlS5b4e0pO4nu\nAHaXNFHSRsB0kkIntSFpE0mbNaaB9wD
30b6QS130VSGapuO2w0n+jyDZnumSNpY0kaRy7y9H9OEV\nGEmZBjxEMipyatnxdBD/riSjO/cA9ze2gTaFXKr4AK4k6eL8H8kv81Ht4qeDQjQV2Z7L03gXpImz\nXWb9U9PteRA4ZKTf58t+zAoquztnVntOIrOCnERmBTmJzApyEpkV5CQyK8hJZFbQ/wPTMFRqoBLr\nRQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gymnasium.wrappers import RecordVideo\n", + "\n", + "with make_env() as record_env, RecordVideo(record_env, video_folder=\"videos\") as env_monitor:\n", + " final_rewards = evaluate(agent, env_monitor, n_games=20)\n", + "\n", + "print(\"Final mean reward\", np.mean(final_rewards))" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFz5JREFUeJzt3X20HHV9x/H35z5xQxIeEkIMJAUfooItpi1G6sORolhE\nFGzViihpy7HtsfRYH9qqfcJWrZ6K2HP06EFFUquAjzVVasmJIIVaHsSIQagBBBMTEhACuXm6T9/+\nMXPL3jtzc/fe3Z3dze/zOmfP3f3N7M539u53Z+a3M7+vIgIzS09PuwMws/Zw8pslyslvlignv1mi\nnPxmiXLymyXKyZ8wSSdKCkl97Y5lNiRdIOm6dsfR7Zz8TSTpBkmPSTqswmWGpGdUtbyqlX1BRcQX\nIuLl7YzrUODkbxJJJwIvBgJ4dVuD6SDK+HPWgfxPaZ4Lgf8BrgTW1E6QtFjSv0t6QtJtkt4v6aaa\n6c+WtF7So5L+V9Lra6ZdKekTkr4labekWyQ9PZ92Yz7bDyUNSfrdqUFJ6pH015IelLRT0r9IOnLK\nbH8gaZuk7ZLeWfPc1ZJuz+PeIemjNdNOk/TfknZJ+qGk02um3SDpA5JuBvYC75V0+5S43i5pXX7/\nlZJ+kC9ni6RLamadWMdd+Tr+hqTfm/L+vSB/Xx/P/75gSiz/IOnm/P27TtIxU9+nJEWEb024AfcC\nbwV+HRgBltZMuzq/HQ6cDGwBbsqnzc8f/z7QB/wa8AjwnHz6lcCjwOp8+heAq2teO4BnHCSuP8hj\nexqwAPga8Pl82on586/K4/gV4GHgZfn07wFvzu8vAE7L7x8P/AI4m2wDcmb+eEk+/QbgZ8Bz8piP\nBHYDK2viug14Q37/9HzZPcApwA7gvCkx9tU89/dq3r9FwGPAm/NlnZ8/XlwTy33AM4F5+eMPtfvz\n0gk3b/mbQNKLgBOAL0XE98k+bG/Mp/UCvwP8XUTsjYgfA2trnn4O8EBEfC4iRiPiDuCrwGtr5vla\nRNwaEaNkyb9qFuFdAHw0Iu6PiCHgPcAbpnTyvS8i9kTEj4DPkSUQZF9iz5B0TEQMRcT/5O1vAq6N\niGsjYjwi1gO3k30ZTLgyIu7K1+lx4BsTrytpJfBsYB1ARNwQET/KX+tOsi+jl9S5fq8ENkfE5/Nl\nXQXcA7yqZp7PRcRPImIf8CVm9/4dspz8zbEGuC4iHskff5End/2XkG2RttTMX3v/BOD5+e7zLkm7\nyBL2KTXzPFRzfy/ZVrhex
wEP1jx+MI9n6TTxPJg/B+Aisi3mPfnu9Dk1Mb9uSswvApZN85qQvScT\nXypvBP4tIvYCSHq+pOslPSzpceCPgXp3zaeu38Q6HF/zuJH375DVVT/xdCJJ84DXA72SJj5khwFH\nSXousAkYBZYDP8mnr6h5iS3AdyPizBaFuI0sWSf8Uh7PjjymiXjuqZm+DSAiNgPn5x12vw18RdLi\nPObPR8RbDrLcqZeLXgccI2kV2ZfA22umfRH4OPCKiNgv6WM8mfwzXXY6df0m1uHbMzwved7yN+48\nYIzsWH5VfjsJ+C/gwogYIzvOvkTS4ZKeTdY5OOGbwDMlvVlSf357nqST6lz+DrLj+elcBbxd0lMl\nLQA+CFyTH0JM+Js8tueQ9T1cAyDpTZKWRMQ4sCufdwz4V+BVkn5LUq+kQUmnS1rONPLlfQX4J7Lj\n9PU1kxcCj+aJv5r8kCn3MDB+kHW8luz9e6OkvrzT82Sy99UOwsnfuDVkx5Q/i4iHJm5kW7IL8mPr\ni8k6vR4CPk+WkAcAImI38HLgDWRbsYeAD5PtPdTjEmBtvvv9+pLpV+TLvBH4KbAf+NMp83yXrFNw\nA/CRiJg4geYs4C5JQ8A/k3XQ7Y+ILcC5wHvJknML8OfM/Hn6IvAy4MtTvnzeCvy9pN3A35IdlwOQ\nHxp8ALg5X8fTal8wIn5B1m/yTrJOx78Azqk5BLNpKO8RtQpJ+jDwlIhYM+PMZi3iLX8F8t/xT8nO\nd9Fqso60r7c7LkubO/yqsZBsV/84YCdwKdlPX2Zt491+s0R5t98sUQ3t9ks6i6wXuBf4TER86GDz\n9w/Mj8HBoxtZpJkdxP79jzEyvEf1zDvn5M9PW/0E2XndW4HbJK3LT18tNTh4NKeuvniuizSzGdx+\n68frnreR3f7VwL35OePDZBeunNvA65lZhRpJ/uOZfP72ViafTw2ApD/MLwu9fWRkTwOLM7NmaiT5\ny44rCj8dRMTlEXFqRJza3z+/gcWZWTM10uG3lckXqCwnvyBkOhraR//NmxpYpJkdjA7sq3veRrb8\ntwEr8wtGBsjOTV/XwOuZWYXmvOWPiFFJFwP/SfZT3xURcVfTIjOzlmrod/6IuJbskkoz6zI+w88s\nUZVe2BNHzGP/i0+pcpFmSYn/uqHueb3lN0uUk98sUU5+s0Q5+c0S5eQ3S1Slvf3Di4It549Oaovx\n4iUCkkcXMoio/7Mxm3k7SbPjHr6r/ud6y2+WKCe/WaKc/GaJcvKbJaracftDxNjkDo4YKX7/lHUC\ntpsGxgttU9cFgLK2duor7wBST7E9hjtsW1ASI/3F/wNADPcW2zqsv6/0MzRa/Lw09PmfxXM77L9t\nZlVx8pslyslvlignv1miGq3Y8wCwGxgDRiPi1IPOPyL6fj657LzGGomgQmVfk2UdSh3WyTStsvUp\n70vrLNNtrro19ibHrZH6O/ya0dv/mxHxSBNex8wq5N1+s0Q1mvwBXCfp+5L+sGyG2oo9Y3tcsces\nUzS62//CiNgm6VhgvaR7IuLG2hki4nLgcoDB5Su65YjY7JDX6NDd2/K/OyV9nax4543Tzg/ElBOx\nVNbh0YFfEVPjhmk6Kzss9rK4gdJ9vtL/RTuV9F2NT7M+PZ0We4my2MtOYqzqMzTn3X5J8yUtnLgP\nvBxwLS6zLtHIln8p8HVJE6/zxYj4dlOiMrOWa6Rc1/3Ac5sYi5lVyD/1mSWq0kt6RUknWYd1kE2n\nGzr3ykx7BmUXdJCVvb893XJGaInS2Nv4GfKW3yxRTn6zRDn5zRLl5DdLlJPfLFGV9vYHEFO+bnx6\nb2tNe3pvyamzGi22tdUhdnpv6WeoG0/vNbPu5uQ3S5ST3yxRTn6zRFVbsacvGF1cR69Sh3WaAaWd\nTx0Z51SzKf7SaevTaPGjblifJscY/S7RbWYzcPKbJcrJb5YoJ79Zombs8JN0BXAOsDMifjl
vWwRc\nA5wIPAC8PiIem3Fp40J7p47g2Wm9MmZdrMkluq8EzprS9m5gQ0SsBDbkj82si8yY/Pk4/I9OaT4X\nWJvfXwuc1+S4zKzF5nrMvzQitgPkf4+dbsZJFXuGXLHHrFO0vMMvIi6PiFMj4tTeBfNbvTgzq9Nc\nz/DbIWlZRGyXtAzYWdezAnpGpjY2ehqXmf2/WfSfz3XLvw5Yk99fA3xjjq9jZm0yY/JLugr4HvAs\nSVslXQR8CDhT0mbgzPyxmXWRGXf7I+L8aSa9tMmxmFmFfIafWaIqv6R3bHGhx8/MmqXPl/Sa2Qyc\n/GaJcvKbJcrJb5aoakt0D4uBrQNVLtIsKRpu7iW9ZnYIcvKbJcrJb5YoJ79Zopz8Zoly8pslyslv\nlignv1minPxmiapnJJ8rJO2UtKmm7RJJP5e0Mb+d3dowzazZ5lq0A+CyiFiV365tblhm1mpzLdph\nZl2ukWP+iyXdmR8WHN20iMysEnNN/k8CTwdWAduBS6ebcVLFnj2u2GPWKeaU/BGxIyLGImIc+DSw\n+iDzPlmxZ74r9ph1ijklf16lZ8JrgE3TzWtmnWnGwTzyoh2nA8dI2gr8HXC6pFVkxYEeAP6ohTGa\nWQvMtWjHZ1sQi5lVyGf4mSXKyW+WKCe/WaKc/GaJcvKbJcrJb5YoJ79Zopz8Zoly8pslyslvlign\nv1minPxmiXLymyXKyW+WKCe/WaKc/GaJcvKbJaqeij0rJF0v6W5Jd0l6W96+SNJ6SZvzvx6+26yL\n1LPlHwXeGREnAacBfyLpZODdwIaIWAlsyB+bWZeop2LP9oi4I7+/G7gbOB44F1ibz7YWOK9VQZpZ\n883qmF/SicCvArcASyNiO2RfEMCx0zzHRTvMOlDdyS9pAfBV4M8i4ol6n+eiHWadqa7kl9RPlvhf\niIiv5c07Jop35H93tiZEM2uFenr7RTZO/90R8dGaSeuANfn9NcA3mh+embXKjEU7gBcCbwZ+JGlj\n3vZe4EPAlyRdBPwMeF1rQjSzVqinYs9NgKaZ/NLmhmNmVfEZfmaJcvKbJaqeY/6m0Tj07Z18BDE2\nLwrzxXQHGW3Ut68Y1PhAcb7x3uL6dIve4fre+LGB7l1He5K3/GaJcvKbJcrJb5YoJ79Zoirt8Ot/\naA/L//G/J7Vte9cLCvMNH9neDqWBJ4odX8ddekuhbdebVhfbVrYkpKZTyVu8Yv1Qoa1vZ/Eyjvsv\nPK7Q1s0dnanylt8sUU5+s0Q5+c0S5eQ3S5ST3yxR1Z7ee9hh9J749Mlt41VGUJ+ekWJb39Ilhbbx\n3gqCaRGNF3/RGD7qsEJb7xNl5zCXvGAXvxep8pbfLFFOfrNEOfnNEtVIxZ5LJP1c0sb8dnbrwzWz\nZqmnw2+iYs8dkhYC35e0Pp92WUR8pN6F7V/Sx0/eMnl4/94DJaeFtvlM0QOLigFsfttTC20aK3ly\nGzswy8ZGmHZwhP3FeR84r2TeviMKTQMPdd7/zGavnjH8tgMTxTl2S5qo2GNmXayRij0AF0u6U9IV\n0xXqdMUes87USMWeTwJPB1aR7RlcWvY8V+wx60xzrtgTETsiYiwixoFPA8XrW82sY814zD9dxR5J\nyyYKdQKvATbN9Fo9IzBv5+ROpZGFnTeAp8aKAQw+XJxvZGGxrZ3XtZ+y+r5C29ED+0rn/e59xYEH\nLlv95ULbkt7i9fwXrntroa1vdweOumoH1UjFnvMlrSLr530A+KOWRGhmLdFIxZ5rmx+OmVXFZ/iZ\nJcrJb5aoSi/pBegZnfy47NLSaPNgkBqts63sDL82Xtq6cePTCm1/dWZ55fS3/MZ3C233DC8rtG3a\nt7zQ1lNnZR/rbN7ymyXKyW+WKCe/WaKc/GaJqrbDTzDeP7mp3Wfz1Wtq3NB5sfc/Xvwu/8Tml5TO\ne8ep1xTaHhrbW2j74I/PKrT1HJhDcNZxvOU3S5ST3yx
RTn6zRDn5zRLl5DdLVKW9/dELw0dOOXW3\nAyv2jA0WTy8em1cyYxcMWrnnzkWl7c/ce2GhbXS4+HHof7BYxccODd7ymyXKyW+WKCe/WaLqqdgz\nKOlWST/MK/a8L29/qqRbJG2WdI2kknKuZtap6unwOwCcERFD+Si+N0n6D+AdZBV7rpb0KeAisuG8\np6XBMfpPmjwg5MjdxYow7e4EHDmi2JO38ITHC2177j2y0Na7r7PO+e0ZLY/nWct2FNru2XFsoS1w\nh9+hasYtf2SG8of9+S2AM4Cv5O1rgfNaEqGZtUS94/b35iP37gTWA/cBuyJiYnybrUxTwqu2Ys/o\nE8ULR8ysPepK/rw4xypgOVlxjpPKZpvmuf9fsafviMPnHqmZNdWsevsjYhdwA3AacJSkiT6D5cC2\n5oZmZq1UT8WeJcBIROySNA94GfBh4HrgtcDVwBqgfKTIGhFibKzzf10s67QbGy+Ju9Mu6C8x3l9+\nGuJLj7mn0HbvL44ptI00PSLrFPX09i8D1krqJdtT+FJEfFPSj4GrJb0f+AFZSS8z6xL1VOy5k6ws\n99T2+3FxTrOu1fn74GbWEk5+s0RVeknvQN8oJyx+dFLb/b0LCvOpAy/zXXrE7kLbTw8rxt67v7O+\nT0fnl3f4/crglkLbnkeKP8X6nO1DV2d9Us2sMk5+s0Q5+c0S5eQ3S1SlHX6j4z3sHJrSSdaBnXtj\nC4tB9ajYcdYz0gVn+C0oqS0O3HWgeB2W9rWxvrhVzlt+s0Q5+c0S5eQ3S5ST3yxR1RbtGOpj+ObF\nk9oGy/uj2mrgiWLH17atKwpthw9XEU1j5u0sP0fvW5eeVmg74szitqCsNLl1Lo3VP6+3/GaJcvKb\nJcrJb5YoJ79Zohqp2HOlpJ9K2pjfVrU+XDNrlkYq9gD8eUR85SDPnaRnBOZv74K61ocQjZe/30PP\nOrrQNvho8bTm6On8U5jtST2z+PWsnjH8Aiir2GNmXWxOFXsi4pZ80gck3SnpMkmlRd0mVezZv6dJ\nYZtZo+ZUsUfSLwPvAZ4NPA9YBPzlNM99smLP4PwmhW1mjZprxZ6zImJ7XsTzAPA5PIy3WVeZc8Ue\nScsiYrskkVXo3TTTa4VgzCNCVqy8w2500NfuH4pmU0SqkYo938m/GARsBP54DrGaWZs0UrHnjJZE\nZGaV8Bl+Zoly8pslqtLr+cf7Ye9SnzFm1iqzGX/BW36zRDn5zRLl5DdLlJPfLFGVdvj1jMDhO3xB\noFmr9IzMYt7WhWFmnczJb5YoJ79Zopz8ZomqtMNPMbuKImY2OyWV5KflLb9Zopz8Zoly8pslyslv\nlqi6kz8fvvsHkr6ZP36qpFskbZZ0jSSPzmfWRWbT2/824G7giPzxh4HLIuJqSZ8CLgI+ebAX6BkO\nFm6ZUtS+pCLMeG/5Nf8Du4YLbRorVplphfGB4oCXIwuLF0+X9bb2DRXPuewZruZnj+gt/34fPqr4\nXd0zVgxeo8W2vqHi/6Eqw0eXlodAJbFHX3HdB3YdKD55mqpGzTY2v/h5GRsoxjhdlaR6Pv+9++vP\nh3qLdiwHXgl8Jn8s4AxgolTXWrIRfM2sS9S72/8x4C+Aia+VxcCuiJioDLYVOL7sibUVe0ZGXLHH\nrFPUU6X3HGBnRHy/trlk1tJ9p9qKPf39rthj1inqOeZ/IfBqSWcDg2TH/B8DjpLUl2/9lwPbWhem\nmTVbPeP2v4esLh+STgfeFREXSPoy8FrgamAN8I2ZXmt0vtixenKHzciC4g7D+EB5B8xxNxU7e/qf\nqKbjbOj4YgfZL1aVxFnStPjO4tu8YGs1nWYjC8sr82x7cXGnr/dAcYdu4PFi27G3Nx7XXP38JeUj\nVPbtK8ZZ9tlasaG+Ts1WePTk4ud3aEVJR2VveTxP+d5goe2wxybX5J5NSfVGfuf/S+Adku4l6wP4\nbAOvZWYVm9WFPRF
xA1mhTiLiflyc06xr+Qw/s0Q5+c0SpYjqBtSU9DDwYP7wGOCRyhbeWofSuoDX\np9MdbH1OiIgl9bxIpck/acHS7RFxalsW3mSH0rqA16fTNWt9vNtvlignv1mi2pn8l7dx2c12KK0L\neH06XVPWp23H/GbWXt7tN0uUk98sUZUnv6SzJP2vpHslvbvq5TdK0hWSdkraVNO2SNL6fEiz9ZKO\nbmeMsyFphaTrJd0t6S5Jb8vbu26dJA1KulXSD/N1eV/e3tVDzrVqCL1Kk19SL/AJ4BXAycD5kk6u\nMoYmuBI4a0rbu4ENEbES2JA/7hajwDsj4iTgNOBP8v9JN67TAeCMiHgusAo4S9JpPDnk3ErgMbIh\n57rJxBB6E5qyPlVv+VcD90bE/RExTHY58LkVx9CQiLgReHRK87lkQ5lBlw1pFhHbI+KO/P5usg/Z\n8XThOkVmKH/Yn9+CLh5yrpVD6FWd/McDW2oeTzv8V5dZGhHbIUsm4Ng2xzMnkk4EfhW4hS5dp3wX\neSOwE1gP3EedQ851qDkPoTeTqpO/7uG/rFqSFgBfBf4sIp5odzxzFRFjEbGKbHSp1cBJZbNVG9Xc\nNDqE3kwqLdRJ9i21oubxoTL81w5JyyJiu6RlZFudriGpnyzxvxARX8ubu3qdImKXpBvI+jG6dci5\nlg6hV/WW/zZgZd5bOQC8AVhXcQytsI5sKDOoc0izTpEfQ34WuDsiPlozqevWSdISSUfl9+cBLyPr\nw7iebMg56JJ1gWwIvYhYHhEnkuXKdyLiApq1PhFR6Q04G/gJ2bHYX1W9/CbEfxWwHRgh25O5iOw4\nbAOwOf+7qN1xzmJ9XkS223gnsDG/nd2N6wScAvwgX5dNwN/m7U8DbgXuBb4MHNbuWOewbqcD32zm\n+vj0XrNE+Qw/s0Q5+c0S5eQ3S5ST3yxRTn6zRDn5zRLl5DdL1P8B8FPBd33wU/8AAAAASUVORK5C\nYII=\n", - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show video. This may not work in some setups. 
If it doesn't\n", + "# work for you, you can download the videos and view them locally.\n", + "\n", + "from pathlib import Path\n", + "from base64 import b64encode\n", + "from IPython.display import HTML\n", + "\n", + "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", + "video_path = video_paths[-1] # You can also try other indices\n", + "\n", + "if 'google.colab' in sys.modules:\n", + " # https://stackoverflow.com/a/57378660/1214547\n", + " with video_path.open('rb') as fp:\n", + " mp4 = fp.read()\n", + " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", + "else:\n", + " data_url = str(video_path)\n", + "\n", + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(data_url))" ] - }, - "metadata": {}, - "output_type": "display_data" } - ], - "source": [ - "s = env.reset()\n", - "for _ in range(100):\n", - " s, _, _, _ = env.step(env.action_space.sample())\n", - "\n", - "plt.title('Game image')\n", - "plt.imshow(env.render('rgb_array'))\n", - "plt.show()\n", - "\n", - "plt.title('Agent observation')\n", - "plt.imshow(s.reshape([42, 42]))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### POMDP setting\n", - "\n", - "The Atari game we're working with is actually a POMDP: your agent needs to know timing at which enemies spawn and move, but cannot do so unless it has some memory. \n", - "\n", - "Let's design another agent that has a recurrent neural net memory to solve this. 
Here's a sketch.\n", - "\n", - "![img](img1.jpg)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "# a special module that converts [batch, channel, w, h] to [batch, units]\n", - "\n", - "\n", - "class Flatten(nn.Module):\n", - " def forward(self, input):\n", - " return input.view(input.size(0), -1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class SimpleRecurrentAgent(nn.Module):\n", - " def __init__(self, obs_shape, n_actions, reuse=False):\n", - " \"\"\"A simple actor-critic agent\"\"\"\n", - " super(self.__class__, self).__init__()\n", - "\n", - " self.conv0 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2))\n", - " self.conv1 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))\n", - " self.conv2 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))\n", - " self.flatten = Flatten()\n", - "\n", - " self.hid = nn.Linear(512, 128)\n", - " self.rnn = nn.LSTMCell(128, 128)\n", - "\n", - " self.logits = nn.Linear(128, n_actions)\n", - " self.state_value = nn.Linear(128, 1)\n", - "\n", - " def forward(self, prev_state, obs_t):\n", - " \"\"\"\n", - " Takes agent's previous hidden state and a new observation,\n", - " returns a new hidden state and whatever the agent needs to learn\n", - " \"\"\"\n", - "\n", - " # Apply the whole neural net for one step here.\n", - " # See docs on self.rnn(...).\n", - " # The recurrent cell should take the last feedforward dense layer as input.\n", - " \n", - "\n", - " new_state = \n", - " logits = \n", - " state_value = \n", - "\n", - " return new_state, (logits, state_value)\n", - "\n", - " def get_initial_state(self, batch_size):\n", - " \"\"\"Return a list of agent memory states at game start. 
Each state is a np array of shape [batch_size, ...]\"\"\"\n", - " return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n", - "\n", - " def sample_actions(self, agent_outputs):\n", - " \"\"\"pick actions given numeric agent outputs (np arrays)\"\"\"\n", - " logits, state_values = agent_outputs\n", - " probs = F.softmax(logits)\n", - " return torch.multinomial(probs, 1)[:, 0].data.numpy()\n", - "\n", - " def step(self, prev_state, obs_t):\n", - " \"\"\" like forward, but obs_t is a numpy array \"\"\"\n", - " obs_t = torch.tensor(np.asarray(obs_t), dtype=torch.float32)\n", - " (h, c), (l, s) = self.forward(prev_state, obs_t)\n", - " return (h.detach(), c.detach()), (l.detach(), s.detach())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_parallel_games = 5\n", - "gamma = 0.99\n", - "\n", - "agent = SimpleRecurrentAgent(obs_shape, n_actions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "state = [env.reset()]\n", - "_, (logits, value) = agent.step(agent.get_initial_state(1), state)\n", - "print(\"action logits:\\n\", logits)\n", - "print(\"state values:\\n\", value)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Let's play!\n", - "Let's build a function that measures agent's average reward." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def evaluate(agent, env, n_games=1):\n", - " \"\"\"Plays an entire game start to end, returns session rewards.\"\"\"\n", - "\n", - " game_rewards = []\n", - " for _ in range(n_games):\n", - " # initial observation and memory\n", - " observation = env.reset()\n", - " prev_memories = agent.get_initial_state(1)\n", - "\n", - " total_reward = 0\n", - " while True:\n", - " new_memories, readouts = agent.step(\n", - " prev_memories, observation[None, ...])\n", - " action = agent.sample_actions(readouts)\n", - "\n", - " observation, reward, done, info = env.step(action[0])\n", - "\n", - " total_reward += reward\n", - " prev_memories = new_memories\n", - " if done:\n", - " break\n", - "\n", - " game_rewards.append(total_reward)\n", - " return game_rewards" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym.wrappers\n", - "\n", - "with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n", - " rewards = evaluate(agent, env_monitor, n_games=3)\n", - "\n", - "print(rewards)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Show video. This may not work in some setups. 
If it doesn't\n", - "# work for you, you can download the videos and view them locally.\n", - "\n", - "from pathlib import Path\n", - "from base64 import b64encode\n", - "from IPython.display import HTML\n", - "\n", - "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", - "video_path = video_paths[-1] # You can also try other indices\n", - "\n", - "if 'google.colab' in sys.modules:\n", - " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open('rb') as fp:\n", - " mp4 = fp.read()\n", - " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", - "else:\n", - " data_url = str(video_path)\n", - "\n", - "HTML(\"\"\"\n", - "\n", - "\"\"\".format(data_url))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training on parallel games\n", - "\n", - "We introduce a class called EnvPool - it's a tool that handles multiple environments for you. Here's how it works:\n", - "![img](img2.jpg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from env_pool import EnvPool\n", - "pool = EnvPool(agent, make_env, n_parallel_games)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We gonna train our agent on a thing called __rollouts:__\n", - "![img](img3.jpg)\n", - "\n", - "A rollout is just a sequence of T observations, actions and rewards that agent took consequently.\n", - "* First __s0__ is not necessarily initial state for the environment\n", - "* Final state is not necessarily terminal\n", - "* We sample several parallel rollouts for efficiency" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for each of n_parallel_games, take 10 steps\n", - "rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - 
"source": [ - "print(\"Actions shape:\", rollout_actions.shape)\n", - "print(\"Rewards shape:\", rollout_rewards.shape)\n", - "print(\"Mask shape:\", rollout_mask.shape)\n", - "print(\"Observations shape: \", rollout_obs.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Actor-critic objective\n", - "\n", - "Here we define a loss function that uses rollout above to train advantage actor-critic agent.\n", - "\n", - "\n", - "Our loss consists of three components:\n", - "\n", - "* __The policy \"loss\"__\n", - " $$ \\hat J = {1 \\over T} \\cdot \\sum_t { \\log \\pi(a_t | s_t) } \\cdot A_{const}(s,a) $$\n", - " * This function has no meaning in and of itself, but it was built such that\n", - " * $ \\nabla \\hat J = {1 \\over N} \\cdot \\sum_t { \\nabla \\log \\pi(a_t | s_t) } \\cdot A(s,a) \\approx \\nabla E_{s, a \\sim \\pi} R(s,a) $\n", - " * Therefore if we __maximize__ J_hat with gradient descent we will maximize expected reward\n", - " \n", - " \n", - "* __The value \"loss\"__\n", - " $$ L_{td} = {1 \\over T} \\cdot \\sum_t { [r + \\gamma \\cdot V_{const}(s_{t+1}) - V(s_t)] ^ 2 }$$\n", - " * Ye Olde TD_loss from q-learning and alike\n", - " * If we minimize this loss, V(s) will converge to $V_\\pi(s) = E_{a \\sim \\pi(a | s)} R(s,a) $\n", - "\n", - "\n", - "* __Entropy Regularizer__\n", - " $$ H = - {1 \\over T} \\sum_t \\sum_a {\\pi(a|s_t) \\cdot \\log \\pi (a|s_t)}$$\n", - " * If we __maximize__ entropy we discourage agent from predicting zero probability to actions\n", - " prematurely (a.k.a. 
exploration)\n", - " \n", - " \n", - "So we optimize a linear combination of $L_{td}$ $- \\hat J$, $-H$\n", - " \n", - "```\n", - "\n", - "```\n", - "\n", - "```\n", - "\n", - "```\n", - "\n", - "```\n", - "\n", - "```\n", - "\n", - "\n", - "__One more thing:__ since we train on T-step rollouts, we can use N-step formula for advantage for free:\n", - " * At the last step, $A(s_t,a_t) = r(s_t, a_t) + \\gamma \\cdot V(s_{t+1}) - V(s) $\n", - " * One step earlier, $A(s_t,a_t) = r(s_t, a_t) + \\gamma \\cdot r(s_{t+1}, a_{t+1}) + \\gamma ^ 2 \\cdot V(s_{t+2}) - V(s) $\n", - " * Et cetera, et cetera. This way agent starts training much faster since it's estimate of A(s,a) depends less on his (imperfect) value function and more on actual rewards. There's also a [nice generalization](https://arxiv.org/abs/1506.02438) of this.\n", - "\n", - "\n", - "__Note:__ it's also a good idea to scale rollout_len up to learn longer sequences. You may wish set it to >=20 or to start at 10 and then scale up as time passes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def to_one_hot(y, n_dims=None):\n", - " \"\"\" Take an integer tensor and convert it to 1-hot matrix. 
\"\"\"\n", - " y_tensor = y.to(dtype=torch.int64).view(-1, 1)\n", - " n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n", - " y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n", - " return y_one_hot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "opt = torch.optim.Adam(agent.parameters(), lr=1e-5)\n", - "\n", - "\n", - "def train_on_rollout(states, actions, rewards, is_not_done, prev_memory_states, gamma=0.99):\n", - " \"\"\"\n", - " Takes a sequence of states, actions and rewards produced by generate_session.\n", - " Updates agent's weights by following the policy gradient above.\n", - " Please use Adam optimizer with default parameters.\n", - " \"\"\"\n", - "\n", - " # shape: [batch_size, time, c, h, w]\n", - " states = torch.tensor(np.asarray(states), dtype=torch.float32)\n", - " actions = torch.tensor(np.array(actions), dtype=torch.int64) # shape: [batch_size, time]\n", - " rewards = torch.tensor(np.array(rewards), dtype=torch.float32) # shape: [batch_size, time]\n", - " is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32) # shape: [batch_size, time]\n", - " rollout_length = rewards.shape[1] - 1\n", - "\n", - " # predict logits, probas and log-probas using an agent.\n", - " memory = [m.detach() for m in prev_memory_states]\n", - "\n", - " logits = [] # append logit sequence here\n", - " state_values = [] # append state values here\n", - " for t in range(rewards.shape[1]):\n", - " obs_t = states[:, t]\n", - "\n", - " # use agent to comute logits_t and state values_t.\n", - " # append them to logits and state_values array\n", - "\n", - " memory, (logits_t, values_t) = \n", - "\n", - " logits.append(logits_t)\n", - " state_values.append(values_t)\n", - "\n", - " logits = torch.stack(logits, dim=1)\n", - " state_values = torch.stack(state_values, dim=1)\n", - " probas = F.softmax(logits, dim=2)\n", - " logprobas = 
F.log_softmax(logits, dim=2)\n", - "\n", - " # select log-probabilities for chosen actions, log pi(a_i|s_i)\n", - " actions_one_hot = to_one_hot(actions, n_actions).view(\n", - " actions.shape[0], actions.shape[1], n_actions)\n", - " logprobas_for_actions = torch.sum(logprobas * actions_one_hot, dim=-1)\n", - "\n", - " # Now let's compute two loss components:\n", - " # 1) Policy gradient objective.\n", - " # Notes: Please don't forget to call .detach() on advantage term. Also please use mean, not sum.\n", - " # it's okay to use loops if you want\n", - " J_hat = 0 # policy objective as in the formula for J_hat\n", - "\n", - " # 2) Temporal difference MSE for state values\n", - " # Notes: Please don't forget to call on V(s') term. Also please use mean, not sum.\n", - " # it's okay to use loops if you want\n", - " value_loss = 0\n", - "\n", - " cumulative_returns = state_values[:, -1].detach()\n", - "\n", - " for t in reversed(range(rollout_length)):\n", - " r_t = rewards[:, t] # current rewards\n", - " # current state values\n", - " V_t = state_values[:, t]\n", - " V_next = state_values[:, t + 1].detach() # next state values\n", - " # log-probability of a_t in s_t\n", - " logpi_a_s_t = logprobas_for_actions[:, t]\n", - "\n", - " # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n", - " cumulative_returns = G_t = r_t + gamma * cumulative_returns\n", - "\n", - " # Compute temporal difference error (MSE for V(s))\n", - " value_loss += \n", - "\n", - " # compute advantage A(s_t, a_t) using cumulative returns and V(s_t) as baseline\n", - " advantage = \n", - " advantage = advantage.detach()\n", - "\n", - " # compute policy pseudo-loss aka -J_hat.\n", - " J_hat += \n", - "\n", - " # regularize with entropy\n", - " entropy_reg = \n", - "\n", - " # add-up three loss components and average over time\n", - " loss = -J_hat / rollout_length +\\\n", - " value_loss / rollout_length +\\\n", - " -0.01 * entropy_reg\n", - "\n", - " # Gradient descent step\n", - " \n", 
- "\n", - " return loss.data.numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# let's test it\n", - "memory = list(pool.prev_memory_states)\n", - "rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)\n", - "\n", - "train_on_rollout(rollout_obs, rollout_actions,\n", - " rollout_rewards, rollout_mask, memory)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train \n", - "\n", - "just run train step and see if agent learns any better" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.display import clear_output\n", - "from tqdm import trange\n", - "from pandas import DataFrame\n", - "moving_average = lambda x, **kw: DataFrame(\n", - " {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n", - "\n", - "rewards_history = []" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in trange(15000):\n", - "\n", - " memory = list(pool.prev_memory_states)\n", - " rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n", - " 10)\n", - " train_on_rollout(rollout_obs, rollout_actions,\n", - " rollout_rewards, rollout_mask, memory)\n", - "\n", - " if i % 100 == 0:\n", - " rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n", - " clear_output(True)\n", - " plt.plot(rewards_history, label='rewards')\n", - " plt.plot(moving_average(np.array(rewards_history),\n", - " span=10), label='rewards ewma@10')\n", - " plt.legend()\n", - " plt.show()\n", - " if rewards_history[-1] >= 10000:\n", - " print(\"Your agent has just passed the minimum homework threshold\")\n", - " break" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Relax and grab some refreshments while your agent is locked in an infinite loop of violence and death.\n", - "\n", - "__How to interpret 
plots:__\n", - "\n", - "The session reward is the easy thing: it should in general go up over time, but it's okay if it fluctuates ~~like crazy~~. It's also OK if it reward doesn't increase substantially before some 10k initial steps. However, if reward reaches zero and doesn't seem to get up over 2-3 evaluations, there's something wrong happening.\n", - "\n", - "\n", - "Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n", - "\n", - "If it does, the culprit is likely:\n", - "* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot log p(a_i) $\n", - "* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n", - "* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n", - "* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n", - "\n", - "If you're debugging, just run `logits, values = agent.step(batch_states)` and manually look into logits and values. This will reveal the problem 9 times out of 10: you'll likely see some NaNs or insanely large numbers or zeros. Try to catch the moment when this happens for the first time and investigate from there." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### \"Final\" evaluation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gym.wrappers\n", - "\n", - "with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n", - " final_rewards = evaluate(agent, env_monitor, n_games=20)\n", - "\n", - "print(\"Final mean reward\", np.mean(final_rewards))" - ] + ], + "metadata": { + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Show video. This may not work in some setups. If it doesn't\n", - "# work for you, you can download the videos and view them locally.\n", - "\n", - "from pathlib import Path\n", - "from base64 import b64encode\n", - "from IPython.display import HTML\n", - "\n", - "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", - "video_path = video_paths[-1] # You can also try other indices\n", - "\n", - "if 'google.colab' in sys.modules:\n", - " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open('rb') as fp:\n", - " mp4 = fp.read()\n", - " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", - "else:\n", - " data_url = str(video_path)\n", - "\n", - "HTML(\"\"\"\n", - "\n", - "\"\"\".format(data_url))" - ] - } - ], - "metadata": { - "language_info": { - "name": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 1 + "nbformat": 4, + "nbformat_minor": 1 } diff --git a/week09_policy_II/mujoco_wrappers.py b/week09_policy_II/mujoco_wrappers.py index 72bc1bd9c..9ca1b9dae 100644 --- a/week09_policy_II/mujoco_wrappers.py +++ b/week09_policy_II/mujoco_wrappers.py @@ -1,6 +1,6 @@ """ MuJoCo env wrappers. 
""" # Adapted from https://github.com/openai/baselines -import gym +import gymnasium as gym import numpy as np @@ -83,17 +83,17 @@ def observation(self, obs): return obs def step(self, action): - obs, rews, resets, info = self.env.step(action) + obs, rews, terminated, truncated, info = self.env.step(action) self.ret = self.ret * self.gamma + rews obs = self.observation(obs) if self.ret_rmv: self.ret_rmv.update(self.ret) rews = np.clip(rews / np.sqrt(self.ret_rmv.var + self.eps), -self.cliprew, self.cliprew) - self.ret[resets] = 0. - return obs, rews, resets, info + self.ret[terminated] = 0. + return obs, rews, terminated, truncated, info def reset(self, **kwargs): self.ret = np.zeros(getattr(self.env.unwrapped, "nenvs", 1)) - obs = self.env.reset(**kwargs) - return self.observation(obs) + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info diff --git a/week09_policy_II/ppo.ipynb b/week09_policy_II/ppo.ipynb index e176a27d5..9640490c8 100644 --- a/week09_policy_II/ppo.ipynb +++ b/week09_policy_II/ppo.ipynb @@ -17,6 +17,8 @@ " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", "\n", " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week09_policy_II/mujoco_wrappers.py\n", + "\n", + " !pip -q install gymnasium[mujoco]\n", " \n", " !touch .setup_complete\n", "\n", @@ -43,27 +45,7 @@ "You will be solving a continuous control environment on which it may be easier and faster \n", "to train an agent, however note that PPO here may not be the best algorithm as, for example,\n", "Deep Deterministic Policy Gradient and Soft Actor Critic may be more suited \n", - "for continuous control environments. 
To run the environment you will need to install \n", - "[pybullet-gym](https://github.com/benelot/pybullet-gym) which unlike MuJoCo \n", - "does not require you to have a license.\n", - "\n", - "To install the library:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "31MBortONCVv", - "outputId": "13ea5dac-6194-497a-8ca3-d7bde217798c" - }, - "outputs": [], - "source": [ - "!git clone https://github.com/benelot/pybullet-gym lib/pybullet-gym\n", - "!pip install -e lib/pybullet-gym" + "for continuous control environments." ] }, { @@ -87,7 +69,7 @@ "The overall structure of the code is similar to the one in the A2C optional homework, but don't worry if you haven't done it, it should be relatively easy to figure it out. \n", "First, we will create an instance of the environment. \n", "We will normalize the observations and rewards, but before that you will need a wrapper that will \n", - "write summaries, mainly, the total reward during an episode. You can either use one for `TensorFlow` \n", + "write summaries, mainly, the total reward during an episode. You can either use one for `TensorBoard` \n", "implemented in `atari_wrappers.py` file from the optional A2C homework, or implement your own. 
" ] }, @@ -103,16 +85,25 @@ }, "outputs": [], "source": [ - "import gym \n", - "import pybulletgym\n", + "import gymnasium as gym\n", "\n", - "env = gym.make(\"HalfCheetahMuJoCoEnv-v0\")\n", + "env = gym.make(\"HalfCheetah-v4\", render_mode=\"rgb_array\")\n", "print(\"observation space: \", env.observation_space,\n", - " \"\\nobservations:\", env.reset())\n", - "print(\"action space: \", env.action_space, \n", + " \"\\nobservations:\", env.reset()[0])\n", + "print(\"action space: \", env.action_space,\n", " \"\\naction_sample: \", env.action_space.sample())" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow(env.render())" + ] + }, { "cell_type": "code", "execution_count": null, @@ -137,17 +128,17 @@ " self.current_len = 0\n", "\n", " def step(self, action):\n", - " obs, rew, done, info = self.env.step(action)\n", + " obs, rew, terminated, truncated, info = self.env.step(action)\n", "\n", " self.current_reward += rew\n", " self.current_len += 1\n", " self.current_step_var += 1\n", "\n", - " if done:\n", + " if terminated or truncated:\n", " self.episode_rewards.append((self.current_step_var, self.current_reward))\n", " self.episode_lens.append((self.current_step_var, self.current_len))\n", "\n", - " return obs, rew, done, info\n", + " return obs, rew, terminated, truncated, info\n", "\n", " def reset(self, **kwargs):\n", " self.episode_counter += 1\n", @@ -182,8 +173,8 @@ "source": [ "from mujoco_wrappers import Normalize\n", "\n", - "env = Normalize(Summaries(gym.make(\"HalfCheetahMuJoCoEnv-v0\")));\n", - "env.unwrapped.seed(0);" + "env = Normalize(Summaries(gym.make(\"HalfCheetah-v4\", render_mode=\"rgb_array\")));\n", + "env.reset(seed=0)" ] }, { @@ -221,7 +212,7 @@ "\n", "class PolicyModel(nn. 
Module):\n", " def __init__(self):\n", - " super(PolicyModel, self).__init__()\n", + " super().__init__()\n", " self.h = 64\n", "\n", " self.policy_model = < Create your model >\n", @@ -371,16 +362,16 @@ " self.nsteps = nsteps\n", " self.transforms = transforms or []\n", " self.step_var = step_var if step_var is not None else 0\n", - " self.state = {\"latest_observation\": self.env.reset()}\n", + " self.state = {\"latest_observation\": self.env.reset()[0]}\n", "\n", " @property\n", " def nenvs(self):\n", " \"\"\" Returns number of batched envs or `None` if env is not batched \"\"\"\n", " return getattr(self.env.unwrapped, \"nenvs\", None)\n", "\n", - " def reset(self):\n", + " def reset(self, **kwargs):\n", " \"\"\" Resets env and runner states. \"\"\"\n", - " self.state[\"latest_observation\"] = self.env.reset()\n", + " self.state[\"latest_observation\"], info = self.env.reset(**kwargs)\n", " self.policy.reset()\n", "\n", " def get_next(self):\n", @@ -400,7 +391,8 @@ " for key, val in act.items():\n", " trajectory[key].append(val)\n", "\n", - " obs, rew, done, _ = self.env.step(trajectory[\"actions\"][-1])\n", + " obs, rew, terminated, truncated, _ = self.env.step(trajectory[\"actions\"][-1])\n", + " done = np.logical_or(terminated, truncated)\n", " self.state[\"latest_observation\"] = obs\n", " rewards.append(rew)\n", " resets.append(done)\n", @@ -410,7 +402,7 @@ " # auto-reset.\n", " if not self.nenvs and np.all(done):\n", " self.state[\"env_steps\"] = i + 1\n", - " self.state[\"latest_observation\"] = self.env.reset()\n", + " self.state[\"latest_observation\"] = self.env.reset()[0]\n", "\n", " trajectory.update(\n", " observations=observations,\n", diff --git a/week09_policy_II/seminar_TRPO_pytorch.ipynb b/week09_policy_II/seminar_TRPO_pytorch.ipynb index 984f6782b..f2d8dcc8f 100644 --- a/week09_policy_II/seminar_TRPO_pytorch.ipynb +++ b/week09_policy_II/seminar_TRPO_pytorch.ipynb @@ -9,6 +9,9 @@ "import sys, os\n", "if 'google.colab' in sys.modules and not 
os.path.exists('.setup_complete'):\n", " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + "\n", + " !pip install -q gymnasium\n", + "\n", " !touch .setup_complete\n", "\n", "# This code creates a virtual display to draw game images on.\n", @@ -52,9 +55,9 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", + "import gymnasium as gym\n", "\n", - "env = gym.make(\"Acrobot-v1\")\n", + "env = gym.make(\"Acrobot-v1\", render_mode=\"rgb_array\")\n", "env.reset()\n", "observation_shape = env.observation_space.shape\n", "n_actions = env.action_space.n\n", @@ -69,9 +72,7 @@ "metadata": {}, "outputs": [], "source": [ - "from PIL import Image\n", - "\n", - "Image.fromarray(env.render('rgb_array'))" + "plt.imshow(env.render())" ] }, { @@ -101,7 +102,7 @@ " We recommend that you start simple:\n", " use 1-2 hidden layers with 100-500 units and relu for the first try\n", " '''\n", - " nn.Module.__init__(self)\n", + " super().__init__()\n", "\n", " assert isinstance(state_shape, tuple)\n", " assert len(state_shape) == 1\n", @@ -162,7 +163,7 @@ "outputs": [], "source": [ "# Check if log-probabilities satisfies all the requirements\n", - "log_probs = agent.get_log_probs(torch.tensor(env.reset()[np.newaxis], dtype=torch.float32))\n", + "log_probs = agent.get_log_probs(torch.tensor(env.reset()[0][np.newaxis], dtype=torch.float32))\n", "assert (\n", " isinstance(log_probs, torch.Tensor) and\n", " log_probs.requires_grad\n", @@ -172,8 +173,8 @@ "assert torch.allclose(sums, torch.ones_like(sums))\n", "\n", "# Demo use\n", - "print(\"sampled:\", [agent.act(env.reset()) for _ in range(5)])\n", - "print(\"greedy:\", [agent.act(env.reset(), sample=False) for _ in range(5)])" + "print(\"sampled:\", [agent.act(env.reset()[0]) for _ in range(5)])\n", + "print(\"greedy:\", [agent.act(env.reset()[0], sample=False) for _ in range(5)])" ] }, { @@ -270,16 +271,16 @@ " total_timesteps = 0\n", " while total_timesteps < 
n_timesteps:\n", " obervations, actions, rewards, action_probs = [], [], [], []\n", - " obervation = env.reset()\n", + " obervation, _ = env.reset()\n", " for _ in range(max_pathlength):\n", " action, policy = agent.act(obervation)\n", " obervations.append(obervation)\n", " actions.append(action)\n", " action_probs.append(policy)\n", - " obervation, reward, done, _ = env.step(action)\n", + " obervation, reward, terminated, truncated, _ = env.step(action)\n", " rewards.append(reward)\n", " total_timesteps += 1\n", - " if done or total_timesteps >= n_timesteps:\n", + " if terminated or truncated or total_timesteps >= n_timesteps:\n", " path = {\n", " \"observations\": np.array(obervations),\n", " \"policy\": np.array(action_probs),\n", @@ -697,8 +698,8 @@ "# Homework option II (10+pts)\n", "\n", "Let's use TRPO to train evil robots! (pick any of two)\n", - "* [MuJoCo robots](https://gym.openai.com/envs#mujoco)\n", - "* [Box2d robot](https://gym.openai.com/envs/BipedalWalker-v2)\n", + "* [MuJoCo robots](https://gymnasium.farama.org/environments/mujoco/#mujoco)\n", + "* [Box2d robot](https://gymnasium.farama.org/environments/box2d/bipedal_walker/)\n", "\n", "The catch here is that those environments have continuous action spaces.\n", "\n", diff --git a/week09_policy_II/td3_and_sac/hw-continuous-control_pytorch.ipynb b/week09_policy_II/td3_and_sac/hw-continuous-control_pytorch.ipynb new file mode 100644 index 000000000..714df58e0 --- /dev/null +++ b/week09_policy_II/td3_and_sac/hw-continuous-control_pytorch.ipynb @@ -0,0 +1,1066 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", + "\n", + " !wget -q 
https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week09_policy_II/td3_and_sac/logger.py\n", + "\n", + " !pip -q install gymnasium[mujoco]\n", + " !pip -q install tensorboardX\n", + "\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Continuous Control\n", + "\n", + "\n", + "In this notebook you will solve continuous control environment using either [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) or [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1801.01290.pdf). Both are off-policy algorithms that are current go-to algorithms for continuous control tasks.\n", + "\n", + "**Select one** of these two algorithms (TD3 or SAC) to implement. Both algorithms are extensions of basic [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971) algorithm, and DDPG is kind of \"DQN with another neural net approximating greedy policy\", and all that differs is a set of stabilization tricks:\n", + "* TD3 trains deterministic policy, while SAC uses *stochastic policy*. This means that for SAC you can solve exploration-exploitation trade-off by simple sampling from policy, while in TD3 you will have to add noise to your actions.\n", + "* TD3 proposes to stabilize targets by adding a *clipped noise* to actions, which slightly prevents overestimation. In SAC, we formally switch to formalism of Maximum Entropy RL and add *entropy bonus* into our value function.\n", + "\n", + "Also both algorithms utilize a *twin trick*: train two critics and use pessimistic targets by taking minimum from two proposals. Standard trick with target networks is also necessary. 
We will go through all these tricks step-by-step.\n", + "\n", + "SAC is probably less clumsy scheme than TD3, but requires a bit more code to implement. More detailed description of algorithms can be found in Spinning Up documentation:\n", + "* on [DDPG](https://spinningup.openai.com/en/latest/algorithms/ddpg.html)\n", + "* on [TD3](https://spinningup.openai.com/en/latest/algorithms/td3.html)\n", + "* on [SAC](https://spinningup.openai.com/en/latest/algorithms/sac.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we will create an instance of the environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-09-16T18:41:00.003174Z", + "start_time": "2020-09-16T18:40:59.921640Z" + } + }, + "outputs": [], + "source": [ + "env = gym.make(\"Ant-v4\", render_mode=\"rgb_array\")\n", + "\n", + "# we want to look inside\n", + "env.reset()\n", + "\n", + "# examples of states and actions\n", + "print(\"observation space: \", env.observation_space,\n", + " \"\\nobservations:\", env.reset()[0])\n", + "print(\"action space: \", env.action_space,\n", + " \"\\naction_sample: \", env.action_space.sample())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.imshow(env.render())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run random policy and see how it looks." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class RandomActor():\n", + " def get_action(self, states):\n", + " assert len(states.shape) == 1, \"can't work with batches\"\n", + " return env.action_space.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s, _ = env.reset()\n", + "rewards_per_step = []\n", + "actor = RandomActor()\n", + "\n", + "for i in range(10000):\n", + " a = actor.get_action(s)\n", + " s, r, terminated, truncated, _ = env.step(a)\n", + "\n", + " rewards_per_step.append(r)\n", + "\n", + " if terminated or truncated:\n", + " s, _ = env.reset()\n", + " print(\"done: \", i)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So, basically most episodes are 1000 steps long (then happens termination by time), though sometimes we are terminated earlier if simulation discovers some obvious reasons to think that we crashed our ant. Important thing about continuous control tasks like this is that we receive non-trivial signal at each step: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rewards_per_step[100:110]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This dense signal will guide our optimizations. It also partially explains why off-policy algorithms are more effective and sample-efficient than on-policy algorithms like PPO: 1-step targets are already quite informative." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will add only one wrapper to our environment to simply write summaries, mainly, the total reward during an episode." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from logger import TensorboardSummaries as Summaries\n", + "\n", + "env = gym.make(\"Ant-v4\", render_mode=\"rgb_array\")\n", + "env = Summaries(env, \"MyFirstWalkingAnt\");\n", + "\n", + "state_dim = env.observation_space.shape[0] # dimension of state space (27 numbers)\n", + "action_dim = env.action_space.shape[0] # dimension of action space (8 numbers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's start with *critic* model. On the one hand, it will function as an approximation of $Q^*(s, a)$, on the other hand it evaluates current actor $\\pi$ and can be viewed as $Q^{\\pi}(s, a)$. This critic will take both state $s$ and action $a$ as input and output a scalar value. Recommended architecture is 3-layered MLP.\n", + "\n", + "**Danger:** when models have a scalar output it is a good rule to squeeze it to avoid unexpected broadcasting, since [batch_size, 1] broadcasts with many tensor sizes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "class Critic(nn.Module):\n", + " def __init__(self, state_dim, action_dim):\n", + " super().__init__() \n", + "\n", + " \n", + "\n", + " def get_qvalues(self, states, actions):\n", + " '''\n", + " input:\n", + " states - tensor, (batch_size x features)\n", + " actions - tensor, (batch_size x actions_dim)\n", + " output:\n", + " qvalues - tensor, critic estimation, (batch_size)\n", + " '''\n", + " qvalues = \n", + "\n", + " assert len(qvalues.shape) == 1 and qvalues.shape[0] == states.shape[0]\n", + " \n", + " return qvalues" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's define a policy, or an actor $\\pi$. Use architecture, similar to critic (3-layered MLP). The output depends on algorithm:\n", + "\n", + "For **TD3**, model *deterministic policy*. You should output `action_dim` numbers in range $[-1, 1]$. Unfortunately, deterministic policies lead to problems with stability and exploration, so we will need three \"modes\" of how this policy can be operating:\n", + "* First one - greedy - is a simple feedforward pass through network that will be used to train the actor.\n", + "* Second one - exploration mode - is when we need to add noise (e.g. Gaussian) to our actions to collect more diverse data. 
\n", + "* Third mode - \"clipped noised\" - will be used when we will require a target for critic, where we need to somehow \"noise\" our actor output, but not too much, so we add *clipped noise* to our output:\n", + "$$\\pi_{\\theta}(s) + \\varepsilon, \\quad \\varepsilon = \\operatorname{clip}(\\epsilon, -0.5, 0.5), \\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-09-16T18:41:06.246418Z", + "start_time": "2020-09-16T18:41:05.841255Z" + } + }, + "outputs": [], + "source": [ + "# template for TD3; template for SAC is below\n", + "class TD3_Actor(nn.Module):\n", + " def __init__(self, state_dim, action_dim):\n", + " super().__init__() \n", + "\n", + " \n", + "\n", + " def get_action(self, states, std_noise=0.1):\n", + " '''\n", + " Used to collect data by interacting with environment,\n", + " so your have to add some noise to actions.\n", + " input:\n", + " states - numpy, (batch_size x features)\n", + " output:\n", + " actions - numpy, (batch_size x actions_dim)\n", + " '''\n", + " # no gradient computation is required here since we will use this only for interaction\n", + " with torch.no_grad():\n", + " actions = \n", + " \n", + " assert isinstance(actions, (list,np.ndarray)), \"convert actions to numpy to send into env\"\n", + " assert actions.max() <= 1. and actions.min() >= -1, \"actions must be in the range [-1, 1]\"\n", + " return actions\n", + " \n", + " def get_best_action(self, states):\n", + " '''\n", + " Will be used to optimize actor. Requires differentiable w.r.t. 
parameters actions.\n", + " input:\n", + " states - PyTorch tensor, (batch_size x features)\n", + " output:\n", + " actions - PyTorch tensor, (batch_size x actions_dim)\n", + " '''\n", + " actions = \n", + " \n", + " assert actions.requires_grad, \"you must be able to compute gradients through actions\"\n", + " return actions\n", + " \n", + " def get_target_action(self, states, std_noise=0.2, clip_eta=0.5):\n", + " '''\n", + " Will be used to create target for critic optimization.\n", + " Returns actions with added \"clipped noise\".\n", + " input:\n", + " states - PyTorch tensor, (batch_size x features)\n", + " output:\n", + " actions - PyTorch tensor, (batch_size x actions_dim)\n", + " '''\n", + " # no gradient computation is required here since we will use this only for interaction\n", + " with torch.no_grad():\n", + " actions = \n", + " \n", + " # actions can fly out of [-1, 1] range after added noise\n", + " return actions.clamp(-1, 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For **SAC**, model *gaussian policy*. This means policy distribution is going to be multivariate normal with diagonal covariance. The policy head will predict the mean and covariance, and it should be guaranteed that covariance is non-negative. **Important:** the way you model covariance strongly influences optimization procedure, so here are some options: let $f_{\\theta}$ be the output of covariance head, then:\n", + "* use exponential function $\\sigma(s) = \\exp(f_{\\theta}(s))$\n", + "* transform output to $[-1, 1]$ using `tanh`, then project output to some interval $[m, M]$, where $m = -20$, $M = 2$ and then use exponential function. This will guarantee the range of modeled covariance is adequate. So, the resulting formula is:\n", + "$$\\sigma(s) = \\exp^{m + 0.5(M - m)(\\tanh(f_{\\theta}(s)) + 1)}$$\n", + "* `softplus` operation $\\sigma(s) = \\log(1 + \\exp^{f_{\\theta}(s)})$ seems to work poorly here. 
o_O\n", + "\n", + "**Note**: `torch.distributions.Normal` already has everything you will need to work with such policy after you modeled mean and covariance, i.e. sampling via reparametrization trick (see `rsample` method) and compute log probability (see `log_prob` method).\n", + "\n", + "There is one more problem with gaussian distribution. We need to force our actions to be in $[-1, 1]$ bound. To achieve this, model unbounded gaussian $\\mathcal{N}(\\mu_{\\theta}(s), \\sigma_{\\theta}(s)^2I)$, where $\\mu$ can be arbitrary. Then every time you have samples $u$ from this gaussian policy, squash it using $\\operatorname{tanh}$ function to get a sample from $[-1, 1]$:\n", + "$$u \\sim \\mathcal{N}(\\mu, \\sigma^2I)$$\n", + "$$a = \\operatorname{tanh}(u)$$\n", + "\n", + "**Important:** after that you are required to use change of variable formula every time you compute likelihood (see appendix C in [paper on SAC](https://arxiv.org/pdf/1801.01290.pdf) for details):\n", + "$$\\log p(a \\mid \\mu, \\sigma) = \\log p(u \\mid \\mu, \\sigma) - \\sum_{i = 1}^D \\log (1 - \\operatorname{tanh}^2(u_i)),$$\n", + "where $D$ is `action_dim`. In practice, add something like 1e-6 inside logarithm to protect from computational instabilities." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-09-16T18:41:06.246418Z", + "start_time": "2020-09-16T18:41:05.841255Z" + } + }, + "outputs": [], + "source": [ + "# template for SAC\n", + "from torch.distributions import Normal\n", + "\n", + "class SAC_Actor(nn.Module):\n", + " def __init__(self, state_dim, action_dim):\n", + " super().__init__() \n", + "\n", + " \n", + " \n", + " def apply(self, states):\n", + " '''\n", + " For given batch of states samples actions and also returns its log prob.\n", + " input:\n", + " states - PyTorch tensor, (batch_size x features)\n", + " output:\n", + " actions - PyTorch tensor, (batch_size x action_dim)\n", + " log_prob - PyTorch tensor, (batch_size)\n", + " '''\n", + " \n", + " \n", + " return actions, log_prob \n", + "\n", + " def get_action(self, states):\n", + " '''\n", + " Used to interact with environment by sampling actions from policy\n", + " input:\n", + " states - numpy, (batch_size x features)\n", + " output:\n", + " actions - numpy, (batch_size x actions_dim)\n", + " '''\n", + " # no gradient computation is required here since we will use this only for interaction\n", + " with torch.no_grad():\n", + " \n", + " # hint: you can use `apply` method here\n", + " actions = \n", + " \n", + " assert isinstance(actions, (list,np.ndarray)), \"convert actions to numpy to send into env\"\n", + " assert actions.max() <= 1. and actions.min() >= -1, \"actions must be in the range [-1, 1]\"\n", + " return actions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ReplayBuffer\n", + "\n", + "The same as in DQN. You can copy code from your DQN assignment, just check that it works fine with continuous actions (probably it is). 
\n", + "\n", + "Let's recall the interface:\n", + "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n", + "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n", + "* `len(exp_replay)` - returns number of elements stored in replay buffer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ReplayBuffer():\n", + " def __init__(self, size):\n", + " \"\"\"\n", + " Create Replay buffer.\n", + " Parameters\n", + " ----------\n", + " size: int\n", + " Max number of transitions to store in the buffer. When the buffer\n", + " overflows the old memories are dropped.\n", + "\n", + " Note: for this assignment you can pick any data structure you want.\n", + " If you want to keep it simple, you can store a list of tuples of (s, a, r, s') in self._storage\n", + " However you may find out there are faster and/or more memory-efficient ways to do so.\n", + " \"\"\"\n", + " self._storage = []\n", + " self._maxsize = size\n", + "\n", + " # OPTIONAL: YOUR CODE\n", + "\n", + " def __len__(self):\n", + " return len(self._storage)\n", + "\n", + " def add(self, obs_t, action, reward, obs_tp1, done):\n", + " '''\n", + " Make sure, _storage will not exceed _maxsize. 
\n", + " Make sure, FIFO rule is being followed: the oldest examples has to be removed earlier\n", + " ''' \n", + " data = (obs_t, action, reward, obs_tp1, done)\n", + " storage = self._storage\n", + " maxsize = self._maxsize\n", + " \n", + " # add data to storage\n", + "\n", + " def sample(self, batch_size):\n", + " \"\"\"Sample a batch of experiences.\n", + " Parameters\n", + " ----------\n", + " batch_size: int\n", + " How many transitions to sample.\n", + " Returns\n", + " -------\n", + " obs_batch: np.array\n", + " batch of observations\n", + " act_batch: np.array\n", + " batch of actions executed given obs_batch\n", + " rew_batch: np.array\n", + " rewards received as results of executing act_batch\n", + " next_obs_batch: np.array\n", + " next set of observations seen after executing act_batch\n", + " done_mask: np.array\n", + " done_mask[i] = 1 if executing act_batch[i] resulted in\n", + " the end of an episode and 0 otherwise.\n", + " \"\"\"\n", + " storage = self._storage\n", + " \n", + " # randomly generate batch_size integers\n", + " # to be used as indexes of samples\n", + " \n", + " \n", + " # collect for each index\n", + " \n", + " return \n", + " # , , , , " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "exp_replay = ReplayBuffer(10)\n", + "\n", + "for _ in range(30):\n", + " exp_replay.add(env.reset()[0], env.action_space.sample(),\n", + " 1.0, env.reset()[0], done=False)\n", + "\n", + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n", + " 5)\n", + "\n", + "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n", + " \"\"\"\n", + " Play the game for exactly n steps, record every (s,a,r,s', done) to 
replay buffer. \n", + " Whenever game ends, add record with done=True and reset the game.\n", + " It is guaranteed that env has done=False when passed to this function.\n", + "\n", + " :returns: return sum of rewards over time and the state in which the env stays\n", + " \"\"\"\n", + " s = initial_state\n", + " sum_rewards = 0\n", + "\n", + " # Play the game for n_steps as per instructions above\n", + " for t in range(n_steps):\n", + " \n", + " # select action using policy with exploration\n", + " a = \n", + " \n", + " ns, r, terminated, truncated, _ = env.step(a)\n", + " \n", + " exp_replay.add(s, a, r, ns, terminated)\n", + " \n", + " s = env.reset()[0] if terminated or truncated else ns\n", + " \n", + " sum_rewards += r \n", + "\n", + " return sum_rewards, s" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#testing your code.\n", + "exp_replay = ReplayBuffer(2000)\n", + "actor = (state_dim, action_dim).to(DEVICE)\n", + "\n", + "state, _ = env.reset()\n", + "play_and_record(state, actor, env, exp_replay, n_steps=1000)\n", + "\n", + "# if you're using your own experience replay buffer, some of those tests may need correction.\n", + "# just make sure you know what your code does\n", + "assert len(exp_replay) == 1000, \"play_and_record should have added exactly 1000 steps, \"\\\n", + " \"but instead added %i\" % len(exp_replay)\n", + "is_dones = list(zip(*exp_replay._storage))[-1]\n", + "\n", + "for _ in range(100):\n", + " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n", + " 10)\n", + " assert obs_batch.shape == next_obs_batch.shape == (10,) + (state_dim,)\n", + " assert act_batch.shape == (\n", + " 10, action_dim), \"actions batch should have shape (10, 8) but is instead %s\" % str(act_batch.shape)\n", + " assert reward_batch.shape == (\n", + " 10,), \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", + " assert 
is_done_batch.shape == (\n", + " 10,), \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", + " assert all(int(i) in (0, 1)\n", + " for i in is_dones), \"is_done should be strictly True or False\"\n", + "\n", + "print(\"Well done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialization\n", + "\n", + "Let's start initializing our algorithm. Here are our hyperparameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gamma=0.99 # discount factor\n", + "max_buffer_size = 10**5 # size of experience replay\n", + "start_timesteps = 5000 # size of experience replay when start training\n", + "timesteps_per_epoch=1 # steps in environment per step of network updates\n", + "batch_size=128 # batch size for all optimizations\n", + "max_grad_norm=10 # max grad norm for all optimizations\n", + "tau=0.005 # speed of updating target networks\n", + "policy_update_freq=<> # frequency of actor update; vanilla choice is 2 for TD3 or 1 for SAC\n", + "alpha=0.1 # temperature for SAC\n", + "\n", + "# iterations passed\n", + "n_iterations = 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is our experience replay:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# experience replay\n", + "exp_replay = ReplayBuffer(max_buffer_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here are our models: *two* critics and one actor."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# models to train\n", + "actor = (state_dim, action_dim).to(DEVICE)\n", + "critic1 = Critic(state_dim, action_dim).to(DEVICE)\n", + "critic2 = Critic(state_dim, action_dim).to(DEVICE)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To stabilize training, we will require **target networks** - slow updating copies of our models. In **TD3**, both critics and actor have their copies, in **SAC** it is assumed that only critics require target copies while actor is always used fresh." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# target networks: slow-updated copies of actor and two critics\n", + "target_critic1 = Critic(state_dim, action_dim).to(DEVICE)\n", + "target_critic2 = Critic(state_dim, action_dim).to(DEVICE)\n", + "target_actor = TD3_Actor(state_dim, action_dim).to(DEVICE) # comment this line if you chose SAC\n", + "\n", + "# initialize them as copies of original models\n", + "target_critic1.load_state_dict(critic1.state_dict())\n", + "target_critic2.load_state_dict(critic2.state_dict())\n", + "target_actor.load_state_dict(actor.state_dict()) # comment this line if you chose SAC " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In continuous control, target networks are usually updated using exponential smoothing:\n", + "$$\\theta^{-} \\leftarrow \\tau \\theta + (1 - \\tau) \\theta^{-},$$\n", + "where $\\theta^{-}$ are target network weights, $\\theta$ - fresh parameters, $\\tau$ - hyperparameter. 
This util function will do it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def update_target_networks(model, target_model):\n", + " for param, target_param in zip(model.parameters(), target_model.parameters()):\n", + " target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we will have three optimization procedures to train our three models, so let's welcome our three Adams:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# optimizers: for every model we have\n", + "opt_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)\n", + "opt_critic1 = torch.optim.Adam(critic1.parameters(), lr=3e-4)\n", + "opt_critic2 = torch.optim.Adam(critic2.parameters(), lr=3e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# just to avoid writing this code three times\n", + "def optimize(name, model, optimizer, loss):\n", + " '''\n", + " Makes one step of SGD optimization, clips norm with max_grad_norm and \n", + " logs everything into tensorboard\n", + " '''\n", + " loss = loss.mean()\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n", + " optimizer.step()\n", + "\n", + " # logging\n", + " env.writer.add_scalar(name, loss.item(), n_iterations)\n", + " env.writer.add_scalar(name + \"_grad_norm\", grad_norm.item(), n_iterations)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Critic target computation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's discuss our losses for critic and actor.\n", + "\n", + "To train both critics we would like to minimize MSE using 1-step targets: for one sampled transition $(s, a, r, s')$ it 
should look something like this:\n", + "$$y(s, a) = r + \\gamma V(s').$$\n", + "\n", + "How do we evaluate next state and compute $V(s')$? Well, technically Monte-Carlo estimation looks simple:\n", + "$$V(s') \\approx Q(s', a')$$\n", + "where (important!) $a'$ is a sample from our current policy $\\pi(a' \\mid s')$.\n", + "\n", + "But our actor $\\pi$ will be actually trained to search for actions $a'$ where our critic gives big estimates, and this straightforward approach leads to serious overestimation issues. We require some hacks. First, we will use target networks for $Q$ (and **TD3** also uses target network for $\\pi$). Second, we will use *two* critics and take minimum across their estimations:\n", + "$$V(s') = \\min_{i = 1,2} Q^{-}_i(s', a'),$$\n", + "where $a'$ is sampled from target policy $\\pi^{-}(a' \\mid s')$ in **TD3** and from fresh policy $\\pi(a' \\mid s')$ in **SAC**.\n", + "\n", + "###### And last but not least:\n", + "* in **TD3** to compute $a'$ use *mode with clipped noise* that will prevent our policy from exploiting narrow peaks in our critic approximation;\n", + "* in **SAC** add (estimation of) entropy bonus in next state $s'$:\n", + "$$V(s') = \\min_{i = 1,2} Q^{-}_i(s', a') - \\alpha \\log \\pi (a' \\mid s')$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_critic_target(rewards, next_states, is_done):\n", + " '''\n", + " Important: use target networks for this method! 
Do not use \"fresh\" models except fresh policy in SAC!\n", + " input:\n", + " rewards - PyTorch tensor, (batch_size)\n", + " next_states - PyTorch tensor, (batch_size x features)\n", + " is_done - PyTorch tensor, (batch_size)\n", + " output:\n", + " critic target - PyTorch tensor, (batch_size)\n", + " '''\n", + " with torch.no_grad():\n", + " critic_target = \n", + " \n", + " assert not critic_target.requires_grad, \"target must not require grad.\"\n", + " assert len(critic_target.shape) == 1, \"dangerous extra dimension in target?\"\n", + "\n", + " return critic_target" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To train actor we want simply to maximize\n", + "$$\\mathbb{E}_{a \\sim \\pi(a \\mid s)} Q(s, a) \\to \\max_{\\pi}$$\n", + "\n", + "* in **TD3**, because of deterministic policy, the expectation reduces:\n", + "$$Q(s, \\pi(s)) \\to \\max_{\\pi}$$\n", + "* in **SAC**, use reparametrization trick to compute gradients and also do not forget to add entropy regularizer to motivate policy to be as stochastic as possible:\n", + "$$\\mathbb{E}_{a \\sim \\pi(a \\mid s)} \\left[ Q(s, a) - \\alpha \\log \\pi(a \\mid s) \\right] \\to \\max_{\\pi}$$\n", + "\n", + "**Note:** We will use (fresh) critic1 here as Q-function to \"exploit\". You can also use both critics and again take minimum across their estimations (this is done in original implementation of **SAC** and not done in **TD3**), but this seems to be not of high importance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_actor_loss(states):\n", + " '''\n", + " Returns actor loss on batch of states\n", + " input:\n", + " states - PyTorch tensor, (batch_size x features)\n", + " output:\n", + " actor loss - PyTorch tensor, (batch_size)\n", + " '''\n", + " # make sure you have gradients w.r.t. 
actor parameters\n", + " actions = \n", + " \n", + " assert actions.requires_grad, \"actions must be differentiable with respect to policy parameters\"\n", + " \n", + " # compute actor loss\n", + " actor_loss = \n", + " return actor_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally combining all together and launching our algorithm. Your goal is to reach at least 1000 average reward during evaluation after training in this ant environment (*since this is a new hometask, this threshold might be updated, so at least just see if your ant learned to walk in the rendered simulation*).\n", + "\n", + "* rewards should rise more or less steadily in this environment. There can be some drops due to instabilities of algorithm, but it should eventually start rising after 100K-200K iterations. If no progress in reward is observed after these first 100K-200K iterations, there is a bug.\n", + "* gradient norm appears to be quite big for this task, it is ok if it reaches 100-200 (we handled it with clip_grad_norm). Consider everything exploded if it starts growing exponentially, then there is a bug." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seed = \n", + "np.random.seed(seed)\n", + "env.unwrapped.reset(seed=seed) # gymnasium removed Env.seed(); the env RNG is seeded through reset(seed=...)\n", + "torch.manual_seed(seed);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.notebook import trange\n", + "\n", + "interaction_state, _ = env.reset() # gymnasium reset returns (observation, info); keep only the observation\n", + "random_actor = RandomActor()\n", + "\n", + "for n_iterations in trange(0, 1000000, timesteps_per_epoch):\n", + " # if experience replay is small yet, no training happens\n", + " # we also collect data using random policy to collect more diverse starting data\n", + " if len(exp_replay) < start_timesteps:\n", + " _, interaction_state = play_and_record(interaction_state, random_actor, env, exp_replay, timesteps_per_epoch)\n", + " continue\n", + " \n", + " # perform a step in environment and store it in experience replay\n", + " _, interaction_state = play_and_record(interaction_state, actor, env, exp_replay, timesteps_per_epoch)\n", + " \n", + " # sample a batch from experience replay\n", + " states, actions, rewards, next_states, is_done = exp_replay.sample(batch_size)\n", + " \n", + " # move everything to PyTorch tensors\n", + " states = torch.tensor(states, device=DEVICE, dtype=torch.float)\n", + " actions = torch.tensor(actions, device=DEVICE, dtype=torch.float)\n", + " rewards = torch.tensor(rewards, device=DEVICE, dtype=torch.float)\n", + " next_states = torch.tensor(next_states, device=DEVICE, dtype=torch.float)\n", + " is_done = torch.tensor(\n", + " is_done.astype('float32'),\n", + " device=DEVICE,\n", + " dtype=torch.float\n", + " )\n", + " \n", + " # losses\n", + " critic1_loss = \n", + " optimize(\"critic1\", critic1, opt_critic1, critic1_loss)\n", + "\n", + " critic2_loss = \n", + " optimize(\"critic2\", critic2, opt_critic2, critic2_loss)\n", + "\n", + " # actor update is less frequent in TD3\n", + " if n_iterations % policy_update_freq ==
0:\n", + " actor_loss = \n", + " optimize(\"actor\", actor, opt_actor, actor_loss)\n", + "\n", + " # update target networks\n", + " update_target_networks(critic1, target_critic1)\n", + " update_target_networks(critic2, target_critic2)\n", + " update_target_networks(actor, target_actor) # comment this line if you chose SAC" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-09-16T18:41:47.560269Z", + "start_time": "2020-09-16T18:41:47.546277Z" + } + }, + "outputs": [], + "source": [ + "def evaluate(env, actor, n_games=1, t_max=1000):\n", + " '''\n", + " Plays n_games and returns rewards and rendered games\n", + " '''\n", + " rewards = []\n", + "\n", + " for _ in range(n_games):\n", + " s, _ = env.reset()\n", + "\n", + " R = 0\n", + " for _ in range(t_max):\n", + " # select action for final evaluation of your policy\n", + " action = \n", + "\n", + " assert (action.max() <= 1).all() and (action.min() >= -1).all()\n", + "\n", + " s, r, terminated, truncated, _ = env.step(action)\n", + "\n", + " R += r\n", + "\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " rewards.append(R)\n", + " return np.array(rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-09-16T18:38:45.130920Z", + "start_time": "2020-09-16T18:38:13.090472Z" + } + }, + "outputs": [], + "source": [ + "# evaluation will take some time!\n", + "sessions = evaluate(env, actor, n_games=20)\n", + "score = sessions.mean()\n", + "print(f\"Your score: {score}\")\n", + "\n", + "assert score >= 1000, \"Needs more training?\"\n", + "print(\"Well done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Record" + ] + }, 
# %%
from gymnasium.wrappers import RecordVideo

# let's hope this will work
# don't forget to pray
with gym.make("Ant-v4", render_mode="rgb_array") as env, RecordVideo(
    env=env, video_folder="./videos"
) as env_monitor:
    # note that t_max is 300, so collected reward will be smaller than 1000
    evaluate(env_monitor, actor, n_games=1, t_max=300)

# %%
# Show video. This may not work in some setups. If it doesn't
# work for you, you can download the videos and view them locally.

from pathlib import Path
from base64 import b64encode
from IPython.display import HTML
import sys

video_paths = sorted(p for p in Path('videos').iterdir() if p.suffix == '.mp4')
if not video_paths:
    # fail with a readable message instead of a bare IndexError below
    raise FileNotFoundError("No .mp4 files found in ./videos - run the recording cell first")
video_path = video_paths[-1]  # You can also try other indices

if 'google.colab' in sys.modules:
    # Colab cannot serve local files to the browser, so embed the video
    # in the page as a base64 data URL.
    # https://stackoverflow.com/a/57378660/1214547
    with video_path.open('rb') as fp:
        mp4 = fp.read()
    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
else:
    data_url = str(video_path)

# The template must contain the <video> element and a `{}` placeholder;
# without them `.format(data_url)` renders an empty output.
HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format(data_url))
class TensorboardSummaries(gym.Wrapper):
    """Environment wrapper that writes per-episode statistics to TensorBoard.

    Rewards and episode lengths are accumulated on every `step()`. Once every
    wrapped (sub-)environment has finished at least one episode since the last
    write, the statistics are written under `logs/<prefix>` and the
    bookkeeping restarts.

    :param env: environment to wrap; may expose `nenvs` on its unwrapped
        instance to signal a batch of parallel environments (default 1)
    :param prefix: log sub-directory name; defaults to `env.spec.id`
    :param running_mean_size: number of most recent episodes used for the
        running mean of the reward
    :param step_var: initial value of the global step counter used as the
        x-axis of the summaries; defaults to 0
    """

    def __init__(self, env, prefix=None, running_mean_size=100, step_var=None):
        super().__init__(env)
        self.episode_counter = 0
        self.prefix = prefix or self.env.spec.id
        self.writer = SummaryWriter(f"logs/{self.prefix}")
        # Honour a caller-supplied starting step instead of silently dropping it.
        self.step_var = 0 if step_var is None else step_var

        self.nenvs = getattr(self.env.unwrapped, "nenvs", 1)
        self.rewards = np.zeros(self.nenvs)
        self.had_ended_episodes = np.zeros(self.nenvs, dtype=bool)
        self.episode_lengths = np.zeros(self.nenvs)
        self.reward_queues = [
            deque([], maxlen=running_mean_size) for _ in range(self.nenvs)
        ]

    def should_write_summaries(self):
        """Returns true if it's time to write summaries."""
        return np.all(self.had_ended_episodes)

    def add_summaries(self):
        """Writes summaries and restarts episode-length / finished-episode tracking.

        Must only be called when `should_write_summaries()` is true, so that
        every reward queue holds at least one finished episode.
        """
        last_rewards = [q[-1] for q in self.reward_queues]
        self.writer.add_scalar(
            "Episodes/total_reward", np.mean(last_rewards), self.step_var
        )
        self.writer.add_scalar(
            f"Episodes/reward_mean_{self.reward_queues[0].maxlen}",
            np.mean([np.mean(q) for q in self.reward_queues]),
            self.step_var,
        )
        self.writer.add_scalar(
            "Episodes/episode_length", np.mean(self.episode_lengths), self.step_var
        )
        # min/max are only informative when several environments run in parallel
        if self.had_ended_episodes.size > 1:
            self.writer.add_scalar(
                "Episodes/min_reward", min(last_rewards), self.step_var
            )
            self.writer.add_scalar(
                "Episodes/max_reward", max(last_rewards), self.step_var
            )
        self.episode_lengths.fill(0)
        self.had_ended_episodes.fill(False)

    def step(self, action):
        """Steps the wrapped env, updates statistics and writes summaries when due.

        :returns: the unchanged `(obs, rew, terminated, truncated, info)` tuple
        """
        obs, rew, terminated, truncated, info = self.env.step(action)
        self.rewards += rew
        # episode length is counted only until the first finished episode
        self.episode_lengths[~self.had_ended_episodes] += 1

        info_collection = [info] if isinstance(info, dict) else info
        # np.atleast_1d handles Python bools, numpy bool scalars and per-env
        # sequences uniformly (an isinstance(..., bool) check misses np.bool_
        # and makes the indexing below fail).
        terminated_flags = np.atleast_1d(terminated)
        truncated_flags = np.atleast_1d(truncated)

        for i, env_info in enumerate(info_collection):
            # "real_done" (if the env provides it) overrides the step flags
            episode_done = env_info.get(
                "real_done", terminated_flags[i] or truncated_flags[i]
            )
            if not episode_done:
                continue
            self.had_ended_episodes[i] = True
            self.reward_queues[i].append(self.rewards[i])
            # start accumulating the next episode from zero
            self.rewards[i] = 0

        self.step_var += self.nenvs
        if self.should_write_summaries():
            self.add_summaries()
        return obs, rew, terminated, truncated, info

    def reset(self, **kwargs):
        """Clears the running statistics and resets the wrapped env."""
        self.rewards.fill(0)
        self.episode_lengths.fill(0)
        self.had_ended_episodes.fill(False)
        return self.env.reset(**kwargs)

    def close(self):
        """Flushes and closes the TensorBoard writer, then closes the wrapped env."""
        self.writer.close()
        return super().close()
@@ -98,7 +99,7 @@ "source": [ "---\n", "\n", - "But before we do that, we first need to make a wrapper for Gym environments to allow saving and loading game states to facilitate backtracking." + "But before we do that, we first need to make a wrapper for Gymnasium environments to allow saving and loading game states to facilitate backtracking." ] }, { @@ -107,8 +108,7 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", - "from gym.core import Wrapper\n", + "import gymnasium as gym\n", "from pickle import dumps, loads\n", "from collections import namedtuple\n", "\n", @@ -117,7 +117,7 @@ " \"action_result\", (\"snapshot\", \"observation\", \"reward\", \"is_done\", \"info\"))\n", "\n", "\n", - "class WithSnapshots(Wrapper):\n", + "class WithSnapshots(gym.Wrapper):\n", " \"\"\"\n", " Creates a wrapper that supports saving and loading environemnt states.\n", " Required for planning algorithms.\n", @@ -128,8 +128,8 @@ " - ...\n", "\n", " You can also use reset() and step() directly for convenience.\n", - " - s = self.reset() # same as self.env.reset()\n", - " - s, r, done, _ = self.step(action) # same as self.env.step(action)\n", + " - s, _ = self.reset() # same as self.env.reset()\n", + " - s, r, terminated, truncated, _ = self.step(action) # same as self.env.step(action)\n", " \n", " Note that while you may use self.render(), it will spawn a window that cannot be pickled.\n", " Thus, you will need to call self.close() before pickling will work again.\n", @@ -153,9 +153,10 @@ " self.render() # close popup windows since we can't pickle them\n", " self.close()\n", " \n", - " if self.unwrapped.viewer is not None:\n", - " self.unwrapped.viewer.close()\n", - " self.unwrapped.viewer = None\n", + " self.unwrapped.screen = None\n", + " self.unwrapped.clock = None\n", + " self.unwrapped.surf = None\n", + "\n", " return dumps(self.env)\n", "\n", " def load_snapshot(self, snapshot, render=False):\n", @@ -181,7 +182,8 @@ "\n", " :returns: next snapshot, 
next_observation, reward, is_done, info\n", "\n", - " Basically it returns next snapshot and everything that env.step would have returned.\n", + " Basically it returns next snapshot and almost everything that env.step would have returned.\n", + " Note that is_done = terminated or truncated\n", " \"\"\"\n", "\n", " \n", @@ -210,7 +212,9 @@ "outputs": [], "source": [ "# make env\n", - "env = WithSnapshots(gym.make(\"CartPole-v0\"))\n", + "env = WithSnapshots(gym.make(\"CartPole-v1\",\n", + " render_mode=\"rgb_array\",\n", + " max_episode_steps=200))\n", "env.reset()\n", "\n", "n_actions = env.action_space.n" @@ -223,7 +227,7 @@ "outputs": [], "source": [ "print(\"initial_state:\")\n", - "plt.imshow(env.render('rgb_array'))\n", + "plt.imshow(env.render())\n", "env.close()\n", "\n", "# create first snapshot\n", @@ -238,13 +242,16 @@ "source": [ "# play without making snapshots (faster)\n", "while True:\n", - " is_done = env.step(env.action_space.sample())[2]\n", - " if is_done:\n", + " _, _, terminated, truncated, _ = env.step(env.action_space.sample())\n", + " if terminated:\n", " print(\"Whoops! 
We died!\")\n", " break\n", + " if truncated:\n", + " print(\"Time is over!\")\n", + " break\n", "\n", "print(\"final state:\")\n", - "plt.imshow(env.render('rgb_array'))\n", + "plt.imshow(env.render())\n", "env.close()" ] }, @@ -258,7 +265,7 @@ "env.load_snapshot(snap0)\n", "\n", "print(\"\\n\\nAfter loading snapshot\")\n", - "plt.imshow(env.render('rgb_array'))\n", + "plt.imshow(env.render())\n", "env.close()" ] }, @@ -524,8 +531,8 @@ "metadata": {}, "outputs": [], "source": [ - "env = WithSnapshots(gym.make(\"CartPole-v0\"))\n", - "root_observation = env.reset()\n", + "env = WithSnapshots(gym.make(\"CartPole-v1\", render_mode=\"rgb_array\", max_episode_steps=200))\n", + "root_observation, _ = env.reset()\n", "root_snapshot = env.get_snapshot()\n", "root = Root(root_snapshot, root_observation)" ] @@ -559,7 +566,6 @@ "source": [ "from IPython.display import clear_output\n", "from itertools import count\n", - "from gym.wrappers import Monitor\n", "\n", "total_reward = 0 # sum of rewards\n", "test_env = loads(root_snapshot) # env used to show progress\n", @@ -570,16 +576,16 @@ " best_child = \n", "\n", " # take action\n", - " s, r, done, _ = test_env.step(best_child.action)\n", + " s, r, terminated, truncated, _ = test_env.step(best_child.action)\n", "\n", " # show image\n", " clear_output(True)\n", " plt.title(\"step %i\" % i)\n", - " plt.imshow(test_env.render('rgb_array'))\n", + " plt.imshow(test_env.render())\n", " plt.show()\n", "\n", " total_reward += r\n", - " if done:\n", + " if terminated or truncated:\n", " print(\"Finished with reward = \", total_reward)\n", " break\n", "\n", @@ -624,6 +630,11 @@ "\n", "\"Build this\" assignment\n", "\n", + "Don't forget to run:\n", + "``` \n", + "pip install gymnasium[atari,accept-rom-license]\n", + "``` \n", + "\n", "Apply MCTS to play Atari games. In particular, let's start with ```gym.make(\"MsPacman-ramDeterministic-v0\")```.\n", "\n", "This requires two things:\n",