diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 2ea57e3e1c..9af9fb7a7a 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -18,7 +18,7 @@ Steps to reproduce the behavior: 4. See error **Console logs / stack traces** -Please wrap in [triple backticks (```)](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) to make it easier to read. +Please wrap in triple backticks (```) to make it easier to read. **Screenshots** If applicable, add screenshots to help explain your problem. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 6bf245b4cf..47c73ebf07 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -17,8 +17,8 @@ Describe the changes made in this PR. ### Checklist - [ ] Added tests that prove my fix is effective or that my feature works -- [ ] Updated the [changelog](https://github.com/Unity-Technologies/ml-agents/blob/main/com.unity.ml-agents/CHANGELOG.md) (if applicable) -- [ ] Updated the [documentation](https://github.com/Unity-Technologies/ml-agents/tree/main/docs) (if applicable) -- [ ] Updated the [migration guide](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Migrating.md) (if applicable) +- [ ] Updated the changelog (if applicable) +- [ ] Updated the documentation (if applicable) +- [ ] Updated the migration guide (if applicable) ### Other comments diff --git a/.github/stale.yml b/.github/stale.yml index 2328fea17c..88e2766248 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -10,15 +10,15 @@ only: issues # Issue specific configuration issues: limitPerRun: 5 - daysUntilStale: 28 - daysUntilClose: 14 + daysUntilStale: 90 + daysUntilClose: 30 markComment: > This issue has been automatically marked as stale because it has not had activity in the - last 28 days. It will be closed in the next 14 days if no further activity occurs. + last 90 days. It will be closed in the next 30 days if no further activity occurs. Thank you for your contributions. closeComment: > This issue has been automatically closed because it has not had activity in the - last 42 days. If this issue is still valid, please ping a maintainer. + last 120 days. If this issue is still valid, please ping a maintainer. Thank you for your contributions. exemptLabels: - request diff --git a/.github/workflows/colab.yml b/.github/workflows/colab.yml index 8c2236cb2c..8e07f3108a 100644 --- a/.github/workflows/colab.yml +++ b/.github/workflows/colab.yml @@ -8,7 +8,10 @@ on: - 'colab/**' - '.github/workflows/colab.yml' push: - branches: [main] + branches: + - main + - develop + - 'release/**' workflow_dispatch: jobs: diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 2fd09d2fe4..f44ec032e6 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -21,7 +21,7 @@ jobs: node-version: '12' - name: Install manual dependencies run: | - sudo npm install -g markdown-link-check + sudo npm install -g markdown-link-check@3.8.7 python -m pip install pre-commit pre-commit install - name: Run markdown checker @@ -43,13 +43,13 @@ jobs: # If one test in the matrix fails we still want to run the others. 
fail-fast: false matrix: - python-version: [3.7.x, 3.8.x, 3.9.x] + python-version: [3.8.x, 3.9.x, 3.10.x] include: - - python-version: 3.7.x - pip_constraints: test_constraints_min_version.txt - python-version: 3.8.x - pip_constraints: test_constraints_mid_version.txt + pip_constraints: test_constraints_min_version.txt - python-version: 3.9.x + pip_constraints: test_constraints_mid_version.txt + - python-version: 3.10.x pip_constraints: test_constraints_max_version.txt steps: - uses: actions/checkout@v2 @@ -87,7 +87,7 @@ jobs: run: | pytest --cov=ml-agents --cov=ml-agents-envs \ --cov-report=html --junitxml=junit/test-results-${{ matrix.python-version }}.xml \ - -p no:warnings -v -n auto + -p no:warnings -v -n 8 - name: Upload pytest test results uses: actions/upload-artifact@v2 with: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index e64e672991..53d05952ce 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -3,7 +3,10 @@ name: pre-commit on: pull_request: push: - branches: [main] + branches: + - main + - develop + - 'release/**' workflow_dispatch: jobs: @@ -17,7 +20,7 @@ jobs: submodules: recursive - uses: actions/setup-python@v2 with: - python-version: 3.7.x + python-version: 3.8.x - uses: actions/setup-ruby@v1 env: ImageOS: ubuntu20 @@ -48,8 +51,9 @@ jobs: with: node-version: '12' - name: Install manual dependencies + # pin markdown-link-check version to support multi-level reference link run: | - sudo npm install -g markdown-link-check + sudo npm install -g markdown-link-check@3.8.7 python -m pip install pre-commit pre-commit install - name: Run markdown checker diff --git a/.github/workflows/publish_docs.yaml b/.github/workflows/publish_docs.yaml new file mode 100644 index 0000000000..0cec1bf586 --- /dev/null +++ b/.github/workflows/publish_docs.yaml @@ -0,0 +1,24 @@ +name: Publish HTML Docs + +on: + workflow_dispatch: + +jobs: + publish: + name: Publish Docs to GH Pages + runs-on: [self-hosted, Linux, X64] + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + token: ${{ secrets.PUBLIC_GH_TOKEN }} + - name: Setup Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Publish docs + run: | + pip install mkdocs + git remote add public git@github.com:Unity-Technologies/ml-agents.git + mkdocs gh-deploy --clean -r public + diff --git a/.github/workflows/publish_pypi.yaml b/.github/workflows/publish_pypi.yaml index 294aed09cb..1ff4dd8fc9 100644 --- a/.github/workflows/publish_pypi.yaml +++ b/.github/workflows/publish_pypi.yaml @@ -20,10 +20,10 @@ jobs: steps: - uses: actions/checkout@main - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: pip install setuptools wheel twine --user - name: verify git tag vs. version diff --git a/.github/workflows/publish_pypi_python_api.yaml b/.github/workflows/publish_pypi_python_api.yaml index 465b05ac8d..c7adc7e45c 100644 --- a/.github/workflows/publish_pypi_python_api.yaml +++ b/.github/workflows/publish_pypi_python_api.yaml @@ -19,11 +19,11 @@ jobs: package-path: [ml-agents-envs] steps: - - uses: actions/checkout@main - - name: Set up Python 3.7 + - uses: actions/checkout@v2 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: pip install setuptools wheel twine --user - name: verify git tag vs.
version @@ -39,6 +39,7 @@ jobs: if: startsWith(github.ref, 'refs/tags') && contains(github.ref, 'test') uses: pypa/gh-action-pypi-publish@717ba43cfbb0387f6ce311b169a825772f54d295 with: + user: __token__ password: ${{ secrets.TEST_PYPI_PASSWORD }} repository_url: https://test.pypi.org/legacy/ packages_dir: ${{ matrix.package-path }}/dist/ @@ -46,5 +47,6 @@ jobs: if: startsWith(github.ref, 'refs/tags') && !contains(github.ref, 'test') uses: pypa/gh-action-pypi-publish@717ba43cfbb0387f6ce311b169a825772f54d295 with: + user: __token__ password: ${{ secrets.PYPI_PASSWORD }} packages_dir: ${{ matrix.package-path }}/dist/ diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 23c8a77acc..9e50aaea5d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -9,7 +9,10 @@ on: - 'test_requirements.txt' - '.github/workflows/pytest.yml' push: - branches: [main] + branches: + - main + - develop + - 'release/**' workflow_dispatch: inputs: pytest_markers: @@ -36,13 +39,13 @@ jobs: # If one test in the matrix fails we still want to run the others. fail-fast: false matrix: - python-version: [3.7.x, 3.8.x, 3.9.x] + python-version: [3.8.x, 3.9.x, 3.10.x] include: - - python-version: 3.7.x - pip_constraints: test_constraints_min_version.txt - python-version: 3.8.x - pip_constraints: test_constraints_mid_version.txt + pip_constraints: test_constraints_min_version.txt - python-version: 3.9.x + pip_constraints: test_constraints_mid_version.txt + - python-version: 3.10.x pip_constraints: test_constraints_max_version.txt steps: - uses: actions/checkout@v2 @@ -88,7 +91,7 @@ jobs: run: | pytest --cov=ml-agents --cov=ml-agents-envs \ --cov-report=html --junitxml=junit/test-results-${{ matrix.python-version }}.xml \ - -p no:warnings -v -m "${{ steps.pytest_marker.outputs.markers }}" -n auto + -p no:warnings -v -m "${{ steps.pytest_marker.outputs.markers }}" -n 8 - name: Upload pytest test results uses: actions/upload-artifact@v2 with: diff --git a/.gitmodules b/.gitmodules index cd43b6df8e..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "com.unity.ml-agents"] - path = com.unity.ml-agents - url = ../com.unity.ml-agents.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba9a84327d..8d15b038b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/python/black - rev: 22.1.0 + rev: 22.3.0 hooks: - id: black exclude: > @@ -25,7 +25,7 @@ repos: exclude: ".*_pb2.py" args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional] additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools, types-filelock] -- repo: https://gitlab.com/pycqa/flake8 +- repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 @@ -128,6 +128,6 @@ repos: - id: generate-markdown-docs name: generate markdown docs language: python - entry: ./utils/generate_markdown_docs.py --package_dirs ml-agents-envs + entry: ./utils/generate_markdown_docs.py --package_dirs ml-agents-envs ml-agents pass_filenames: false additional_dependencies: [pyyaml, pydoc-markdown==3.10.1] diff --git a/.yamato/com.unity.ml-agents-coverage.yml b/.yamato/com.unity.ml-agents-coverage.yml index 09ef550417..f46612541b 100644 --- a/.yamato/com.unity.ml-agents-coverage.yml +++ b/.yamato/com.unity.ml-agents-coverage.yml @@ -12,11 +12,10 @@ test_coverage_{{ package.name }}_{{ platform.name }}_{{ editor.version }}: image: {{ platform.image }} flavor: {{
platform.flavor}} commands: - - git submodule update --init --recursive - - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm + - npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - upm-ci project test -u {{ editor.version }} --type project-tests --project-path {{ editor.testProject }} --package-filter {{ package.name }} {{ coverageOptions }} --extra-utr-arg "reruncount=2" - | - conda activate python3.7 + conda activate python3.8 python3 ml-agents/tests/yamato/check_coverage_percent.py upm-ci~/test-results/ {{ package.minCoveragePct }} artifacts: logs: @@ -29,6 +28,7 @@ test_coverage_{{ package.name }}_{{ platform.name }}_{{ editor.version }}: {% if platform.name == "linux" %} expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR diff --git a/.yamato/com.unity.ml-agents-optional-dep-tests.yml b/.yamato/com.unity.ml-agents-optional-dep-tests.yml deleted file mode 100644 index c70b4cf250..0000000000 --- a/.yamato/com.unity.ml-agents-optional-dep-tests.yml +++ /dev/null @@ -1,66 +0,0 @@ -optional_deps: - - name: Analytics - project: "OptionalDepedencyTests/NoAnalyticsModule" - version: 2020.3 - - name: Physics - project: OptionalDepedencyTests/NoPhysicsModule - version: 2020.3 - - name: Physics2D - project: OptionalDepedencyTests/NoPhysics2DModule - version: 2020.3 ---- - - {% for optional_dep in optional_deps %} -OptionalDependencyTests_{{ optional_dep.name }}: - name : Test Optional Package Dependencies {{ optional_dep.name }} - agent: - type: Unity::VM - image: ml-agents/ml-agents-ubuntu-18.04:latest - flavor: b1.medium - commands: - - git submodule update --init --recursive - - | - curl -L https://artifactory.prd.it.unity3d.com/artifactory/api/gpg/key/public | sudo apt-key add - - sudo sh -c "echo 'deb https://artifactory.prd.it.unity3d.com/artifactory/unity-apt-local bionic main' > /etc/apt/sources.list.d/unity.list" - sudo apt update - sudo apt install -y unity-config - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - unity-config settings editor-path ./.Editor - unity-config project create opt-deps-test - unity-config project add dependency com.unity.ml-agents/ - unity-config project add testable com.unity.ml-agents - unity-config project add dependency com.unity.modules.imageconversion@1.0.0 - unity-config project add dependency com.unity.modules.jsonserialize@1.0.0 - {% unless optional_dep.name == "Physics" %} - unity-config project add dependency com.unity.modules.physics@1.0.0 - {% endunless %} - {% unless optional_dep.name == "Physics2D" %} - unity-config project add dependency com.unity.modules.physics2d@1.0.0 - {% endunless %} - {% unless optional_dep.name == "Analytics" %} - unity-config project add dependency com.unity.modules.unityanalytics@1.0.0 - {% endunless %} - upm-ci project test -u {{ optional_dep.version }} --type project-tests --project-path opt-deps-test --package-filter com.unity.ml-agents - artifacts: - logs: - paths: - - "upm-ci~/test-results/**/*" - dependencies: - - .yamato/com.unity.ml-agents-pack.yml#pack - {% for coverage_editor in coverage_test_editors %} - {% for coverage_platform in coverage_test_platforms %} - {% for coverage_package in coverage_test_packages %} - - 
.yamato/com.unity.ml-agents-coverage.yml#test_coverage_{{ coverage_package.name }}_{{ coverage_platform.name }}_{{ coverage_editor.version }} - {% endfor %} - {% endfor %} - {% endfor %} - triggers: - cancel_old_ci: true - expression: | - (pull_request.target eq "main" OR - pull_request.target match "release.+") AND - NOT pull_request.draft AND - (pull_request.changes.any match "com.unity.ml-agents/**" OR - pull_request.changes.any match ".yamato/com.unity.ml-agents-optional-dep-tests.yml") - {% endfor %} - diff --git a/.yamato/com.unity.ml-agents-pack.yml b/.yamato/com.unity.ml-agents-pack.yml index b72eb073ac..71e6d5d9af 100644 --- a/.yamato/com.unity.ml-agents-pack.yml +++ b/.yamato/com.unity.ml-agents-pack.yml @@ -5,12 +5,11 @@ pack: image: ml-agents/ml-agents-ubuntu-18.04:latest flavor: b1.small commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python3 -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade - unity-downloader-cli -u 2020.3 -c editor --wait --fast + unity-downloader-cli -u 2021.3 -c editor --wait --fast ./.Editor/Unity -projectPath Project -batchMode -executeMethod Unity.MLAgents.SampleExporter.ExportCuratedSamples -logFile - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm upm-ci project pack --project-path Project diff --git a/.yamato/com.unity.ml-agents-performance.yml b/.yamato/com.unity.ml-agents-performance.yml index 211381a818..c9f16cf2a3 100644 --- a/.yamato/com.unity.ml-agents-performance.yml +++ b/.yamato/com.unity.ml-agents-performance.yml @@ -1,6 +1,6 @@ test_editors: - - version: 2020.3 - - version: 2021.2 + - version: 2021.3 + - version: 2022.1 --- {% for editor in test_editors %} Run_Mac_Perfomance_Tests{{ editor.version }}: @@ -12,7 +12,6 @@ Run_Mac_Perfomance_Tests{{ editor.version }}: variables: UNITY_VERSION: {{ editor.version }} commands: - - git submodule update --init --recursive - python3 -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade - unity-downloader-cli -u {{ editor.version }} -c editor --wait --fast - curl -s https://artifactory.prd.it.unity3d.com/artifactory/unity-tools-local/utr-standalone/utr --output utr @@ -21,7 +20,7 @@ Run_Mac_Perfomance_Tests{{ editor.version }}: triggers: cancel_old_ci: true recurring: - - branch: main + - branch: develop frequency: daily artifacts: logs: diff --git a/.yamato/com.unity.ml-agents-promotion.yml b/.yamato/com.unity.ml-agents-promotion.yml index 1a36644f0c..e4930acfea 100644 --- a/.yamato/com.unity.ml-agents-promotion.yml +++ b/.yamato/com.unity.ml-agents-promotion.yml @@ -1,5 +1,5 @@ test_editors: - - version: 2019.3 + - version: 2021.3 test_platforms: - name: win type: Unity::VM @@ -18,7 +18,7 @@ promotion_test_{{ platform.name }}_{{ editor.version }}: variables: UPMCI_PROMOTION: 1 commands: - - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm + - npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - upm-ci package test --unity-version {{ editor.version }} --package-path com.unity.ml-agents artifacts: logs: @@ -48,7 +48,7 @@ promote: variables: UPMCI_PROMOTION: 1 commands: - - npm install upm-ci-utils@stable -g --registry 
https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm + - npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - upm-ci package promote --package-path com.unity.ml-agents # triggers: # tags: diff --git a/.yamato/com.unity.ml-agents-publish.yml b/.yamato/com.unity.ml-agents-publish.yml index a87e83c35e..3f28322007 100644 --- a/.yamato/com.unity.ml-agents-publish.yml +++ b/.yamato/com.unity.ml-agents-publish.yml @@ -7,7 +7,7 @@ publish: variables: UPMCI_ENABLE_PACKAGE_SIGNING: 1 commands: - - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm + - npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - upm-ci package publish --package-path com.unity.ml-agents triggers: cancel_old_ci: true diff --git a/.yamato/com.unity.ml-agents-test.yml b/.yamato/com.unity.ml-agents-test.yml index b77dcfc3c2..54e334cc8d 100644 --- a/.yamato/com.unity.ml-agents-test.yml +++ b/.yamato/com.unity.ml-agents-test.yml @@ -1,10 +1,10 @@ {% metadata_file .yamato/coverage_tests.metafile %} test_editors: - - version: 2020.3 + - version: 2021.3 # We want some scene tests to run in the DevProject, but packages there only support 2020+ testProject: Project enableNoDefaultPackages: !!bool true - - version: 2021.2 + - version: 2022.1 testProject: DevProject enableNoDefaultPackages: !!bool true @@ -59,7 +59,7 @@ all_package_tests: triggers: cancel_old_ci: true recurring: - - branch: main + - branch: develop frequency: daily {% for package in packages %} @@ -79,8 +79,7 @@ test_{{ package.name }}_{{ platform.name }}_{{ editor.version }}: image: {{ platform.image }} flavor: {{ platform.flavor}} commands: - - git submodule update --init --recursive - - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm + - npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm - upm-ci package test -u {{ editor.version }} --package-path {{ package.name }} {{ noDefaultPackagesOptions }} --warnings-as-errors --extra-utr-arg "reruncount=2" artifacts: logs: @@ -100,6 +99,7 @@ test_{{ package.name }}_{{ platform.name }}_{{ editor.version }}: {% if platform.name == "linux" %} expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR @@ -124,14 +124,13 @@ test_{{ package.name }}_{{ platform.name }}_trunk: image: {{ platform.image }} flavor: {{ platform.flavor}} commands: - - git submodule update --init --recursive - | {% if platform.name == "linux" %} - conda activate python3.7 + conda activate python3.8 {% endif %} python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade unity-downloader-cli -u trunk -c editor --wait --fast - npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm + npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm upm-ci project test -u {{ editor.version }} --project-path {{ editor.testProject }} --package-filter {{ package.name }} --extra-create-project-arg="-upmNoDefaultPackages" 
--extra-utr-arg "reruncount=2" artifacts: logs: diff --git a/.yamato/compressed-sensor-test.yml b/.yamato/compressed-sensor-test.yml index a0745255c2..23e28a5533 100644 --- a/.yamato/compressed-sensor-test.yml +++ b/.yamato/compressed-sensor-test.yml @@ -10,10 +10,9 @@ test_compressed_obs_{{ editor.version }}: variables: UNITY_VERSION: {{ editor.version }} commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -u -m ml-agents.tests.yamato.setup_venv python ml-agents/tests/yamato/scripts/run_compressed_sensor.py --env=artifacts/testPlayer-TestGridCompressed @@ -25,6 +24,7 @@ test_compressed_obs_{{ editor.version }}: {% if editor.extra_test == "sensor" %} expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR diff --git a/.yamato/coverage_tests.metafile b/.yamato/coverage_tests.metafile index a12e8cd2ad..7f5aaad096 100644 --- a/.yamato/coverage_tests.metafile +++ b/.yamato/coverage_tests.metafile @@ -1,5 +1,5 @@ coverage_test_editors: - - version: 2020.3 + - version: 2021.3 testProject: DevProject coverage_test_platforms: diff --git a/.yamato/gym-interface-test.yml b/.yamato/gym-interface-test.yml index f71f67d3b0..a90a8ce9ec 100644 --- a/.yamato/gym-interface-test.yml +++ b/.yamato/gym-interface-test.yml @@ -10,10 +10,9 @@ test_gym_interface_{{ editor.version }}: variables: UNITY_VERSION: {{ editor.version }} commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python -m pip install wheel --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -u -m ml-agents.tests.yamato.setup_venv @@ -25,6 +24,7 @@ test_gym_interface_{{ editor.version }}: {% if editor.extra_test == "gym" %} expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR diff --git a/.yamato/protobuf-generation-test.yml b/.yamato/protobuf-generation-test.yml index 2bc7f62b65..be490867e9 100644 --- a/.yamato/protobuf-generation-test.yml +++ b/.yamato/protobuf-generation-test.yml @@ -6,13 +6,12 @@ test_linux_protobuf_generation: flavor: b1.large variables: GRPC_VERSION: "1.14.1" - CS_PROTO_PATH: "Runtime/Grpc/CommunicatorObjects" + CS_PROTO_PATH: "com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects" commands: - - git submodule update --init --recursive - | sudo apt-get update && sudo apt-get install -y nuget eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 nuget install Grpc.Tools -Version $GRPC_VERSION -OutputDirectory protobuf-definitions/ python3 -m pip install --upgrade pip --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python3 -m pip install grpcio==1.28.1 grpcio-tools==1.13.0 protobuf==3.11.3 six==1.14.0 mypy-protobuf==1.16.0 --progress-bar=off --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple @@ -23,7 +22,6 
@@ test_linux_protobuf_generation: popd mkdir -p artifacts touch artifacts/proto.patch - cd com.unity.ml-agents git diff --exit-code -- :/ ":(exclude,top)$CS_PROTO_PATH/*.meta" \ || { GIT_ERR=$?; echo "protobufs need to be regenerated, apply the patch uploaded to artifacts."; \ echo "Apply the patch with the command 'git apply proto.patch'"; \ @@ -32,6 +30,7 @@ test_linux_protobuf_generation: cancel_old_ci: true expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "protobuf-definitions/**" OR diff --git a/.yamato/pytest-gpu.yml b/.yamato/pytest-gpu.yml index b1824b8b10..99999ca224 100644 --- a/.yamato/pytest-gpu.yml +++ b/.yamato/pytest-gpu.yml @@ -5,19 +5,24 @@ pytest_gpu: image: ml-agents/ml-agents-ubuntu-18.04:latest flavor: b1.large commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python3 -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python3 -u -m ml-agents.tests.yamato.setup_venv python3 -m pip install --progress-bar=off -r test_requirements.txt --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python3 -m pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple - python3 -m pytest -m "not slow" -n auto --junitxml=junit/test-results.xml -p no:warnings + if python -c "exec('import torch \nif not torch.cuda.is_available(): raise')" &> /dev/null; then + echo 'all good' + else + exit 1 + fi + python3 -m pytest -m "not slow" --junitxml=junit/test-results.xml -p no:warnings triggers: cancel_old_ci: true expression: | (push.branch eq "main" OR + push.branch eq "develop" OR push.branch match "release.+") AND push.changes.any match "ml-agents/**" AND NOT push.changes.all match "**/*.md" diff --git a/.yamato/python-ll-api-test.yml b/.yamato/python-ll-api-test.yml index 51b8e269f2..cde94d31fb 100644 --- a/.yamato/python-ll-api-test.yml +++ b/.yamato/python-ll-api-test.yml @@ -10,10 +10,9 @@ test_linux_ll_api_{{ editor.version }}: variables: UNITY_VERSION: {{ editor.version }} commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -u -m ml-agents.tests.yamato.setup_venv python ml-agents/tests/yamato/scripts/run_llapi.py @@ -27,6 +26,7 @@ test_linux_ll_api_{{ editor.version }}: {% if editor.extra_test == "llapi" %} expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR diff --git a/.yamato/sonar-python-package.yml b/.yamato/sonar-python-package.yml new file mode 100644 index 0000000000..3e9403bc30 --- /dev/null +++ b/.yamato/sonar-python-package.yml @@ -0,0 +1,23 @@ +csharp: + name: Sonarqube Scan for ml-agents python repo + agent: + type: Unity::metal::macmini + image: package-ci/mac:v1.8.1-822785 + flavor: m1.mac + variables: + SONARQUBE_PROJECT_KEY: ai-ml-agents-toolkit + SONARQUBE_URL: https://sonarqube.internal.unity3d.com 
+ SONARQUBE_LOGIN: a08467db099d82931708d480b8dbf428cf1921d5 + TARGET_BRANCH: develop + commands: + - npm install shellcheck --save-dev + - npm install upm-ci-utils@1.27.0 -g --registry https://artifactory.prd.it.unity3d.com/artifactory/api/npm/upm-npm + - curl https://binaries.sonarsource.com/Distribution/sonar-scanner-cli/sonar-scanner-cli-4.7.0.2747-macosx.zip -o sonar-scanner-cli-macosx.zip -L + - unzip sonar-scanner-cli-macosx.zip -d ~/sonar-scanner-cli + - ~/sonar-scanner-cli/sonar-scanner-4.7.0.2747-macosx/bin/sonar-scanner -Dsonar.projectKey=$SONARQUBE_PROJECT_KEY -Dsonar.sources=ml-agents-envs -Dsonar.sources=ml-agents -Dsonar.sources=ml-agents-plugin-examples -Dsonar.sources=ml-agents-trainer-plugin -Dsonar.sources=utils -Dsonar.host.url=$SONARQUBE_URL -Dsonar.login=$SONARQUBE_LOGIN -Dsonar.branch.name=$TARGET_BRANCH -Dsonar.scm.provider=git + triggers: + cancel_old_ci: true + expression: | + ((pull_request.target eq "main" OR pull_request.target eq "develop" OR pull_request.target match "release.+") + AND NOT pull_request.push.changes.all match "**/*.md") OR + (push.branch eq "main" OR push.branch eq "develop") diff --git a/.yamato/standalone-build-test.yml b/.yamato/standalone-build-test.yml index c709d2a7b5..fac93943e9 100644 --- a/.yamato/standalone-build-test.yml +++ b/.yamato/standalone-build-test.yml @@ -10,10 +10,9 @@ test_linux_standalone_{{ editor.version }}: variables: UNITY_VERSION: {{ editor.version }} commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python3 -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python3 -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade unity-downloader-cli -u {{ editor.version }} -c editor --wait --fast @@ -27,6 +26,7 @@ test_linux_standalone_{{ editor.version }}: cancel_old_ci: true expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR diff --git a/.yamato/standalone-build-webgl-test.yml b/.yamato/standalone-build-webgl-test.yml index 395d45f3f2..8d7a699768 100644 --- a/.yamato/standalone-build-webgl-test.yml +++ b/.yamato/standalone-build-webgl-test.yml @@ -1,4 +1,4 @@ -{% capture editor_version %}2020.3{% endcapture %} +{% capture editor_version %}2021.3{% endcapture %} test_webgl_standalone_{{ editor_version }}: name: Test WebGL Standalone {{ editor_version }} agent: @@ -8,10 +8,9 @@ test_webgl_standalone_{{ editor_version }}: variables: UNITY_VERSION: {{ editor_version }} commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade unity-downloader-cli -u {{ editor_version }} -c editor -c WebGL --wait --fast @@ -19,7 +18,7 @@ test_webgl_standalone_{{ editor_version }}: triggers: cancel_old_ci: true recurring: - - branch: main + - branch: develop frequency: weekly artifacts: logs: diff --git a/.yamato/test_versions.metafile b/.yamato/test_versions.metafile index 7cfdce7c03..8fdd62bb10
100644 --- a/.yamato/test_versions.metafile +++ b/.yamato/test_versions.metafile @@ -3,9 +3,9 @@ # For each "other" test, we only run it against a single version of the # editor to reduce the number of yamato jobs test_editors: - - version: 2020.3 + - version: 2021.3 extra_test: gym - - version: 2021.2 + - version: 2022.1 extra_test: sensor - version: trunk extra_test: llapi diff --git a/.yamato/training-backcompat-tests.yml b/.yamato/training-backcompat-tests.yml deleted file mode 100644 index f179c4ca5a..0000000000 --- a/.yamato/training-backcompat-tests.yml +++ /dev/null @@ -1,45 +0,0 @@ - -test_mac_backcompat_2019.4: - {% capture editor_version %}2019.4{% endcapture %} - {% capture csharp_backcompat_version %}1.0.0{% endcapture %} - # This test has to run on mac because it requires the custom build of tensorflow without AVX - # Test against 2020.1 because 2020.2 has to run against package version 1.2.0 - name: Test Mac Backcompat Training {{ editor_version }} - agent: - type: Unity::VM::osx - image: ml-agents/ml-agents-bokken-mac:0.1.5-853758 - flavor: b1.small - variables: - UNITY_VERSION: {{ editor_version }} - commands: - - git submodule update --init --recursive - - | - python -m venv venv && source venv/bin/activate - python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple - python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade - unity-downloader-cli -u {{ editor_version }} -c editor --wait --fast - # Backwards-compatibility tests. - # If we make a breaking change to the communication protocol, these will need - # to be disabled until the next release. - python -u -m ml-agents.tests.yamato.standalone_build_tests --build-target=mac - python -u -m ml-agents.tests.yamato.training_int_tests --csharp {{ csharp_backcompat_version }} - - | - python -m venv venv_old && source venv_old/bin/activate - python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple - python -u -m ml-agents.tests.yamato.training_int_tests --python 0.24.0 - triggers: - cancel_old_ci: true - recurring: - - branch: main - frequency: daily - artifacts: - logs: - paths: - - "artifacts/standalone_build.txt" - - "artifacts/inference.nn.txt" - - "artifacts/inference.onnx.txt" - - "artifacts/*.log" - standalonebuild: - paths: - - "artifacts/testPlayer*/**" - - "artifacts/models/**" diff --git a/.yamato/training-int-tests.yml b/.yamato/training-int-tests.yml index 860405e176..18f5f4f83d 100644 --- a/.yamato/training-int-tests.yml +++ b/.yamato/training-int-tests.yml @@ -10,10 +10,9 @@ test_linux_training_int_{{ editor.version }}: variables: UNITY_VERSION: {{ editor.version }} commands: - - git submodule update --init --recursive - | eval "$($HOME/anaconda/bin/conda shell.bash hook)" - conda activate python3.7 + conda activate python3.8 python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -u -m ml-agents.tests.yamato.training_int_tests dependencies: @@ -22,6 +21,7 @@ test_linux_training_int_{{ editor.version }}: cancel_old_ci: true expression: | (pull_request.target eq "main" OR + pull_request.target eq "develop" OR pull_request.target match "release.+") AND NOT pull_request.draft AND (pull_request.changes.any match "com.unity.ml-agents/**" OR diff --git a/DevProject/Packages/packages-lock.json b/DevProject/Packages/packages-lock.json index 7eac0c8e6b..56804861c9 
100644 --- a/DevProject/Packages/packages-lock.json +++ b/DevProject/Packages/packages-lock.json @@ -1,7 +1,7 @@ { "dependencies": { "com.unity.barracuda": { - "version": "2.3.1-preview", + "version": "3.0.0", "depth": 1, "source": "registry", "dependencies": { @@ -12,7 +12,7 @@ "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, "com.unity.burst": { - "version": "1.6.0", + "version": "1.6.6", "depth": 2, "source": "registry", "dependencies": { @@ -46,7 +46,7 @@ "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, "com.unity.mathematics": { - "version": "1.2.1", + "version": "1.2.6", "depth": 3, "source": "registry", "dependencies": {}, @@ -57,7 +57,7 @@ "depth": 0, "source": "local", "dependencies": { - "com.unity.barracuda": "2.3.1-preview", + "com.unity.barracuda": "3.0.0", "com.unity.modules.imageconversion": "1.0.0", "com.unity.modules.jsonserialize": "1.0.0" } @@ -72,16 +72,14 @@ } }, "com.unity.nuget.mono-cecil": { - "version": "0.1.6-preview.2", + "version": "1.10.1", "depth": 1, "source": "registry", - "dependencies": { - "nuget.mono-cecil": "0.1.6-preview" - }, + "dependencies": {}, "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, "com.unity.nuget.newtonsoft-json": { - "version": "2.0.0-preview", + "version": "3.0.2", "depth": 1, "source": "registry", "dependencies": {}, @@ -108,7 +106,7 @@ "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, "com.unity.settings-manager": { - "version": "1.0.1", + "version": "1.0.3", "depth": 1, "source": "registry", "dependencies": {}, @@ -159,13 +157,6 @@ "dependencies": {}, "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" }, - "nuget.mono-cecil": { - "version": "0.1.6-preview", - "depth": 2, - "source": "registry", - "dependencies": {}, - "url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates" - }, "nuget.moq": { "version": "1.0.0", "depth": 1, diff --git a/DevProject/ProjectSettings/EditorBuildSettings.asset b/DevProject/ProjectSettings/EditorBuildSettings.asset index 1823a7a08e..7512fc3e3a 100644 --- a/DevProject/ProjectSettings/EditorBuildSettings.asset +++ b/DevProject/ProjectSettings/EditorBuildSettings.asset @@ -9,5 +9,5 @@ EditorBuildSettings: path: Assets/ML-Agents/Scripts/Tests/Runtime/AcademyTest/AcademyStepperTestScene.unity guid: 9bafc50b1e55b43b2b1ae9620f1f8311 m_configObjects: - com.unity.ml-agents.settings: {fileID: 11400000, guid: 7017f4eb06bef4889a3608a54b1cc59e, + com.unity.ml-agents.settings: {fileID: 11400000, guid: 905d6ca857fdf4d028b93658cf00e271, type: 2} diff --git a/DevProject/ProjectSettings/ProjectVersion.txt b/DevProject/ProjectSettings/ProjectVersion.txt index 4c9401b919..8ea1b855ae 100644 --- a/DevProject/ProjectSettings/ProjectVersion.txt +++ b/DevProject/ProjectSettings/ProjectVersion.txt @@ -1,2 +1,2 @@ -m_EditorVersion: 2020.3.25f1 -m_EditorVersionWithRevision: 2020.3.25f1 (9b9180224418) +m_EditorVersion: 2021.3.11f1 +m_EditorVersionWithRevision: 2021.3.11f1 (0a5ca18544bf) diff --git a/Project/Assets/ML-Agents/Examples/3DBall/Scenes/Visual3DBall.unity b/Project/Assets/ML-Agents/Examples/3DBall/Scenes/Visual3DBall.unity index eee59cc5d6..0a57e315a0 100644 --- a/Project/Assets/ML-Agents/Examples/3DBall/Scenes/Visual3DBall.unity +++ b/Project/Assets/ML-Agents/Examples/3DBall/Scenes/Visual3DBall.unity @@ -43,7 +43,7 @@ RenderSettings: --- !u!157 &3 
LightmapSettings: m_ObjectHideFlags: 0 - serializedVersion: 11 + serializedVersion: 12 m_GIWorkflowMode: 0 m_GISettings: serializedVersion: 2 @@ -54,7 +54,7 @@ LightmapSettings: m_EnableBakedLightmaps: 1 m_EnableRealtimeLightmaps: 1 m_LightmapEditorSettings: - serializedVersion: 10 + serializedVersion: 12 m_Resolution: 2 m_BakeResolution: 40 m_AtlasSize: 1024 @@ -62,6 +62,7 @@ LightmapSettings: m_AOMaxDistance: 1 m_CompAOExponent: 1 m_CompAOExponentDirect: 0 + m_ExtractAmbientOcclusion: 0 m_Padding: 2 m_LightmapParameters: {fileID: 0} m_LightmapsBakeMode: 1 @@ -76,10 +77,16 @@ LightmapSettings: m_PVRDirectSampleCount: 32 m_PVRSampleCount: 500 m_PVRBounces: 2 + m_PVREnvironmentSampleCount: 500 + m_PVREnvironmentReferencePointCount: 2048 + m_PVRFilteringMode: 2 + m_PVRDenoiserTypeDirect: 0 + m_PVRDenoiserTypeIndirect: 0 + m_PVRDenoiserTypeAO: 0 m_PVRFilterTypeDirect: 0 m_PVRFilterTypeIndirect: 0 m_PVRFilterTypeAO: 0 - m_PVRFilteringMode: 1 + m_PVREnvironmentMIS: 0 m_PVRCulling: 1 m_PVRFilteringGaussRadiusDirect: 1 m_PVRFilteringGaussRadiusIndirect: 5 @@ -87,9 +94,12 @@ LightmapSettings: m_PVRFilteringAtrousPositionSigmaDirect: 0.5 m_PVRFilteringAtrousPositionSigmaIndirect: 2 m_PVRFilteringAtrousPositionSigmaAO: 1 - m_ShowResolutionOverlay: 1 + m_ExportTrainingData: 0 + m_TrainingDataDestination: TrainingData + m_LightProbeSampleCountMultiplier: 4 m_LightingDataAsset: {fileID: 0} - m_UseShadowmask: 1 + m_LightingSettings: {fileID: 4890085278179872738, guid: 7480a805600d24847ab9e9df1dc971fe, + type: 2} --- !u!196 &4 NavMeshSettings: serializedVersion: 2 @@ -109,6 +119,8 @@ NavMeshSettings: manualTileSize: 0 tileSize: 256 accuratePlacement: 0 + maxJobWorkers: 0 + preserveTilesOutsideBounds: 0 debug: m_Flags: 0 m_NavMeshData: {fileID: 0} @@ -125,92 +137,92 @@ PrefabInstance: objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalPosition.x + propertyPath: m_Pivot.x value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalPosition.y + propertyPath: m_Pivot.y value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalPosition.z - value: 0 + propertyPath: m_RootOrder + value: 1 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalRotation.x + propertyPath: m_AnchorMax.x value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalRotation.y + propertyPath: m_AnchorMax.y value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalRotation.z + propertyPath: m_AnchorMin.x value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_LocalRotation.w - value: 1 + propertyPath: m_AnchorMin.y + value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_RootOrder - value: 1 + propertyPath: m_SizeDelta.x + value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_AnchoredPosition.x + propertyPath: m_SizeDelta.y value: 0 objectReference: 
{fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_AnchoredPosition.y + propertyPath: m_LocalPosition.x value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_SizeDelta.x + propertyPath: m_LocalPosition.y value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_SizeDelta.y + propertyPath: m_LocalPosition.z value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_AnchorMin.x - value: 0 + propertyPath: m_LocalRotation.w + value: 1 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_AnchorMin.y + propertyPath: m_LocalRotation.x value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_AnchorMax.x + propertyPath: m_LocalRotation.y value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_AnchorMax.y + propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_Pivot.x + propertyPath: m_AnchoredPosition.x value: 0 objectReference: {fileID: 0} - target: {fileID: 224194346362733190, guid: 3ce107b4a79bc4eef83afde434932a68, type: 3} - propertyPath: m_Pivot.y + propertyPath: m_AnchoredPosition.y value: 0 objectReference: {fileID: 0} m_RemovedComponents: [] @@ -226,6 +238,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall (2) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 7 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 40 @@ -238,6 +258,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -250,14 +274,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 7 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -283,6 +299,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall (7) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - 
target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 12 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 60 @@ -295,6 +319,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -307,14 +335,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 12 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -340,6 +360,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall (1) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 6 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 20 @@ -352,6 +380,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -364,14 +396,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 6 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -397,6 +421,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall (5) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 10 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 20 @@ -409,6 +441,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 
+ objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -421,14 +457,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 10 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -454,6 +482,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall (4) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 9 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 0 @@ -466,6 +502,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -478,14 +518,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 9 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -529,6 +561,7 @@ Transform: m_LocalRotation: {x: -0.069583125, y: 0.0049145464, z: 0.0702813, w: 0.99508524} m_LocalPosition: {x: 0, y: 0, z: 0} m_LocalScale: {x: 5, y: 0.19999993, z: 5} + m_ConstrainProportionsScale: 0 m_Children: [] m_Father: {fileID: 0} m_RootOrder: 4 @@ -578,9 +611,10 @@ MonoBehaviour: m_GameObject: {fileID: 1746325439} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 1077351063, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Script: {fileID: 11500000, guid: 4f231c4fb786f3946a6b90b886c48677, type: 3} m_Name: m_EditorClassIdentifier: + m_SendPointerHoverToParent: 1 m_HorizontalAxis: Horizontal m_VerticalAxis: Vertical m_SubmitButton: Submit @@ -597,7 +631,7 @@ MonoBehaviour: m_GameObject: {fileID: 1746325439} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: -619905303, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Script: {fileID: 11500000, guid: 76c392e42b5098c458856cdf6ecaaaa1, type: 3} m_Name: m_EditorClassIdentifier: m_FirstSelected: {fileID: 0} @@ -613,6 +647,7 @@ Transform: m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} m_LocalPosition: {x: 0, y: 0, z: 0} m_LocalScale: {x: 1, y: 1, z: 1} + m_ConstrainProportionsScale: 0 m_Children: [] m_Father: {fileID: 0} m_RootOrder: 3 @@ -628,6 +663,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall 
(6) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 11 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 40 @@ -640,6 +683,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -652,14 +699,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 11 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -681,6 +720,10 @@ PrefabInstance: m_Modification: m_TransformParent: {fileID: 0} m_Modifications: + - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} + propertyPath: m_RootOrder + value: 2 + objectReference: {fileID: 0} - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} propertyPath: m_LocalPosition.x value: 0 @@ -693,6 +736,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 0 objectReference: {fileID: 0} + - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} + propertyPath: m_LocalRotation.w + value: 0.8681629 + objectReference: {fileID: 0} - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} propertyPath: m_LocalRotation.x value: 0.31598538 @@ -705,14 +752,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0.13088542 objectReference: {fileID: 0} - - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} - propertyPath: m_LocalRotation.w - value: 0.8681629 - objectReference: {fileID: 0} - - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} - propertyPath: m_RootOrder - value: 2 - objectReference: {fileID: 0} - target: {fileID: 4943719350691982, guid: 5889392e3f05b448a8a06c5def6c2dec, type: 3} propertyPath: m_LocalEulerAnglesHint.y value: -45 @@ -730,6 +769,14 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall (3) objectReference: {fileID: 0} + - target: {fileID: 1321468028730240, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 8 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 60 @@ -742,6 +789,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: 
ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -754,14 +805,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 8 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -805,9 +848,10 @@ Camera: m_ClearFlags: 2 m_BackGroundColor: {r: 0.46666667, g: 0.5647059, b: 0.60784316, a: 1} m_projectionMatrixMode: 1 + m_GateFitMode: 2 + m_FOVAxisMode: 0 m_SensorSize: {x: 36, y: 24} m_LensShift: {x: 0, y: 0} - m_GateFitMode: 2 m_FocalLength: 50 m_NormalizedViewPortRect: serializedVersion: 2 @@ -845,6 +889,7 @@ Transform: m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} m_LocalPosition: {x: 28.99, y: 14.09, z: -40.6} m_LocalScale: {x: 1, y: 1, z: 1} + m_ConstrainProportionsScale: 0 m_Children: [] m_Father: {fileID: 0} m_RootOrder: 0 @@ -860,6 +905,10 @@ PrefabInstance: propertyPath: m_Name value: Visual3DBall objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_RootOrder + value: 5 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalPosition.x value: 0 @@ -872,6 +921,10 @@ PrefabInstance: propertyPath: m_LocalPosition.z value: 5 objectReference: {fileID: 0} + - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} + propertyPath: m_LocalRotation.w + value: 1 + objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalRotation.x value: 0 @@ -884,14 +937,6 @@ PrefabInstance: propertyPath: m_LocalRotation.z value: 0 objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_LocalRotation.w - value: 1 - objectReference: {fileID: 0} - - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} - propertyPath: m_RootOrder - value: 5 - objectReference: {fileID: 0} - target: {fileID: 4679453577574622, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} propertyPath: m_LocalEulerAnglesHint.x value: 0 @@ -904,5 +949,10 @@ PrefabInstance: propertyPath: m_LocalEulerAnglesHint.z value: 0 objectReference: {fileID: 0} + - target: {fileID: 7705253412956426214, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, + type: 3} + propertyPath: MaxStep + value: 500 + objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: ec49a7b8b70a24ab48d7ca0bf5a063a6, type: 3} diff --git a/Project/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab b/Project/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab index e665a66219..3fc47a925a 100644 --- a/Project/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab +++ b/Project/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab @@ -1509,7 +1509,7 @@ GameObject: m_PrefabAsset: {fileID: 0} serializedVersion: 6 m_Component: - - component: {fileID: 
4661241202043188} + - component: {fileID: 4661241202043188} - component: {fileID: 33091290307090862} - component: {fileID: 23221028280036978} - component: {fileID: 65040597978940982} @@ -1520,7 +1520,7 @@ GameObject: m_NavMeshLayer: 0 m_StaticEditorFlags: 0 m_IsActive: 1 ---- !u!4 &4661241202043188 +--- !u!4 &4661241202043188 Transform: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -1623,7 +1623,7 @@ Transform: - {fileID: 4463282243387382} - {fileID: 4087387112911168} - {fileID: 4248851000827934} - - {fileID: 4661241202043188} + - {fileID: 4661241202043188} - {fileID: 4676865447046996} - {fileID: 4404525033858484} - {fileID: 4912732971874350} diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/U_Logo_White_RGB.png b/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/U_Logo_White_RGB.png new file mode 100644 index 0000000000..385a5558fa Binary files /dev/null and b/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/U_Logo_White_RGB.png differ diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/UnityLogo.png.meta b/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/U_Logo_White_RGB.png.meta similarity index 71% rename from Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/UnityLogo.png.meta rename to Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/U_Logo_White_RGB.png.meta index ab0294dd90..c339b6cd70 100644 --- a/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/UnityLogo.png.meta +++ b/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/U_Logo_White_RGB.png.meta @@ -1,9 +1,9 @@ fileFormatVersion: 2 -guid: 2e85738fe64714cffbf72f0f11de6307 +guid: ff9a4fb150ec44c1dae2f2c249a05286 TextureImporter: - fileIDToRecycleName: {} + internalIDToNameTable: [] externalObjects: {} - serializedVersion: 9 + serializedVersion: 11 mipmaps: mipMapMode: 0 enableMipMap: 0 @@ -23,6 +23,7 @@ TextureImporter: isReadable: 0 streamingMipmaps: 0 streamingMipmapsPriority: 0 + vTOnly: 0 grayScaleToAlpha: 0 generateCubemap: 6 cubemapConvolution: 0 @@ -31,12 +32,12 @@ TextureImporter: maxTextureSize: 2048 textureSettings: serializedVersion: 2 - filterMode: -1 - aniso: -1 - mipBias: -100 + filterMode: 1 + aniso: 1 + mipBias: 0 wrapU: 1 wrapV: 1 - wrapW: -1 + wrapW: 0 nPOTScale: 0 lightmap: 0 compressionQuality: 50 @@ -54,11 +55,15 @@ TextureImporter: textureType: 8 textureShape: 1 singleChannelComponent: 0 + flipbookRows: 1 + flipbookColumns: 1 maxTextureSizeSet: 0 compressionQualitySet: 0 textureFormatSet: 0 + ignorePngGamma: 0 + applyGammaDecoding: 0 platformSettings: - - serializedVersion: 2 + - serializedVersion: 3 buildTarget: DefaultTexturePlatform maxTextureSize: 2048 resizeAlgorithm: 0 textureFormat: -1 textureCompression: 1 compressionQuality: 50 crunchedCompression: 0 allowsAlphaSplitting: 0 overridden: 0 androidETC2FallbackOverride: 0 - - serializedVersion: 2 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 buildTarget: Standalone maxTextureSize: 2048 resizeAlgorithm: 0 textureFormat: -1 textureCompression: 1 compressionQuality: 50 crunchedCompression: 0 allowsAlphaSplitting: 0 overridden: 0 androidETC2FallbackOverride: 0 - - serializedVersion: 2 - buildTarget: iPhone - maxTextureSize: 2048 - resizeAlgorithm: 0 - textureFormat: -1 - textureCompression: 1
- compressionQuality: 50 - crunchedCompression: 0 - allowsAlphaSplitting: 0 - overridden: 0 - androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 spriteSheet: serializedVersion: 2 sprites: [] outline: [] physicsShape: [] bones: [] - spriteID: b33f877521c6d4d8782f018141dc1d6a + spriteID: 5e97eb03825dee720800000000000000 + internalID: 0 vertices: [] indices: edges: [] weights: [] + secondaryTextures: [] spritePackingTag: pSDRemoveMatte: 0 pSDShowRemoveMatteOption: 0 diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/UnityLogo.png b/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/UnityLogo.png deleted file mode 100644 index 7a955ee8ea..0000000000 Binary files a/Project/Assets/ML-Agents/Examples/SharedAssets/Materials/Textures/UnityLogo.png and /dev/null differ diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Canvas_Watermark.prefab b/Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Canvas_Watermark.prefab index 9dcc81245d..078274d921 100644 --- a/Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Canvas_Watermark.prefab +++ b/Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Canvas_Watermark.prefab @@ -1,22 +1,12 @@ %YAML 1.1 %TAG !u! tag:unity3d.com,2011: ---- !u!1001 &100100000 -Prefab: - m_ObjectHideFlags: 1 - serializedVersion: 2 - m_Modification: - m_TransformParent: {fileID: 0} - m_Modifications: [] - m_RemovedComponents: [] - m_ParentPrefab: {fileID: 0} - m_RootGameObject: {fileID: 1537641056927260} - m_IsPrefabParent: 1 --- !u!1 &1508578353888260 GameObject: m_ObjectHideFlags: 0 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} - serializedVersion: 5 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + serializedVersion: 6 m_Component: - component: {fileID: 224796324260922368} - component: {fileID: 222875034646499690} @@ -28,76 +18,133 @@ GameObject: m_NavMeshLayer: 0 m_StaticEditorFlags: 0 m_IsActive: 1 ---- !u!1 &1537641056927260 -GameObject: +--- !u!224 &224796324260922368 +RectTransform: m_ObjectHideFlags: 0 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} - serializedVersion: 5 - m_Component: - - component: {fileID: 224194346362733190} - - component: {fileID: 223703725700644330} - - component: {fileID: 114816648722094340} - - component: {fileID: 114595077744033850} - m_Layer: 5 - m_Name: Canvas_Watermark - m_TagString: Untagged - m_Icon: {fileID: 0} - m_NavMeshLayer: 0 - m_StaticEditorFlags: 0 - m_IsActive: 1 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 1508578353888260} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 0.3300893, y: 0.3300892, z: 0.3300892} + m_Children: [] + m_Father: {fileID: 224194346362733190} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 1, y: 1} + m_AnchorMax: {x: 1, y: 1} + m_AnchoredPosition: {x: -209, y: -116} + m_SizeDelta: {x: 715.7, y: 715.69995} + m_Pivot: {x: 0.5, y: 0.5} +--- !u!222 &222875034646499690 +CanvasRenderer: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 1508578353888260} + m_CullTransparentMesh: 1 --- !u!114 &114223610671736162 MonoBehaviour: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} + 
m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 1508578353888260} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: -765806418, guid: f70555f144d8491a825f0804e09c671c, type: 3} + m_Script: {fileID: 11500000, guid: fe87c0e1cc204ed48ad3b37840f39efc, type: 3} m_Name: m_EditorClassIdentifier: m_Material: {fileID: 0} m_Color: {r: 1, g: 1, b: 1, a: 1} m_RaycastTarget: 1 + m_RaycastPadding: {x: 0, y: 0, z: 0, w: 0} + m_Maskable: 1 m_OnCullStateChanged: m_PersistentCalls: m_Calls: [] - m_TypeName: UnityEngine.UI.MaskableGraphic+CullStateChangedEvent, UnityEngine.UI, - Version=1.0.0.0, Culture=neutral, PublicKeyToken=null - m_Sprite: {fileID: 21300000, guid: 2e85738fe64714cffbf72f0f11de6307, type: 3} + m_Sprite: {fileID: 21300000, guid: ff9a4fb150ec44c1dae2f2c249a05286, type: 3} m_Type: 0 - m_PreserveAspect: 0 + m_PreserveAspect: 1 m_FillCenter: 1 m_FillMethod: 4 m_FillAmount: 1 m_FillClockwise: 1 m_FillOrigin: 0 ---- !u!114 &114595077744033850 -MonoBehaviour: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} + m_UseSpriteMesh: 0 + m_PixelsPerUnitMultiplier: 1 +--- !u!1 &1537641056927260 +GameObject: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + serializedVersion: 6 + m_Component: + - component: {fileID: 224194346362733190} + - component: {fileID: 223703725700644330} + - component: {fileID: 114816648722094340} + - component: {fileID: 114595077744033850} + m_Layer: 5 + m_Name: Canvas_Watermark + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!224 &224194346362733190 +RectTransform: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 1537641056927260} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 0, y: 0, z: 0} + m_Children: + - {fileID: 224796324260922368} + m_Father: {fileID: 0} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0, y: 0} + m_AnchorMax: {x: 0, y: 0} + m_AnchoredPosition: {x: 0, y: 0} + m_SizeDelta: {x: 0, y: 0} + m_Pivot: {x: 0, y: 0} +--- !u!223 &223703725700644330 +Canvas: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 1537641056927260} m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 1301386320, guid: f70555f144d8491a825f0804e09c671c, type: 3} - m_Name: - m_EditorClassIdentifier: - m_IgnoreReversedGraphics: 1 - m_BlockingObjects: 0 - m_BlockingMask: - serializedVersion: 2 - m_Bits: 4294967295 + serializedVersion: 3 + m_RenderMode: 0 + m_Camera: {fileID: 0} + m_PlaneDistance: 100 + m_PixelPerfect: 0 + m_ReceivesEvents: 1 + m_OverrideSorting: 0 + m_OverridePixelPerfect: 0 + m_SortingBucketNormalizedSize: 0 + m_AdditionalShaderChannelsFlag: 0 + m_SortingLayerID: 0 + m_SortingOrder: 0 + m_TargetDisplay: 0 --- !u!114 &114816648722094340 MonoBehaviour: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 1537641056927260} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 1980459831, guid: 
f70555f144d8491a825f0804e09c671c, type: 3} + m_Script: {fileID: 11500000, guid: 0cd44c1031e13a943bb63640046fad76, type: 3} m_Name: m_EditorClassIdentifier: m_UiScaleMode: 1 @@ -110,66 +157,21 @@ MonoBehaviour: m_FallbackScreenDPI: 96 m_DefaultSpriteDPI: 96 m_DynamicPixelsPerUnit: 1 ---- !u!222 &222875034646499690 -CanvasRenderer: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} - m_GameObject: {fileID: 1508578353888260} ---- !u!223 &223703725700644330 -Canvas: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} + m_PresetInfoIsWorld: 0 +--- !u!114 &114595077744033850 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 1537641056927260} m_Enabled: 1 - serializedVersion: 3 - m_RenderMode: 0 - m_Camera: {fileID: 0} - m_PlaneDistance: 100 - m_PixelPerfect: 0 - m_ReceivesEvents: 1 - m_OverrideSorting: 0 - m_OverridePixelPerfect: 0 - m_SortingBucketNormalizedSize: 0 - m_AdditionalShaderChannelsFlag: 0 - m_SortingLayerID: 0 - m_SortingOrder: 0 - m_TargetDisplay: 0 ---- !u!224 &224194346362733190 -RectTransform: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} - m_GameObject: {fileID: 1537641056927260} - m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} - m_LocalPosition: {x: 0, y: 0, z: 0} - m_LocalScale: {x: 0, y: 0, z: 0} - m_Children: - - {fileID: 224796324260922368} - m_Father: {fileID: 0} - m_RootOrder: 0 - m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} - m_AnchorMin: {x: 0, y: 0} - m_AnchorMax: {x: 0, y: 0} - m_AnchoredPosition: {x: 0, y: 0} - m_SizeDelta: {x: 0, y: 0} - m_Pivot: {x: 0, y: 0} ---- !u!224 &224796324260922368 -RectTransform: - m_ObjectHideFlags: 1 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 100100000} - m_GameObject: {fileID: 1508578353888260} - m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} - m_LocalPosition: {x: 0, y: 0, z: 0} - m_LocalScale: {x: 0.5588671, y: 0.558867, z: 0.558867} - m_Children: [] - m_Father: {fileID: 224194346362733190} - m_RootOrder: 0 - m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} - m_AnchorMin: {x: 1, y: 1} - m_AnchorMax: {x: 1, y: 1} - m_AnchoredPosition: {x: -209, y: -116} - m_SizeDelta: {x: 715.7, y: 715.69995} - m_Pivot: {x: 0.5, y: 0.5} + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: dc42784cf147c0c48a680349fa168899, type: 3} + m_Name: + m_EditorClassIdentifier: + m_IgnoreReversedGraphics: 1 + m_BlockingObjects: 0 + m_BlockingMask: + serializedVersion: 2 + m_Bits: 4294967295 diff --git a/Project/Packages/manifest.json b/Project/Packages/manifest.json index ea720f3454..7395e3d6c9 100644 --- a/Project/Packages/manifest.json +++ b/Project/Packages/manifest.json @@ -6,7 +6,7 @@ "com.unity.ml-agents.extensions": "file:../../com.unity.ml-agents.extensions", "com.unity.nuget.newtonsoft-json": "2.0.0", "com.unity.test-framework": "1.1.29", - "com.unity.toolchain.macos-x86_64-linux-x86_64": "0.1.20-preview", + "com.unity.toolchain.macos-x86_64-linux-x86_64": "2.0.3", "com.unity.ugui": "1.0.0", "com.unity.modules.imageconversion": "1.0.0", "com.unity.modules.jsonserialize": "1.0.0", diff --git a/Project/Packages/packages-lock.json b/Project/Packages/packages-lock.json index 7d765c0063..b92d60ef54 100644 --- a/Project/Packages/packages-lock.json +++ b/Project/Packages/packages-lock.json @@ -1,7 +1,7 @@ { "dependencies": { "com.unity.barracuda": { - "version": "2.3.1-preview", + 
"version": "3.0.0", "depth": 1, "source": "registry", "dependencies": { @@ -12,7 +12,7 @@ "url": "https://packages.unity.com" }, "com.unity.burst": { - "version": "1.6.0", + "version": "1.6.6", "depth": 2, "source": "registry", "dependencies": { @@ -44,7 +44,7 @@ "url": "https://packages.unity.com" }, "com.unity.mathematics": { - "version": "1.2.1", + "version": "1.2.6", "depth": 3, "source": "registry", "dependencies": {}, @@ -55,7 +55,7 @@ "depth": 0, "source": "local", "dependencies": { - "com.unity.barracuda": "2.3.1-preview", + "com.unity.barracuda": "3.0.0", "com.unity.modules.imageconversion": "1.0.0", "com.unity.modules.jsonserialize": "1.0.0" } @@ -77,18 +77,18 @@ "url": "https://packages.unity.com" }, "com.unity.sysroot": { - "version": "0.1.19-preview", + "version": "2.0.4", "depth": 1, "source": "registry", "dependencies": {}, "url": "https://packages.unity.com" }, "com.unity.sysroot.linux-x86_64": { - "version": "0.1.14-preview", + "version": "2.0.3", "depth": 1, "source": "registry", "dependencies": { - "com.unity.sysroot": "0.1.18-preview" + "com.unity.sysroot": "2.0.4" }, "url": "https://packages.unity.com" }, @@ -104,12 +104,12 @@ "url": "https://packages.unity.com" }, "com.unity.toolchain.macos-x86_64-linux-x86_64": { - "version": "0.1.20-preview", + "version": "2.0.3", "depth": 0, "source": "registry", "dependencies": { - "com.unity.sysroot": "0.1.19-preview", - "com.unity.sysroot.linux-x86_64": "0.1.14-preview" + "com.unity.sysroot": "2.0.4", + "com.unity.sysroot.linux-x86_64": "2.0.3" }, "url": "https://packages.unity.com" }, diff --git a/Project/ProjectSettings/ProjectSettings.asset b/Project/ProjectSettings/ProjectSettings.asset index 5895e58e3d..0685899de9 100644 --- a/Project/ProjectSettings/ProjectSettings.asset +++ b/Project/ProjectSettings/ProjectSettings.asset @@ -3,7 +3,7 @@ --- !u!129 &1 PlayerSettings: m_ObjectHideFlags: 0 - serializedVersion: 22 + serializedVersion: 23 productGUID: cd7e9a0e0d1d14312ad9e89757262f3b AndroidProfiler: 0 AndroidFilterTouchesWhenObscured: 0 @@ -145,23 +145,25 @@ PlayerSettings: enable360StereoCapture: 0 isWsaHolographicRemotingEnabled: 0 enableFrameTimingStats: 0 + enableOpenGLProfilerGPURecorders: 1 useHDRDisplay: 0 D3DHDRBitDepth: 0 m_ColorGamuts: 00000000 targetPixelDensity: 30 resolutionScalingMode: 0 + resetResolutionOnWindowResize: 0 androidSupportedAspectRatio: 1 androidMaxAspectRatio: 2.1 applicationIdentifier: Android: com.Company.ProductName - Standalone: com.UnityTechnologies.UnityEnvironment + Standalone: com.Unity-Technologies.UnityEnvironment buildNumber: Standalone: 0 iPhone: 0 tvOS: 0 overrideDefaultApplicationIdentifier: 0 AndroidBundleVersionCode: 1 - AndroidMinSdkVersion: 19 + AndroidMinSdkVersion: 22 AndroidTargetSdkVersion: 0 AndroidPreferredInstallLocation: 1 aotOptions: nimt-trampolines=1024 @@ -217,6 +219,7 @@ PlayerSettings: iOSLaunchScreeniPadCustomStoryboardPath: iOSDeviceRequirements: [] iOSURLSchemes: [] + macOSURLSchemes: [] iOSBackgroundModes: 0 iOSMetalForceHardShadows: 0 metalEditorSupport: 1 @@ -311,6 +314,9 @@ PlayerSettings: - m_BuildTarget: iOSSupport m_APIs: 10000000 m_Automatic: 1 + - m_BuildTarget: AndroidPlayer + m_APIs: 0b00000008000000 + m_Automatic: 0 m_BuildTargetVRSettings: [] openGLRequireES31: 0 openGLRequireES31AEP: 0 @@ -329,6 +335,7 @@ PlayerSettings: m_EncodingQuality: 1 m_BuildTargetGroupLightmapSettings: [] m_BuildTargetNormalMapEncoding: [] + m_BuildTargetDefaultTextureCompressionFormat: [] playModeTestRunnerEnabled: 0 runPlayModeTestAsEditModeTest: 0 
actionOnDotNetUnhandledException: 1 @@ -347,6 +354,7 @@ PlayerSettings: switchScreenResolutionBehavior: 2 switchUseCPUProfiler: 0 switchUseGOLDLinker: 0 + switchLTOSetting: 0 switchApplicationID: 0x0005000C10000001 switchNSODependencies: switchTitleNames_0: @@ -477,7 +485,9 @@ PlayerSettings: switchPlayerConnectionEnabled: 1 switchUseNewStyleFilepaths: 0 switchUseMicroSleepForYield: 1 + switchEnableRamDiskSupport: 0 switchMicroSleepForYieldTime: 25 + switchRamDiskSpaceSize: 12 ps4NPAgeRating: 12 ps4NPTitleSecret: ps4NPTrophyPackPath: @@ -574,18 +584,15 @@ PlayerSettings: webGLThreadsSupport: 0 webGLDecompressionFallback: 0 scriptingDefineSymbols: - 1: - 7: UNITY_POST_PROCESSING_STACK_V2 - 13: UNITY_POST_PROCESSING_STACK_V2 - 14: UNITY_POST_PROCESSING_STACK_V2 - 17: UNITY_POST_PROCESSING_STACK_V2 - 18: UNITY_POST_PROCESSING_STACK_V2 - 19: UNITY_POST_PROCESSING_STACK_V2 - 21: UNITY_POST_PROCESSING_STACK_V2 - 23: UNITY_POST_PROCESSING_STACK_V2 - 25: UNITY_POST_PROCESSING_STACK_V2 - 26: UNITY_POST_PROCESSING_STACK_V2 - 27: UNITY_POST_PROCESSING_STACK_V2 + : UNITY_POST_PROCESSING_STACK_V2 + Android: UNITY_POST_PROCESSING_STACK_V2 + Nintendo Switch: UNITY_POST_PROCESSING_STACK_V2 + PS4: UNITY_POST_PROCESSING_STACK_V2 + Standalone: + WebGL: UNITY_POST_PROCESSING_STACK_V2 + Windows Store Apps: UNITY_POST_PROCESSING_STACK_V2 + XboxOne: UNITY_POST_PROCESSING_STACK_V2 + tvOS: UNITY_POST_PROCESSING_STACK_V2 additionalCompilerArguments: {} platformArchitecture: {} scriptingBackend: {} @@ -595,7 +602,6 @@ PlayerSettings: suppressCommonWarnings: 1 allowUnsafeCode: 0 useDeterministicCompilation: 1 - useReferenceAssemblies: 1 enableRoslynAnalyzers: 1 additionalIl2CppArgs: scriptingRuntimeVersion: 1 @@ -633,6 +639,7 @@ PlayerSettings: metroFTAName: metroFTAFileTypes: [] metroProtocolName: + vcxProjDefaultLanguage: XboxOneProductId: XboxOneUpdateKey: XboxOneSandboxId: @@ -682,4 +689,6 @@ PlayerSettings: organizationId: cloudEnabled: 0 legacyClampBlendShapeWeights: 1 + playerDataPath: + forceSRGBBlit: 1 virtualTexturingSupportEnabled: 0 diff --git a/Project/ProjectSettings/ProjectVersion.txt b/Project/ProjectSettings/ProjectVersion.txt index 4c9401b919..8ea1b855ae 100644 --- a/Project/ProjectSettings/ProjectVersion.txt +++ b/Project/ProjectSettings/ProjectVersion.txt @@ -1,2 +1,2 @@ -m_EditorVersion: 2020.3.25f1 -m_EditorVersionWithRevision: 2020.3.25f1 (9b9180224418) +m_EditorVersion: 2021.3.11f1 +m_EditorVersionWithRevision: 2021.3.11f1 (0a5ca18544bf) diff --git a/README.md b/README.md deleted file mode 100644 index 4227a3424d..0000000000 --- a/README.md +++ /dev/null @@ -1,189 +0,0 @@ - - -# Unity ML-Agents Toolkit - -[![docs badge](https://img.shields.io/badge/docs-reference-blue.svg)](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/) - -[![license badge](https://img.shields.io/badge/license-Apache--2.0-green.svg)](LICENSE.md) - -([latest release](https://github.com/Unity-Technologies/ml-agents/releases/tag/latest_release)) -([all releases](https://github.com/Unity-Technologies/ml-agents/releases)) - -**The Unity Machine Learning Agents Toolkit** (ML-Agents) is an open-source -project that enables games and simulations to serve as environments for -training intelligent agents. We provide implementations (based on PyTorch) -of state-of-the-art algorithms to enable game developers and hobbyists to easily -train intelligent agents for 2D, 3D and VR/AR games. 
Researchers can also use the -provided simple-to-use Python API to train Agents using reinforcement learning, -imitation learning, neuroevolution, or any other methods. These trained agents can be -used for multiple purposes, including controlling NPC behavior (in a variety of -settings such as multi-agent and adversarial), automated testing of game builds -and evaluating different game design decisions pre-release. The ML-Agents -Toolkit is mutually beneficial for both game developers and AI researchers as it -provides a central platform where advances in AI can be evaluated on Unity’s -rich environments and then made accessible to the wider research and game -developer communities. - -## Features - -- 18+ [example Unity environments](docs/Learning-Environment-Examples.md) -- Support for multiple environment configurations and training scenarios -- Flexible Unity SDK that can be integrated into your game or custom Unity scene -- Support for training single-agent, multi-agent cooperative, and multi-agent - competitive scenarios via several Deep Reinforcement Learning algorithms (PPO, SAC, MA-POCA, self-play). -- Support for learning from demonstrations through two Imitation Learning algorithms (BC and GAIL). -- Easily definable Curriculum Learning scenarios for complex tasks -- Train robust agents using environment randomization -- Flexible agent control with On Demand Decision Making -- Train using multiple concurrent Unity environment instances -- Utilizes the [Unity Inference Engine](docs/Unity-Inference-Engine.md) to - provide native cross-platform support -- Unity environment [control from Python](docs/Python-LLAPI.md) -- Wrap Unity learning environments as a [gym](docs/Python-Gym-API.md) - -See our [ML-Agents Overview](docs/ML-Agents-Overview.md) page for detailed -descriptions of all these features. - -## Releases & Documentation - -**Our latest, stable release is `Release 19`. Click -[here](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/Readme.md) -to get started with the latest release of ML-Agents.** - -The table below lists all our releases, including our `main` branch which is -under active development and may be unstable. A few helpful guidelines: -- The [Versioning page](docs/Versioning.md) overviews how we manage our GitHub - releases and the versioning process for each of the ML-Agents components. -- The [Releases page](https://github.com/Unity-Technologies/ml-agents/releases) - contains details of the changes between releases. -- The [Migration page](docs/Migrating.md) contains details on how to upgrade - from earlier releases of the ML-Agents Toolkit. -- The **Documentation** links in the table below include installation and usage - instructions specific to each release. Remember to always use the - documentation that corresponds to the release version you're using. -- The `com.unity.ml-agents` package is [verified](https://docs.unity3d.com/2020.1/Documentation/Manual/pack-safe.html) - for Unity 2020.1 and later. Verified packages releases are numbered 1.0.x. 
- -| **Version** | **Release Date** | **Source** | **Documentation** | **Download** | **Python Package** | **Unity Package** | -|:--------------------------:|:--------------------:|:--------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------:|:-------------------------------------------------------:|:----------------------------------------------------------------------------------------:| -| **main (unstable)** | -- | [source](https://github.com/Unity-Technologies/ml-agents/tree/main) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/main/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/main.zip) | -- | -- | -| **Release 19** | **January 14, 2022** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_19)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_19.zip)** | **[0.28.0](https://pypi.org/project/mlagents/0.28.0/)** | -- | -| **Release 18** | **June 09, 2021** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_18)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_18_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_18.zip)** | **[0.27.0](https://pypi.org/project/mlagents/0.27.0/)** | **[2.1.0](https://docs.unity3d.com/Packages/com.unity.ml-agents@2.1/manual/index.html)** | -| **Verified Package 1.0.8** | **May 26, 2021** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/com.unity.ml-agents_1.0.8)** | **[docs](https://github.com/Unity-Technologies/ml-agents/blob/release_2_verified_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/com.unity.ml-agents_1.0.8.zip)** | **[0.16.1](https://pypi.org/project/mlagents/0.16.1/)** | **[1.0.8](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.0/manual/index.html)** | - -If you are a researcher interested in a discussion of Unity as an AI platform, -see a pre-print of our -[reference paper on Unity and the ML-Agents Toolkit](https://arxiv.org/abs/1809.02627). - -If you use Unity or the ML-Agents Toolkit to conduct research, we ask that you -cite the following paper as a reference: - -Juliani, A., Berges, V., Teng, E., Cohen, A., Harper, J., Elion, C., Goy, C., -Gao, Y., Henry, H., Mattar, M., Lange, D. (2020). Unity: A General Platform for -Intelligent Agents. _arXiv preprint -[arXiv:1809.02627](https://arxiv.org/abs/1809.02627)._ -https://github.com/Unity-Technologies/ml-agents. - -## Additional Resources - -We have a Unity Learn course, -[ML-Agents: Hummingbirds](https://learn.unity.com/course/ml-agents-hummingbirds), -that provides a gentle introduction to Unity and the ML-Agents Toolkit. - -We've also partnered with -[CodeMonkeyUnity](https://www.youtube.com/c/CodeMonkeyUnity) to create a -[series of tutorial videos](https://www.youtube.com/playlist?list=PLzDRvYVwl53vehwiN_odYJkPBzcqFw110) -on how to implement and use the ML-Agents Toolkit. 
- -We have also published a series of blog posts that are relevant for ML-Agents: - -- (July 12, 2021) - [ML-Agents plays Dodgeball](https://blog.unity.com/technology/ml-agents-plays-dodgeball) -- (May 5, 2021) - [ML-Agents v2.0 release: Now supports training complex cooperative behaviors](https://blogs.unity3d.com/2021/05/05/ml-agents-v2-0-release-now-supports-training-complex-cooperative-behaviors/) -- (December 28, 2020) - [Happy holidays from the Unity ML-Agents team!](https://blogs.unity3d.com/2020/12/28/happy-holidays-from-the-unity-ml-agents-team/) -- (November 20, 2020) - [How Eidos-Montréal created Grid Sensors to improve observations for training agents](https://blogs.unity3d.com/2020/11/20/how-eidos-montreal-created-grid-sensors-to-improve-observations-for-training-agents/) -- (November 11, 2020) - [2020 AI@Unity interns shoutout](https://blogs.unity3d.com/2020/11/11/2020-aiunity-interns-shoutout/) -- (May 12, 2020) - [Announcing ML-Agents Unity Package v1.0!](https://blogs.unity3d.com/2020/05/12/announcing-ml-agents-unity-package-v1-0/) -- (February 28, 2020) - [Training intelligent adversaries using self-play with ML-Agents](https://blogs.unity3d.com/2020/02/28/training-intelligent-adversaries-using-self-play-with-ml-agents/) -- (November 11, 2019) - [Training your agents 7 times faster with ML-Agents](https://blogs.unity3d.com/2019/11/11/training-your-agents-7-times-faster-with-ml-agents/) -- (October 21, 2019) - [The AI@Unity interns help shape the world](https://blogs.unity3d.com/2019/10/21/the-aiunity-interns-help-shape-the-world/) -- (April 15, 2019) - [Unity ML-Agents Toolkit v0.8: Faster training on real games](https://blogs.unity3d.com/2019/04/15/unity-ml-agents-toolkit-v0-8-faster-training-on-real-games/) -- (March 1, 2019) - [Unity ML-Agents Toolkit v0.7: A leap towards cross-platform inference](https://blogs.unity3d.com/2019/03/01/unity-ml-agents-toolkit-v0-7-a-leap-towards-cross-platform-inference/) -- (December 17, 2018) - [ML-Agents Toolkit v0.6: Improved usability of Brains and Imitation Learning](https://blogs.unity3d.com/2018/12/17/ml-agents-toolkit-v0-6-improved-usability-of-brains-and-imitation-learning/) -- (October 2, 2018) - [Puppo, The Corgi: Cuteness Overload with the Unity ML-Agents Toolkit](https://blogs.unity3d.com/2018/10/02/puppo-the-corgi-cuteness-overload-with-the-unity-ml-agents-toolkit/) -- (September 11, 2018) - [ML-Agents Toolkit v0.5, new resources for AI researchers available now](https://blogs.unity3d.com/2018/09/11/ml-agents-toolkit-v0-5-new-resources-for-ai-researchers-available-now/) -- (June 26, 2018) - [Solving sparse-reward tasks with Curiosity](https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/) -- (June 19, 2018) - [Unity ML-Agents Toolkit v0.4 and Udacity Deep Reinforcement Learning Nanodegree](https://blogs.unity3d.com/2018/06/19/unity-ml-agents-toolkit-v0-4-and-udacity-deep-reinforcement-learning-nanodegree/) -- (May 24, 2018) - [Imitation Learning in Unity: The Workflow](https://blogs.unity3d.com/2018/05/24/imitation-learning-in-unity-the-workflow/) -- (March 15, 2018) - [ML-Agents Toolkit v0.3 Beta released: Imitation Learning, feedback-driven features, and more](https://blogs.unity3d.com/2018/03/15/ml-agents-v0-3-beta-released-imitation-learning-feedback-driven-features-and-more/) -- (December 11, 2017) - [Using Machine Learning Agents in a real game: a beginner’s guide](https://blogs.unity3d.com/2017/12/11/using-machine-learning-agents-in-a-real-game-a-beginners-guide/) -- (December 8, 2017) - 
[Introducing ML-Agents Toolkit v0.2: Curriculum Learning, new environments, and more](https://blogs.unity3d.com/2017/12/08/introducing-ml-agents-v0-2-curriculum-learning-new-environments-and-more/) -- (September 19, 2017) - [Introducing: Unity Machine Learning Agents Toolkit](https://blogs.unity3d.com/2017/09/19/introducing-unity-machine-learning-agents/) -- Overviewing reinforcement learning concepts - ([multi-armed bandit](https://blogs.unity3d.com/2017/06/26/unity-ai-themed-blog-entries/) - and - [Q-learning](https://blogs.unity3d.com/2017/08/22/unity-ai-reinforcement-learning-with-q-learning/)) - -### More from Unity - -- [Unity Robotics](https://github.com/Unity-Technologies/Unity-Robotics-Hub) -- [Unity Computer Vision](https://unity.com/computer-vision) -- [Unity Game Simulation](https://unity.com/products/game-simulation) - -## Community and Feedback - -The ML-Agents Toolkit is an open-source project and we encourage and welcome -contributions. If you wish to contribute, be sure to review our -[contribution guidelines](com.unity.ml-agents/CONTRIBUTING.md) and -[code of conduct](CODE_OF_CONDUCT.md). - -For problems with the installation and setup of the ML-Agents Toolkit, or -discussions about how to best setup or train your agents, please create a new -thread on the -[Unity ML-Agents forum](https://forum.unity.com/forums/ml-agents.453/) and make -sure to include as much detail as possible. If you run into any other problems -using the ML-Agents Toolkit or have a specific feature request, please -[submit a GitHub issue](https://github.com/Unity-Technologies/ml-agents/issues). - -Please tell us which samples you would like to see shipped with the ML-Agents Unity -package by replying to -[this forum thread](https://forum.unity.com/threads/feedback-wanted-shipping-sample-s-with-the-ml-agents-package.1073468/). - - -Your opinion matters a great deal to us. Only by hearing your thoughts on the -Unity ML-Agents Toolkit can we continue to improve and grow. Please take a few -minutes to -[let us know about it](https://unitysoftware.co1.qualtrics.com/jfe/form/SV_55pQKCZ578t0kbc). - -For any other questions or feedback, connect directly with the ML-Agents team at -ml-agents@unity3d.com. - -## Privacy - -In order to improve the developer experience for Unity ML-Agents Toolkit, we have added in-editor analytics. -Please refer to "Information that is passively collected by Unity" in the -[Unity Privacy Policy](https://unity3d.com/legal/privacy-policy). - -## License - -[Apache License 2.0](LICENSE.md) diff --git a/SURVEY.md b/SURVEY.md index 8523a19d81..1eb6bb1b7b 100644 --- a/SURVEY.md +++ b/SURVEY.md @@ -2,6 +2,6 @@ Your opinion matters a great deal to us. Only by hearing your thoughts on the Unity ML-Agents Toolkit can we continue to improve and grow. Please take a few -minutes to let us know about it. +minutes to let us know about it. Please email us at [ml-agents@unity3d.com](mailto:ml-agents@unity3d.com). 
-[Fill out the survey](https://goo.gl/forms/qFMYSYr5TlINvG6f1) + diff --git a/colab/Colab_UnityEnvironment_1_Run.ipynb b/colab/Colab_UnityEnvironment_1_Run.ipynb index f27fc3ff5d..d431c3d14f 100644 --- a/colab/Colab_UnityEnvironment_1_Run.ipynb +++ b/colab/Colab_UnityEnvironment_1_Run.ipynb @@ -145,7 +145,7 @@ " import mlagents\n", " print(\"ml-agents already installed\")\n", "except ImportError:\n", - " !python -m pip install -q mlagents==0.28.0\n", + " !python -m pip install -q mlagents==0.29.0\n", " print(\"Installed ml-agents\")" ], "execution_count": null, @@ -500,4 +500,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/colab/Colab_UnityEnvironment_2_Train.ipynb b/colab/Colab_UnityEnvironment_2_Train.ipynb index 2c6cb798a1..dda37a4e18 100644 --- a/colab/Colab_UnityEnvironment_2_Train.ipynb +++ b/colab/Colab_UnityEnvironment_2_Train.ipynb @@ -135,7 +135,7 @@ " import mlagents\n", " print(\"ml-agents already installed\")\n", "except ImportError:\n", - " !python -m pip install -q mlagents==0.28.0\n", + " !python -m pip install -q mlagents==0.29.0\n", " print(\"Installed ml-agents\")" ], "execution_count": null, @@ -686,4 +686,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/colab/Colab_UnityEnvironment_3_SideChannel.ipynb b/colab/Colab_UnityEnvironment_3_SideChannel.ipynb index 735a23380a..c053fff759 100644 --- a/colab/Colab_UnityEnvironment_3_SideChannel.ipynb +++ b/colab/Colab_UnityEnvironment_3_SideChannel.ipynb @@ -136,7 +136,7 @@ " import mlagents\n", " print(\"ml-agents already installed\")\n", "except ImportError:\n", - " !python -m pip install -q mlagents==0.28.0\n", + " !python -m pip install -q mlagents==0.29.0\n", " print(\"Installed ml-agents\")" ], "execution_count": null, @@ -290,4 +290,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb b/colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb index 24685404be..8c2e671c05 100644 --- a/colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb +++ b/colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb @@ -1,20 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Colab-UnityEnvironment-1-Run.ipynb", - "private_outputs": true, - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "language": "python", - "display_name": "Python 3" - } - }, "cells": [ { "cell_type": "markdown", @@ -38,6 +22,11 @@ { "cell_type": "code", "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "#@title Install Rendering Dependencies { display-mode: \"form\" }\n", @@ -112,13 +101,7 @@ " !bash frame-buffer start\n", " os.environ[\"DISPLAY\"] = \":1\"\n", "pro_bar.update(progress(100, 100))" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", @@ -131,22 +114,22 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "N8yfQqkbebQ5", "pycharm": { "is_executing": true } }, + "outputs": [], "source": [ "try:\n", " import mlagents\n", " print(\"ml-agents already installed\")\n", "except ImportError:\n", - " !python -m pip install -q mlagents==0.28.0\n", + " !python -m pip install -q mlagents==0.29.0\n", " print(\"Installed ml-agents\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -168,38 +151,161 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "YSf-WhxbqtLw" }, + "outputs": 
[], "source": [ - "from math import ceil\n", + "from dataclasses import dataclass\n", + "from pathlib import Path\n", + "from typing import Callable, Any\n", + "\n", + "import gym\n", + "from gym import Env\n", "\n", "from stable_baselines3 import PPO\n", - "from stable_baselines3.common.vec_env import VecMonitor\n", + "from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv\n", + "from supersuit import observation_lambda_v0\n", "\n", - "from mlagents_envs.envs.unity_vec_env import make_mla_sb3_env, LimitedConfig\n", "\n", - "# 250K should train to a reward ~= 0.90 for the \"Basic\" environment.\n", - "# We set the value lower here to demonstrate just a small amount of trianing.\n", - "TOTAL_TAINING_STEPS_GOAL = 40 * 1000\n", - "NUM_ENVS = 12\n", - "STEPS_PER_UPDATE = 2048" - ], - "execution_count": 29, - "outputs": [] + "from mlagents_envs.environment import UnityEnvironment\n", + "from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper\n", + "from mlagents_envs.registry import UnityEnvRegistry, default_registry\n", + "from mlagents_envs.side_channel.engine_configuration_channel import (\n", + " EngineConfig,\n", + " EngineConfigurationChannel,\n", + ")\n", + "\n", + "NUM_ENVS = 8" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Environment and Engine Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Default values from CLI (See cli_utils.py)\n", + "DEFAULT_ENGINE_CONFIG = EngineConfig(\n", + " width=84,\n", + " height=84,\n", + " quality_level=4,\n", + " time_scale=20,\n", + " target_frame_rate=-1,\n", + " capture_frame_rate=60,\n", + ")\n", + "\n", + "# Some config subset of an actual config.yaml file for MLA.\n", + "@dataclass\n", + "class LimitedConfig:\n", + " # The local path to a Unity executable or the name of an entry in the registry.\n", + " env_path_or_name: str\n", + " base_port: int\n", + " base_seed: int = 0\n", + " num_env: int = 1\n", + " engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG\n", + " visual_obs: bool = False\n", + " # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.\n", + " # WARNING: Make sure to use MultiInputPolicy if you turn this on.\n", + " allow_multiple_obs: bool = False\n", + " env_registry: UnityEnvRegistry = default_registry" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Unity Environment SB3 Factory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _unity_env_from_path_or_registry(\n", + " env: str, registry: UnityEnvRegistry, **kwargs: Any\n", + ") -> UnityEnvironment:\n", + " if Path(env).expanduser().absolute().exists():\n", + " return UnityEnvironment(file_name=env, **kwargs)\n", + " elif env in registry:\n", + " return registry.get(env).make(**kwargs)\n", + " else:\n", + " raise ValueError(f\"Environment '{env}' wasn't a local path or registry entry\")\n", + " \n", + "def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:\n", + " def handle_obs(obs, space):\n", + " if isinstance(space, gym.spaces.Tuple):\n", + " if len(space) == 1:\n", + " return obs[0]\n", + " # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n", + " return {str(i): v for i, v in enumerate(obs)}\n", + " return obs\n", + "\n", + " def handle_obs_space(space):\n", + " if isinstance(space, gym.spaces.Tuple):\n", + " 
if len(space) == 1:\n", + "            return space[0]\n", + "        # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n", + "        return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})\n", + "    return space\n", + "\n", + "    def create_env(env: str, worker_id: int) -> Callable[[], Env]:\n", + "        def _f() -> Env:\n", + "            engine_configuration_channel = EngineConfigurationChannel()\n", + "            engine_configuration_channel.set_configuration(config.engine_config)\n", + "            kwargs[\"side_channels\"] = kwargs.get(\"side_channels\", []) + [\n", + "                engine_configuration_channel\n", + "            ]\n", + "            unity_env = _unity_env_from_path_or_registry(\n", + "                env=env,\n", + "                registry=config.env_registry,\n", + "                worker_id=worker_id,\n", + "                base_port=config.base_port,\n", + "                seed=config.base_seed + worker_id,\n", + "                **kwargs,\n", + "            )\n", + "            new_env = UnityToGymWrapper(\n", + "                unity_env=unity_env,\n", + "                uint8_visual=config.visual_obs,\n", + "                allow_multiple_obs=config.allow_multiple_obs,\n", + "            )\n", + "            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)\n", + "            return new_env\n", + "\n", + "        return _f\n", + "\n", + "    env_facts = [\n", + "        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)\n", + "    ]\n", + "    return SubprocVecEnv(env_facts)" ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Start Environment from the registry" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": true, + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# -----------------\n", @@ -220,78 +326,84 @@ "  ),\n", "  no_graphics=True,  # Set to false if you are running locally and want to watch the environments move around as they train.\n", ")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n", - "is_executing": true - } - } + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Create the model" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": true, + "name": "#%%\n" + } + }, "outputs": [], "source": [ + "# 250K should train to a reward ~= 0.90 for the \"Basic\" environment.\n", + "# We set the value lower here to demonstrate just a small amount of training.\n", + "BATCH_SIZE = 32\n", + "BUFFER_SIZE = 256\n", + "UPDATES = 50\n", + "TOTAL_TRAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES\n", + "BETA = 0.0005\n", + "N_EPOCHS = 3\n", + "STEPS_PER_UPDATE = BUFFER_SIZE / NUM_ENVS\n", + "\n", "# Helps gather stats for our eval() calls later so we can see reward stats.\n", "env = VecMonitor(env)\n", - "# Attempt to approximate settings from 3DBall.yaml\n", + "\n", + "# Policy and value function with 2 layers of 32 units each and no shared layers.\n", + "policy_kwargs = {\"net_arch\" : [{\"pi\": [32,32], \"vf\": [32,32]}]}\n", + "\n", "model = PPO(\n", "    \"MlpPolicy\",\n", "    env,\n", "    verbose=1,\n", - "    learning_rate=lambda prog: 0.0003 * (1.0 - prog),\n", + "    learning_rate=lambda progress: 0.0003 * (1.0 - progress),\n", + "    clip_range=lambda progress: 0.2 * (1.0 - progress),\n", + "    clip_range_vf=lambda progress: 0.2 * (1.0 - progress),\n", "    # Uncomment this if you want to log tensorboard results when running this notebook locally.\n", "    # tensorboard_log=\"results\",\n", + "    policy_kwargs=policy_kwargs,\n", "    n_steps=int(STEPS_PER_UPDATE),\n", + "    batch_size=BATCH_SIZE,\n", + "    n_epochs=N_EPOCHS,\n", + "
ent_coef=BETA,\n", ")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n", - "is_executing": true - } - } + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Train the model" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": true, + "name": "#%%\n" + } + }, "outputs": [], "source": [ - "training_rounds = ceil(TOTAL_TAINING_STEPS_GOAL / int(STEPS_PER_UPDATE * NUM_ENVS))\n", - "for i in range(training_rounds):\n", - " print(f\"Training round {i + 1}/{training_rounds}\")\n", + "# 0.93 is considered solved for the Basic environment\n", + "for i in range(UPDATES):\n", + " print(f\"Training round {i + 1}/{UPDATES}\")\n", " # NOTE: rest_num_timesteps should only happen the first time so that tensorboard logs are consistent.\n", - " model.learn(total_timesteps=6000, reset_num_timesteps=(i == 0))\n", + " model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=(i == 0))\n", " model.policy.eval()" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n", - "is_executing": true - } - } + ] }, { "cell_type": "markdown", @@ -305,6 +417,7 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "vdWG6_SqtNtv", "pycharm": { @@ -312,12 +425,46 @@ "name": "#%%\n" } }, + "outputs": [], "source": [ "env.close()\n", "print(\"Closed environment\")\n" - ], + ] + }, + { + "cell_type": "code", "execution_count": null, - "outputs": [] + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "Colab-UnityEnvironment-1-Run.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/com.unity.ml-agents b/com.unity.ml-agents deleted file mode 160000 index 24ce875a0c..0000000000 --- a/com.unity.ml-agents +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 24ce875a0ce256bffd06d7c80e944459c5b9ba8d diff --git a/com.unity.ml-agents.extensions/package.json b/com.unity.ml-agents.extensions/package.json index 636102c730..436ba05296 100644 --- a/com.unity.ml-agents.extensions/package.json +++ b/com.unity.ml-agents.extensions/package.json @@ -2,10 +2,10 @@ "name": "com.unity.ml-agents.extensions", "displayName": "ML Agents Extensions", "version": "0.6.1-preview", - "unity": "2020.3", + "unity": "2021.3", "description": "A source-only package for new features based on ML-Agents", "dependencies": { - "com.unity.ml-agents": "2.2.1-exp.1", + "com.unity.ml-agents": "2.3.0-exp.3", "com.unity.modules.physics": "1.0.0" } } diff --git a/com.unity.ml-agents/.gitignore b/com.unity.ml-agents/.gitignore new file mode 100755 index 0000000000..b40e78e61a --- /dev/null +++ b/com.unity.ml-agents/.gitignore @@ -0,0 +1,30 @@ +artifacts/** +build/** +.build_script/** +node_modules/** +.DS_Store +.npmrc +!Documentation~ +!.Documentation +npm-debug.log +build.sh.meta +build.bat.meta +.idea/ +!Samples/*/*.unitypackage + +/[Ll]ibrary/ +/Logs/ +/[Tt]emp/ +/[Oo]bj/ +/[Bb]uild/ +/[Bb]uilds/ +/Assets/AssetStoreTools* +/Assets/Plugins* 
+/Assets/Demonstrations* +/csharp_timers.json + +# Visual Studio 2015 cache directory +/.vs/ + +*.api +*.api.meta diff --git a/com.unity.ml-agents/.npmignore b/com.unity.ml-agents/.npmignore new file mode 100755 index 0000000000..0f2d8d322c --- /dev/null +++ b/com.unity.ml-agents/.npmignore @@ -0,0 +1,20 @@ +artifacts/** +build/** +.build_script/** +node_modules/** +Documentation/ApiDocs/** +Documentation~/ApiDocs/** +.DS_Store +.npmrc +.npmignore +.gitignore +CONTRIBUTING.md +CONTRIBUTING.md.meta +QAReport.md +QAReport.md.meta +.gitlab-ci.yml +build.sh +build.sh.meta +build.bat +build.bat.meta +upm-ci.log diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md new file mode 100755 index 0000000000..45deb61b80 --- /dev/null +++ b/com.unity.ml-agents/CHANGELOG.md @@ -0,0 +1,902 @@ +# Changelog + +All notable changes to this package will be documented in this file. + +The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) +and this project adheres to +[Semantic Versioning](http://semver.org/spec/v2.0.0.html). + +## [2.3.0-exp.3] - 2022-11-21 +### Major Changes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- The minimum supported Unity version was updated to 2021.3. (#) + +#### ml-agents / ml-agents-envs +- Add your trainers to the package using the ML-Agents Custom Trainers plugin. (#) + - ML-Agents Custom Trainers plugin is an extensible plugin system to define new trainers based on the + High level trainer API, read more [here](../docs/Python-Custom-Trainer-Plugin.md). +- Refactored core modules to make ML-Agents internal classes more generalizable to various RL algorithms. (#) +- The minimum supported Python version for ML-Agents has changed to 3.8.13. (#) +- The minimum supported version of PyTorch was changed to 1.8.0. (#) +- Add shared critic configurability for PPO. (#) +- We moved `UnityToGymWrapper` and `PettingZoo` API to `ml-agents-envs` package. All these environments will be +versioned under `ml-agents-envs` package in the future (#) + +### Minor Changes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- Added switch to RayPerceptionSensor to allow rays to be ordered left to right. (#26) + - Current alternating order is still the default but will be deprecated. +- Added support for enabling/disabling the camera object attached to a camera sensor in order to improve performance. (#31) + +#### ml-agents / ml-agents-envs +- Renamed the path that shadows torch to "mlagents/trainers/torch_entities" and updated the respective imports (#) + + +### Bug Fixes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +#### ml-agents / ml-agents-envs + + +## [2.3.0-exp.2] - 2022-03-28 +### Major Changes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +#### ml-agents / ml-agents-envs +- Refactored to support the new ML-Agents Pro package. 
+- The minimum supported Python version for ML-Agents-envs is changed to 3.7.2 (#) +- Added support for the PettingZoo multi-agent API (#) +- Refactored `gym-unity` into the `ml-agents-envs` package (#) + +### Minor Changes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- Upgrade barracuda dependency to 3.0.0 (#) +#### ml-agents / ml-agents-envs +- Added the new unity_vec_env file to the ml-agents-envs module +- Extended support to python 3.9.10 + +### Bug Fixes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +#### ml-agents / ml-agents-envs + +## [2.2.1-exp.1] - 2022-01-14 +### Major Changes + +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- The minimum supported Unity version was updated to 2020.3. (#5673) +- Added a new feature to replicate training areas dynamically during runtime. (#5568) +- Update Barracuda to 2.3.1-preview (#5591) +- Update Input System to 1.3.0 (#5661) + +#### ml-agents / ml-agents-envs / gym-unity (Python) + +### Minor Changes + +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525) +- Added the ability to get a read-only view of the stacked observations (#5523) + +#### ml-agents / ml-agents-envs / gym-unity (Python) +- Set gym version in gym-unity to gym release 0.20.0 (#5540) +- Added support for having `beta`, `epsilon`, and `learning rate` on separate schedules (affects only PPO and POCA). (#5538) +- Changed default behavior to restart crashed Unity environments rather than exiting. (#5553) + - Rate & lifetime limits on this are configurable via 3 new yaml options + 1. env_params.max_lifetime_restarts (--max-lifetime-restarts) [default=10] + 2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1] + 3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60] +- Deterministic action selection is now supported during training and inference(#5619) + - Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can + be achieved by adding `deterministic: true` under `network_settings` of the run options configuration.(#5597) + - Extra tensors are now serialized to support deterministic action selection in onnx. (#5593) + - Support inference with deterministic action selection in editor (#5599) +- Added minimal analytics collection to LL-API (#5511) +- Update Colab notebooks for GridWorld example with DQN illustrating the use of the Python API and how to export to ONNX (#5643) + +### Bug Fixes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- Update gRPC native lib to universal for arm64 and x86_64. This change should enable ml-agents usage on mac M1 (#5283, #5519) +- Fixed a bug where ml-agents code wouldn't compile on platforms that didn't support analytics (PS4/5, XBoxOne) (#5628) + +#### ml-agents / ml-agents-envs / gym-unity (Python) +- Fixed a bug where the critics were not being normalized during training. (#5595) +- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586) +- Fixed a bug in multi-agent cooperative training where agents might not receive all of the states of +terminated teammates. 
(#5441) +- Fixed wrong attribute name in argparser for torch device option (#5433)(#5467) +- Fixed conflicting CLI and yaml options regarding resume & initialize_from (#5495) +- Fixed failing tests for gym-unity due to gym 0.20.0 release (#5540) +- Fixed a bug in VAIL where the variational bottleneck was not properly passing gradients (#5546) +- Harden user PII protection logic and extend TrainingAnalytics to expose detailed configuration parameters. (#5512) + +## [2.1.0-exp.1] - 2021-06-09 +### Minor Changes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- Update Barracuda to 2.0.0-pre.3. (#5385) +- Fixed NullReferenceException when adding Behavior Parameters with no Agent. (#5382) +- Add stacking option in Editor for `VectorSensorComponent`. (#5376) + +#### ml-agents / ml-agents-envs / gym-unity (Python) +- Lock cattrs dependency version to 1.6. (#5397) +- Added a fully connected visual encoder for environments with very small image inputs. (#5351) +- Colab notebooks illustrating the use of the Python API are now part of the repository. (#5399) + +### Bug Fixes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- RigidBodySensorComponent now displays a warning if it's used in a way that won't generate useful observations. (#5387) +- Update the documentation with a note saying that `GridSensor` does not work in 2D environments. (#5396) +- Fixed an error where sensors would not reset properly before collecting the last observation at the end of an +episode. (#5375) + +#### ml-agents / ml-agents-envs / gym-unity (Python) +- The calculation of the target entropy of SAC with continuous actions was incorrect and has been fixed. (#5372) +- Fixed an issue where the histogram stats would not be reported correctly in TensorBoard. (#5410) +- Fixed error when importing models which use the ResNet encoder. (#5358) + + +## [2.0.0-exp.1] - 2021-04-22 +### Major Changes +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- The minimum supported Unity version was updated to 2019.4. (#5166) +- Several breaking interface changes were made. See the +[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_17_docs/docs/Migrating.md) for more +details. +- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart. +- The interface for disabling discrete actions in `IDiscreteActionMask` has changed. +`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with +`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060) +- IActuator now implements IHeuristicProvider. (#5110) +- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. The `ITypedSensor` +and `IDimensionPropertiesSensor` interfaces were removed. (#5127) +- `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor` +interface was removed. (#5164) +- The abstract method `SensorComponent.GetObservationShape()` was no longer being called, so it has been removed. (#5172) +- `SensorComponent.CreateSensor()` was replaced with `SensorComponent.CreateSensors()`, which returns an `ISensor[]`. (#5181) +- `Match3Sensor` was refactored to produce cell and special type observations separately, and `Match3SensorComponent` now +produces two `Match3Sensor`s (unless there are no special types). Previously trained models will have different observation +sizes and will need to be retrained. 
+- The `AbstractBoard` class for integration with Match-3 games was changed to make it easier to support boards with
+different sizes using the same model. For a summary of the interface changes, please see the Migration Guide. (#5189)
+- Updated the Barracuda package to version `1.4.0-preview`. (#5236)
+- `GridSensor` has been refactored and moved to the main package, with changes to both sensor interfaces and behaviors.
+Existing GridSensors created by the extension package will not work in newer versions. Previously trained models will
+need to be retrained. Please see the Migration Guide for more details. (#5256)
+- Models trained with 1.x versions of ML-Agents will no longer work at inference if they were trained using recurrent neural networks. (#5254)
+
+### Minor Changes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- The `.onnx` models' input names have changed. All input placeholders now use the prefix `obs_`, removing the distinction between visual and vector observations. In addition, the inputs and outputs of LSTMs changed. Models created with this version will not be usable with previous versions of the package. (#5080, #5236)
+- The `.onnx` models' discrete action output now contains the discrete action values and not the logits. Models created with this version will not be usable with previous versions of the package. (#5080)
+- Added ML-Agents package settings. (#5027)
+- Made com.unity.modules.unityanalytics an optional dependency. (#5109)
+- Made com.unity.modules.physics and com.unity.modules.physics2d optional dependencies. (#5112)
+- The default `InferenceDevice` is now `InferenceDevice.Default`, which is equivalent to `InferenceDevice.Burst`. If you
+depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175)
+- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
+- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
+- `DecisionRequester.ShouldRequestDecision()` and `ShouldRequestAction()` methods were added. These are used to
+determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are called (respectively). (#5223)
+- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
+- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
+`IHeuristicProvider.Heuristic()`. (#5227)
+- `Agent` will now call `IDisposable.Dispose()` on all `ISensor`s that implement the `IDisposable` interface. (#5233)
+- `CameraSensor`, `RenderTextureSensor`, and `Match3Sensor` will now reuse their `Texture2D`s, reducing the
+amount of memory that needs to be allocated during runtime. (#5233)
+- Optimized `ObservationWriter.WriteTexture()` so that it doesn't call `Texture2D.GetPixels32()` for `RGB24` textures.
+This results in much less memory being allocated during inference with `CameraSensor` and `RenderTextureSensor`. (#5233)
+- The Match-3 integration utilities were moved from `com.unity.ml-agents.extensions` to `com.unity.ml-agents`. (#5259)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Some console output has been moved from `info` to `debug` and will not be printed by default.
+If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the YAML config file. (#5211)
+- When using a configuration YAML, it is required to define all behaviors found in a Unity
+executable in the trainer configuration YAML, or to specify `default_settings`. (#5210)
+- The embedding size of attention layers used when a BufferSensor is in the scene has been changed. It is now fixed to 128 units. It might be impossible to resume training from a checkpoint of a previous version. (#5272)
+
+### Bug Fixes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- Fixed a bug where sensors and actuators could get sorted inconsistently on systems with different Culture
+settings. Unfortunately, this may require retraining models if it changes the resulting order of the sensors
+or actuators on your system. (#5194)
+- Removed additional memory allocations that were occurring due to assert messages and iterating over DemonstrationRecorders. (#5246)
+- Fixed a bug where an agent would try to access uninitialized fields when a new RayPerceptionSensorComponent was created on it. (#5261)
+- Fixed a bug where the DemonstrationRecorder would throw a null reference exception if Num Steps To Record was > 0 and Record was turned off. (#5274)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed a bug where --results-dir had no effect. (#5269)
+- Fixed a bug where old `.pt` checkpoints were not deleted during training. (#5271)
+- The `UnityToGymWrapper` initializer now accepts an optional `action_space_seed` argument. If this is specified, it will
+be used to set the random seed on the resulting action space. (#5303)
+
+
+## [1.9.1-preview] - 2021-04-13
+### Major Changes
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The `--resume` flag now supports resuming experiments with additional reward providers or
+  loading partial models if the network architecture has changed. See
+  [here](https://github.com/Unity-Technologies/ml-agents/blob/release_16_docs/docs/Training-ML-Agents.md#loading-an-existing-model)
+  for more details. (#5213)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Fixed erroneous warnings when using the Demonstration Recorder. (#5216)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed an issue which was causing increased variance when using LSTMs. Also fixed an issue with LSTM when used with POCA and `sequence_length` < `time_horizon`. (#5206)
+- Fixed a bug where the SAC replay buffer would not be saved out at the end of a run, even if `save_replay_buffer` was enabled. (#5205)
+- ELO now correctly resumes when loading from a checkpoint. (#5202)
+- In the Python API, fixed `validate_action` to expect the right dimensions when `set_action_single_agent` is called. (#5208)
+- In the `GymToUnityWrapper`, raise an appropriate warning if `step()` is called after an environment is done. (#5204)
+- Fixed an issue where using one of the `gym` wrappers would override user-set log levels. (#5201)
+
+## [1.9.0-preview] - 2021-03-17
+### Major Changes
+#### com.unity.ml-agents (C#)
+- The `BufferSensor` and `BufferSensorComponent` have been added. They allow the Agent to observe a variable number of entities. For an example, see the [Sorter environment](https://github.com/Unity-Technologies/ml-agents/blob/release_15_docs/docs/Learning-Environment-Examples.md#sorter). (#4909)
+- The `SimpleMultiAgentGroup` class and `IMultiAgentGroup` interface have been added. These allow Agents to be given rewards and
+  end episodes in groups. For examples, see the [Cooperative Push Block](https://github.com/Unity-Technologies/ml-agents/blob/release_15_docs/docs/Learning-Environment-Examples.md#cooperative-push-block), [Dungeon Escape](https://github.com/Unity-Technologies/ml-agents/blob/release_15_docs/docs/Learning-Environment-Examples.md#dungeon-escape) and [Soccer](https://github.com/Unity-Technologies/ml-agents/blob/release_15_docs/docs/Learning-Environment-Examples.md#soccer-twos) environments. (#4923)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The MA-POCA trainer has been added. This is a new trainer that enables Agents to learn how to work together in groups. Configure
+  `poca` as the trainer in the configuration YAML after instantiating a `SimpleMultiAgentGroup` to use this feature. (#5005)
+
+### Minor Changes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- Updated com.unity.barracuda to 1.3.2-preview. (#5084)
+- Added 3D Ball to the `com.unity.ml-agents` samples. (#5077)
+- Made com.unity.modules.unityanalytics an optional dependency. (#5109)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The `encoding_size` setting for RewardSignals has been deprecated. Please use `network_settings` instead. (#4982)
+- Sensor names are now passed through to `ObservationSpec.name`. (#5036)
+
+### Bug Fixes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- An issue that caused `GAIL` to fail for environments where agents can terminate episodes by self-sacrifice has been fixed. (#4971)
+- Made the error message clearer when observations of different shapes are sent to the trainer. (#5030)
+- An issue that prevented curriculums from incrementing with self-play has been fixed. (#5098)
+
+## [1.8.1-preview] - 2021-03-08
+### Minor Changes
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The `cattrs` version dependency was updated to allow `>=1.1.0` on Python 3.8 or higher. (#4821)
+
+### Bug Fixes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- Fixed an issue where queuing InputEvents overwrote data from previous events in the same frame. (#5034)
+
+## [1.8.0-preview] - 2021-02-17
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- TensorFlow trainers have been removed; please use the Torch trainers instead. (#4707)
+- A plugin system for `mlagents-learn` has been added. You can now define custom
+  `StatsWriter` implementations and register them to be called during training.
+  More types of plugins will be added in the future. (#4788)
+
+### Minor Changes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- The `ActionSpec` constructor is now public. Previously, it was not possible to create an
+  ActionSpec with both continuous and discrete actions from code. (#4896)
+- `StatAggregationMethod.Sum` can now be passed to `StatsRecorder.Add()`. This
+  will result in the values being summed (instead of averaged) when written to
+  TensorBoard. Thanks to @brccabral for the contribution! (#4816)
+- The upper limit for the time scale (set with the `--time-scale` parameter in mlagents-learn) was
+  removed when training with a player. The Editor still requires it to be clamped to 100. (#4867)
+- Added the IHeuristicProvider interface to allow IActuators as well as Agents to implement the Heuristic function to generate actions.
+  Updated the Basic example and the Match3 Example to use Actuators.
+  Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)
+- Added `VectorSensor.AddObservation(IList<float>)`. `VectorSensor.AddObservation(IEnumerable<float>)`
+  is deprecated. The `IList<float>` version is recommended, as it does not generate any
+  additional memory allocations. (#4887)
+- Added `ObservationWriter.AddList()` and deprecated `ObservationWriter.AddRange()`.
+  `AddList()` is recommended, as it does not generate any additional memory allocations. (#4887)
+- The Barracuda dependency was upgraded to 1.3.0. (#4898)
+- Added `ActuatorComponent.CreateActuators`, and deprecated `ActuatorComponent.CreateActuator`. The
+  default implementation will wrap `ActuatorComponent.CreateActuator` in an array and return that. (#4899)
+- `InferenceDevice.Burst` was added, indicating that the Agent's model will be run using Barracuda's Burst backend.
+  This is the default for new Agents, but existing ones that use `InferenceDevice.CPU` should update to
+  `InferenceDevice.Burst`. (#4925)
+- Added an InputActuatorComponent to allow the generation of Agent action spaces from an InputActionAsset.
+  Projects wanting to use this feature will need to add the
+  [Input System Package](https://docs.unity3d.com/Packages/com.unity.inputsystem@1.1/manual/index.html)
+  at version 1.1.0-preview.3 or later. (#4881)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- TensorBoard now logs the Environment Reward as both a scalar and a histogram. (#4878)
+- Added a `--torch-device` command-line option to `mlagents-learn`, which sets the default
+  [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888)
+- The `--cpu` command-line option had no effect and was removed. Use `--torch-device=cpu` to force CPU training. (#4888)
+- The `mlagents_envs` API has changed: `BehaviorSpec` now has an `observation_specs` property containing a list of `ObservationSpec`. For more information on `ObservationSpec` see [here](https://github.com/Unity-Technologies/ml-agents/blob/release_13_docs/docs/Python-API.md#behaviorspec). (#4763, #4825)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Fixed a compile warning about using an obsolete enum in `GrpcExtensions.cs`. (#4812)
+- CameraSensor now logs an error if the GraphicsDevice is null. (#4880)
+- Removed unnecessary memory allocations in `ActuatorManager.UpdateActionArray()`. (#4877)
+- Removed unnecessary memory allocations in `SensorShapeValidator.ValidateSensors()`. (#4879)
+- Removed unnecessary memory allocations in `SideChannelManager.GetSideChannelMessage()`. (#4886)
+- Removed several memory allocations that happened during inference. On a test scene, this
+  reduced the amount of memory allocated by approximately 25%. (#4887)
+- Removed several memory allocations that happened during inference with discrete actions. (#4922)
+- Properly catch permission errors when writing timer files. (#4921)
+- Unexpected exceptions during training initialization and shutdown are now logged. If you see
+  "noisy" logs, please let us know! (#4930, #4935)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed a bug that would cause an exception when `RunOptions` was deserialized via `pickle`. (#4842)
+- Fixed a bug that could cause a crash if a new behavior appeared during multi-environment training. (#4872)
+- Fixed the computation of entropy for continuous actions. (#4869)
+- Fixed a bug that would cause `UnityEnvironment` to wait the full timeout
+  period and report a misleading error message if the executable crashed
+  without closing the connection. It now periodically checks the process status
+  while waiting for a connection, and raises a better error message if it crashes. (#4880)
+- A `-logfile` option passed in the `--env-args` option to `mlagents-learn` is
+  no longer overwritten. (#4880)
+- The `load_weights` function was being called unnecessarily often in the Ghost Trainer, leading to training slowdowns. (#4934)
+
+
+## [1.7.2-preview] - 2020-12-22
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Added the analytics package dependency to the package manifest. (#4794)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed the docker build process. (#4791)
+
+
+## [1.7.0-preview] - 2020-12-21
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- PyTorch trainers now support training agents with both continuous and discrete action spaces. (#4702)
+  The `.onnx` models generated by the trainers of this release are incompatible with versions of Barracuda before `1.2.1-preview`. If you upgrade the trainers, you must upgrade the version of the Barracuda package as well (which can be done by upgrading the `com.unity.ml-agents` package).
+### Minor Changes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- Agents with both continuous and discrete actions are now supported. You can specify
+both continuous and discrete action sizes in Behavior Parameters. (#4702, #4718)
+- In order to improve the developer experience for the Unity ML-Agents Toolkit, we have added in-editor analytics.
+Please refer to "Information that is passively collected by Unity" in the
+[Unity Privacy Policy](https://unity3d.com/legal/privacy-policy). (#4677)
+- The FoodCollector example environment now uses continuous actions for moving and
+discrete actions for shooting. (#4746)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- `ActionSpec.validate_action()` now enforces that `UnityEnvironment.set_action_for_agent()` receives a 1D `np.array`. (#4691)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Removed noisy warnings about API minor version mismatches in both the C# and Python code. (#4688)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+
+## [1.6.0-preview] - 2020-11-18
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+ - PyTorch trainers are now the default. See the
+   [installation docs](https://github.com/Unity-Technologies/ml-agents/blob/release_10_docs/docs/Installation.md) for
+   more information on installing PyTorch. For the time being, TensorFlow is still available;
+   you can use the TensorFlow backend by adding `--tensorflow` to the CLI, or
+   adding `framework: tensorflow` in the configuration YAML. (#4517)
+
+### Minor Changes
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
+- The Barracuda dependency was upgraded to 1.1.2. (#4571)
+- Utilities were added to `com.unity.ml-agents.extensions` to make it easier to
+integrate with match-3 games. See the [readme](https://github.com/Unity-Technologies/ml-agents/blob/release_10_docs/com.unity.ml-agents.extensions/Documentation~/Match3.md)
+for more details. (#4515)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The `action_probs` node is no longer listed as an output in TensorFlow models. (#4613)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- `Agent.CollectObservations()` and `Agent.EndEpisode()` will now throw an exception
+if they are called recursively (for example, if they call `Agent.EndEpisode()`).
+Previously, this would result in an infinite loop and cause the editor to hang. (#4573)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed an issue where runs could not be resumed when using TensorFlow and Ghost Training. (#4593)
+- Changed the tensor type of the step count from int32 to int64 to address the overflow issue when the step
+count grows larger than 2^31. Previous TensorFlow checkpoints will become incompatible and cannot be loaded. (#4607)
+- Removed an extra period after "Training" in the console log. (#4674)
+
+
+## [1.5.0-preview] - 2020-10-14
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+ - Added the Random Network Distillation (RND) intrinsic reward signal to the PyTorch
+   trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
+   YAML configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Training-Configuration-File.md#rnd-intrinsic-reward) (#4473)
+### Minor Changes
+#### com.unity.ml-agents (C#)
+ - Stacking for compressed observations is now supported. An additional setting
+   option `Observation Stacks` is added in the editor to sensor components that support
+   compressed observations. A new class `ISparseChannelSensor` with an
+   additional method `GetCompressedChannelMapping()` is added to generate a mapping
+   of the channels in compressed data to the actual channel after decompression,
+   so that the Python side can decompress correctly. (#4476)
+ - Added a new visual 3DBall environment. (#4513)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+ - The Communication API was changed to 1.2.0 to indicate support for stacked
+   compressed observations. A new entry `compressed_channel_mapping` is added to the
+   proto to handle decompression correctly. Newer versions of the package that wish to
+   make use of this will also need a compatible version of the Python trainers. (#4476)
+ - In the `VisualFoodCollector` scene, a vector flag representing the frozen state of
+   the agent is added to the input observations in addition to the original first-person
+   camera frame. The scene is able to train with the provided default config file. (#4511)
+ - Added conversion to string for sampler classes to increase the verbosity of
+   the curriculum lesson changes. The lesson updates now output the sampler
+   stats in addition to the lesson and parameter name to the console. (#4484)
+ - Localized documentation in Russian was added. Thanks to @SergeyMatrosov for
+   the contribution. (#4529)
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+ - Fixed a bug where accessing the Academy outside of play mode would cause the
+   Academy to get stepped multiple times when in play mode. (#4532)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+
+## [1.4.0-preview] - 2020-09-16
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+### Minor Changes
+#### com.unity.ml-agents (C#)
+- The `IActuator` interface and `ActuatorComponent` abstract class were added.
+These are analogous to `ISensor` and `SensorComponent`, but for applying actions
+for an Agent. They allow you to control the action space more programmatically
+than defining the actions in the Agent's Behavior Parameters. See
+[BasicActuatorComponent.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_7_docs/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs)
+for an example of how to use them. (#4297, #4315)
+- Updated Barracuda to 1.1.1-preview. (#4482)
+- Enabled C# formatting using `dotnet-format`. (#4362)
+- GridSensor was added to the `com.unity.ml-agents.extensions` package. Thank you
+to Jaden Travnik from Eidos Montreal for the contribution! (#4399)
+- Added `Agent.EpisodeInterrupted()`, which can be used to reset the agent when
+it has reached a user-determined maximum number of steps. This behaves similarly
+to `Agent.EndEpisode()` but has a slightly different effect on training. (#4453)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Experimental PyTorch support has been added. Use `--torch` when running `mlagents-learn`, or add
+`framework: pytorch` to your trainer configuration (under the behavior name) to enable it.
+Note that PyTorch 1.6.0 or greater should be installed to use this feature; see
+[the PyTorch website](https://pytorch.org/) for installation instructions and
+[the relevant ML-Agents docs](https://github.com/Unity-Technologies/ml-agents/blob/release_7_docs/docs/Training-ML-Agents.md#using-pytorch-experimental) for usage. (#4335)
+- The minimum supported version of TensorFlow was increased to 1.14.0. (#4411)
+- Compressed visual observations with >3 channels are now supported. In
+`ISensor.GetCompressedObservation()`, this can be done by writing 3 channels at a
+time to a PNG and concatenating the resulting bytes. (#4399)
+- The Communication API was changed to 1.1.0 to indicate support for concatenated PNGs
+(see above). Newer versions of the package that wish to make use of this will also need
+a compatible version of the trainer. (#4462)
+- A CNN (`vis_encode_type: match3`) for smaller grids, e.g. board games, has been added.
+(#4434)
+- You can now again specify a default configuration for your behaviors. Specify `default_settings` in
+your trainer configuration to do so. (#4448)
+- Improved the executable detection logic for environments on Windows. (#4485)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Previously, `com.unity.ml-agents` was not declaring built-in packages as
+dependencies in its package.json. The relevant dependencies are now listed. (#4384)
+- Agents no longer try to send observations when they become disabled if the
+Academy has been shut down. (#4489)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed the sample code in the custom SideChannel example (a sketch of the pattern follows below). (#4466)
+- A bug in the observation normalizer that would cause rewards to decrease
+when using `--resume` was fixed. (#4463)
+- Fixed a bug in exporting PyTorch models when using multiple discrete actions. (#4491)
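+
+For reference, a custom side channel on the Python side follows this general pattern (a
+minimal sketch based on the side channel documentation; the `StringLogChannel` name and the
+UUID are illustrative, and the UUID must match whatever the corresponding C# channel uses):
+
+```python
+import uuid
+
+from mlagents_envs.side_channel.side_channel import (
+    IncomingMessage,
+    OutgoingMessage,
+    SideChannel,
+)
+
+
+class StringLogChannel(SideChannel):
+    """Illustrative channel exchanging strings with the Unity side."""
+
+    def __init__(self) -> None:
+        # The UUID must match the one used by the C# side of the channel.
+        super().__init__(uuid.UUID("621f0a70-4f87-11ea-a6bf-784f4387d1f7"))
+
+    def on_message_received(self, msg: IncomingMessage) -> None:
+        # Called for each message the Unity side sends on this channel.
+        print(msg.read_string())
+
+    def send_string(self, data: str) -> None:
+        # Queue a string to be sent to Unity alongside the next step.
+        msg = OutgoingMessage()
+        msg.write_string(data)
+        super().queue_message_to_send(msg)
+```
+
+The channel instance is then passed to `UnityEnvironment(side_channels=[...])` so that
+messages flow with each environment step.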
+
+## [1.3.0-preview] - 2020-08-12
+
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The minimum supported Python version for ml-agents-envs was changed to 3.6.1. (#4244)
+- The interaction between EnvManager and TrainerController was changed; EnvManager.advance() was split into two stages,
+and TrainerController now uses the results from the first stage to handle new behavior names. This change speeds up
+Python training by approximately 5-10%. (#4259)
+
+### Minor Changes
+#### com.unity.ml-agents (C#)
+- StatsSideChannel now stores multiple values per key. This means that multiple
+calls to `StatsRecorder.Add()` with the same key in the same step will no
+longer overwrite each other. (#4236)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The versions of `numpy` supported by ml-agents-envs were changed to disallow 1.19.0 or later. This was done to reflect
+a similar change in TensorFlow's requirements. (#4274)
+- Model checkpoints are now also saved as .nn files during training. (#4127)
+- Model checkpoint info is saved in TrainingStatus.json after training is concluded. (#4127)
+- The CSV statistics writer was removed. (#4300)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Academy.EnvironmentStep() will now throw an exception if it is called
+recursively (for example, by an Agent's CollectObservations method).
+Previously, this would result in an infinite loop and cause the editor to hang.
+(#4226)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The algorithm used to normalize observations was introducing NaNs if the initial observations were too large
+due to incorrect initialization. The initialization was fixed and is now the observation means from the
+first trajectory processed. (#4299)
+
+## [1.2.0-preview] - 2020-07-15
+
+### Major Changes
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The Parameter Randomization feature has been refactored to enable sampling of new parameters per episode to improve robustness. The
+  `resampling-interval` parameter has been removed and the config structure updated. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Training-ML-Agents.md). (#4065)
+- The Parameter Randomization feature has been merged with the Curriculum feature. It is now possible to specify a sampler
+in the lesson of a Curriculum. Curriculum has been refactored and is now specified at the level of the parameter, not the
+behavior. More information
+[here](https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Training-ML-Agents.md). (#4160)
+
+### Minor Changes
+#### com.unity.ml-agents (C#)
+- `SideChannelsManager` was renamed to `SideChannelManager`. The old name is still supported, but deprecated. (#4137)
+- `RayPerceptionSensor.Perceive()` now additionally stores the GameObject that was hit by the ray. (#4111)
+- The Barracuda dependency was upgraded to 1.0.1. (#4188)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Added new Google Colab notebooks to show how to use `UnityEnvironment` (a sketch of the basic loop follows below). (#4117)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Fixed an issue where RayPerceptionSensor would raise an exception when the
+list of tags was empty, or a tag in the list was invalid (unknown, null, or
+empty string). (#4155)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed an error when setting `initialize_from` in the trainer configuration YAML to
+`null`. (#4175)
+- Fixed an issue with FoodCollector, Soccer, and WallJump when playing with keyboard. (#4147, #4174)
+- Fixed a crash in StatsReporter when using threaded trainers with very frequent summary writes.
+(#4201)
+- `mlagents-learn` will now raise an error immediately if `--num-envs` is greater than 1 without setting the `--env`
+argument. (#4203)
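+
+The basic `UnityEnvironment` loop that those notebooks walk through has roughly the
+following shape (a sketch using a recent `ml-agents-envs` API, assuming a scene with at
+least one behavior; `file_name=None` connects to a running Editor instance):
+
+```python
+from mlagents_envs.environment import UnityEnvironment
+
+# file_name=None connects to the Editor instead of a built executable.
+env = UnityEnvironment(file_name=None)
+env.reset()
+
+# Each behavior in the scene is identified by name and described by a spec.
+behavior_name = list(env.behavior_specs)[0]
+spec = env.behavior_specs[behavior_name]
+
+for _ in range(10):
+    decision_steps, terminal_steps = env.get_steps(behavior_name)
+    # Sample random actions of the right shape for the agents awaiting a decision.
+    actions = spec.action_spec.random_action(len(decision_steps))
+    env.set_actions(behavior_name, actions)
+    env.step()
+
+env.close()
+```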
+
+## [1.1.0-preview] - 2020-06-10
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Added new Walker environments. Improved ragdoll stability/performance. (#4037)
+- `max_step` in the `TerminalStep` and `TerminalSteps` objects was renamed `interrupted`.
+- `beta` and `epsilon` in `PPO` are no longer decayed by default but follow the same schedule as the learning rate. (#3940)
+- `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946)
+- The first version of the Unity Environment Registry (Experimental) has been released. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Unity-Environment-Registry.md). (#3967)
+- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
+were replaced by `allow_multiple_obs`, which allows one or more visual observations and
+vector observations to be used simultaneously. (#3981) Thank you @shakenes!
+- Curriculum and Parameter Randomization configurations have been merged
+  into the main training configuration file. Note that this means training
+  configuration files are now environment-specific. (#3791)
+- The format for trainer configuration has changed, and the "default" behavior has been deprecated.
+  See the [Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Migrating.md) for more details. (#3936)
+- Training artifacts (trained models, summaries) are now found in the `results/`
+  directory. (#3829)
+- When using Curriculum, the current lesson will resume if training is quit and resumed. As such,
+  the `--lesson` CLI option has been removed. (#4025)
+### Minor Changes
+#### com.unity.ml-agents (C#)
+- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
+  observations via reflection. (#3925, #4006)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Unity Player logs are now written out to the results directory. (#3877)
+- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
+- The `--save-freq` CLI option has been removed, and replaced by a `checkpoint_interval` option in the trainer configuration YAML. (#4034)
+- When trying to load/resume from a checkpoint created with an earlier version of ML-Agents,
+  a warning will be logged. (#4035)
+### Bug Fixes
+- Fixed an issue where SAC would perform too many model updates when resuming from a
+  checkpoint, and too few when using `buffer_init_steps`. (#4038)
+- Fixed a bug in the ONNX export that would cause constants needed for inference to not be visible to some versions of
+  the Barracuda importer. (#4073)
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+
+## [1.0.2-preview] - 2020-05-20
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- Fixed a missing .meta file
+
+
+## [1.0.1-preview] - 2020-05-19
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+- A bug that would cause the editor to go into a loop when a prefab was selected was fixed. (#3949)
+- BrainParameters.ToProto() no longer throws an exception if none of the fields have been set. (#3930)
+- The Barracuda dependency was upgraded to 0.7.1-preview. (#3977)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- An issue was fixed where using `--initialize-from` would resume from the past step count. (#3962)
+- The gym wrapper error for the wrong number of agents now fires more consistently, and more details
+  were added to the error message when the input dimension is wrong. (#3963)
+
+
+## [1.0.0-preview] - 2020-04-30
+### Major Changes
+#### com.unity.ml-agents (C#)
+
+- The `MLAgents` C# namespace was renamed to `Unity.MLAgents`, and other nested
+  namespaces were similarly renamed. (#3843)
+- The offset logic was removed from DecisionRequester. (#3716)
+- The signature of `Agent.Heuristic()` was changed to take a float array as a
+  parameter, instead of returning the array. This was done to prevent a common
+  source of error where users would return arrays of the wrong size. (#3765)
+- The communication API version has been bumped up to 1.0.0 and will use
+  [Semantic Versioning](https://semver.org/) to do compatibility checks for
+  communication between Unity and the Python process. (#3760)
+- The obsolete `Agent` methods `GiveModel`, `Done`, `InitializeAgent`,
+  `AgentAction` and `AgentReset` have been removed. (#3770)
+- The SideChannel API has changed:
+  - Introduced the `SideChannelManager` to register, unregister and access side
+    channels. (#3807)
+  - `Academy.FloatProperties` was replaced by `Academy.EnvironmentParameters`.
+    See the [Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_1_docs/docs/Migrating.md)
+    for more details on upgrading. (#3807)
+  - `SideChannel.OnMessageReceived` is now a protected method (was public).
+  - SideChannel IncomingMessages methods now take an optional default argument,
+    which is used when trying to read more data than the message contains. (#3751)
+  - Added a feature to allow sending stats from C# environments to TensorBoard
+    (and other Python StatsWriters). To do this from your code, use
+    `Academy.Instance.StatsRecorder.Add(key, value)`. (#3660)
+- `CameraSensorComponent.m_Grayscale` and
+  `RenderTextureSensorComponent.m_Grayscale` were changed from `public` to
+  `private`. These can still be accessed via their corresponding properties.
+  (#3808)
+- Public fields and properties on several classes were renamed to follow Unity's
+  C# style conventions. All public fields and properties now use "PascalCase"
+  instead of "camelCase"; for example, `Agent.maxStep` was renamed to
+  `Agent.MaxStep`. For a full list of changes, see the pull request. (#3828)
+- `WriteAdapter` was renamed to `ObservationWriter`. If you have a custom
+  `ISensor` implementation, you will need to change the signature of its
+  `Write()` method. (#3834)
+- The Barracuda dependency was upgraded to 0.7.0-preview (which has breaking
+  namespace and assembly name changes). (#3875)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+- The `--load` and `--train` command-line flags have been deprecated. Training
+  now happens by default; use `--resume` to resume training instead of
+  `--load`. (#3705)
+- The Jupyter notebooks have been removed from the repository. (#3704)
+- The multi-agent gym option was removed from the gym wrapper. For multi-agent
+  scenarios, use the [Low Level Python API](https://github.com/Unity-Technologies/ml-agents/blob/release_1_docs/docs/Python-API.md). (#3681)
+- The low level Python API has changed. See the
+  [Low Level Python API](https://github.com/Unity-Technologies/ml-agents/blob/release_1_docs/docs/Python-API.md)
+  documentation for more information. If you use `mlagents-learn` for training, this should be a
+  transparent change. (#3681)
+- Added the ability to start training (initialize model weights) from a previous run
+  ID. (#3710)
+- The GhostTrainer has been extended to support asymmetric games, and the
+  asymmetric example environment Strikers Vs. Goalie has been added. (#3653)
+- The `UnityEnv` class from the `gym-unity` package was renamed
+  `UnityToGymWrapper` and no longer creates the `UnityEnvironment`. Instead, the
+  `UnityEnvironment` must be passed as input to the constructor of
+  `UnityToGymWrapper`. (#3812)
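+
+With this change, typical wrapper usage looks roughly like the following (a sketch; the
+path to the environment binary is illustrative):
+
+```python
+from mlagents_envs.environment import UnityEnvironment
+from gym_unity.envs import UnityToGymWrapper
+
+# The UnityEnvironment is created explicitly and handed to the wrapper.
+unity_env = UnityEnvironment("path/to/EnvironmentBinary")  # illustrative path
+env = UnityToGymWrapper(unity_env)
+
+obs = env.reset()
+for _ in range(100):
+    # The wrapper exposes the usual gym step API.
+    obs, reward, done, info = env.step(env.action_space.sample())
+    if done:
+        obs = env.reset()
+env.close()
+```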
+
+### Minor Changes
+
+#### com.unity.ml-agents (C#)
+
+- Added a new 3-joint Worm ragdoll environment. (#3798)
+- `StackingSensor` was changed from `internal` visibility to `public`. (#3701)
+- The internal event `Academy.AgentSetStatus` was renamed to
+  `Academy.AgentPreStep` and made public. (#3716)
+- The Academy.InferenceSeed property was added. This is used to initialize the
+  random number generator in ModelRunner, and is incremented for each
+  ModelRunner. (#3823)
+- `Agent.GetObservations()` was added, which returns a read-only view of the
+  observations added in `CollectObservations()`. (#3825)
+- `UnityRLCapabilities` was added to help inform users when RL features are
+  mismatched between C# and Python packages. (#3831)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+- The format of console output has changed slightly and now matches the name of the
+  model/summary directory. (#3630, #3616)
+- Renamed the 'Generalization' feature to 'Environment Parameter Randomization'.
+  (#3646)
+- Timer files now contain a dictionary of metadata, including things like the
+  package version numbers. (#3758)
+- The way that UnityEnvironment decides the port was changed. If no port is
+  specified, the behavior will depend on the `file_name` parameter. If it is
+  `None`, 5004 (the editor port) will be used; otherwise 5005 (the base
+  environment port) will be used. (#3673)
+- Running `mlagents-learn` with the same `--run-id` twice will no longer
+  overwrite the existing files. (#3705)
+- Model updates can now happen asynchronously with environment steps for better
+  performance. (#3690)
+- `num_updates` and `train_interval` for SAC were replaced with
+  `steps_per_update`. (#3690)
+- The maximum compatible version of TensorFlow was changed to allow TensorFlow
+  2.1 and 2.2. This will allow use with Python 3.8 using TensorFlow 2.2.0rc3.
+  (#3830)
+- `mlagents-learn` will no longer set the width and height of the executable
+  window to 84x84 when no width or height arguments are given. (#3867)
+
+### Bug Fixes
+
+#### com.unity.ml-agents (C#)
+
+- Fixed a display bug when viewing Demonstration files in the inspector. The
+  shapes of the observations in the file now display correctly. (#3771)
+
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+- Fixed an issue where exceptions from environments provided a return code of 0.
+  (#3680)
+- Self-Play team changes will now trigger a full environment reset. This
+  prevents trajectories in progress during a team change from getting into the
+  buffer. (#3870)
+
+## [0.15.1-preview] - 2020-03-30
+
+### Bug Fixes
+
+- Raised the wall in the CrawlerStatic scene to prevent the Agent from falling off.
+  (#3650)
+- Fixed an issue where specifying `vis_encode_type` was required only for SAC.
+  (#3677)
+- Fixed the reported entropy values for continuous actions. (#3684)
+- Fixed an issue where switching models using `SetModel()` during training would
+  use an excessive amount of memory. (#3664)
+- Environment subprocesses now close immediately on timeout or wrong API
+  version. (#3679)
+- Fixed an issue in the gym wrapper that would raise an exception if an Agent
+  called EndEpisode multiple times in the same step. (#3700)
+- Fixed an issue where logging output was not visible; logging levels are now
+  set consistently. (#3703)
+
+## [0.15.0-preview] - 2020-03-18
+
+### Major Changes
+
+- `Agent.CollectObservations` now takes a VectorSensor argument. (#3352, #3389)
+- Added the `Agent.CollectDiscreteActionMasks` virtual method with a
+  `DiscreteActionMasker` argument to specify which discrete actions are
+  unavailable to the Agent. (#3525)
+- Beta support for ONNX export was added. If the `tf2onnx` Python package is
+  installed, models will be saved to `.onnx` as well as `.nn` format. Note that
+  Barracuda 0.6.0 or later is required to import the `.onnx` files properly.
+- Multi-GPU training and the `--multi-gpu` option have been removed temporarily.
+  (#3345)
+- All Sensor related code has been moved to the namespace `MLAgents.Sensors`.
+- All SideChannel related code has been moved to the namespace
+  `MLAgents.SideChannels`.
+- `BrainParameters` and `SpaceType` have been removed from the public API.
+- `BehaviorParameters` have been removed from the public API.
+- The following methods in the `Agent` class have been deprecated and will be
+  removed in a later release:
+  - `InitializeAgent()` was renamed to `Initialize()`
+  - `AgentAction()` was renamed to `OnActionReceived()`
+  - `AgentReset()` was renamed to `OnEpisodeBegin()`
+  - `Done()` was renamed to `EndEpisode()`
+  - `GiveModel()` was renamed to `SetModel()`
+
+### Minor Changes
+
+- Monitor.cs was moved to Examples. (#3372)
+- Automatic stepping for the Academy is now controlled from the
+  AutomaticSteppingEnabled property. (#3376)
+- The GetEpisodeCount, GetStepCount, and GetTotalStepCount methods of Academy
+  were changed to the EpisodeCount, StepCount, and TotalStepCount properties,
+  respectively. (#3376)
+- Several classes were changed from public to internal visibility. (#3390)
+- Academy.RegisterSideChannel and UnregisterSideChannel methods were added.
+  (#3391)
+- A tutorial on adding custom SideChannels was added. (#3391)
+- The stepping logic for the Agent and the Academy has been simplified. (#3448)
+- Updated Barracuda to 0.6.1-preview.
+- The interface for `RayPerceptionSensor.PerceiveStatic()` was changed to take
+  an input class and write to an output class, and the method was renamed to
+  `Perceive()`.
+- The checkpoint file suffix was changed from `.cptk` to `.ckpt`. (#3470)
+- The command-line argument used to determine the port that an environment will
+  listen on was changed from `--port` to `--mlagents-port`.
+- `DemonstrationRecorder` can now record observations outside of the editor.
+- `DemonstrationRecorder` now has an optional path for the demonstrations. This
+  will default to `Application.dataPath` if not set.
+- `DemonstrationStore` was changed to accept a `Stream` for its constructor, and
+  was renamed to `DemonstrationWriter`.
+- The method `GetStepCount()` on the Agent class has been replaced with the
+  property getter `StepCount`.
+- `RayPerceptionSensorComponent` and related classes now display the debug
+  gizmos whenever the Agent is selected (not just in Play mode).
+- Most fields on `RayPerceptionSensorComponent` can now be changed while the
+  editor is in Play mode. The exceptions to this are fields that affect the
+  number of observations.
+- Most fields on `CameraSensorComponent` and `RenderTextureSensorComponent` were
+  changed to private and replaced by properties with the same name.
+- Unused static methods from the `Utilities` class (ShiftLeft, ReplaceRange, + AddRangeNoAlloc, and GetSensorFloatObservationSize) were removed. +- The `Agent` class is no longer abstract. +- SensorBase was moved out of the package and into the Examples directory. +- `AgentInfo.actionMasks` has been renamed to `AgentInfo.discreteActionMasks`. +- `DecisionRequester` has been made internal (you can still use the + DecisionRequesterComponent from the inspector). `RepeatAction` was renamed + `TakeActionsBetweenDecisions` for clarity. (#3555) +- The `IFloatProperties` interface has been removed. +- Fix #3579. +- Improved inference performance for models with multiple action branches. + (#3598) +- Fixed an issue when using GAIL with less than `batch_size` number of + demonstrations. (#3591) +- The interfaces to the `SideChannel` classes (on C# and python) have changed to + use new `IncomingMessage` and `OutgoingMessage` classes. These should make + reading and writing data to the channel easier. (#3596) +- Updated the ExpertPyramid.demo example demonstration file (#3613) +- Updated project version for example environments to 2018.4.18f1. (#3618) +- Changed the Product Name in the example environments to remove spaces, so that + the default build executable file doesn't contain spaces. (#3612) + +## [0.14.1-preview] - 2020-02-25 + +### Bug Fixes + +- Fixed an issue which caused self-play training sessions to consume a lot of + memory. (#3451) +- Fixed an IndexError when using GAIL or behavioral cloning with demonstrations + recorded with 0.14.0 or later (#3464) +- Updated the `gail_config.yaml` to work with per-Agent steps (#3475) +- Fixed demonstration recording of experiences when the Agent is done. (#3463) +- Fixed a bug with the rewards of multiple Agents in the gym interface (#3471, + #3496) + +## [0.14.0-preview] - 2020-02-13 + +### Major Changes + +- A new self-play mechanism for training agents in adversarial scenarios was + added (#3194) +- Tennis and Soccer environments were refactored to enable training with + self-play (#3194, #3331) +- UnitySDK folder was split into a Unity Package (com.unity.ml-agents) and our + examples were moved to the Project folder (#3267) +- Academy is now a singleton and is no longer abstract (#3210, #3184) +- In order to reduce the size of the API, several classes and methods were + marked as internal or private. Some public fields on the Agent were trimmed + (#3342, #3353, #3269) +- Decision Period and on-demand decision checkboxes were removed from the Agent. + on-demand decision is now the default (#3243) +- Calling Done() on the Agent will reset it immediately and call the AgentReset + virtual method (#3291, #3242) +- The "Reset on Done" setting in AgentParameters was removed; this is now always + true. AgentOnDone virtual method on the Agent was removed (#3311, #3222) +- Trainer steps are now counted per-Agent, not per-environment as in previous + versions. 
For instance, if you have 10 Agents in the scene, 20 environment + steps now correspond to 200 steps as printed in the terminal and in + Tensorboard (#3113) + +### Minor Changes + +- Barracuda was updated to 0.5.0-preview (#3329) +- --num-runs option was removed from mlagents-learn (#3155) +- Curriculum config files are now YAML formatted and all curricula for a + training run are combined into a single file (#3186) +- ML-Agents components, such as BehaviorParameters and various Sensor + implementations, now appear in the Components menu (#3231) +- Exceptions are now raised in Unity (in debug mode only) if NaN observations or + rewards are passed (#3221) +- RayPerception MonoBehavior, which was previously deprecated, was removed + (#3304) +- Uncompressed visual (i.e. 3d float arrays) observations are now supported. + CameraSensorComponent and RenderTextureSensor now have an option to write + uncompressed observations (#3148) +- Agent’s handling of observations during training was improved so that an extra + copy of the observations is no longer maintained (#3229) +- Error message for missing trainer config files was improved to include the + absolute path (#3230) +- Support for 2017.4 LTS was dropped (#3121, #3168) +- Some documentation improvements were made (#3296, #3292, #3295, #3281) + +### Bug Fixes + +- Numpy warning when stats don’t exist (#3251) +- A bug that caused RayPerceptionSensor to behave inconsistently with transforms + that have non-1 scale was fixed (#3321) +- Some small bugfixes to tensorflow_to_barracuda.py were backported from the + barracuda release (#3341) +- Base port in the jupyter notebook example was updated to use the same port + that the editor uses (#3283) + +## [0.13.0-preview] - 2020-01-24 + +### This is the first release of _Unity Package ML-Agents_. + +_Short description of this release_ diff --git a/com.unity.ml-agents/CHANGELOG.md.meta b/com.unity.ml-agents/CHANGELOG.md.meta new file mode 100644 index 0000000000..6331df01c1 --- /dev/null +++ b/com.unity.ml-agents/CHANGELOG.md.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: e19737407870a49abaaa1a90dae1a334 +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/CODEOWNERS b/com.unity.ml-agents/CODEOWNERS new file mode 100644 index 0000000000..52cb7ada8f --- /dev/null +++ b/com.unity.ml-agents/CODEOWNERS @@ -0,0 +1,2 @@ +# see https://docs.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners for more information +* @unity/behavior-authoring \ No newline at end of file diff --git a/com.unity.ml-agents/CODEOWNERS.meta b/com.unity.ml-agents/CODEOWNERS.meta new file mode 100644 index 0000000000..f288a23537 --- /dev/null +++ b/com.unity.ml-agents/CODEOWNERS.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 5de323c2110f44676ba99dc49409363c +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/CONTRIBUTING.md b/com.unity.ml-agents/CONTRIBUTING.md new file mode 100644 index 0000000000..65b01e775b --- /dev/null +++ b/com.unity.ml-agents/CONTRIBUTING.md @@ -0,0 +1,97 @@ +# Contribution Guidelines + +Thank you for your interest in contributing to the ML-Agents Toolkit! We are +incredibly excited to see how members of our community will use and extend the +ML-Agents Toolkit. To facilitate your contributions, we've outlined a brief set +of guidelines to ensure that your extensions can be easily integrated. 
+
+## Communication
+
+First, please read through our
+[code of conduct](https://github.com/Unity-Technologies/ml-agents/blob/main/CODE_OF_CONDUCT.md),
+as we expect all our contributors to follow it.
+
+Second, before starting on a project that you intend to contribute to the
+ML-Agents Toolkit (whether environments or modifications to the codebase), we
+**strongly** recommend posting on our
+[Issues page](https://github.com/Unity-Technologies/ml-agents/issues) and
+briefly outlining the changes you plan to make. This will enable us to provide
+some context that may be helpful for you. This could range from advice and
+feedback on how best to implement your changes to reasons why such changes may
+not be desirable.
+
+Lastly, if you're looking for input on what to contribute, feel free to reach
+out to us directly at ml-agents@unity3d.com and/or browse the GitHub issues with
+the `Requests` or `Bug` label.
+
+## Git Branches
+
+The main branch corresponds to the most recent version of the project. Note
+that this may be newer than the
+[latest release](https://github.com/Unity-Technologies/ml-agents/releases/tag/latest_release).
+
+When contributing to the project, please make sure that your Pull Request (PR)
+contains the following:
+
+- Detailed description of the changes performed
+- Corresponding changes to documentation, unit tests and sample environments (if
+  applicable)
+- Summary of the tests performed to validate your changes
+- Issue numbers that the PR resolves (if any)
+
+## Environments
+
+We are currently not accepting environment contributions directly into ML-Agents.
+However, we believe community-created environments have a lot of value to the
+community. If you have an interesting environment and are willing to share,
+feel free to showcase it and share any relevant files in the
+[ML-Agents forum](https://forum.unity.com/forums/ml-agents.453/).
+
+## Continuous Integration (CI)
+
+We run continuous integration on all PRs; all tests must be passing before the PR is merged.
+
+Several static checks are run on the codebase using the
+[pre-commit framework](https://pre-commit.com/) during CI. To execute the same
+checks locally, run (the quotes keep the shell from treating `>` as a redirect):
+```bash
+pip install "pre-commit>=2.8.0"
+pip install "identify>=2.1.3"
+pre-commit run --all-files
+```
+
+Some hooks (for example, `black`) will output the corrected version of the code;
+others (like `mypy`) may require more effort to fix. You can optionally run
+`pre-commit install` to install it as a git hook; after this it will run on all
+commits that you make.
+
+### Code style
+
+All Python code should be formatted with
+[`black`](https://github.com/psf/black).
+
+C# code is formatted using [`dotnet-format`](https://github.com/dotnet/format).
+You must have [dotnet](https://dotnet.microsoft.com/download) installed first
+(but don't need to install `dotnet-format` - `pre-commit` will do that for you).
+
+### Python type annotations
+
+We use [`mypy`](http://mypy-lang.org/) to perform static type checking on Python
+code. Currently not all code is annotated, but we will increase coverage over
+time. If you are adding or refactoring code, please:
+
+1. Add type annotations to the new or refactored code.
+2. Make sure that code calling or called by the modified code also has type
+   annotations.
+
+The
+[type hint cheat sheet](https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html)
+provides a good introduction to adding type hints.
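+
+As a small illustration, a fully annotated function looks like this (a hypothetical
+helper, not code from the repository):
+
+```python
+from typing import Dict, List, Optional
+
+
+def mean_reward_by_behavior(
+    rewards: Dict[str, List[float]], behavior_name: Optional[str] = None
+) -> Dict[str, float]:
+    """Return the mean reward per behavior, optionally filtered to one behavior."""
+    names = [behavior_name] if behavior_name is not None else list(rewards)
+    # max(len(...), 1) guards against empty reward lists.
+    return {name: sum(rewards[name]) / max(len(rewards[name]), 1) for name in names}
+```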
+
+## Contributor License Agreements
+
+When you open a pull request, you will be asked to acknowledge our Contributor
+License Agreement. We allow both individual contributions and contributions made
+on behalf of companies. We use an open source tool called CLA assistant. If you
+have any questions on our CLA, please
+[submit an issue](https://github.com/Unity-Technologies/ml-agents/issues) or
+email us at ml-agents@unity3d.com.
diff --git a/com.unity.ml-agents/CONTRIBUTING.md.meta b/com.unity.ml-agents/CONTRIBUTING.md.meta
new file mode 100644
index 0000000000..acf109e975
--- /dev/null
+++ b/com.unity.ml-agents/CONTRIBUTING.md.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 60b8c21afae8d449ebcdd512e85e97ac
+TextScriptImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md b/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md
new file mode 100644
index 0000000000..7e90deef39
--- /dev/null
+++ b/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md
@@ -0,0 +1,161 @@
+# About ML-Agents package (`com.unity.ml-agents`)
+
+The _ML-Agents_ package contains the primary C# SDK for the [Unity ML-Agents
+Toolkit].
+
+The package allows you to convert any Unity scene into a learning environment
+and train character behaviors using a variety of machine learning algorithms.
+Additionally, it allows you to embed these trained behaviors back into Unity
+scenes to control your characters. More specifically, the package provides the
+following core functionalities:
+
+- Define Agents: entities, or characters, whose behavior will be learned. Agents
+  are entities that generate observations (through sensors), take actions, and
+  receive rewards from the environment.
+- Define Behaviors: entities that specify how an agent should act. Multiple
+  agents can share the same Behavior and a scene may have multiple Behaviors.
+- Record demonstrations of an agent within the Editor. You can use
+  demonstrations to help train a behavior for that agent.
+- Embed a trained behavior into the scene via the [Unity Inference Engine].
+  Embedded behaviors allow you to switch an Agent between learning and
+  inference.
+
+Note that the _ML-Agents_ package does not contain the machine learning
+algorithms for training behaviors. The _ML-Agents_ package only supports
+instrumenting a Unity scene, setting it up for training, and then embedding the
+trained model back into your Unity scene. The machine learning algorithms that
+orchestrate training are part of the companion [Python package].
+
+Note that we also provide an _ML-Agents Extensions_ package
+(`com.unity.ml-agents.extensions`) that contains early/experimental features
+that you may find useful. This package is only available from the [ML-Agents
+GitHub repo].
+
+## Package contents
+
+The following table describes the package folder structure:
+
+| **Location**           | **Description**                                                          |
+| ---------------------- | ------------------------------------------------------------------------ |
+| _Documentation~_       | Contains the documentation for the Unity package.                        |
+| _Editor_               | Contains utilities for Editor windows and drawers.                       |
+| _Plugins_              | Contains third-party DLLs.                                               |
+| _Runtime_              | Contains core C# APIs for integrating ML-Agents into your Unity scene.   |
+| _Runtime/Integrations_ | Contains utilities for integrating ML-Agents into specific game genres.  |
+| _Tests_                | Contains the unit tests for the package.                                 |
+
+## Installation
+
+To install this _ML-Agents_ package, follow the instructions in the [Package
+Manager documentation].
+
+To install the companion Python package to enable training behaviors, follow the
+[installation instructions] on our [GitHub repository].
+
+### Advanced Installation
+With the changes to Unity Package Manager in 2021, experimental packages will not show up in the package list and have to be installed manually. There are two recommended ways to install the package manually:
+
+#### GitHub via Package Manager
+
+In Unity 2019.4 or later, open the Package Manager, hit the "+" button, and select "Add package from git URL".
+
+![Package Manager git URL](https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/images/unity_package_manager_git_url.png)
+
+In the dialog that appears, enter:
+```
+git+https://github.com/Unity-Technologies/ml-agents.git?path=com.unity.ml-agents#release_19
+```
+
+You can also edit your project's `manifest.json` directly and add the following line to the `dependencies`
+section:
+```
+"com.unity.ml-agents": "git+https://github.com/Unity-Technologies/ml-agents.git?path=com.unity.ml-agents#release_19",
+```
+See [Git dependencies](https://docs.unity3d.com/Manual/upm-git.html#subfolder) for more information. Note that this
+may take several minutes to resolve the packages the first time that you add it.
+
+#### Local Installation for Development
+
+[Clone the repository](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/Installation.md#clone-the-ml-agents-toolkit-repository-optional) and follow the
+[Local Installation for Development](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/Installation.md#advanced-local-installation-for-development-1)
+directions.
+
+## Requirements
+
+This version of the Unity ML-Agents package is compatible with the following
+versions of the Unity Editor:
+
+- 2019.4 and later
+
+## Known Limitations
+
+### Training
+
+Training is limited to the Unity Editor and Standalone builds on Windows, MacOS,
+and Linux with the Mono scripting backend. Currently, training does not work
+with the IL2CPP scripting backend. Your environment will default to inference
+mode if training is not supported or is not currently running.
+
+### Inference
+
+Inference is executed via the
+[Unity Inference Engine](https://docs.unity3d.com/Packages/com.unity.barracuda@latest/index.html).
+
+**CPU**
+
+All platforms supported.
+
+**GPU**
+
+All platforms supported except:
+
+- WebGL and GLES 3/2 on Android / iPhone
+
+**NOTE:** Mobile platform support includes:
+
+- Vulkan for Android
+- Metal for iOS
+
+### Headless Mode
+
+If you enable Headless mode, you will not be able to collect visual observations
+from your agents.
+
+### Rendering Speed and Synchronization
+
+Currently, the speed of the game physics can only be increased to 100x real-time.
+The Academy also moves in time with FixedUpdate() rather than Update(), so game
+behavior implemented in Update() may be out of sync with the agent decision
+making. See [Execution Order of Event Functions] for more information.
+
+You can control the frequency of Academy stepping by calling
+`Academy.Instance.DisableAutomaticStepping()`, and then calling
+`Academy.Instance.EnvironmentStep()`.
+
+### Unity Inference Engine Models
+
+Currently, only models created with our trainers are supported for running
+ML-Agents with a neural network behavior.
+
+## Helpful links
+
+If you are new to the Unity ML-Agents package, or have a question after reading
+the documentation, you can check out our [GitHub Repository], which also
+includes a number of ways to [connect with us], including our [ML-Agents Forum].
+
+To improve the developer experience for the Unity ML-Agents Toolkit, we have
+added in-editor analytics. Please refer to "Information that is passively
+collected by Unity" in the
+[Unity Privacy Policy](https://unity3d.com/legal/privacy-policy).
+
+[unity ML-Agents Toolkit]: https://github.com/Unity-Technologies/ml-agents
+[unity inference engine]: https://docs.unity3d.com/Packages/com.unity.barracuda@latest/index.html
+[package manager documentation]: https://docs.unity3d.com/Manual/upm-ui-install.html
+[installation instructions]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Installation.md
+[github repository]: https://github.com/Unity-Technologies/ml-agents
+[python package]: https://github.com/Unity-Technologies/ml-agents
+[execution order of event functions]: https://docs.unity3d.com/Manual/ExecutionOrder.html
+[connect with us]: https://github.com/Unity-Technologies/ml-agents#community-and-feedback
+[ml-agents forum]: https://forum.unity.com/forums/ml-agents.453/
+[ML-Agents GitHub repo]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/com.unity.ml-agents.extensions
diff --git a/com.unity.ml-agents/Documentation~/filter.yml b/com.unity.ml-agents/Documentation~/filter.yml
new file mode 100755
index 0000000000..ce144daf61
--- /dev/null
+++ b/com.unity.ml-agents/Documentation~/filter.yml
@@ -0,0 +1,14 @@
+apiRules:
+- exclude:
+    uidRegex: .*Test.*
+    type: Namespace
+- exclude:
+    uidRegex: ^Unity.MLAgents\.CommunicatorObjects$
+    type: Namespace
+- exclude:
+    uidRegex: ^Unity.MLAgents\.Editor$
+    type: Namespace
+- exclude:
+    uidRegex: ^Unity.MLAgentsExamples$
+    type: Namespace
+
diff --git a/com.unity.ml-agents/Editor.meta b/com.unity.ml-agents/Editor.meta
new file mode 100644
index 0000000000..89d980b088
--- /dev/null
+++ b/com.unity.ml-agents/Editor.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: e9de88a64ac5c4d2eb8955836199d61b
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/AgentEditor.cs b/com.unity.ml-agents/Editor/AgentEditor.cs
new file mode 100644
index 0000000000..fdebf1cb15
--- /dev/null
+++ b/com.unity.ml-agents/Editor/AgentEditor.cs
@@ -0,0 +1,31 @@
+using UnityEngine;
+using UnityEditor;
+
+namespace Unity.MLAgents.Editor
+{
+    /*
+     This code is meant to modify the behavior of the inspector on Agent Components.
+     */
+    [CustomEditor(typeof(Agent), true)]
+    [CanEditMultipleObjects]
+    internal class AgentEditor : UnityEditor.Editor
+    {
+        public override void OnInspectorGUI()
+        {
+            var serializedAgent = serializedObject;
+            serializedAgent.Update();
+
+            var maxSteps = serializedAgent.FindProperty("MaxStep");
+
+            EditorGUILayout.PropertyField(
+                maxSteps,
+                new GUIContent("Max Step", "The per-agent maximum number of steps.")
+            );
+
+            serializedAgent.ApplyModifiedProperties();
+
+            EditorGUILayout.LabelField("", GUI.skin.horizontalSlider);
+            base.OnInspectorGUI();
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/AgentEditor.cs.meta b/com.unity.ml-agents/Editor/AgentEditor.cs.meta
new file mode 100755
index 0000000000..66bc325f8b
--- /dev/null
+++ b/com.unity.ml-agents/Editor/AgentEditor.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: c3b291e1cd0c64781861652b579d0ac1
+timeCreated: 1503270350
+licenseType: Free
+MonoImporter:
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
new file mode 100644
index 0000000000..a95b2846f3
--- /dev/null
+++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
@@ -0,0 +1,192 @@
+using System.Collections.Generic;
+using UnityEditor;
+using Unity.Barracuda;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Policies;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Sensors.Reflection;
+using CheckTypeEnum = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck.CheckTypeEnum;
+
+namespace Unity.MLAgents.Editor
+{
+    /*
+     This code is meant to modify the behavior of the inspector on BehaviorParameters Components.
+     */
+    [CustomEditor(typeof(BehaviorParameters))]
+    [CanEditMultipleObjects]
+    internal class BehaviorParametersEditor : UnityEditor.Editor
+    {
+        const float k_TimeBetweenModelReloads = 2f;
+        // Time since the last reload of the model
+        float m_TimeSinceModelReload;
+        // Whether or not the model needs to be reloaded
+        bool m_RequireReload;
+        const string k_BehaviorName = "m_BehaviorName";
+        const string k_BrainParametersName = "m_BrainParameters";
+        const string k_ModelName = "m_Model";
+        const string k_InferenceDeviceName = "m_InferenceDevice";
+        const string k_DeterministicInference = "m_DeterministicInference";
+        const string k_BehaviorTypeName = "m_BehaviorType";
+        const string k_TeamIdName = "TeamId";
+        const string k_UseChildSensorsName = "m_UseChildSensors";
+        const string k_ObservableAttributeHandlingName = "m_ObservableAttributeHandling";
+
+        public override void OnInspectorGUI()
+        {
+            var so = serializedObject;
+            so.Update();
+            bool needPolicyUpdate; // Whether the name, model, inference device, or BehaviorType changed.
+
+            var behaviorParameters = (BehaviorParameters)target;
+            var agent = behaviorParameters.gameObject.GetComponent<Agent>();
+            if (agent == null)
+            {
+                EditorGUILayout.HelpBox(
+                    "No Agent is associated with this Behavior Parameters. Attach an Agent to " +
+                    "this GameObject to configure your Agent with these behavior parameters.",
+                    MessageType.Warning);
+            }
+
+            // Drawing the Behavior Parameters
+            EditorGUI.indentLevel++;
+            EditorGUI.BeginChangeCheck(); // global
+
+            EditorGUI.BeginChangeCheck();
+            {
+                EditorGUILayout.PropertyField(so.FindProperty(k_BehaviorName));
+            }
+            needPolicyUpdate = EditorGUI.EndChangeCheck();
+
+            EditorGUI.BeginChangeCheck();
+            EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
+            {
+                EditorGUILayout.PropertyField(so.FindProperty(k_BrainParametersName), true);
+            }
+            EditorGUI.EndDisabledGroup();
+
+            EditorGUI.BeginChangeCheck();
+            {
+                EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
+                EditorGUI.indentLevel++;
+                EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
+                EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
+                EditorGUI.indentLevel--;
+            }
+            needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
+
+            EditorGUI.BeginChangeCheck();
+            {
+                EditorGUILayout.PropertyField(so.FindProperty(k_BehaviorTypeName));
+            }
+            needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
+
+            EditorGUILayout.PropertyField(so.FindProperty(k_TeamIdName));
+            EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
+            {
+                EditorGUILayout.PropertyField(so.FindProperty(k_UseChildSensorsName), true);
+                EditorGUILayout.PropertyField(so.FindProperty(k_ObservableAttributeHandlingName), true);
+            }
+            EditorGUI.EndDisabledGroup();
+
+            EditorGUI.indentLevel--;
+            m_RequireReload = EditorGUI.EndChangeCheck();
+            DisplayFailedModelChecks();
+            so.ApplyModifiedProperties();
+
+            if (needPolicyUpdate)
+            {
+                UpdateAgentPolicy();
+            }
+        }
+
+        ///
+        /// Must be called within OnInspectorGUI()
+        ///
+        void DisplayFailedModelChecks()
+        {
+            if (m_RequireReload && m_TimeSinceModelReload > k_TimeBetweenModelReloads)
+            {
+                m_RequireReload = false;
+                m_TimeSinceModelReload = 0;
+            }
+            // Display all failed checks
+            D.logEnabled = false;
+            Model barracudaModel = null;
+            var model = (NNModel)serializedObject.FindProperty(k_ModelName).objectReferenceValue;
+            var behaviorParameters = (BehaviorParameters)target;
+
+            // Grab the sensor components, since we need them to determine the observation sizes.
+            // TODO make these methods of BehaviorParameters
+            var agent = behaviorParameters.gameObject.GetComponent<Agent>();
+            if (agent == null)
+            {
+                return;
+            }
+            agent.sensors = new List<ISensor>();
+            agent.InitializeSensors();
+            var sensors = agent.sensors.ToArray();
+
+            ActuatorComponent[] actuatorComponents;
+            if (behaviorParameters.UseChildActuators)
+            {
+                actuatorComponents = behaviorParameters.GetComponentsInChildren<ActuatorComponent>();
+            }
+            else
+            {
+                actuatorComponents = behaviorParameters.GetComponents<ActuatorComponent>();
+            }
+
+            // Get the total size of the sensors generated by ObservableAttributes.
+            // If there are any errors (e.g. unsupported type, write-only properties), display them too.
+            int observableAttributeSensorTotalSize = 0;
+            if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
+            {
+                List<string> observableErrors = new List<string>();
+                observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors);
+                foreach (var check in observableErrors)
+                {
+                    EditorGUILayout.HelpBox(check, MessageType.Warning);
+                }
+            }
+
+            var brainParameters = behaviorParameters.BrainParameters;
+            if (model != null)
+            {
+                barracudaModel = ModelLoader.Load(model);
+            }
+            if (brainParameters != null)
+            {
+                var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
+                    barracudaModel, brainParameters, sensors, actuatorComponents,
+                    observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
+                );
+                foreach (var check in failedChecks)
+                {
+                    if (check != null)
+                    {
+                        switch (check.CheckType)
+                        {
+                            case CheckTypeEnum.Info:
+                                EditorGUILayout.HelpBox(check.Message, MessageType.Info);
+                                break;
+                            case CheckTypeEnum.Warning:
+                                EditorGUILayout.HelpBox(check.Message, MessageType.Warning);
+                                break;
+                            case CheckTypeEnum.Error:
+                                EditorGUILayout.HelpBox(check.Message, MessageType.Error);
+                                break;
+                            default:
+                                break;
+                        }
+                    }
+                }
+            }
+        }
+
+        void UpdateAgentPolicy()
+        {
+            var behaviorParameters = (BehaviorParameters)target;
+            behaviorParameters.UpdateAgentPolicy();
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs.meta b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs.meta
new file mode 100644
index 0000000000..6eb612f3e3
--- /dev/null
+++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 72b0b21a2d4ee4bc2be0530fd134720d
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/BrainParametersDrawer.cs b/com.unity.ml-agents/Editor/BrainParametersDrawer.cs
new file mode 100644
index 0000000000..52f40e20d3
--- /dev/null
+++ b/com.unity.ml-agents/Editor/BrainParametersDrawer.cs
@@ -0,0 +1,172 @@
+using UnityEngine;
+using UnityEditor;
+using Unity.MLAgents.Policies;
+
+namespace Unity.MLAgents.Editor
+{
+    ///
+    /// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the
+    /// Inspector.
+    ///
+    [CustomPropertyDrawer(typeof(BrainParameters))]
+    internal class BrainParametersDrawer : PropertyDrawer
+    {
+        // The height of a line in the Unity Inspector
+        const float k_LineHeight = 17f;
+        const int k_VecObsNumLine = 3;
+        const string k_ActionSpecName = "m_ActionSpec";
+        const string k_ContinuousActionSizeName = "m_NumContinuousActions";
+        const string k_DiscreteBranchSizeName = "BranchSizes";
+        const string k_ActionDescriptionPropName = "VectorActionDescriptions";
+        const string k_VecObsPropName = "VectorObservationSize";
+        const string k_NumVecObsPropName = "NumStackedVectorObservations";
+
+        ///
+        public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
+        {
+            return GetHeightDrawVectorObservation() +
+                GetHeightDrawVectorAction(property);
+        }
+
+        ///
+        public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
+        {
+            var indent = EditorGUI.indentLevel;
+            EditorGUI.indentLevel = 0;
+            position.height = k_LineHeight;
+            EditorGUI.BeginProperty(position, label, property);
+            EditorGUI.indentLevel++;
+
+            // Vector Observations
+            DrawVectorObservation(position, property);
+            position.y += GetHeightDrawVectorObservation();
+
+            // Vector Action
+            DrawVectorAction(position, property);
+            position.y += GetHeightDrawVectorAction(property);
+
+            EditorGUI.EndProperty();
+            EditorGUI.indentLevel = indent;
+        }
+
+        ///
+        /// Draws the Vector Observations for the Brain Parameters
+        ///
+        /// Rectangle on the screen to use for the property GUI.
+        /// The SerializedProperty of the BrainParameters
+        /// to make the custom GUI for.
+        static void DrawVectorObservation(Rect position, SerializedProperty property)
+        {
+            EditorGUI.LabelField(position, "Vector Observation");
+            position.y += k_LineHeight;
+
+            EditorGUI.indentLevel++;
+            EditorGUI.PropertyField(position,
+                property.FindPropertyRelative(k_VecObsPropName),
+                new GUIContent("Space Size",
+                    "Length of state " +
+                    "vector for brain (In Continuous state space). " +
+                    "Or number of possible values (in Discrete state space)."));
+            position.y += k_LineHeight;
+
+            EditorGUI.PropertyField(position,
+                property.FindPropertyRelative(k_NumVecObsPropName),
+                new GUIContent("Stacked Vectors",
+                    "Number of states that will be stacked before " +
+                    "being fed to the neural network."));
+            position.y += k_LineHeight;
+            EditorGUI.indentLevel--;
+        }
+
+        ///
+        /// The height required to draw the Vector Observation parameters
+        ///
+        /// The height of the drawer of the Vector Observations
+        static float GetHeightDrawVectorObservation()
+        {
+            return k_VecObsNumLine * k_LineHeight;
+        }
+
+        ///
+        /// Draws the Vector Actions parameters for the Brain Parameters
+        ///
+        /// Rectangle on the screen to use for the property GUI.
+        /// The SerializedProperty of the BrainParameters
+        /// to make the custom GUI for.
+        static void DrawVectorAction(Rect position, SerializedProperty property)
+        {
+            EditorGUI.LabelField(position, "Actions");
+            position.y += k_LineHeight;
+            EditorGUI.indentLevel++;
+            var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
+            DrawContinuousVectorAction(position, actionSpecProperty);
+            position.y += k_LineHeight;
+            DrawDiscreteVectorAction(position, actionSpecProperty);
+        }
+
+        ///
+        /// Draws the Continuous Vector Actions parameters for the Brain Parameters
+        ///
+        /// Rectangle on the screen to use for the property GUI.
+        /// The SerializedProperty of the BrainParameters
+        /// to make the custom GUI for.
+ static void DrawContinuousVectorAction(Rect position, SerializedProperty property) + { + var continuousActionSize = property.FindPropertyRelative(k_ContinuousActionSizeName); + EditorGUI.PropertyField( + position, + continuousActionSize, + new GUIContent("Continuous Actions", "Number of continuous actions.")); + } + + /// + /// Draws the Discrete Vector Actions parameters for the Brain Parameters + /// + /// Rectangle on the screen to use for the property GUI. + /// The SerializedProperty of the BrainParameters + /// to make the custom GUI for. + static void DrawDiscreteVectorAction(Rect position, SerializedProperty property) + { + var branchSizes = property.FindPropertyRelative(k_DiscreteBranchSizeName); + var newSize = EditorGUI.IntField( + position, "Discrete Branches", branchSizes.arraySize); + + // This check is here due to: + // https://fogbugz.unity3d.com/f/cases/1246524/ + // If this case has been resolved, please remove this if condition. + if (newSize != branchSizes.arraySize) + { + branchSizes.arraySize = newSize; + } + + position.y += k_LineHeight; + position.x += 20; + position.width -= 20; + for (var branchIndex = 0; + branchIndex < branchSizes.arraySize; + branchIndex++) + { + var branchActionSize = + branchSizes.GetArrayElementAtIndex(branchIndex); + + EditorGUI.PropertyField( + position, + branchActionSize, + new GUIContent("Branch " + branchIndex + " Size", + "Number of possible actions for the branch number " + branchIndex + ".")); + position.y += k_LineHeight; + } + } + + /// + /// The Height required to draw the Vector Action parameters. + /// + /// The height of the drawer of the Vector Action. + static float GetHeightDrawVectorAction(SerializedProperty property) + { + var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName); + var numActionLines = 3 + actionSpecProperty.FindPropertyRelative(k_DiscreteBranchSizeName).arraySize; + return numActionLines * k_LineHeight; + } + } +} diff --git a/com.unity.ml-agents/Editor/BrainParametersDrawer.cs.meta b/com.unity.ml-agents/Editor/BrainParametersDrawer.cs.meta new file mode 100644 index 0000000000..9379a5f0eb --- /dev/null +++ b/com.unity.ml-agents/Editor/BrainParametersDrawer.cs.meta @@ -0,0 +1,12 @@ +fileFormatVersion: 2 +guid: b060ae8e687cf49bcae88b24db17bfa6 +timeCreated: 1517291065 +licenseType: Free +MonoImporter: + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/BufferSensorComponentEditor.cs b/com.unity.ml-agents/Editor/BufferSensorComponentEditor.cs new file mode 100644 index 0000000000..f41edb2740 --- /dev/null +++ b/com.unity.ml-agents/Editor/BufferSensorComponentEditor.cs @@ -0,0 +1,31 @@ +using UnityEditor; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Editor +{ + [CustomEditor(typeof(BufferSensorComponent), editorForChildClasses: true)] + [CanEditMultipleObjects] + internal class BufferSensorComponentEditor : UnityEditor.Editor + { + public override void OnInspectorGUI() + { + var so = serializedObject; + so.Update(); + + // Drawing the BufferSensorComponent + + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + // These fields affect the sensor order or observation size, + // So can't be changed at runtime. 
+ EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservableSize"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_MaxNumObservables"), true); + } + EditorGUI.EndDisabledGroup(); + + so.ApplyModifiedProperties(); + } + + } +} diff --git a/com.unity.ml-agents/Editor/BufferSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/BufferSensorComponentEditor.cs.meta new file mode 100644 index 0000000000..62de961c7e --- /dev/null +++ b/com.unity.ml-agents/Editor/BufferSensorComponentEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: b042fe65027f94c1eb38a2ee1362d38d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs new file mode 100644 index 0000000000..1df66ee3c9 --- /dev/null +++ b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs @@ -0,0 +1,49 @@ +using UnityEditor; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Editor +{ + [CustomEditor(typeof(CameraSensorComponent), editorForChildClasses: true)] + [CanEditMultipleObjects] + internal class CameraSensorComponentEditor : UnityEditor.Editor + { + public override void OnInspectorGUI() + { + var so = serializedObject; + so.Update(); + + // Drawing the CameraSensorComponent + EditorGUI.BeginChangeCheck(); + + EditorGUILayout.PropertyField(so.FindProperty("m_Camera"), true); + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + // These fields affect the sensor order or observation size, + // So can't be changed at runtime. 
+ EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_Width"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true); + } + EditorGUI.EndDisabledGroup(); + EditorGUILayout.PropertyField(so.FindProperty("m_RuntimeCameraEnable"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true); + + var requireSensorUpdate = EditorGUI.EndChangeCheck(); + so.ApplyModifiedProperties(); + + if (requireSensorUpdate) + { + UpdateSensor(); + } + } + + void UpdateSensor() + { + var sensorComponent = serializedObject.targetObject as CameraSensorComponent; + sensorComponent?.UpdateSensor(); + } + } +} diff --git a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs.meta new file mode 100644 index 0000000000..70b1e31432 --- /dev/null +++ b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: fdda773c024894cf0ae47d1b1396c38d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/DemonstrationDrawer.cs b/com.unity.ml-agents/Editor/DemonstrationDrawer.cs new file mode 100644 index 0000000000..d15baa9a1d --- /dev/null +++ b/com.unity.ml-agents/Editor/DemonstrationDrawer.cs @@ -0,0 +1,137 @@ +using System.Collections.Generic; +using System.Text; +using UnityEditor; +using Unity.MLAgents.Demonstrations; + +namespace Unity.MLAgents.Editor +{ + /// + /// Renders a custom UI for DemonstrationSummary ScriptableObject. + /// + [CustomEditor(typeof(DemonstrationSummary))] + [CanEditMultipleObjects] + internal class DemonstrationEditor : UnityEditor.Editor + { + SerializedProperty m_BrainParameters; + SerializedProperty m_DemoMetaData; + SerializedProperty m_ObservationShapes; + const string k_BrainParametersName = "brainParameters"; + const string k_MetaDataName = "metaData"; + const string k_ObservationSummariesName = "observationSummaries"; + const string k_DemonstrationName = "demonstrationName"; + const string k_NumberStepsName = "numberSteps"; + const string k_NumberEpisodesName = "numberEpisodes"; + const string k_MeanRewardName = "meanReward"; + const string k_ActionSpecName = "m_ActionSpec"; + const string k_NumContinuousActionsName = "m_NumContinuousActions"; + const string k_NumDiscreteActionsName = "BranchSizes"; + const string k_ShapeName = "shape"; + + + void OnEnable() + { + m_BrainParameters = serializedObject.FindProperty(k_BrainParametersName); + m_DemoMetaData = serializedObject.FindProperty(k_MetaDataName); + m_ObservationShapes = serializedObject.FindProperty(k_ObservationSummariesName); + } + + /// + /// Renders Inspector UI for Demonstration metadata. 
+ /// + void MakeMetaDataProperty(SerializedProperty property) + { + var nameProp = property.FindPropertyRelative(k_DemonstrationName); + var experiencesProp = property.FindPropertyRelative(k_NumberStepsName); + var episodesProp = property.FindPropertyRelative(k_NumberEpisodesName); + var rewardsProp = property.FindPropertyRelative(k_MeanRewardName); + + var nameLabel = nameProp.displayName + ": " + nameProp.stringValue; + var experiencesLabel = experiencesProp.displayName + ": " + experiencesProp.intValue; + var episodesLabel = episodesProp.displayName + ": " + episodesProp.intValue; + var rewardsLabel = rewardsProp.displayName + ": " + rewardsProp.floatValue; + + EditorGUILayout.LabelField(nameLabel); + EditorGUILayout.LabelField(experiencesLabel); + EditorGUILayout.LabelField(episodesLabel); + EditorGUILayout.LabelField(rewardsLabel); + } + + /// + /// Constructs label for a serialized integer array. + /// + static string BuildIntArrayLabel(SerializedProperty actionSizeProperty) + { + var actionSize = actionSizeProperty.arraySize; + var actionLabel = new StringBuilder("[ "); + for (var i = 0; i < actionSize; i++) + { + actionLabel.Append(actionSizeProperty.GetArrayElementAtIndex(i).intValue); + if (i < actionSize - 1) + { + actionLabel.Append(", "); + } + } + + actionLabel.Append(" ]"); + return actionLabel.ToString(); + } + + /// + /// Renders Inspector UI for BrainParameters of a DemonstrationSummary. + /// Only the Action size and type are used from the BrainParameters. + /// + void MakeActionsProperty(SerializedProperty property) + { + var actSpecProperty = property.FindPropertyRelative(k_ActionSpecName); + var continuousSizeProperty = actSpecProperty.FindPropertyRelative(k_NumContinuousActionsName); + var discreteSizeProperty = actSpecProperty.FindPropertyRelative(k_NumDiscreteActionsName); + var continuousSizeLabel = "Continuous Actions: " + continuousSizeProperty.intValue; + var discreteSizeLabel = "Discrete Action Branches: "; + discreteSizeLabel += discreteSizeProperty == null ? "[]" : BuildIntArrayLabel(discreteSizeProperty); + EditorGUILayout.LabelField(continuousSizeLabel); + EditorGUILayout.LabelField(discreteSizeLabel); + } + + /// + /// Render the observation shapes of a DemonstrationSummary. 
+        ///
+        ///
+        void MakeObservationsProperty(SerializedProperty obsSummariesProperty)
+        {
+            var shapesLabels = new List<string>();
+            var numObservations = obsSummariesProperty.arraySize;
+            for (var i = 0; i < numObservations; i++)
+            {
+                var summary = obsSummariesProperty.GetArrayElementAtIndex(i);
+                var shapeProperty = summary.FindPropertyRelative(k_ShapeName);
+                shapesLabels.Add(BuildIntArrayLabel(shapeProperty));
+            }
+
+            var shapeLabel = $"Shapes: {string.Join(", ", shapesLabels)}";
+            EditorGUILayout.LabelField(shapeLabel);
+        }
+
+        public override void OnInspectorGUI()
+        {
+            serializedObject.Update();
+
+            EditorGUILayout.LabelField("Meta Data", EditorStyles.boldLabel);
+            EditorGUI.indentLevel++;
+            MakeMetaDataProperty(m_DemoMetaData);
+            EditorGUI.indentLevel--;
+
+            EditorGUILayout.LabelField("Observations", EditorStyles.boldLabel);
+            EditorGUI.indentLevel++;
+            MakeObservationsProperty(m_ObservationShapes);
+            EditorGUI.indentLevel--;
+
+            EditorGUILayout.LabelField("Actions", EditorStyles.boldLabel);
+            EditorGUI.indentLevel++;
+            MakeActionsProperty(m_BrainParameters);
+            EditorGUI.indentLevel--;
+
+            serializedObject.ApplyModifiedProperties();
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/DemonstrationDrawer.cs.meta b/com.unity.ml-agents/Editor/DemonstrationDrawer.cs.meta
new file mode 100644
index 0000000000..57c0681302
--- /dev/null
+++ b/com.unity.ml-agents/Editor/DemonstrationDrawer.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 84f9cd83f56c74790a51444a6cfe4945
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/DemonstrationImporter.cs b/com.unity.ml-agents/Editor/DemonstrationImporter.cs
new file mode 100644
index 0000000000..767f5556f9
--- /dev/null
+++ b/com.unity.ml-agents/Editor/DemonstrationImporter.cs
@@ -0,0 +1,75 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using Unity.MLAgents.CommunicatorObjects;
+using UnityEditor;
+using UnityEngine;
+#if UNITY_2020_2_OR_NEWER
+using UnityEditor.AssetImporters;
+#else
+using UnityEditor.Experimental.AssetImporters;
+#endif
+using Unity.MLAgents.Demonstrations;
+
+namespace Unity.MLAgents.Editor
+{
+    ///
+    /// Asset Importer used to parse demonstration files.
+    ///
+    [ScriptedImporter(1, new[] { "demo" })]
+    internal class DemonstrationImporter : ScriptedImporter
+    {
+        const string k_IconPath = "Packages/com.unity.ml-agents/Editor/Icons/DemoIcon.png";
+
+        public override void OnImportAsset(AssetImportContext ctx)
+        {
+            var inputType = Path.GetExtension(ctx.assetPath);
+            if (inputType == null)
+            {
+                throw new Exception("Demonstration import error.");
+            }
+
+            try
+            {
+                // Read first three proto objects containing metadata, brain parameters, and observations.
+                Stream reader = File.OpenRead(ctx.assetPath);
+
+                var metaDataProto = DemonstrationMetaProto.Parser.ParseDelimitedFrom(reader);
+                var metaData = metaDataProto.ToDemonstrationMetaData();
+
+                reader.Seek(DemonstrationWriter.MetaDataBytes + 1, 0);
+                var brainParamsProto = BrainParametersProto.Parser.ParseDelimitedFrom(reader);
+                var brainParameters = brainParamsProto.ToBrainParameters();
+
+                // Read the first AgentInfoActionPair so that we can get the observation sizes.
+                List<ObservationSummary> observationSummaries;
+                try
+                {
+                    var agentInfoActionPairProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader);
+                    observationSummaries = agentInfoActionPairProto.GetObservationSummaries();
+                }
+                catch
+                {
+                    // Just in case there weren't any AgentInfoActionPair or they couldn't be read.
+                    observationSummaries = new List<ObservationSummary>();
+                }
+
+                reader.Close();
+
+                var demonstrationSummary = ScriptableObject.CreateInstance<DemonstrationSummary>();
+                demonstrationSummary.Initialize(brainParameters, metaData, observationSummaries);
+                userData = demonstrationSummary.ToString();
+
+                var texture = (Texture2D)
+                    AssetDatabase.LoadAssetAtPath(k_IconPath, typeof(Texture2D));
+
+                ctx.AddObjectToAsset(ctx.assetPath, demonstrationSummary, texture);
+                ctx.SetMainObject(demonstrationSummary);
+            }
+            catch
+            {
+                // ignored
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/DemonstrationImporter.cs.meta b/com.unity.ml-agents/Editor/DemonstrationImporter.cs.meta
new file mode 100644
index 0000000000..bbdca977a3
--- /dev/null
+++ b/com.unity.ml-agents/Editor/DemonstrationImporter.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 7bd65ce151aaa4a41a45312543c56be1
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/EditorUtilities.cs b/com.unity.ml-agents/Editor/EditorUtilities.cs
new file mode 100644
index 0000000000..8ef266f259
--- /dev/null
+++ b/com.unity.ml-agents/Editor/EditorUtilities.cs
@@ -0,0 +1,19 @@
+using UnityEngine;
+
+namespace Unity.MLAgents.Editor
+{
+    ///
+    /// A static helper class for the Editor components of the ML-Agents SDK.
+    ///
+    public static class EditorUtilities
+    {
+        ///
+        /// Whether or not properties that affect the model can be updated at the current time.
+        ///
+        ///
+        public static bool CanUpdateModelProperties()
+        {
+            return !Application.isPlaying;
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/EditorUtilities.cs.meta b/com.unity.ml-agents/Editor/EditorUtilities.cs.meta
new file mode 100644
index 0000000000..2e58d7e8a3
--- /dev/null
+++ b/com.unity.ml-agents/Editor/EditorUtilities.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 840f5a76642c24b789ee312f0aa8e33b
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs b/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs
new file mode 100644
index 0000000000..fa9208e274
--- /dev/null
+++ b/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs
@@ -0,0 +1,109 @@
+using UnityEditor;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Editor
+{
+    [CustomEditor(typeof(GridSensorComponent), editorForChildClasses: true)]
+    [CanEditMultipleObjects]
+    internal class GridSensorComponentEditor : UnityEditor.Editor
+    {
+        public override void OnInspectorGUI()
+        {
+#if !MLA_UNITY_PHYSICS_MODULE
+            EditorGUILayout.HelpBox("The Physics Module is not currently present. 
" + + "Please add it to your project in order to use the GridSensor APIs in the " + + $"{nameof(GridSensorComponent)}", MessageType.Warning); +#endif + + var so = serializedObject; + so.Update(); + + // Drawing the GridSensorComponent + EditorGUI.BeginChangeCheck(); + + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + // These fields affect the sensor order or observation size, + // So can't be changed at runtime. + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_SensorName)), true); + + EditorGUILayout.LabelField("Grid Settings", EditorStyles.boldLabel); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CellScale)), true); + // We only supports 2D GridSensor now so lock gridSize.y to 1 + var gridSize = so.FindProperty(nameof(GridSensorComponent.m_GridSize)); + var gridSize2d = new Vector3Int(gridSize.vector3IntValue.x, 1, gridSize.vector3IntValue.z); + var newGridSize = EditorGUILayout.Vector3IntField("Grid Size", gridSize2d); + gridSize.vector3IntValue = new Vector3Int(newGridSize.x, 1, newGridSize.z); + } + EditorGUI.EndDisabledGroup(); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_AgentGameObject)), true); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_RotateWithAgent)), true); + + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + // detectable tags + var detectableTags = so.FindProperty(nameof(GridSensorComponent.m_DetectableTags)); + var newSize = EditorGUILayout.IntField("Detectable Tags", detectableTags.arraySize); + if (newSize != detectableTags.arraySize) + { + detectableTags.arraySize = newSize; + } + EditorGUI.indentLevel++; + for (var i = 0; i < detectableTags.arraySize; i++) + { + var objectTag = detectableTags.GetArrayElementAtIndex(i); + EditorGUILayout.PropertyField(objectTag, new GUIContent("Tag " + i), true); + } + EditorGUI.indentLevel--; + } + EditorGUI.EndDisabledGroup(); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ColliderMask)), true); + EditorGUILayout.LabelField("Sensor Settings", EditorStyles.boldLabel); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ObservationStacks)), true); + EditorGUI.EndDisabledGroup(); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CompressionType)), true); + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + EditorGUILayout.LabelField("Collider and Buffer", EditorStyles.boldLabel); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_InitialColliderBufferSize)), true); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_MaxColliderBufferSize)), true); + } + EditorGUI.EndDisabledGroup(); + + EditorGUILayout.LabelField("Debug Gizmo", EditorStyles.boldLabel); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ShowGizmos)), true); + EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_GizmoYOffset)), true); + + // detectable objects + var debugColors = so.FindProperty(nameof(GridSensorComponent.m_DebugColors)); + var detectableObjectSize = so.FindProperty(nameof(GridSensorComponent.m_DetectableTags)).arraySize; + if (detectableObjectSize != debugColors.arraySize) + { + debugColors.arraySize = detectableObjectSize; + } + EditorGUILayout.LabelField("Debug Colors"); + EditorGUI.indentLevel++; + for (var i = 0; i < debugColors.arraySize; i++) + { + var 
debugColor = debugColors.GetArrayElementAtIndex(i); + EditorGUILayout.PropertyField(debugColor, new GUIContent("Tag " + i + " Color"), true); + } + EditorGUI.indentLevel--; + + var requireSensorUpdate = EditorGUI.EndChangeCheck(); + so.ApplyModifiedProperties(); + + if (requireSensorUpdate) + { + UpdateSensor(); + } + } + + void UpdateSensor() + { + var sensorComponent = serializedObject.targetObject as GridSensorComponent; + sensorComponent?.UpdateSensor(); + } + } +} diff --git a/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta new file mode 100644 index 0000000000..c27459abce --- /dev/null +++ b/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 584686b36fcb2435c8be47d70c332ed0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/Icons.meta b/com.unity.ml-agents/Editor/Icons.meta new file mode 100644 index 0000000000..6071205cd3 --- /dev/null +++ b/com.unity.ml-agents/Editor/Icons.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e6f6d464e3884bf883137660dee8aebf +timeCreated: 1581721596 \ No newline at end of file diff --git a/com.unity.ml-agents/Editor/Icons/DemoIcon.png b/com.unity.ml-agents/Editor/Icons/DemoIcon.png new file mode 100644 index 0000000000..ddc91181bf Binary files /dev/null and b/com.unity.ml-agents/Editor/Icons/DemoIcon.png differ diff --git a/com.unity.ml-agents/Editor/Icons/DemoIcon.png.meta b/com.unity.ml-agents/Editor/Icons/DemoIcon.png.meta new file mode 100644 index 0000000000..37831fb256 --- /dev/null +++ b/com.unity.ml-agents/Editor/Icons/DemoIcon.png.meta @@ -0,0 +1,86 @@ +fileFormatVersion: 2 +guid: 3352a0e8d253b4a4ea3782a6d7e09d9b +TextureImporter: + fileIDToRecycleName: {} + externalObjects: {} + serializedVersion: 4 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -1 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 1 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + platformSettings: + - buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + - buildTarget: Standalone + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + spriteSheet: + 
serializedVersion: 2
+    sprites: []
+    outline: []
+    physicsShape: []
+  spritePackingTag:
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs b/com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs
new file mode 100644
index 0000000000..a91c07ecd7
--- /dev/null
+++ b/com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs
@@ -0,0 +1,69 @@
+using System.Linq;
+using UnityEngine;
+using UnityEditor;
+using UnityEditor.Build;
+using UnityEditor.Build.Reporting;
+
+
+namespace Unity.MLAgents.Editor
+{
+    internal class MLAgentsSettingsBuildProvider : IPreprocessBuildWithReport, IPostprocessBuildWithReport
+    {
+        private MLAgentsSettings m_SettingsAddedToPreloadedAssets;
+
+        public int callbackOrder => 0;
+
+        public void OnPreprocessBuild(BuildReport report)
+        {
+            var wasDirty = IsPlayerSettingsDirty();
+            m_SettingsAddedToPreloadedAssets = null;
+
+            var preloadedAssets = PlayerSettings.GetPreloadedAssets().ToList();
+            if (!preloadedAssets.Contains(MLAgentsSettingsManager.Settings))
+            {
+                m_SettingsAddedToPreloadedAssets = MLAgentsSettingsManager.Settings;
+                preloadedAssets.Add(m_SettingsAddedToPreloadedAssets);
+                PlayerSettings.SetPreloadedAssets(preloadedAssets.ToArray());
+            }
+
+            if (!wasDirty)
+                ClearPlayerSettingsDirtyFlag();
+        }
+
+        public void OnPostprocessBuild(BuildReport report)
+        {
+            if (m_SettingsAddedToPreloadedAssets == null)
+                return;
+
+            var wasDirty = IsPlayerSettingsDirty();
+
+            var preloadedAssets = PlayerSettings.GetPreloadedAssets().ToList();
+            if (preloadedAssets.Contains(m_SettingsAddedToPreloadedAssets))
+            {
+                preloadedAssets.Remove(m_SettingsAddedToPreloadedAssets);
+                PlayerSettings.SetPreloadedAssets(preloadedAssets.ToArray());
+            }
+
+            m_SettingsAddedToPreloadedAssets = null;
+
+            if (!wasDirty)
+                ClearPlayerSettingsDirtyFlag();
+        }
+
+
+        private static bool IsPlayerSettingsDirty()
+        {
+            var settings = Resources.FindObjectsOfTypeAll<PlayerSettings>();
+            if (settings != null && settings.Length > 0)
+                return EditorUtility.IsDirty(settings[0]);
+            return false;
+        }
+
+        private static void ClearPlayerSettingsDirtyFlag()
+        {
+            var settings = Resources.FindObjectsOfTypeAll<PlayerSettings>();
+            if (settings != null && settings.Length > 0)
+                EditorUtility.ClearDirty(settings[0]);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs.meta b/com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs.meta
new file mode 100644
index 0000000000..214ea9863f
--- /dev/null
+++ b/com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: bd59ff34305fa4259a2735e08afdb424
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs b/com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs
new file mode 100644
index 0000000000..4077927a4e
--- /dev/null
+++ b/com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs
@@ -0,0 +1,194 @@
+using System;
+using System.Linq;
+using System.IO;
+using System.Runtime.CompilerServices;
+using UnityEngine;
+using UnityEditor;
+using UnityEngine.UIElements;
+
+[assembly: InternalsVisibleTo("Unity.ML-Agents.DevTests.Editor")]
+namespace Unity.MLAgents.Editor
+{
+    internal class MLAgentsSettingsProvider : SettingsProvider, IDisposable
+    {
+        const string k_SettingsPath = "Project/ML-Agents";
+        private static MLAgentsSettingsProvider s_Instance;
+        private string[] m_AvailableSettingsAssets;
+        private int m_CurrentSelectedSettingsAsset;
+        private SerializedObject m_SettingsObject;
+        [SerializeField]
+        private MLAgentsSettings m_Settings;
+
+
+        private MLAgentsSettingsProvider(string path, SettingsScope scope = SettingsScope.Project)
+            : base(path, scope)
+        {
+            s_Instance = this;
+        }
+
+        [SettingsProvider]
+        public static SettingsProvider CreateMLAgentsSettingsProvider()
+        {
+            return new MLAgentsSettingsProvider(k_SettingsPath, SettingsScope.Project);
+        }
+
+        public override void OnActivate(string searchContext, VisualElement rootElement)
+        {
+            base.OnActivate(searchContext, rootElement);
+            MLAgentsSettingsManager.OnSettingsChange += Reinitialize;
+        }
+
+        public override void OnDeactivate()
+        {
+            base.OnDeactivate();
+            MLAgentsSettingsManager.OnSettingsChange -= Reinitialize;
+        }
+
+        public void Dispose()
+        {
+            m_SettingsObject?.Dispose();
+        }
+
+        public override void OnTitleBarGUI()
+        {
+            if (EditorGUILayout.DropdownButton(EditorGUIUtility.IconContent("_Popup"), FocusType.Passive, EditorStyles.label))
+            {
+                var menu = new GenericMenu();
+                for (var i = 0; i < m_AvailableSettingsAssets.Length; i++)
+                {
+                    menu.AddItem(ExtractDisplayName(m_AvailableSettingsAssets[i]), m_CurrentSelectedSettingsAsset == i, (path) =>
+                    {
+                        MLAgentsSettingsManager.Settings = AssetDatabase.LoadAssetAtPath<MLAgentsSettings>((string)path);
+                    }, m_AvailableSettingsAssets[i]);
+                }
+                menu.AddSeparator("");
+                menu.AddItem(new GUIContent("New Settings Asset…"), false, CreateNewSettingsAsset);
+                menu.ShowAsContext();
+                Event.current.Use();
+            }
+        }
+
+        private GUIContent ExtractDisplayName(string name)
+        {
+            if (name.StartsWith("Assets/"))
+                name = name.Substring("Assets/".Length);
+            if (name.EndsWith(".asset"))
+                name = name.Substring(0, name.Length - ".asset".Length);
+            if (name.EndsWith(".mlagents.settings"))
+                name = name.Substring(0, name.Length - ".mlagents.settings".Length);
+
+            // Ugly hack: GenericMenu interprets "/" as a submenu path. But luckily, "/" is not the only slash we have in Unicode.
+            return new GUIContent(name.Replace("/", "\u29f8"));
+        }
+
+        private void CreateNewSettingsAsset()
+        {
+            // The asset database always uses forward slashes, so use forward slashes for all the paths.
+            var projectName = PlayerSettings.productName;
+            var path = EditorUtility.SaveFilePanel("Create ML-Agents Settings File", "Assets",
+                projectName + ".mlagents.settings", "asset");
+            if (string.IsNullOrEmpty(path))
+            {
+                return;
+            }
+
+            path = path.Replace("\\", "/"); // Make sure we only get '/' separators.
+            var assetPath = Application.dataPath + "/";
+            if (!path.StartsWith(assetPath, StringComparison.CurrentCultureIgnoreCase))
+            {
+                Debug.LogError(string.Format(
+                    "Settings must be stored in Assets folder of the project (got: '{0}')", path));
+                return;
+            }
+
+            var extension = Path.GetExtension(path);
+            if (string.Compare(extension, ".asset", StringComparison.InvariantCultureIgnoreCase) != 0)
+            {
+                path += ".asset";
+            }
+            var relativePath = "Assets/" + path.Substring(assetPath.Length);
+            CreateNewSettingsAsset(relativePath);
+        }
+
+        private static void CreateNewSettingsAsset(string relativePath)
+        {
+            var settings = ScriptableObject.CreateInstance<MLAgentsSettings>();
+            AssetDatabase.CreateAsset(settings, relativePath);
+            EditorGUIUtility.PingObject(settings);
+            // Install the settings. This will lead to an MLAgentsSettingsManager.OnSettingsChange event
+            // which in turn will cause this Provider to reinitialize
+            MLAgentsSettingsManager.Settings = settings;
+        }
+
+        public override void OnGUI(string searchContext)
+        {
+            if (m_Settings == null)
+            {
+                InitializeWithCurrentSettings();
+            }
+
+            if (m_AvailableSettingsAssets.Length == 0)
+            {
+                EditorGUILayout.HelpBox(
+                    "Click the button below to create a settings asset you can edit.",
+                    MessageType.Info);
+                if (GUILayout.Button("Create settings asset", GUILayout.Height(30)))
+                    CreateNewSettingsAsset();
+                GUILayout.Space(20);
+            }
+
+            using (new EditorGUI.DisabledScope(m_AvailableSettingsAssets.Length == 0))
+            {
+                EditorGUI.BeginChangeCheck();
+                EditorGUILayout.LabelField("Trainer Settings", EditorStyles.boldLabel);
+                EditorGUI.indentLevel++;
+                EditorGUILayout.PropertyField(m_SettingsObject.FindProperty("m_ConnectTrainer"), new GUIContent("Connect to Trainer"));
+                EditorGUILayout.PropertyField(m_SettingsObject.FindProperty("m_EditorPort"), new GUIContent("Editor Training Port"));
+                EditorGUI.indentLevel--;
+                if (EditorGUI.EndChangeCheck())
+                    m_SettingsObject.ApplyModifiedProperties();
+            }
+        }
+
+        internal void InitializeWithCurrentSettings()
+        {
+            m_AvailableSettingsAssets = FindSettingsInProject();
+
+            m_Settings = MLAgentsSettingsManager.Settings;
+            var currentSettingsPath = AssetDatabase.GetAssetPath(m_Settings);
+            if (string.IsNullOrEmpty(currentSettingsPath))
+            {
+                if (m_AvailableSettingsAssets.Length > 0)
+                {
+                    m_CurrentSelectedSettingsAsset = 0;
+                    m_Settings = AssetDatabase.LoadAssetAtPath<MLAgentsSettings>(m_AvailableSettingsAssets[0]);
+                    MLAgentsSettingsManager.Settings = m_Settings;
+                }
+            }
+            else
+            {
+                var settingsList = m_AvailableSettingsAssets.ToList();
+                m_CurrentSelectedSettingsAsset = settingsList.IndexOf(currentSettingsPath);
+
+                EditorBuildSettings.AddConfigObject(MLAgentsSettingsManager.EditorBuildSettingsConfigKey, m_Settings, true);
+            }
+
+            m_SettingsObject = new SerializedObject(m_Settings);
+        }
+
+        private static string[] FindSettingsInProject()
+        {
+            var guids = AssetDatabase.FindAssets("t:MLAgentsSettings");
+            return guids.Select(guid => AssetDatabase.GUIDToAssetPath(guid)).ToArray();
+        }
+
+        private void Reinitialize()
+        {
+            if (m_Settings != null && MLAgentsSettingsManager.Settings != m_Settings)
+            {
+                InitializeWithCurrentSettings();
+            }
+            Repaint();
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs.meta b/com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs.meta
new file mode 100644
index 0000000000..09eaa72b4e
--- /dev/null
+++ b/com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 162489862d7f64a40990a0c06bb73bd0
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Editor/Match3ActuatorComponentEditor.cs b/com.unity.ml-agents/Editor/Match3ActuatorComponentEditor.cs
new file mode 100644
index 0000000000..0b072c5b52
--- /dev/null
+++ b/com.unity.ml-agents/Editor/Match3ActuatorComponentEditor.cs
@@ -0,0 +1,46 @@
+using UnityEditor;
+using Unity.MLAgents.Integrations.Match3;
+namespace Unity.MLAgents.Editor
+{
+    [CustomEditor(typeof(Match3ActuatorComponent), editorForChildClasses: true)]
+    [CanEditMultipleObjects]
+    internal class Match3ActuatorComponentEditor : UnityEditor.Editor
+    {
+        public override void OnInspectorGUI()
+        {
+            var so = serializedObject;
+            so.Update();
+
+            var component = (Match3ActuatorComponent)target;
+            var board = component.GetComponent<AbstractBoard>();
+            if (board == null)
+            {
+                EditorGUILayout.HelpBox("You must provide an implementation of an AbstractBoard.", MessageType.Warning);
+                return;
+            }
+
+            // Drawing the Match3ActuatorComponent
+            EditorGUI.BeginChangeCheck();
+
+            EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
+            {
+                EditorGUILayout.PropertyField(so.FindProperty("m_ActuatorName"), true);
+                EditorGUILayout.PropertyField(so.FindProperty("m_RandomSeed"), true);
+            }
+            EditorGUI.EndDisabledGroup();
+            EditorGUILayout.PropertyField(so.FindProperty("m_ForceHeuristic"), true);
+
+            var requireSensorUpdate = EditorGUI.EndChangeCheck();
+            so.ApplyModifiedProperties();
+
+            if (requireSensorUpdate)
+            {
+                UpdateActuator();
+            }
+        }
+
+        void UpdateActuator()
+        {
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/Match3ActuatorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/Match3ActuatorComponentEditor.cs.meta
new file mode 100644
index 0000000000..ce515a1234
--- /dev/null
+++ b/com.unity.ml-agents/Editor/Match3ActuatorComponentEditor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: b545474cca77481bbc3c6c161dd6bbc3
+timeCreated: 1618441761
\ No newline at end of file
diff --git a/com.unity.ml-agents/Editor/Match3SensorComponentEditor.cs b/com.unity.ml-agents/Editor/Match3SensorComponentEditor.cs
new file mode 100644
index 0000000000..857bedeee8
--- /dev/null
+++ b/com.unity.ml-agents/Editor/Match3SensorComponentEditor.cs
@@ -0,0 +1,45 @@
+using UnityEditor;
+using Unity.MLAgents.Integrations.Match3;
+namespace Unity.MLAgents.Editor
+{
+    [CustomEditor(typeof(Match3SensorComponent), editorForChildClasses: true)]
+    [CanEditMultipleObjects]
+    internal class Match3SensorComponentEditor : UnityEditor.Editor
+    {
+        public override void OnInspectorGUI()
+        {
+            var so = serializedObject;
+            so.Update();
+
+            var component = (Match3SensorComponent)target;
+            var board = component.GetComponent<AbstractBoard>();
+            if (board == null)
+            {
+                EditorGUILayout.HelpBox("You must provide an implementation of an AbstractBoard.", MessageType.Warning);
+                return;
+            }
+
+            // Drawing the Match3SensorComponent
+            EditorGUI.BeginChangeCheck();
+
+            EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
+            {
+                EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
+                EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
+            }
+            EditorGUI.EndDisabledGroup();
+
+            var requireSensorUpdate = EditorGUI.EndChangeCheck();
+            so.ApplyModifiedProperties();
+
+            if (requireSensorUpdate)
+            {
+                UpdateSensor();
+            }
+        }
+
+        void UpdateSensor()
+        {
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Editor/Match3SensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/Match3SensorComponentEditor.cs.meta
new file mode 100644
index 0000000000..82a80140ed
--- /dev/null
+++ b/com.unity.ml-agents/Editor/Match3SensorComponentEditor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: ab55bf118d03479bb797c0037989c308
+timeCreated: 1618440499
\ No newline at end of file
diff --git a/com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs b/com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs
new file mode 100644
index 0000000000..4231c2776e
--- /dev/null
+++ b/com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs
@@ -0,0 +1,111 @@
+using UnityEngine;
+using UnityEditor;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Editor
+{
+    internal class RayPerceptionSensorComponentBaseEditor : UnityEditor.Editor
+    {
+        bool m_RequireSensorUpdate;
+
+        protected void OnRayPerceptionInspectorGUI(bool is3d)
+        {
+#if !MLA_UNITY_PHYSICS_MODULE
+            if (is3d)
+            {
+                EditorGUILayout.HelpBox("The Physics Module is not currently present. " +
+                    "Please add it to your project in order to use the Ray Perception APIs in the " +
+                    $"{nameof(RayPerceptionSensorComponent3D)}", MessageType.Warning);
+            }
+#endif
+#if !MLA_UNITY_PHYSICS2D_MODULE
+            if (!is3d)
+            {
+                EditorGUILayout.HelpBox("The Physics2D Module is not currently present. " +
+                    "Please add it to your project in order to use the Ray Perception APIs in the " +
+                    $"{nameof(RayPerceptionSensorComponent2D)}", MessageType.Warning);
+            }
+#endif
+            var so = serializedObject;
+            so.Update();
+
+            // Drawing the RayPerceptionSensorComponent
+            EditorGUI.BeginChangeCheck();
+            EditorGUI.indentLevel++;
+
+            // Don't allow certain fields to be modified during play mode.
+            // * SensorName affects the ordering of the Agent's observations
+            // * The number of tags and rays affects the size of the observations.
+            EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
+            {
+                EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
+                EditorGUILayout.PropertyField(so.FindProperty("m_DetectableTags"), true);
+                EditorGUILayout.PropertyField(so.FindProperty("m_RaysPerDirection"), true);
+            }
+            EditorGUI.EndDisabledGroup();
+
+            EditorGUILayout.PropertyField(so.FindProperty("m_MaxRayDegrees"), true);
+            EditorGUILayout.PropertyField(so.FindProperty("m_SphereCastRadius"), true);
+            EditorGUILayout.PropertyField(so.FindProperty("m_RayLength"), true);
+            EditorGUILayout.PropertyField(so.FindProperty("m_RayLayerMask"), true);
+
+            // Because the number of observation stacks affects the observation shape,
+            // it is not editable during play mode.
+ EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), new GUIContent("Stacked Raycasts"), true); + } + EditorGUI.EndDisabledGroup(); + + if (is3d) + { + EditorGUILayout.PropertyField(so.FindProperty("m_StartVerticalOffset"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_EndVerticalOffset"), true); + } + + EditorGUILayout.PropertyField(so.FindProperty("m_AlternatingRayOrder"), true); + + EditorGUILayout.PropertyField(so.FindProperty("rayHitColor"), true); + EditorGUILayout.PropertyField(so.FindProperty("rayMissColor"), true); + + EditorGUI.indentLevel--; + if (EditorGUI.EndChangeCheck()) + { + m_RequireSensorUpdate = true; + } + + so.ApplyModifiedProperties(); + UpdateSensorIfDirty(); + } + + void UpdateSensorIfDirty() + { + if (m_RequireSensorUpdate) + { + var sensorComponent = serializedObject.targetObject as RayPerceptionSensorComponentBase; + sensorComponent?.UpdateSensor(); + m_RequireSensorUpdate = false; + } + } + } + + [CustomEditor(typeof(RayPerceptionSensorComponent2D), editorForChildClasses: true)] + [CanEditMultipleObjects] + internal class RayPerceptionSensorComponent2DEditor : RayPerceptionSensorComponentBaseEditor + { + public override void OnInspectorGUI() + { + OnRayPerceptionInspectorGUI(false); + } + } + + [CustomEditor(typeof(RayPerceptionSensorComponent3D), editorForChildClasses: true)] + [CanEditMultipleObjects] + internal class RayPerceptionSensorComponent3DEditor : RayPerceptionSensorComponentBaseEditor + { + public override void OnInspectorGUI() + { + OnRayPerceptionInspectorGUI(true); + } + } +} diff --git a/com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs.meta b/com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs.meta new file mode 100644 index 0000000000..c02a8ca2a9 --- /dev/null +++ b/com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: c0182483e53c24d0e9f264f711ed89a9 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs new file mode 100644 index 0000000000..8e5fd892d3 --- /dev/null +++ b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs @@ -0,0 +1,43 @@ +using UnityEditor; +using Unity.MLAgents.Sensors; +namespace Unity.MLAgents.Editor +{ + [CustomEditor(typeof(RenderTextureSensorComponent), editorForChildClasses: true)] + [CanEditMultipleObjects] + internal class RenderTextureSensorComponentEditor : UnityEditor.Editor + { + public override void OnInspectorGUI() + { + var so = serializedObject; + so.Update(); + + // Drawing the RenderTextureComponent + EditorGUI.BeginChangeCheck(); + + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + EditorGUILayout.PropertyField(so.FindProperty("m_RenderTexture"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true); + } + EditorGUI.EndDisabledGroup(); + + EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true); + + var requireSensorUpdate = EditorGUI.EndChangeCheck(); + so.ApplyModifiedProperties(); + + if 
(requireSensorUpdate) + { + UpdateSensor(); + } + } + + void UpdateSensor() + { + var sensorComponent = serializedObject.targetObject as RenderTextureSensorComponent; + sensorComponent?.UpdateSensor(); + } + } +} diff --git a/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs.meta new file mode 100644 index 0000000000..fd7a57d05e --- /dev/null +++ b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: dab309e01d2964f0792de3ef914ca6b9 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/Unity.ML-Agents.Editor.asmdef b/com.unity.ml-agents/Editor/Unity.ML-Agents.Editor.asmdef new file mode 100755 index 0000000000..27f67d0d1c --- /dev/null +++ b/com.unity.ml-agents/Editor/Unity.ML-Agents.Editor.asmdef @@ -0,0 +1,30 @@ +{ + "name": "Unity.ML-Agents.Editor", + "references": [ + "Unity.ML-Agents", + "Unity.Barracuda", + "Unity.ML-Agents.CommunicatorObjects" + ], + "optionalUnityReferences": [], + "includePlatforms": [ + "Editor" + ], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": false, + "precompiledReferences": [], + "autoReferenced": true, + "defineConstraints": [], + "versionDefines": [ + { + "name": "com.unity.modules.physics", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS_MODULE" + }, + { + "name": "com.unity.modules.physics2d", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS2D_MODULE" + } + ] +} diff --git a/com.unity.ml-agents/Editor/Unity.ML-Agents.Editor.asmdef.meta b/com.unity.ml-agents/Editor/Unity.ML-Agents.Editor.asmdef.meta new file mode 100644 index 0000000000..0c031cb370 --- /dev/null +++ b/com.unity.ml-agents/Editor/Unity.ML-Agents.Editor.asmdef.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 42675ddec8c314cf08d17ee0f6f5e5a5 +AssemblyDefinitionImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/UnityColors.colors b/com.unity.ml-agents/Editor/UnityColors.colors new file mode 100644 index 0000000000..af773637c7 --- /dev/null +++ b/com.unity.ml-agents/Editor/UnityColors.colors @@ -0,0 +1,64 @@ +%YAML 1.1 +%TAG !u! 
tag:unity3d.com,2011: +--- !u!114 &1 +MonoBehaviour: + m_ObjectHideFlags: 52 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 0} + m_Enabled: 1 + m_EditorHideFlags: 1 + m_Script: {fileID: 12323, guid: 0000000000000000e000000000000000, type: 0} + m_Name: UnityColors + m_EditorClassIdentifier: + m_Presets: + - m_Name: + m_Color: {r: 0.12941177, g: 0.5882353, b: 0.9529412, a: 1} + - m_Name: + m_Color: {r: 0, g: 0.34117648, b: 0.6039216, a: 1} + - m_Name: + m_Color: {r: 0.2627451, g: 0.7019608, b: 0.9019608, a: 1} + - m_Name: + m_Color: {r: 0.92156863, g: 0.25490198, b: 0.47843137, a: 1} + - m_Name: + m_Color: {r: 0.92941177, g: 0.3254902, b: 0.31764707, a: 1} + - m_Name: + m_Color: {r: 0.3647059, g: 0.41568628, b: 0.69411767, a: 1} + - m_Name: + m_Color: {r: 0.46666667, g: 0.5647059, b: 0.60784316, a: 1} + - m_Name: + m_Color: {r: 0.74509805, g: 0.7372549, b: 0.7411765, a: 1} + - m_Name: + m_Color: {r: 0.9254902, g: 0.9372549, b: 0.9411765, a: 1} + - m_Name: + m_Color: {r: 0.6039216, g: 0.31764707, b: 0.627451, a: 1} + - m_Name: + m_Color: {r: 0.2901961, g: 0.1764706, b: 0.5254902, a: 1} + - m_Name: + m_Color: {r: 0.4627451, g: 0.35686275, b: 0.654902, a: 1} + - m_Name: + m_Color: {r: 0.6039216, g: 0.31764707, b: 0.627451, a: 1} + - m_Name: + m_Color: {r: 0.20392157, g: 0.75686276, b: 0.8392157, a: 1} + - m_Name: + m_Color: {r: 0.1254902, g: 0.6509804, b: 0.60784316, a: 1} + - m_Name: + m_Color: {r: 0.39609292, g: 0.49962592, b: 0.6509434, a: 0} + - m_Name: + m_Color: {r: 0.40392157, g: 0.7372549, b: 0.41960785, a: 1} + - m_Name: + m_Color: {r: 0.60784316, g: 0.8039216, b: 0.39607844, a: 1} + - m_Name: + m_Color: {r: 0.8235294, g: 0.8784314, b: 0.34901962, a: 1} + - m_Name: + m_Color: {r: 1, g: 0.79607844, b: 0.15294118, a: 1} + - m_Name: + m_Color: {r: 1, g: 0.93333334, b: 0.34509805, a: 1} + - m_Name: + m_Color: {r: 0.98039216, g: 0.6509804, b: 0.16078432, a: 1} + - m_Name: + m_Color: {r: 0.9529412, g: 0.4392157, b: 0.27450982, a: 1} + - m_Name: + m_Color: {r: 0.74509805, g: 0.22745098, b: 0.15294118, a: 1} + - m_Name: + m_Color: {r: 0.9529412, g: 0.4392157, b: 0.27450982, a: 1} diff --git a/com.unity.ml-agents/Editor/UnityColors.colors.meta b/com.unity.ml-agents/Editor/UnityColors.colors.meta new file mode 100644 index 0000000000..34519c21ce --- /dev/null +++ b/com.unity.ml-agents/Editor/UnityColors.colors.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: b20b0226063034686a6cf92ade284285 +NativeFormatImporter: + externalObjects: {} + mainObjectFileID: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs new file mode 100644 index 0000000000..aae6fd796f --- /dev/null +++ b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs @@ -0,0 +1,31 @@ +using UnityEditor; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Editor +{ + [CustomEditor(typeof(VectorSensorComponent), editorForChildClasses: true)] + [CanEditMultipleObjects] + internal class VectorSensorComponentEditor : UnityEditor.Editor + { + public override void OnInspectorGUI() + { + var so = serializedObject; + so.Update(); + + // Drawing the VectorSensorComponent + + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + // These fields affect the sensor order or observation size, + // So can't be changed at runtime. 
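+                // Example (hypothetical values): m_ObservationSize = 4 with
+                // m_ObservationStacks = 2 contributes 4 * 2 = 8 floats to the agent's
+                // observations, so resizing either field at runtime would invalidate
+                // a model trained on the original shape.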
+ EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationSize"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true); + } + EditorGUI.EndDisabledGroup(); + + so.ApplyModifiedProperties(); + } + } +} diff --git a/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta new file mode 100644 index 0000000000..9862a23944 --- /dev/null +++ b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: aa0230c3402f04921acdbbdb61f6ff00 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/LICENSE.md b/com.unity.ml-agents/LICENSE.md new file mode 100644 index 0000000000..42863a2c98 --- /dev/null +++ b/com.unity.ml-agents/LICENSE.md @@ -0,0 +1,202 @@ +com.unity.ml-agents copyright © 2017 Unity Technologies + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. diff --git a/com.unity.ml-agents/LICENSE.md.meta b/com.unity.ml-agents/LICENSE.md.meta new file mode 100644 index 0000000000..0497eff447 --- /dev/null +++ b/com.unity.ml-agents/LICENSE.md.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 3b008ccfd571c4bc08e5ae283e73db3f +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins.meta b/com.unity.ml-agents/Plugins.meta new file mode 100644 index 0000000000..e4a9de128a --- /dev/null +++ b/com.unity.ml-agents/Plugins.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 694794cb53c6c4bfc9b84ca5022f4ae2 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/IL2CPP.DL.Stubs.c b/com.unity.ml-agents/Plugins/IL2CPP.DL.Stubs.c new file mode 100644 index 0000000000..843dafb645 --- /dev/null +++ b/com.unity.ml-agents/Plugins/IL2CPP.DL.Stubs.c @@ -0,0 +1,10 @@ +// These stubs fix an issue compiling GRPC on Windows with IL2CPP. +// For the moment, only Inference works. 
(training doesn't) + +void * dlopen(const char *filename, int flags) { + return 0; +} + +void * dlsym(void *handle, const char *symbol) { + return 0; +} \ No newline at end of file diff --git a/com.unity.ml-agents/Plugins/IL2CPP.DL.Stubs.c.meta b/com.unity.ml-agents/Plugins/IL2CPP.DL.Stubs.c.meta new file mode 100644 index 0000000000..9f2b819fa0 --- /dev/null +++ b/com.unity.ml-agents/Plugins/IL2CPP.DL.Stubs.c.meta @@ -0,0 +1,89 @@ +fileFormatVersion: 2 +guid: 3509a8908cf600c4f914a0705123a363 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 1 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Editor: 1 + Exclude Linux: 1 + Exclude Linux64: 1 + Exclude LinuxUniversal: 1 + Exclude OSXUniversal: 1 + Exclude Win: 0 + Exclude Win64: 0 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + CPU: AnyCPU + DefaultValueInitialized: true + OS: AnyOS + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Linux64 + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: LinuxUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: OSXUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer.meta b/com.unity.ml-agents/Plugins/ProtoBuffer.meta new file mode 100644 index 0000000000..af0fdcb105 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: e44343d7e31b04d47bd5f7329c918ffe +folderAsset: yes +timeCreated: 1521839636 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/Grpc.Core.dll b/com.unity.ml-agents/Plugins/ProtoBuffer/Grpc.Core.dll new file mode 100644 index 0000000000..601f87c27a Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/Grpc.Core.dll differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/Grpc.Core.dll.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/Grpc.Core.dll.meta new file mode 100644 index 0000000000..1e82ae4a51 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/Grpc.Core.dll.meta @@ -0,0 +1,118 @@ +fileFormatVersion: 2 +guid: cbf24ddeec4054edc9ad4c8295556878 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude CloudRendering: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 0 + Exclude WebGL: 1 + Exclude Win: 0 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + CloudRendering: CloudRendering + 
second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: AnyCPU + DefaultValueInitialized: true + OS: AnyOS + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: {} + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CPU: AnyCPU + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/System.Interactive.Async.dll b/com.unity.ml-agents/Plugins/ProtoBuffer/System.Interactive.Async.dll new file mode 100755 index 0000000000..48efea419e Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/System.Interactive.Async.dll differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta new file mode 100644 index 0000000000..969150b326 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta @@ -0,0 +1,33 @@ +fileFormatVersion: 2 +guid: 9502ce7e38c5947dba996570732b6e9f +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/link.xml b/com.unity.ml-agents/Plugins/ProtoBuffer/link.xml new file mode 100644 index 0000000000..857dfdd0a4 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/link.xml @@ -0,0 +1,10 @@ + + + + + diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/link.xml.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/link.xml.meta new file mode 100644 index 0000000000..872460e078 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/link.xml.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: f94355fa6eab94c2d8529747b92ca3e1 +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes.meta new file mode 100644 index 0000000000..6995400aec --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: b8022add2e5264884a117894eeaf9809 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux.meta 
b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux.meta new file mode 100644 index 0000000000..97848b1297 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: 50c3602c6f6244621861928757e31463 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native.meta new file mode 100644 index 0000000000..a8b33def01 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: ba192b1e561564e1583e0a87334f8682 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so new file mode 100755 index 0000000000..9bf86dc2d7 Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta new file mode 100644 index 0000000000..cf508374c6 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta @@ -0,0 +1,113 @@ +fileFormatVersion: 2 +guid: c9d901caf522f4dc5815786fa764a5da +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude CloudRendering: 1 + Exclude Editor: 0 + Exclude Linux: 1 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 1 + Exclude WebGL: 1 + Exclude Win: 0 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + CloudRendering: CloudRendering + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86_64 + DefaultValueInitialized: true + OS: Linux + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: x86_64 + - first: + Standalone: OSXUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CPU: AnyCPU + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so 
b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so new file mode 100755 index 0000000000..fce3041689 Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta new file mode 100644 index 0000000000..a3592911d6 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta @@ -0,0 +1,113 @@ +fileFormatVersion: 2 +guid: 7dfb52431a6d941c89758cf0a217e3ab +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude CloudRendering: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 1 + Exclude WebGL: 1 + Exclude Win: 0 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + CloudRendering: CloudRendering + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86 + DefaultValueInitialized: true + OS: Linux + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: None + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: OSXUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CPU: AnyCPU + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx.meta new file mode 100644 index 0000000000..69cbe8ef60 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: f43fa6e62fb4c4105b270be1ae7bbbfd +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native.meta new file mode 100644 index 0000000000..24fab959db --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: 55aee008fb6a3411aa96f2f9911f9207 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle 
b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle new file mode 100755 index 0000000000..440d2b9e33 Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta new file mode 100644 index 0000000000..2a1f0df2f5 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta @@ -0,0 +1,137 @@ +fileFormatVersion: 2 +guid: 7eeb863bd08ba4388829c23da03a714f +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude CloudRendering: 1 + Exclude Editor: 0 + Exclude Linux: 1 + Exclude Linux64: 1 + Exclude LinuxUniversal: 1 + Exclude OSXIntel: 0 + Exclude OSXIntel64: 0 + Exclude OSXUniversal: 0 + Exclude WebGL: 1 + Exclude Win: 1 + Exclude Win64: 1 + Exclude iOS: 1 + - first: + : OSXIntel + second: + enabled: 1 + settings: {} + - first: + : OSXIntel64 + second: + enabled: 1 + settings: {} + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + CloudRendering: CloudRendering + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86_64 + DefaultValueInitialized: true + OS: OSX + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 0 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: LinuxUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: OSXIntel + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: OSXIntel64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CPU: AnyCPU + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win.meta new file mode 100644 index 0000000000..b1e54c9a48 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: a961485c3484a4002ac4961a8481f6cc +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native.meta new file mode 100644 index 0000000000..42e4968ae5 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native.meta @@ -0,0 +1,10 @@ 
+fileFormatVersion: 2 +guid: af9f9f367bbc543b8ba41e58dcdd6e66 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll new file mode 100755 index 0000000000..b2e48711b8 Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta new file mode 100644 index 0000000000..56500dc9b6 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta @@ -0,0 +1,105 @@ +fileFormatVersion: 2 +guid: f4d9429fe43154fbd9d158c129e0ff33 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 0 + Exclude Win: 1 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86_64 + DefaultValueInitialized: true + OS: Windows + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: None + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll new file mode 100755 index 0000000000..45d5c324a3 Binary files /dev/null and b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll differ diff --git a/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta new file mode 100644 index 0000000000..77354acf46 --- /dev/null +++ b/com.unity.ml-agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta @@ -0,0 +1,105 @@ +fileFormatVersion: 2 +guid: d74134114def74fb4ae781c015deaa95 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + 
platformData: + - first: + : Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 0 + Exclude Win: 0 + Exclude Win64: 1 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86 + DefaultValueInitialized: true + OS: Windows + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 0 + settings: + CPU: None + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/System.IO.Abstractions.TestingHelpers.dll b/com.unity.ml-agents/Plugins/System.IO.Abstractions.TestingHelpers.dll new file mode 100755 index 0000000000..0d2b68f2e8 Binary files /dev/null and b/com.unity.ml-agents/Plugins/System.IO.Abstractions.TestingHelpers.dll differ diff --git a/com.unity.ml-agents/Plugins/System.IO.Abstractions.TestingHelpers.dll.meta b/com.unity.ml-agents/Plugins/System.IO.Abstractions.TestingHelpers.dll.meta new file mode 100644 index 0000000000..c6d910d8ee --- /dev/null +++ b/com.unity.ml-agents/Plugins/System.IO.Abstractions.TestingHelpers.dll.meta @@ -0,0 +1,33 @@ +fileFormatVersion: 2 +guid: 2d7ba4e1037b64de5b860bcbe15755b3 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Plugins/System.IO.Abstractions.dll b/com.unity.ml-agents/Plugins/System.IO.Abstractions.dll new file mode 100755 index 0000000000..4fe6ccbf43 Binary files /dev/null and b/com.unity.ml-agents/Plugins/System.IO.Abstractions.dll differ diff --git a/com.unity.ml-agents/Plugins/System.IO.Abstractions.dll.meta b/com.unity.ml-agents/Plugins/System.IO.Abstractions.dll.meta new file mode 100644 index 0000000000..5432c24e8a --- /dev/null +++ b/com.unity.ml-agents/Plugins/System.IO.Abstractions.dll.meta @@ -0,0 +1,33 @@ +fileFormatVersion: 2 +guid: b01205587773841ad95e8ceda347e8bd +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + 
- first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/README.md b/com.unity.ml-agents/README.md new file mode 100644 index 0000000000..ae5c0e2c8e --- /dev/null +++ b/com.unity.ml-agents/README.md @@ -0,0 +1,15 @@ +# com.unity.ml-agents + +ML-Agents is a Unity package that allows users to use state-of-the-art machine learning to create intelligent character behaviors in any Unity environment (games, robotics, film, etc.). + +## Installation + +Please refer to the [ML-Agents github repo] for installation instructions. + +## Usage + +Please refer to the [ML-Agents documentation] page for usage guides. + + +[ML-Agents github repo]: https://github.com/Unity-Technologies/ml-agents +[ML-Agents documentation]: https://unity-technologies.github.io/ml-agents/ \ No newline at end of file diff --git a/com.unity.ml-agents/README.md.meta b/com.unity.ml-agents/README.md.meta new file mode 100644 index 0000000000..bbb2279ba2 --- /dev/null +++ b/com.unity.ml-agents/README.md.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 940521c5d10354cde82c2d572d170c97 +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime.meta b/com.unity.ml-agents/Runtime.meta new file mode 100644 index 0000000000..b5ab5034ab --- /dev/null +++ b/com.unity.ml-agents/Runtime.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: a3a287cfa95bf4bdcad4997f7d48153b +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs new file mode 100644 index 0000000000..515ccae2bf --- /dev/null +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -0,0 +1,695 @@ +using System; +using UnityEngine; +using System.Collections.Generic; +#if UNITY_EDITOR +using UnityEditor; +#endif +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; +using Unity.MLAgents.Policies; +using Unity.MLAgents.SideChannels; +using Unity.Barracuda; + +/** + * Welcome to Unity Machine Learning Agents (ML-Agents). + * + * The ML-Agents toolkit contains four entities: Academy, Agent, Communicator and + * Python API. The academy and connected agents live within + * a learning environment (herein called Environment), while the communicator + * manages the communication between the learning environment and the Python + * API. For more information on each of these entities, in addition to how to + * set-up a learning environment and train the behavior of characters in a + * Unity scene, please browse our documentation pages on GitHub: + * https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/ + */ + +namespace Unity.MLAgents +{ + /// + /// Helper class to step the Academy during FixedUpdate phase. + /// + internal class AcademyFixedUpdateStepper : MonoBehaviour + { + void FixedUpdate() + { + // Check if the stepper belongs to the current Academy and destroy it if it's not. + // This is to prevent from having leaked stepper from previous runs. + if (!Academy.IsInitialized || !Academy.Instance.IsStepperOwner(this)) + { + Destroy(this.gameObject); + } + else + { + Academy.Instance.EnvironmentStep(); + } + } + } + + /// + /// The Academy singleton manages agent training and decision making. + /// + /// + /// Access the Academy singleton through the + /// property. 
The Academy instance is initialized the first time it is accessed (which will
+    /// typically be by the first <see cref="Agent"/> initialized in a scene).
+    ///
+    /// At initialization, the Academy attempts to connect to the Python training process through
+    /// the external communicator. If successful, the training process can train <see cref="Agent"/>
+    /// instances. When you set an agent's <see cref="BehaviorParameters.BehaviorType"/> setting
+    /// to <see cref="BehaviorType.Default"/>, the agent exchanges data with the training process
+    /// to make decisions. If no training process is available, agents with the default behavior
+    /// fall back to inference or heuristic decisions. (You can also set agents to always use
+    /// inference or heuristics.)
+    ///
+    [HelpURL("https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/" +
+        "docs/Learning-Environment-Design.md")]
+    public class Academy : IDisposable
+    {
+        ///
+        /// Communication protocol version.
+        /// When connecting to python, this must be compatible with UnityEnvironment.API_VERSION.
+        /// We follow semantic versioning on the communication version, so existing
+        /// functionality will work as long as the major versions match.
+        /// This should be changed whenever a change is made to the communication protocol.
+        ///
+        ///
+        /// History:
+        ///
+        ///
+        /// 1.0.0
+        /// Initial version
+        ///
+        ///
+        /// 1.1.0
+        /// Support concatenated PNGs for compressed observations.
+        ///
+        ///
+        /// 1.2.0
+        /// Support compression mapping for stacked compressed observations.
+        ///
+        ///
+        /// 1.3.0
+        /// Support both continuous and discrete actions.
+        ///
+        ///
+        /// 1.4.0
+        /// Support training analytics sent from python trainer to the editor.
+        ///
+        ///
+        /// 1.5.0
+        /// Support variable length observation training and multi-agent groups.
+        ///
+        ///
+        ///
+        const string k_ApiVersion = "1.5.0";
+
+        ///
+        /// Unity package version of com.unity.ml-agents.
+        /// This must match the version string in package.json and is checked in a unit test.
+        ///
+        internal const string k_PackageVersion = "2.3.0-exp.3";
+
+        const int k_EditorTrainingPort = 5004;
+
+        const string k_PortCommandLineFlag = "--mlagents-port";
+
+        // Lazy initializer pattern, see https://csharpindepth.com/articles/singleton#lazy
+        static Lazy<Academy> s_Lazy = new Lazy<Academy>(() => new Academy());
+
+        ///
+        /// Reports whether the Academy has been initialized yet.
+        ///
+        /// True if the Academy is initialized, false otherwise.
+        public static bool IsInitialized
+        {
+            get { return s_Lazy.IsValueCreated; }
+        }
+
+        ///
+        /// The singleton Academy object.
+        ///
+        /// Getting the instance initializes the Academy, if necessary.
+        public static Academy Instance { get { return s_Lazy.Value; } }
+
+        // Fields not provided in the Inspector.
+
+        ///
+        /// Reports whether or not the communicator is on.
+        ///
+        ///
+        ///
+        /// True, if communicator is on, false otherwise.
+        ///
+        public bool IsCommunicatorOn
+        {
+            get { return Communicator != null; }
+        }
+
+        /// The number of episodes completed by the environment. Incremented
+        /// each time the environment is reset.
+        int m_EpisodeCount;
+
+        /// The number of steps completed within the current episode. Incremented
+        /// each time a step is taken in the environment. Is reset to 0 during
+        /// <see cref="EnvironmentReset"/>.
+        int m_StepCount;
+
+        /// The total number of steps completed during the whole simulation. Incremented
+        /// each time a step is taken in the environment.
+        int m_TotalStepCount;
+
+        /// Pointer to the communicator currently in use by the Academy.
+        internal ICommunicator Communicator;
+
+        bool m_Initialized;
+        List<ModelRunner> m_ModelRunners = new List<ModelRunner>();
+
+        // Flag used to keep track of the first time the Academy is reset.
+        bool m_HadFirstReset;
+
+        // Detect an Academy step called by user code that is also called by the Academy.
+        private RecursionChecker m_StepRecursionChecker = new RecursionChecker("EnvironmentStep");
+
+        // Random seed used for inference.
+        int m_InferenceSeed;
+
+        ///
+        /// Set the random seed used for inference. This should be set before any Agents are added
+        /// to the scene. The seed is passed to the ModelRunner constructor, and incremented each
+        /// time a new ModelRunner is created.
+        ///
+        public int InferenceSeed
+        {
+            set { m_InferenceSeed = value; }
+        }
+
+        int m_NumAreas;
+
+        ///
+        /// Number of training areas to instantiate.
+        ///
+        public int NumAreas => m_NumAreas;
+
+        ///
+        /// Returns the RLCapabilities of the python client that the unity process is connected to.
+        ///
+        internal UnityRLCapabilities TrainerCapabilities { get; set; }
+
+
+        // The Academy uses a series of events to communicate with agents
+        // to facilitate synchronization. More specifically, it ensures
+        // that all the agents perform their steps in a consistent order (i.e. no
+        // agent can act based on a decision before another agent has had a chance
+        // to request a decision).
+
+        // Signals to all the Agents at each environment step so they can use
+        // their Policy to decide on their next action.
+        internal event Action DecideAction;
+
+        // Signals to all the listeners that the academy is being destroyed
+        internal event Action DestroyAction;
+
+        // Signals to the Agent that a new step is about to start.
+        // This will mark the Agent as Done if it has reached its maxSteps.
+        internal event Action AgentIncrementStep;
+
+
+        ///
+        /// Signals to all of the <see cref="Agent"/>s that their step is about to begin.
+        /// This is a good time for an <see cref="Agent"/> to decide if it would like to
+        /// call <see cref="Agent.RequestDecision"/> or <see cref="Agent.RequestAction"/>
+        /// for this step. Any other pre-step setup could be done during this event as well.
+        ///
+        public event Action<int> AgentPreStep;
+
+        // Signals to all the agents at each environment step so they can send
+        // their state to their Policy if they have requested a decision.
+        internal event Action AgentSendState;
+
+        // Signals to all the agents at each environment step so they can act if
+        // they have requested a decision.
+        internal event Action AgentAct;
+
+        // Signals to all the agents each time the Academy force resets.
+        internal event Action AgentForceReset;
+
+        ///
+        /// Signals that the Academy has been reset by the training process.
+        ///
+        public event Action OnEnvironmentReset;
+
+        AcademyFixedUpdateStepper m_FixedUpdateStepper;
+        GameObject m_StepperObject;
+
+
+        ///
+        /// Protected constructor called the first time the Academy is used.
+        /// Academy uses this time to initialize internal data
+        /// structures, initialize the environment and check for the existence
+        /// of a communicator.
+        ///
+        protected Academy()
+        {
+            Application.quitting += Dispose;
+#if UNITY_EDITOR || UNITY_STANDALONE
+            if (!CommunicatorFactory.CommunicatorRegistered)
+            {
+                Debug.Log("Registered Communicator in Academy.");
+                CommunicatorFactory.Register(RpcCommunicator.Create);
+            }
+#endif
+            LazyInitialize();
+
+#if UNITY_EDITOR
+            EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
+#endif
+        }
+
+#if UNITY_EDITOR
+        ///
+        /// Clean up the Academy when switching from edit mode to play mode
+        ///
+        /// State.
+        void HandleOnPlayModeChanged(PlayModeStateChange state)
+        {
+            if (state == PlayModeStateChange.ExitingEditMode)
+            {
+                Dispose();
+            }
+        }
+
+#endif
+
+        ///
+        /// Initialize the Academy if it hasn't already been initialized.
+        /// This method is always safe to call; it will have no effect if the Academy is already
+        /// initialized.
+        ///
+        internal void LazyInitialize()
+        {
+            if (!m_Initialized)
+            {
+                InitializeEnvironment();
+                m_Initialized = true;
+            }
+        }
+
+        ///
+        /// Enable stepping of the Academy during the FixedUpdate phase. This is done by creating
+        /// a temporary GameObject with a MonoBehaviour that calls Academy.EnvironmentStep().
+        ///
+        void EnableAutomaticStepping()
+        {
+            if (m_FixedUpdateStepper != null)
+            {
+                return;
+            }
+
+            m_StepperObject = new GameObject("AcademyFixedUpdateStepper");
+            // Don't show this object in the hierarchy
+            m_StepperObject.hideFlags = HideFlags.HideInHierarchy;
+            m_FixedUpdateStepper = m_StepperObject.AddComponent<AcademyFixedUpdateStepper>();
+            try
+            {
+                // This try-catch is because DontDestroyOnLoad cannot be used in Editor Tests
+                GameObject.DontDestroyOnLoad(m_StepperObject);
+            }
+            catch { }
+        }
+
+        ///
+        /// Disable stepping of the Academy during the FixedUpdate phase. If this is called, the Academy must be
+        /// stepped manually by the user by calling Academy.EnvironmentStep().
+        ///
+        void DisableAutomaticStepping()
+        {
+            if (m_FixedUpdateStepper == null)
+            {
+                return;
+            }
+
+            m_FixedUpdateStepper = null;
+            if (Application.isEditor)
+            {
+                UnityEngine.Object.DestroyImmediate(m_StepperObject);
+            }
+            else
+            {
+                UnityEngine.Object.Destroy(m_StepperObject);
+            }
+
+            m_StepperObject = null;
+        }
+
+        ///
+        /// Determines whether or not the Academy is automatically stepped during the FixedUpdate phase.
+        ///
+        /// Set true to enable automatic stepping; false to disable.
+        public bool AutomaticSteppingEnabled
+        {
+            get { return m_FixedUpdateStepper != null; }
+            set
+            {
+                if (value)
+                {
+                    EnableAutomaticStepping();
+                }
+                else
+                {
+                    DisableAutomaticStepping();
+                }
+            }
+        }
+
+        // Used to read Python-provided environment parameters
+        static int ReadPortFromArgs()
+        {
+            var args = Environment.GetCommandLineArgs();
+            var inputPort = "";
+            for (var i = 0; i < args.Length; i++)
+            {
+                if (args[i] == k_PortCommandLineFlag)
+                {
+                    inputPort = args[i + 1];
+                }
+            }
+
+            try
+            {
+                return int.Parse(inputPort);
+            }
+            catch
+            {
+                // No arg passed, or malformed port number.
+#if UNITY_EDITOR
+                // Try connecting on the default editor port
+                return MLAgentsSettingsManager.Settings.ConnectTrainer ? MLAgentsSettingsManager.Settings.EditorPort : -1;
+#else
+                // This is an executable, so we don't try to connect.
+                return -1;
+#endif
+            }
+        }
+
+        EnvironmentParameters m_EnvironmentParameters;
+        StatsRecorder m_StatsRecorder;
+
+        ///
+        /// Returns the <see cref="EnvironmentParameters"/> instance. If training
+        /// features such as Curriculum Learning or Environment Parameter Randomization are used,
+        /// then the values of the parameters generated from the training process can be
+        /// retrieved here.
+        ///
+        ///
+        public EnvironmentParameters EnvironmentParameters
+        {
+            get { return m_EnvironmentParameters; }
+        }
+
+        ///
+        /// Returns the <see cref="StatsRecorder"/> instance. This instance can be used
+        /// to record any statistics from the Unity environment.
+        ///
+        ///
+        public StatsRecorder StatsRecorder
+        {
+            get { return m_StatsRecorder; }
+        }
+
+        ///
+        /// Initializes the environment, configures it and initializes the Academy.
+ /// + void InitializeEnvironment() + { + TimerStack.Instance.AddMetadata("communication_protocol_version", k_ApiVersion); + TimerStack.Instance.AddMetadata("com.unity.ml-agents_version", k_PackageVersion); + + EnableAutomaticStepping(); + + SideChannelManager.RegisterSideChannel(new EngineConfigurationChannel()); + SideChannelManager.RegisterSideChannel(new TrainingAnalyticsSideChannel()); + m_EnvironmentParameters = new EnvironmentParameters(); + m_StatsRecorder = new StatsRecorder(); + + // Try to launch the communicator by using the arguments passed at launch + var port = ReadPortFromArgs(); + if (port > 0) + { + Communicator = CommunicatorFactory.Create(); + } + + if (Communicator == null && CommunicatorFactory.Enabled && port > 0) + { + Debug.Log("Communicator failed to start!"); + } + + if (Communicator != null) + { + // We try to exchange the first message with Python. If this fails, it means + // no Python Process is ready to train the environment. In this case, the + // environment must use Inference. + bool initSuccessful = false; + var communicatorInitParams = new CommunicatorInitParameters + { + port = port, + unityCommunicationVersion = k_ApiVersion, + unityPackageVersion = k_PackageVersion, + name = "AcademySingleton", + CSharpCapabilities = new UnityRLCapabilities() + }; + + try + { + initSuccessful = Communicator.Initialize( + communicatorInitParams, + out var unityRlInitParameters + ); + if (initSuccessful) + { + UnityEngine.Random.InitState(unityRlInitParameters.seed); + // We might have inference-only Agents, so set the seed for them too. + m_InferenceSeed = unityRlInitParameters.seed; + m_NumAreas = unityRlInitParameters.numAreas; + TrainerCapabilities = unityRlInitParameters.TrainerCapabilities; + TrainerCapabilities.WarnOnPythonMissingBaseRLCapabilities(); + } + else + { + Debug.Log($"Couldn't connect to trainer on port {port} using API version {k_ApiVersion}. Will perform inference instead."); + Communicator = null; + } + } + catch (Exception ex) + { + Debug.Log($"Unexpected exception when trying to initialize communication: {ex}\nWill perform inference instead."); + Communicator = null; + } + } + + if (Communicator != null) + { + Communicator.QuitCommandReceived += OnQuitCommandReceived; + Communicator.ResetCommandReceived += OnResetCommand; + } + + // If a communicator is enabled/provided, then we assume we are in + // training mode. In the absence of a communicator, we assume we are + // in inference mode. + + ResetActions(); + } + + void ResetActions() + { + DecideAction = () => { }; + DestroyAction = () => { }; + AgentPreStep = i => { }; + AgentSendState = () => { }; + AgentAct = () => { }; + AgentForceReset = () => { }; + OnEnvironmentReset = () => { }; + } + + static void OnQuitCommandReceived() + { +#if UNITY_EDITOR + EditorApplication.isPlaying = false; +#endif + Application.Quit(); + } + + void OnResetCommand() + { + ForcedFullReset(); + } + + /// + /// The current episode count. + /// + /// + /// Current episode number. + /// + public int EpisodeCount + { + get { return m_EpisodeCount; } + } + + /// + /// The current step count (within the current episode). + /// + /// + /// Current step count. + /// + public int StepCount + { + get { return m_StepCount; } + } + + /// + /// Returns the total step count. + /// + /// + /// Total step count. + /// + public int TotalStepCount + { + get { return m_TotalStepCount; } + } + + /// + /// Forces the full reset. The done flags are not affected. 
It is
+        /// called on the first reset during inference and on every external reset
+        /// during training.
+        ///
+        void ForcedFullReset()
+        {
+            EnvironmentReset();
+            AgentForceReset?.Invoke();
+            m_HadFirstReset = true;
+        }
+
+        ///
+        /// Performs a single environment update of the Academy and Agent
+        /// objects within the environment.
+        ///
+        public void EnvironmentStep()
+        {
+            using (m_StepRecursionChecker.Start())
+            {
+                if (!m_HadFirstReset)
+                {
+                    ForcedFullReset();
+                }
+
+                AgentPreStep?.Invoke(m_StepCount);
+
+                m_StepCount += 1;
+                m_TotalStepCount += 1;
+                AgentIncrementStep?.Invoke();
+
+                using (TimerStack.Instance.Scoped("AgentSendState"))
+                {
+                    AgentSendState?.Invoke();
+                }
+
+                using (TimerStack.Instance.Scoped("DecideAction"))
+                {
+                    DecideAction?.Invoke();
+                }
+
+                // If the communicator is not on, we need to clear the SideChannel sending queue
+                if (!IsCommunicatorOn)
+                {
+                    SideChannelManager.GetSideChannelMessage();
+                }
+
+                using (TimerStack.Instance.Scoped("AgentAct"))
+                {
+                    AgentAct?.Invoke();
+                }
+            }
+        }
+
+        ///
+        /// Resets the environment, including the Academy.
+        ///
+        void EnvironmentReset()
+        {
+            m_StepCount = 0;
+            m_EpisodeCount++;
+            OnEnvironmentReset?.Invoke();
+        }
+
+        ///
+        /// Creates or retrieves an existing ModelRunner that uses the same
+        /// NNModel and the InferenceDevice as provided.
+        ///
+        /// The NNModel the ModelRunner must use.
+        /// Description of the actions for the Agent.
+        ///
+        /// The inference device (CPU or GPU) the ModelRunner will use.
+        ///
+        /// Inference only: set to true if the action selection from the model should be
+        /// deterministic.
+        /// The ModelRunner compatible with the input settings.
+        internal ModelRunner GetOrCreateModelRunner(
+            NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false)
+        {
+            var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice));
+            if (modelRunner == null)
+            {
+                modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference);
+                m_ModelRunners.Add(modelRunner);
+                m_InferenceSeed++;
+            }
+            return modelRunner;
+        }
+
+        ///
+        /// Shut down the Academy.
+        ///
+        public void Dispose()
+        {
+            DisableAutomaticStepping();
+
+            // Signal to listeners that the academy is being destroyed now
+            DestroyAction?.Invoke();
+
+            Communicator?.Dispose();
+            Communicator = null;
+
+            m_EnvironmentParameters.Dispose();
+            m_StatsRecorder.Dispose();
+            SideChannelManager.UnregisterAllSideChannels(); // unregister custom side channels
+
+            if (m_ModelRunners != null)
+            {
+                foreach (var mr in m_ModelRunners)
+                {
+                    mr.Dispose();
+                }
+
+                m_ModelRunners = null;
+            }
+
+            // Clear out the actions so we're not keeping references to any old objects
+            ResetActions();
+
+            // TODO - Pass worker ID or some other identifier,
+            // so that multiple envs won't overwrite each other's stats.
+            TimerStack.Instance.SaveJsonTimers();
+            m_Initialized = false;
+
+            // Reset the Lazy instance
+            s_Lazy = new Lazy<Academy>(() => new Academy());
+        }
+
+        ///
+        /// Check if the input AcademyFixedUpdateStepper belongs to this Academy.
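
To make the ordering above concrete: one EnvironmentStep() call fires AgentPreStep, AgentIncrementStep, AgentSendState, DecideAction, and AgentAct, in that order, and the very first call also performs ForcedFullReset(). A sketch of observing this from user code, using only the public events declared earlier in this file:

    Academy.Instance.OnEnvironmentReset += () => Debug.Log("environment reset");
    Academy.Instance.AgentPreStep += stepCount => Debug.Log($"starting step {stepCount}");
    Academy.Instance.EnvironmentStep();   // first call also triggers the forced full reset
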
+ /// + internal bool IsStepperOwner(AcademyFixedUpdateStepper stepper) + { + return GameObject.ReferenceEquals(stepper.gameObject, Academy.Instance.m_StepperObject); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Academy.cs.meta b/com.unity.ml-agents/Runtime/Academy.cs.meta new file mode 100755 index 0000000000..b0a5b6ffc4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Academy.cs.meta @@ -0,0 +1,12 @@ +fileFormatVersion: 2 +guid: b1fc0029fee784d9cb9854f8912bfd07 +timeCreated: 1503613254 +licenseType: Free +MonoImporter: + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Actuators.meta b/com.unity.ml-agents/Runtime/Actuators.meta new file mode 100644 index 0000000000..96bbfb99b3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 26733e59183b6479e8f0e892a8bf09a4 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs new file mode 100644 index 0000000000..b3026eac39 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs @@ -0,0 +1,234 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Unity.MLAgents.Actuators +{ + /// + /// ActionSegment{T} is a data structure that allows access to a segment of an underlying array + /// in order to avoid the copying and allocation of sub-arrays. The segment is defined by + /// the offset into the original array, and an length. + /// + /// The type of object stored in the underlying + public readonly struct ActionSegment : IEnumerable, IEquatable> + where T : struct + { + /// + /// The zero-based offset into the original array at which this segment starts. + /// + public readonly int Offset; + + /// + /// The number of items this segment can access in the underlying array. + /// + public readonly int Length; + + /// + /// An Empty segment which has an offset of 0, a Length of 0, and it's underlying array + /// is also empty. + /// + public static ActionSegment Empty = new ActionSegment(System.Array.Empty(), 0, 0); + + static void CheckParameters(IReadOnlyCollection actionArray, int offset, int length) + { +#if DEBUG + if (offset + length > actionArray.Count) + { + throw new ArgumentOutOfRangeException(nameof(offset), + $"Arguments offset: {offset} and length: {length} " + + $"are out of bounds of actionArray: {actionArray.Count}."); + } +#endif + } + + /// + /// Construct an with just an actionArray. The will + /// be set to 0 and the will be set to `actionArray.Length`. + /// + /// The action array to use for the this segment. + public ActionSegment(T[] actionArray) + : this(actionArray ?? System.Array.Empty(), 0, actionArray?.Length ?? 0) { } + + /// + /// Construct an with an underlying array + /// and offset, and a length. + /// + /// The underlying array which this segment has a view into + /// The zero-based offset into the underlying array. + /// The length of the segment. + public ActionSegment(T[] actionArray, int offset, int length) + { +#if DEBUG + CheckParameters(actionArray ?? System.Array.Empty(), offset, length); +#endif + Array = actionArray ?? System.Array.Empty(); + Offset = offset; + Length = length; + } + + /// + /// Get the underlying of this segment. 
+        ///
+        public T[] Array { get; }
+
+        ///
+        /// Allows access to the underlying array using array syntax.
+        ///
+        /// The zero-based index of the segment.
+        /// Thrown when the index is less than 0 or
+        /// greater than or equal to Length.
+        public T this[int index]
+        {
+            get
+            {
+                if (index < 0 || index >= Length)
+                {
+                    throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length - 1}");
+                }
+                return Array[Offset + index];
+            }
+            set
+            {
+                if (index < 0 || index >= Length)
+                {
+                    throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length - 1}");
+                }
+                Array[Offset + index] = value;
+            }
+        }
+
+        ///
+        /// Sets the segment of the backing array to all zeros.
+        ///
+        public void Clear()
+        {
+            System.Array.Clear(Array, Offset, Length);
+        }
+
+        ///
+        /// Check if the segment is empty.
+        ///
+        /// Whether or not the segment is empty.
+        public bool IsEmpty()
+        {
+            return Array == null || Array.Length == 0;
+        }
+
+        ///
+        /// Returns an enumerator that iterates through the ActionSegment.
+        ///
+        /// An IEnumerator object that can be used to iterate through the ActionSegment.
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return new Enumerator(this);
+        }
+
+        ///
+        /// Returns an enumerator that iterates through the ActionSegment.
+        ///
+        /// An IEnumerator object that can be used to iterate through the ActionSegment.
+        public IEnumerator<T> GetEnumerator()
+        {
+            return new Enumerator(this);
+        }
+
+        ///
+        /// Indicates whether the current ActionSegment is equal to another ActionSegment.
+        ///
+        /// An ActionSegment to compare with this ActionSegment.
+        /// true if the current ActionSegment is equal to the other parameter; otherwise, false.
+        public override bool Equals(object obj)
+        {
+            if (!(obj is ActionSegment<T>))
+            {
+                return false;
+            }
+            return Equals((ActionSegment<T>)obj);
+        }
+
+        ///
+        /// Indicates whether the current ActionSegment is equal to another ActionSegment.
+        ///
+        /// An ActionSegment to compare with this ActionSegment.
+        /// true if the current ActionSegment is equal to the other parameter; otherwise, false.
+        public bool Equals(ActionSegment<T> other)
+        {
+            return Offset == other.Offset && Length == other.Length && Array.SequenceEqual(other.Array);
+        }
+
+        ///
+        /// Computes the hash code of the ActionSegment.
+        ///
+        /// A hash code for the current ActionSegment.
+        public override int GetHashCode()
+        {
+            unchecked
+            {
+                var hashCode = Offset;
+                hashCode = (hashCode * 397) ^ Length;
+                hashCode = (hashCode * 397) ^ (Array != null ? Array.GetHashCode() : 0);
+                return hashCode;
+            }
+        }
+
+        ///
+        /// A private IEnumerator{T} for the ActionSegment{T} value type, which follows its
+        /// rules of being a view into an underlying array.
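
A minimal usage sketch, assuming the ActionSegment<T> API above; the segment is a copy-free view, so reads and writes go straight to the backing array:

    var backing = new float[] { 0f, 1f, 2f, 3f, 4f, 5f };
    var segment = new ActionSegment<float>(backing, 2, 3);   // offset 2, length 3
    var first = segment[0];   // 2f
    segment[1] = -1f;         // writes through: backing[3] is now -1f
    segment.Clear();          // zeroes backing[2..4] only
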
+ /// + struct Enumerator : IEnumerator + { + readonly T[] m_Array; + readonly int m_Start; + readonly int m_End; // cache Offset + Count, since it's a little slow + int m_Current; + + internal Enumerator(ActionSegment arraySegment) + { + Debug.Assert(arraySegment.Array != null); + Debug.Assert(arraySegment.Offset >= 0); + Debug.Assert(arraySegment.Length >= 0); + Debug.Assert(arraySegment.Offset + arraySegment.Length <= arraySegment.Array.Length); + + m_Array = arraySegment.Array; + m_Start = arraySegment.Offset; + m_End = arraySegment.Offset + arraySegment.Length; + m_Current = arraySegment.Offset - 1; + } + + public bool MoveNext() + { + if (m_Current < m_End) + { + m_Current++; + return m_Current < m_End; + } + return false; + } + + public T Current + { + get + { + if (m_Current < m_Start) + throw new InvalidOperationException("Enumerator not started."); + if (m_Current >= m_End) + throw new InvalidOperationException("Enumerator has reached the end already."); + return m_Array[m_Current]; + } + } + + object IEnumerator.Current => Current; + + void IEnumerator.Reset() + { + m_Current = m_Start - 1; + } + + public void Dispose() + { + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta new file mode 100644 index 0000000000..8e08ed0a4a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 4fa1432c1ba3460caaa84303a9011ef2 +timeCreated: 1595869823 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs new file mode 100644 index 0000000000..6b0a001a7d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs @@ -0,0 +1,137 @@ +using System; +using System.Linq; +using UnityEngine; + +namespace Unity.MLAgents.Actuators +{ + /// + /// Defines the structure of the actions to be used by the Actuator system. + /// + [Serializable] + public struct ActionSpec + { + [SerializeField] + int m_NumContinuousActions; + + /// + /// An array of branch sizes for discrete actions. + /// + /// For an IActuator that uses discrete actions, the number of + /// branches is the Length of the Array and each index contains the branch size. + /// The cumulative sum of the total number of discrete actions can be retrieved + /// by the property. + /// + /// For an IActuator with a Continuous it will be null. + /// + public int[] BranchSizes; + + /// + /// The number of continuous actions that an Agent can take. + /// + public int NumContinuousActions { get { return m_NumContinuousActions; } set { m_NumContinuousActions = value; } } + + /// + /// The number of branches for discrete actions that an Agent can take. + /// + public int NumDiscreteActions { get { return BranchSizes == null ? 0 : BranchSizes.Length; } } + + /// + /// Get the total number of Discrete Actions that can be taken by calculating the Sum + /// of all of the Discrete Action branch sizes. + /// + public int SumOfDiscreteBranchSizes { get { return BranchSizes == null ? 0 : BranchSizes.Sum(); } } + + /// + /// Creates a Continuous with the number of actions available. + /// + /// The number of continuous actions available. + /// An Continuous ActionSpec initialized with the number of actions available. 
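
A short usage sketch of the factory methods defined just below (the values are arbitrary):

    var continuous = ActionSpec.MakeContinuous(2);        // 2 continuous actions
    var discrete = ActionSpec.MakeDiscrete(3, 2);         // 2 branches, of sizes 3 and 2
    var combined = ActionSpec.Combine(continuous, discrete);
    // combined.NumContinuousActions == 2, combined.NumDiscreteActions == 2,
    // combined.SumOfDiscreteBranchSizes == 5
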
+        public static ActionSpec MakeContinuous(int numActions)
+        {
+            var actuatorSpace = new ActionSpec(numActions, null);
+            return actuatorSpace;
+        }
+
+        ///
+        /// Creates a Discrete ActionSpec with the array of branch sizes that
+        /// represents the action space.
+        ///
+        /// The array of branch sizes for the discrete actions. Each index
+        /// contains the number of actions available for that branch.
+        /// A Discrete ActionSpec initialized with the array of branch sizes.
+        public static ActionSpec MakeDiscrete(params int[] branchSizes)
+        {
+            var actuatorSpace = new ActionSpec(0, branchSizes);
+            return actuatorSpace;
+        }
+
+        ///
+        /// Create an ActionSpec initialized with the specified action sizes.
+        ///
+        /// The number of continuous actions available.
+        /// The array of branch sizes for the discrete actions. Each index
+        /// contains the number of actions available for that branch.
+        public ActionSpec(int numContinuousActions = 0, int[] discreteBranchSizes = null)
+        {
+            m_NumContinuousActions = numContinuousActions;
+            BranchSizes = discreteBranchSizes ?? Array.Empty<int>();
+        }
+
+        ///
+        /// Check that the ActionSpec uses either all continuous or all discrete actions.
+        /// This is only used when connecting to old versions of the trainer that don't support this.
+        ///
+        ///
+        internal void CheckAllContinuousOrDiscrete()
+        {
+            if (NumContinuousActions > 0 && NumDiscreteActions > 0)
+            {
+                throw new UnityAgentsException(
+                    "Action spaces with both continuous and discrete actions are not supported by the trainer. " +
+                    "ActionSpecs must be all continuous or all discrete."
+                );
+            }
+        }
+
+        ///
+        /// Combines a list of action specs and allocates a new array of branch sizes if needed.
+        ///
+        /// The list of action specs to combine.
+        /// An ActionSpec which represents the aggregate of the ActionSpecs passed in.
+        public static ActionSpec Combine(params ActionSpec[] specs)
+        {
+            var numContinuous = 0;
+            var numDiscrete = 0;
+            for (var i = 0; i < specs.Length; i++)
+            {
+                var spec = specs[i];
+                numContinuous += spec.NumContinuousActions;
+                numDiscrete += spec.NumDiscreteActions;
+            }
+
+            if (numDiscrete <= 0)
+            {
+                return MakeContinuous(numContinuous);
+            }
+
+            var branchSizes = new int[numDiscrete];
+            var offset = 0;
+            for (var i = 0; i < specs.Length; i++)
+            {
+                var spec = specs[i];
+                if (spec.BranchSizes.Length == 0)
+                {
+                    continue;
+                }
+                var branchSizesLength = spec.BranchSizes.Length;
+                Array.Copy(spec.BranchSizes,
+                    0,
+                    branchSizes,
+                    offset,
+                    branchSizesLength);
+                offset += branchSizesLength;
+            }
+            return new ActionSpec(numContinuous, branchSizes);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
new file mode 100644
index 0000000000..a442a91a5e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: ecdd6deefba1416ca149fe09d2a5afd8
+timeCreated: 1595892361
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
new file mode 100644
index 0000000000..af34bef3a3
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
@@ -0,0 +1,25 @@
+using UnityEngine;
+
+namespace Unity.MLAgents.Actuators
+{
+    ///
+    /// Editor components for creating Actuators. Generally an IActuator component should
+    /// have a corresponding ActuatorComponent.
+ /// + public abstract class ActuatorComponent : MonoBehaviour + { + /// + /// Create a collection of s. This is called by the during + /// initialization. + /// + /// A collection of s + public abstract IActuator[] CreateActuators(); + + /// + /// The specification of the possible actions for this ActuatorComponent. + /// This must produce the same results as the corresponding IActuator's ActionSpec. + /// + /// + public abstract ActionSpec ActionSpec { get; } + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta new file mode 100644 index 0000000000..1b7a643ed1 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 77cefae5f6d841be9ff80b41293d271b +timeCreated: 1593017318 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs new file mode 100644 index 0000000000..d44532b16f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs @@ -0,0 +1,149 @@ +using System; +using System.Collections.Generic; + +namespace Unity.MLAgents.Actuators +{ + /// + /// Implementation of IDiscreteActionMask that allows writing to the action mask from an . + /// + internal class ActuatorDiscreteActionMask : IDiscreteActionMask + { + /// When using discrete control, is the starting indices of the actions + /// when all the branches are concatenated with each other. + int[] m_StartingActionIndices; + + int[] m_BranchSizes; + + bool[] m_CurrentMask; + + IList m_Actuators; + + readonly int m_SumOfDiscreteBranchSizes; + readonly int m_NumBranches; + + /// + /// The offset into the branches array that is used when actuators are writing to the action mask. + /// + public int CurrentBranchOffset { get; set; } + + internal ActuatorDiscreteActionMask(IList actuators, int sumOfDiscreteBranchSizes, int numBranches, int[] branchSizes = null) + { + m_Actuators = actuators; + m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes; + m_NumBranches = numBranches; + m_BranchSizes = branchSizes; + } + + /// + public void SetActionEnabled(int branch, int actionIndex, bool isEnabled) + { + LazyInitialize(); +#if DEBUG + if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch]) + { + throw new UnityAgentsException( + "Invalid Action Masking: Action Mask is too large for specified branch."); + } +#endif + m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled; + } + + void LazyInitialize() + { + if (m_BranchSizes == null) + { + m_BranchSizes = new int[m_NumBranches]; + var start = 0; + for (var i = 0; i < m_Actuators.Count; i++) + { + var actuator = m_Actuators[i]; + var branchSizes = actuator.ActionSpec.BranchSizes; + Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length); + start += branchSizes.Length; + } + } + + // By default, the masks are null. If we want to specify a new mask, we initialize + // the actionMasks with trues. + if (m_CurrentMask == null) + { + m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes]; + } + + // If this is the first time the masked actions are used, we generate the starting + // indices for each branch. + if (m_StartingActionIndices == null) + { + m_StartingActionIndices = Utilities.CumSum(m_BranchSizes); + } + } + + /// + /// Get the current mask for an agent. + /// + /// A mask for the agent. 
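
A sketch of the ActuatorComponent pattern above; MyActuator is a hypothetical IActuator implementation, not a type from this diff:

    public class MyActuatorComponent : ActuatorComponent
    {
        public override ActionSpec ActionSpec => ActionSpec.MakeDiscrete(3);

        public override IActuator[] CreateActuators()
        {
            // The actuator's spec must match the ActionSpec property above.
            return new IActuator[] { new MyActuator(ActionSpec) };
        }
    }
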
A boolean array of length equal to the total number of + /// actions. + internal bool[] GetMask() + { +#if DEBUG + if (m_CurrentMask != null) + { + AssertMask(); + } +#endif + return m_CurrentMask; + } + + /// + /// Makes sure that the current mask is usable. + /// + void AssertMask() + { +#if DEBUG + for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++) + { + if (AreAllActionsMasked(branchIndex)) + { + throw new UnityAgentsException( + "Invalid Action Masking : All the actions of branch " + branchIndex + + " are masked."); + } + } +#endif + } + + /// + /// Resets the current mask for an agent. + /// + internal void ResetMask() + { + if (m_CurrentMask != null) + { + Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length); + } + } + + /// + /// Checks if all the actions in the input branch are masked. + /// + /// The index of the branch to check. + /// True if all the actions of the branch are masked. + bool AreAllActionsMasked(int branch) + { + if (m_CurrentMask == null) + { + return false; + } + var start = m_StartingActionIndices[branch]; + var end = m_StartingActionIndices[branch + 1]; + for (var i = start; i < end; i++) + { + if (!m_CurrentMask[i]) + { + return false; + } + } + return true; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta new file mode 100644 index 0000000000..09aa4784b0 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: d2a19e2f43fd4637a38d42b2a5f989f3 +timeCreated: 1595459316 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs new file mode 100644 index 0000000000..1ff35557d9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs @@ -0,0 +1,500 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using UnityEngine; +using UnityEngine.Profiling; + +namespace Unity.MLAgents.Actuators +{ + /// + /// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators. + /// + internal class ActuatorManager : IList + { + // IActuators managed by this object. + List m_Actuators; + + // An implementation of IDiscreteActionMask that allows for writing to it based on an offset. + ActuatorDiscreteActionMask m_DiscreteActionMask; + + ActionSpec m_CombinedActionSpec; + + /// + /// Flag used to check if our IActuators are ready for execution. + /// + /// + bool m_ReadyForExecution; + + /// + /// The sum of all of the discrete branches for all of the s in this manager. + /// + internal int SumOfDiscreteBranchSizes { get; private set; } + + /// + /// The number of the discrete branches for all of the s in this manager. + /// + internal int NumDiscreteActions { get; private set; } + + /// + /// The number of continuous actions for all of the s in this manager. + /// + internal int NumContinuousActions { get; private set; } + + /// + /// Returns the total actions which is calculated by + . + /// + public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions; + + /// + /// Gets the managed by this object. + /// + public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask; + + /// + /// The currently stored object for the s managed by this class. 
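
To make the offset bookkeeping in ActuatorDiscreteActionMask concrete: with branch sizes [3, 2], Utilities.CumSum produces starting indices [0, 3, 5], so the flattened mask has five slots and branch 1 begins at slot 3. Disabling action 1 of branch 1 therefore sets slot 4 (the mask stores the inverse of isEnabled):

    // Inside an actuator's WriteDiscreteActionMask(...):
    actionMask.SetActionEnabled(1, 1, false);   // internally: m_CurrentMask[3 + 1] = true
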
+ /// + public ActionBuffers StoredActions { get; private set; } + + /// + /// Create an ActuatorList with a preset capacity. + /// + /// The capacity of the list to create. + public ActuatorManager(int capacity = 0) + { + m_Actuators = new List(capacity); + } + + /// + /// + /// + void ReadyActuatorsForExecution() + { + ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes, + NumDiscreteActions); + } + + /// + /// This method validates that all s have unique names + /// if the `DEBUG` preprocessor macro is defined, and allocates the appropriate buffers to manage the actions for + /// all of the s that may live on a particular object. + /// + /// The list of actuators to validate and allocate buffers for. + /// The total number of continuous actions for all of the actuators. + /// The total sum of the discrete branches for all of the actuators in order + /// to be able to allocate an . + /// The number of discrete branches for all of the actuators. + internal void ReadyActuatorsForExecution(IList actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches) + { + if (m_ReadyForExecution) + { + return; + } +#if DEBUG + // Make sure the names are actually unique + ValidateActuators(); +#endif + + // Sort the Actuators by name to ensure determinism + SortActuators(m_Actuators); + var continuousActions = numContinuousActions == 0 ? ActionSegment.Empty : + new ActionSegment(new float[numContinuousActions]); + var discreteActions = numDiscreteBranches == 0 ? ActionSegment.Empty : new ActionSegment(new int[numDiscreteBranches]); + + StoredActions = new ActionBuffers(continuousActions, discreteActions); + m_CombinedActionSpec = CombineActionSpecs(actuators); + m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches, m_CombinedActionSpec.BranchSizes); + m_ReadyForExecution = true; + } + + internal static ActionSpec CombineActionSpecs(IList actuators) + { + int numContinuousActions = 0; + int numDiscreteActions = 0; + + foreach (var actuator in actuators) + { + numContinuousActions += actuator.ActionSpec.NumContinuousActions; + numDiscreteActions += actuator.ActionSpec.NumDiscreteActions; + } + + int[] combinedBranchSizes; + if (numDiscreteActions == 0) + { + combinedBranchSizes = Array.Empty(); + } + else + { + combinedBranchSizes = new int[numDiscreteActions]; + var start = 0; + for (var i = 0; i < actuators.Count; i++) + { + var branchSizes = actuators[i].ActionSpec.BranchSizes; + if (branchSizes != null) + { + Array.Copy(branchSizes, 0, combinedBranchSizes, start, branchSizes.Length); + start += branchSizes.Length; + } + } + } + + return new ActionSpec(numContinuousActions, combinedBranchSizes); + } + + /// + /// Returns an ActionSpec representing the concatenation of all IActuator's ActionSpecs + /// + /// + public ActionSpec GetCombinedActionSpec() + { + ReadyActuatorsForExecution(); + return m_CombinedActionSpec; + } + + /// + /// Updates the local action buffer with the action buffer passed in. If the buffer + /// passed in is null, the local action buffer will be cleared. + /// + /// The object which contains all of the + /// actions for the IActuators in this list. 
+ public void UpdateActions(ActionBuffers actions) + { + Profiler.BeginSample("ActuatorManager.UpdateActions"); + ReadyActuatorsForExecution(); + UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions); + UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions); + Profiler.EndSample(); + } + + static void UpdateActionArray(ActionSegment sourceActionBuffer, ActionSegment destination) + where T : struct + { + if (sourceActionBuffer.Length <= 0) + { + destination.Clear(); + } + else + { + if (sourceActionBuffer.Length != destination.Length) + { + Debug.AssertFormat(sourceActionBuffer.Length == destination.Length, + "sourceActionBuffer: {0} is a different size than destination: {1}.", + sourceActionBuffer.Length, + destination.Length); + } + + Array.Copy(sourceActionBuffer.Array, + sourceActionBuffer.Offset, + destination.Array, + destination.Offset, + destination.Length); + } + } + + /// + /// This method will trigger the writing to the by all of the actuators + /// managed by this object. + /// + public void WriteActionMask() + { + ReadyActuatorsForExecution(); + m_DiscreteActionMask.ResetMask(); + var offset = 0; + for (var i = 0; i < m_Actuators.Count; i++) + { + var actuator = m_Actuators[i]; + if (actuator.ActionSpec.NumDiscreteActions > 0) + { + m_DiscreteActionMask.CurrentBranchOffset = offset; + actuator.WriteDiscreteActionMask(m_DiscreteActionMask); + offset += actuator.ActionSpec.NumDiscreteActions; + } + } + } + + /// + /// Iterates through all of the IActuators in this list and calls their + /// method on them, if implemented, with the appropriate + /// s depending on their . + /// + public void ApplyHeuristic(in ActionBuffers actionBuffersOut) + { + Profiler.BeginSample("ActuatorManager.ApplyHeuristic"); + var continuousStart = 0; + var discreteStart = 0; + for (var i = 0; i < m_Actuators.Count; i++) + { + var actuator = m_Actuators[i]; + var numContinuousActions = actuator.ActionSpec.NumContinuousActions; + var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions; + + if (numContinuousActions == 0 && numDiscreteActions == 0) + { + continue; + } + + var continuousActions = ActionSegment.Empty; + if (numContinuousActions > 0) + { + continuousActions = new ActionSegment(actionBuffersOut.ContinuousActions.Array, + continuousStart, + numContinuousActions); + } + + var discreteActions = ActionSegment.Empty; + if (numDiscreteActions > 0) + { + discreteActions = new ActionSegment(actionBuffersOut.DiscreteActions.Array, + discreteStart, + numDiscreteActions); + } + actuator.Heuristic(new ActionBuffers(continuousActions, discreteActions)); + continuousStart += numContinuousActions; + discreteStart += numDiscreteActions; + } + Profiler.EndSample(); + } + + /// + /// Iterates through all of the IActuators in this list and calls their + /// method on them with the appropriate + /// s depending on their . 
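
A worked trace of how ApplyHeuristic above and ExecuteActions below split the combined buffers, assuming two hypothetical actuators already sorted by name:

    // A: 2 continuous actions, 0 branches -> gets ContinuousActions[0..2)
    // B: 0 continuous actions, 1 branch   -> gets DiscreteActions[0..1)
    // continuousStart advances 0 -> 2 and discreteStart advances 0 -> 1,
    // so every actuator sees a zero-copy ActionSegment over the shared arrays.
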
+ /// + public void ExecuteActions() + { + Profiler.BeginSample("ActuatorManager.ExecuteActions"); + ReadyActuatorsForExecution(); + var continuousStart = 0; + var discreteStart = 0; + for (var i = 0; i < m_Actuators.Count; i++) + { + var actuator = m_Actuators[i]; + var numContinuousActions = actuator.ActionSpec.NumContinuousActions; + var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions; + + if (numContinuousActions == 0 && numDiscreteActions == 0) + { + continue; + } + + var continuousActions = ActionSegment.Empty; + if (numContinuousActions > 0) + { + continuousActions = new ActionSegment(StoredActions.ContinuousActions.Array, + continuousStart, + numContinuousActions); + } + + var discreteActions = ActionSegment.Empty; + if (numDiscreteActions > 0) + { + discreteActions = new ActionSegment(StoredActions.DiscreteActions.Array, + discreteStart, + numDiscreteActions); + } + + actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions)); + continuousStart += numContinuousActions; + discreteStart += numDiscreteActions; + } + Profiler.EndSample(); + } + + /// + /// Resets the to be all + /// zeros and calls on each managed by this object. + /// + public void ResetData() + { + if (!m_ReadyForExecution) + { + return; + } + StoredActions.Clear(); + for (var i = 0; i < m_Actuators.Count; i++) + { + m_Actuators[i].ResetData(); + } + m_DiscreteActionMask.ResetMask(); + } + + /// + /// Sorts the s according to their value. + /// + internal static void SortActuators(List actuators) + { + actuators.Sort((x, y) => string.Compare(x.Name, y.Name, StringComparison.InvariantCulture)); + } + + /// + /// Validates that the IActuators managed by this object have unique names. + /// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action + /// buffers, and execution of Actuators remains deterministic across different sessions of running. + /// + void ValidateActuators() + { + for (var i = 0; i < m_Actuators.Count - 1; i++) + { + Debug.Assert( + !m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name), + "Actuator names must be unique."); + } + } + + /// + /// Helper method to update bookkeeping items around buffer management for actuators added to this object. + /// + /// The IActuator to keep bookkeeping for. + void AddToBufferSizes(IActuator actuatorItem) + { + if (actuatorItem == null) + { + return; + } + + NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions; + NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions; + SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes; + } + + /// + /// Helper method to update bookkeeping items around buffer management for actuators removed from this object. + /// + /// The IActuator to keep bookkeeping for. + void SubtractFromBufferSize(IActuator actuatorItem) + { + if (actuatorItem == null) + { + return; + } + + NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions; + NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions; + SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes; + } + + /// + /// Sets all of the bookkeeping items back to 0. + /// + void ClearBufferSizes() + { + NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0; + } + + /// + /// Add an array of s at once. + /// + /// The array of s to add. 
+ public void AddActuators(IActuator[] actuators) + { + for (var i = 0; i < actuators.Length; i++) + { + Add(actuators[i]); + } + } + + /********************************************************************************* + * IList implementation that delegates to m_Actuators List. * + *********************************************************************************/ + + /// + public IEnumerator GetEnumerator() + { + return m_Actuators.GetEnumerator(); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)m_Actuators).GetEnumerator(); + } + + /// + public void Add(IActuator item) + { + Debug.Assert(m_ReadyForExecution == false, + "Cannot add to the ActuatorManager after its buffers have been initialized"); + m_Actuators.Add(item); + AddToBufferSizes(item); + } + + /// + public void Clear() + { + Debug.Assert(m_ReadyForExecution == false, + "Cannot clear the ActuatorManager after its buffers have been initialized"); + m_Actuators.Clear(); + ClearBufferSizes(); + } + + /// + public bool Contains(IActuator item) + { + return m_Actuators.Contains(item); + } + + /// + public void CopyTo(IActuator[] array, int arrayIndex) + { + m_Actuators.CopyTo(array, arrayIndex); + } + + /// + public bool Remove(IActuator item) + { + Debug.Assert(m_ReadyForExecution == false, + "Cannot remove from the ActuatorManager after its buffers have been initialized"); + if (m_Actuators.Remove(item)) + { + SubtractFromBufferSize(item); + return true; + } + return false; + } + + /// + public int Count => m_Actuators.Count; + + /// + public bool IsReadOnly => false; + + /// + public int IndexOf(IActuator item) + { + return m_Actuators.IndexOf(item); + } + + /// + public void Insert(int index, IActuator item) + { + Debug.Assert(m_ReadyForExecution == false, + "Cannot insert into the ActuatorManager after its buffers have been initialized"); + m_Actuators.Insert(index, item); + AddToBufferSizes(item); + } + + /// + public void RemoveAt(int index) + { + Debug.Assert(m_ReadyForExecution == false, + "Cannot remove from the ActuatorManager after its buffers have been initialized"); + var actuator = m_Actuators[index]; + SubtractFromBufferSize(actuator); + m_Actuators.RemoveAt(index); + } + + /// + public IActuator this[int index] + { + get => m_Actuators[index]; + set + { + Debug.Assert(m_ReadyForExecution == false, + "Cannot modify the ActuatorManager after its buffers have been initialized"); + var old = m_Actuators[index]; + SubtractFromBufferSize(old); + m_Actuators[index] = value; + AddToBufferSizes(value); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta new file mode 100644 index 0000000000..aa56b5ca9f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 7bb5b1e3779d4342a8e70f6e3c1d67cc +timeCreated: 1593031463 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs new file mode 100644 index 0000000000..1586e6215b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs @@ -0,0 +1,192 @@ +using System; +using System.Linq; +using UnityEngine; + +namespace Unity.MLAgents.Actuators +{ + /// + /// A structure that wraps the s for a particular and is + /// used when is called. + /// + public readonly struct ActionBuffers + { + /// + /// An empty action buffer. 
+ /// + public static ActionBuffers Empty = new ActionBuffers(ActionSegment.Empty, ActionSegment.Empty); + + /// + /// Holds the Continuous to be used by an . + /// + public ActionSegment ContinuousActions { get; } + + /// + /// Holds the Discrete to be used by an . + /// + public ActionSegment DiscreteActions { get; } + + /// + /// Create an instance with discrete actions stored as a float array. This exists + /// to achieve backward compatibility with the former Agent methods which used a float array for both continuous + /// and discrete actions. + /// + /// The float array of discrete actions. + /// An instance initialized with a + /// initialized from a float array. + public static ActionBuffers FromDiscreteActions(float[] discreteActions) + { + return new ActionBuffers(ActionSegment.Empty, discreteActions == null ? ActionSegment.Empty + : new ActionSegment(Array.ConvertAll(discreteActions, + x => (int)x))); + } + + /// + /// Construct an instance with the continuous and discrete actions that will + /// be used. + /// /// + /// The continuous actions to send to an . + /// The discrete actions to send to an . + public ActionBuffers(float[] continuousActions, int[] discreteActions) + : this(new ActionSegment(continuousActions), new ActionSegment(discreteActions)) { } + + /// + /// Construct an instance with the continuous and discrete actions that will + /// be used. + /// + /// The continuous actions to send to an . + /// The discrete actions to send to an . + public ActionBuffers(ActionSegment continuousActions, ActionSegment discreteActions) + { + ContinuousActions = continuousActions; + DiscreteActions = discreteActions; + } + + /// + /// Construct an instance with . All values are initialized to zeros. + /// /// + /// The to send to an . + public ActionBuffers(ActionSpec actionSpec) + : this(new ActionSegment(new float[actionSpec.NumContinuousActions]), + new ActionSegment(new int[actionSpec.NumDiscreteActions])) + { } + + /// + /// Create an instance with ActionSpec and all actions stored as a float array. + /// + /// of the + /// The float array of all actions, including discrete and continuous actions. + /// An instance initialized with a and a float array. + internal static ActionBuffers FromActionSpec(ActionSpec actionSpec, float[] actions) + { + if (actions == null) + { + return ActionBuffers.Empty; + } + + Debug.Assert(actions.Length == actionSpec.NumContinuousActions + actionSpec.NumDiscreteActions, + $"The length of '{nameof(actions)}' does not match the total size of ActionSpec.\n" + + $"{nameof(actions)}.Length: {actions.Length}\n" + + $"{nameof(actionSpec)}: {actionSpec.NumContinuousActions + actionSpec.NumDiscreteActions}"); + + ActionSegment continuousActionSegment = ActionSegment.Empty; + ActionSegment discreteActionSegment = ActionSegment.Empty; + int offset = 0; + if (actionSpec.NumContinuousActions > 0) + { + continuousActionSegment = new ActionSegment(actions, 0, actionSpec.NumContinuousActions); + offset += actionSpec.NumContinuousActions; + } + if (actionSpec.NumDiscreteActions > 0) + { + int[] discreteActions = new int[actionSpec.NumDiscreteActions]; + for (var i = 0; i < actionSpec.NumDiscreteActions; i++) + { + discreteActions[i] = (int)actions[i + offset]; + } + discreteActionSegment = new ActionSegment(discreteActions); + } + + return new ActionBuffers(continuousActionSegment, discreteActionSegment); + } + + /// + /// Clear the and segments to be all zeros. 
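
A short construction sketch using the float[]/int[] constructor above (the values are arbitrary):

    var buffers = new ActionBuffers(new float[] { 0.5f, -0.25f }, new int[] { 2 });
    var thrust = buffers.ContinuousActions[0];   // 0.5f
    var option = buffers.DiscreteActions[0];     // 2
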
+ /// + public void Clear() + { + ContinuousActions.Clear(); + DiscreteActions.Clear(); + } + + /// + /// Check if the is empty. + /// + /// Whether the buffers are empty. + public bool IsEmpty() + { + return ContinuousActions.IsEmpty() && DiscreteActions.IsEmpty(); + } + + /// + /// Indicates whether the current ActionBuffers is equal to another ActionBuffers. + /// + /// An ActionBuffers to compare with this ActionBuffers. + /// true if the current ActionBuffers is equal to the other parameter; otherwise, false. + public override bool Equals(object obj) + { + if (!(obj is ActionBuffers)) + { + return false; + } + + var ab = (ActionBuffers)obj; + return ab.ContinuousActions.SequenceEqual(ContinuousActions) && + ab.DiscreteActions.SequenceEqual(DiscreteActions); + } + + /// + /// Computes the hash code of the ActionBuffers. + /// + /// A hash code for the current ActionBuffers. + public override int GetHashCode() + { + unchecked + { + return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode(); + } + } + } + + /// + /// An interface that describes an object that can receive actions from a Reinforcement Learning network. + /// + public interface IActionReceiver + { + /// + /// Method called in order too allow object to execute actions based on the + /// contents. The structure of the contents in the + /// are defined by the . + /// + /// The data structure containing the action buffers for this object. + void OnActionReceived(ActionBuffers actionBuffers); + + /// + /// Implement `WriteDiscreteActionMask()` to modify the masks for discrete + /// actions. When using discrete actions, the agent will not perform the masked + /// action. + /// + /// + /// The action mask for the agent. + /// + /// + /// When using Discrete Control, you can prevent the Agent from using a certain + /// action by masking it with . + /// + /// See [Agents - Actions] for more information on masking actions. + /// + /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#actions + /// + /// + void WriteDiscreteActionMask(IDiscreteActionMask actionMask); + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta new file mode 100644 index 0000000000..b14a69d21c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: b25a5b3027c9476ea1a310241be0f10f +timeCreated: 1594756775 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs new file mode 100644 index 0000000000..aa2675905a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs @@ -0,0 +1,42 @@ +namespace Unity.MLAgents.Actuators +{ + /// + /// Abstraction that facilitates the execution of actions. + /// + public interface IActuator : IActionReceiver, IHeuristicProvider + { + /// + /// The specification of the actions for this IActuator. + /// + /// + ActionSpec ActionSpec { get; } + + /// + /// Gets the name of this IActuator which will be used to sort it. + /// + /// + string Name { get; } + + /// + /// Resets the internal state of the actuator. This is called at the end of an Agent's episode. + /// Most implementations can leave this empty. + /// + void ResetData(); + } + + /// + /// Helper methods to be shared by all classes that implement . 
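
A sketch of a receiver implementation; the variable names and the meaning given to each slot are invented for illustration:

    public void OnActionReceived(ActionBuffers actionBuffers)
    {
        var thrust = actionBuffers.ContinuousActions[0];    // e.g. forward thrust
        var jump = actionBuffers.DiscreteActions[0] == 1;   // e.g. branch 0 = { stay, jump }
        Debug.Log($"thrust={thrust}, jump={jump}");
    }
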
+ /// + public static class IActuatorExtensions + { + /// + /// Returns the number of discrete branches + the number of continuous actions. + /// + /// + /// + public static int TotalNumberOfActions(this IActuator actuator) + { + return actuator.ActionSpec.NumContinuousActions + actuator.ActionSpec.NumDiscreteActions; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta new file mode 100644 index 0000000000..4fd0d172ca --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 780d7f0a675f44bfa784b370025b51c3 +timeCreated: 1592848317 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs b/com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs new file mode 100644 index 0000000000..8b77672d17 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs @@ -0,0 +1,49 @@ +namespace Unity.MLAgents.Actuators +{ + /// + /// Identifiers for "built in" actuator types. + /// These are only used for analytics, and should not be used for any runtime decisions. + /// + /// NOTE: Do not renumber these, since the values are used for analytics. Renaming is allowed though. + /// + public enum BuiltInActuatorType + { + /// + /// Default Sensor type if it cannot be determined. + /// + Unknown = 0, + + /// + /// VectorActuator used by the Agent + /// + AgentVectorActuator = 1, + + /// + /// Corresponds to + /// + VectorActuator = 2, + + /// + /// Corresponds to the Match3Actuator in com.unity.ml-agents.extensions. + /// + Match3Actuator = 3, + + /// + /// Corresponds to the InputActionActuator in com.unity.ml-agents.extensions. + /// + InputActionActuator = 4, + } + + /// + /// Interface for actuators that are provided as part of ML-Agents. + /// User-implemented actuators don't need to use this interface. + /// + internal interface IBuiltInActuator + { + /// + /// Return the corresponding BuiltInActuatorType for the actuator. + /// + /// A BuiltInActuatorType corresponding to the actuator. + BuiltInActuatorType GetBuiltInActuatorType(); + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs.meta new file mode 100644 index 0000000000..da1d96f271 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e3d7ef9a9a5043549cc5c0bbee520810 +timeCreated: 1613514041 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs new file mode 100644 index 0000000000..1a100b68e1 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs @@ -0,0 +1,26 @@ +namespace Unity.MLAgents.Actuators +{ + /// + /// Interface for writing a mask to disable discrete actions for agents for the next decision. + /// + public interface IDiscreteActionMask + { + /// + /// Set whether or not the action index for the given branch is allowed. + /// + /// + /// By default, all discrete actions are allowed. + /// If isEnabled is false, the agent will not be able to perform the actions passed as argument + /// at the next decision for the specified action branch. The actionIndex corresponds + /// to the action options the agent will be unable to perform. + /// + /// See [Agents - Actions] for more information on masking actions. 
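
A sketch of masking from an agent or actuator; m_Grounded is a hypothetical state flag:

    public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        if (!m_Grounded)
        {
            // Branch 0, action 2 ("jump") is unavailable for the next decision.
            actionMask.SetActionEnabled(0, 2, false);
        }
    }
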
+ /// + /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#masking-discrete-actions + /// + /// The branch for which the actions will be masked. + /// Index of the action. + /// Whether the action is allowed or not. + void SetActionEnabled(int branch, int actionIndex, bool isEnabled); + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta new file mode 100644 index 0000000000..ebfa10158f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 1bc4e4b71bf4470789488fab2ee65388 +timeCreated: 1595369065 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs new file mode 100644 index 0000000000..b992361c83 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs @@ -0,0 +1,18 @@ +namespace Unity.MLAgents.Actuators +{ + /// + /// Interface that allows objects to fill out an data structure for controlling + /// behavior of Agents or Actuators. + /// + public interface IHeuristicProvider + { + /// + /// Method called on objects which are expected to fill out the data structure. + /// Object that implement this interface should be careful to be consistent in the placement of their actions + /// in the data structure. + /// + /// The data structure to be filled by the + /// object implementing this interface. + void Heuristic(in ActionBuffers actionBuffersOut); + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta new file mode 100644 index 0000000000..ca8338a072 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: be90ffb28f39444a8fb02dfd4a82870c +timeCreated: 1610057456 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs new file mode 100644 index 0000000000..586058aad3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs @@ -0,0 +1,105 @@ +using UnityEngine.Profiling; + +namespace Unity.MLAgents.Actuators +{ + /// + /// IActuator implementation that forwards calls to an and an . + /// + internal class VectorActuator : IActuator, IBuiltInActuator + { + IActionReceiver m_ActionReceiver; + IHeuristicProvider m_HeuristicProvider; + + ActionBuffers m_ActionBuffers; + internal ActionBuffers ActionBuffers + { + get => m_ActionBuffers; + private set => m_ActionBuffers = value; + } + + /// + /// Create a VectorActuator that forwards to the provided IActionReceiver. + /// + /// The used for OnActionReceived and WriteDiscreteActionMask. + /// If this parameter also implements it will be cast and used to forward calls to + /// . + /// + /// + public VectorActuator(IActionReceiver actionReceiver, + ActionSpec actionSpec, + string name = "VectorActuator") + : this(actionReceiver, actionReceiver as IHeuristicProvider, actionSpec, name) { } + + /// + /// Create a VectorActuator that forwards to the provided IActionReceiver. + /// + /// The used for OnActionReceived and WriteDiscreteActionMask. + /// The used to fill the + /// for Heuristic Policies. 
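
A sketch of a heuristic provider driven by keyboard input; the axis name, key, and slot meanings are illustrative:

    public void Heuristic(in ActionBuffers actionBuffersOut)
    {
        var continuousOut = actionBuffersOut.ContinuousActions;
        continuousOut[0] = Input.GetAxis("Horizontal");
        var discreteOut = actionBuffersOut.DiscreteActions;
        discreteOut[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }
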
+ /// + /// + public VectorActuator(IActionReceiver actionReceiver, + IHeuristicProvider heuristicProvider, + ActionSpec actionSpec, + string name = "VectorActuator") + { + m_ActionReceiver = actionReceiver; + m_HeuristicProvider = heuristicProvider; + ActionSpec = actionSpec; + string suffix; + if (actionSpec.NumContinuousActions == 0) + { + suffix = "-Discrete"; + } + else if (actionSpec.NumDiscreteActions == 0) + { + suffix = "-Continuous"; + } + else + { + suffix = $"-Continuous-{actionSpec.NumContinuousActions}-Discrete-{actionSpec.NumDiscreteActions}"; + } + Name = name + suffix; + } + + /// + public void ResetData() + { + m_ActionBuffers = ActionBuffers.Empty; + } + + /// + public void OnActionReceived(ActionBuffers actionBuffers) + { + Profiler.BeginSample("VectorActuator.OnActionReceived"); + m_ActionBuffers = actionBuffers; + m_ActionReceiver.OnActionReceived(m_ActionBuffers); + Profiler.EndSample(); + } + + public void Heuristic(in ActionBuffers actionBuffersOut) + { + Profiler.BeginSample("VectorActuator.Heuristic"); + m_HeuristicProvider?.Heuristic(actionBuffersOut); + Profiler.EndSample(); + } + + /// + public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) + { + m_ActionReceiver.WriteDiscreteActionMask(actionMask); + } + + /// + public ActionSpec ActionSpec { get; } + + /// + public string Name { get; } + + /// + public virtual BuiltInActuatorType GetBuiltInActuatorType() + { + return BuiltInActuatorType.VectorActuator; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta new file mode 100644 index 0000000000..6e9f68b913 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: ff7a3292c0b24b23b3f1c0eeb690ec4c +timeCreated: 1593023833 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs new file mode 100644 index 0000000000..8d5c6e79d2 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -0,0 +1,1427 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using UnityEngine; +using Unity.Barracuda; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; +using Unity.MLAgents.Demonstrations; +using Unity.MLAgents.Policies; +using UnityEngine.Serialization; + +namespace Unity.MLAgents +{ + /// + /// Struct that contains all the information for an Agent, including its + /// observations, actions and current status. + /// + public struct AgentInfo + { + /// + /// Keeps track of the last actions taken by the Brain. + /// + public ActionBuffers storedActions; + + /// + /// For discrete control, specifies the actions that the agent cannot take. + /// An element of the mask array is true if the action is prohibited. + /// + public bool[] discreteActionMasks; + + /// + /// The current agent reward. + /// + public float reward; + + /// + /// The current group reward received by the agent. + /// + public float groupReward; + + /// + /// Whether the agent is done or not. + /// + public bool done; + + /// + /// Whether the agent has reached its max step count for this episode. + /// + public bool maxStepReached; + + /// + /// Episode identifier each agent receives at every reset. It is used + /// to separate between different agents in the environment. + /// + public int episodeId; + + /// + /// MultiAgentGroup identifier. 
+ /// + public int groupId; + + public void ClearActions() + { + storedActions.Clear(); + } + + public void CopyActions(ActionBuffers actionBuffers) + { + var continuousActions = storedActions.ContinuousActions; + for (var i = 0; i < actionBuffers.ContinuousActions.Length; i++) + { + continuousActions[i] = actionBuffers.ContinuousActions[i]; + } + var discreteActions = storedActions.DiscreteActions; + for (var i = 0; i < actionBuffers.DiscreteActions.Length; i++) + { + discreteActions[i] = actionBuffers.DiscreteActions[i]; + } + } + } + + /// + /// Simple wrapper around VectorActuator that overrides GetBuiltInActuatorType + /// so that it can be distinguished from a standard VectorActuator. + /// + internal class AgentVectorActuator : VectorActuator + { + public AgentVectorActuator(IActionReceiver actionReceiver, + IHeuristicProvider heuristicProvider, + ActionSpec actionSpec, + string name = "VectorActuator" + ) : base(actionReceiver, heuristicProvider, actionSpec, name) + { } + + public override BuiltInActuatorType GetBuiltInActuatorType() + { + return BuiltInActuatorType.AgentVectorActuator; + } + } + + /// + /// An agent is an actor that can observe its environment, decide on the + /// best course of action using those observations, and execute those actions + /// within the environment. + /// + /// + /// Use the Agent class as the subclass for implementing your own agents. Add + /// your Agent implementation to a [GameObject] in the [Unity scene] that serves + /// as the agent's environment. + /// + /// Agents in an environment operate in *steps*. At each step, an agent collects observations, + /// passes them to its decision-making policy, and receives an action vector in response. + /// + /// Agents make observations using implementations. The ML-Agents + /// API provides implementations for visual observations () + /// raycast observations (), and arbitrary + /// data observations (). You can add the + /// and or + /// components to an agent's [GameObject] to use + /// those sensor types. You can implement the + /// function in your Agent subclass to use a vector observation. The Agent class calls this + /// function before it uses the observation vector to make a decision. (If you only use + /// visual or raycast observations, you do not need to implement + /// .) + /// + /// Assign a decision making policy to an agent using a + /// component attached to the agent's [GameObject]. The setting + /// determines how decisions are made: + /// + /// * : decisions are made by the external process, + /// when connected. Otherwise, decisions are made using inference. If no inference model + /// is specified in the BehaviorParameters component, then heuristic decision + /// making is used. + /// * : decisions are always made using the trained + /// model specified in the component. + /// * : when a decision is needed, the agent's + /// function is called. Your implementation is responsible for + /// providing the appropriate action. + /// + /// To trigger an agent decision automatically, you can attach a + /// component to the Agent game object. You can also call the agent's + /// function manually. You only need to call when the agent is + /// in a position to act upon the decision. In many cases, this will be every [FixedUpdate] + /// callback, but could be less frequent. For example, an agent that hops around its environment + /// can only take an action when it touches the ground, so several frames might elapse between + /// one decision and the need for the next. 
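// [Editor's illustrative sketch, not part of this changeset.] The "hopping agent"
// cadence described above, inside a hypothetical Agent subclass: request a decision
// only when it can be acted on, otherwise repeat the last action. IsGrounded() is
// a hypothetical helper.
void FixedUpdate()
{
    if (IsGrounded())
    {
        RequestDecision();  // new decision this step (also implies RequestAction)
    }
    else
    {
        RequestAction();    // keep applying the previously decided action
    }
}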
+ /// + /// Use the function to implement the actions your agent can take, + /// such as moving to reach a goal or interacting with its environment. + /// + /// When you call on an agent or the agent reaches its count, + /// its current episode ends. You can reset the agent -- or remove it from the + /// environment -- by implementing the function. An agent also + /// becomes done when the resets the environment, which only happens when + /// the receives a reset signal from an external process via the + /// . + /// + /// The Agent class extends the Unity [MonoBehaviour] class. You can implement the + /// standard [MonoBehaviour] functions as needed for your agent. Since an agent's + /// observations and actions typically take place during the [FixedUpdate] phase, you should + /// only use the [MonoBehaviour.Update] function for cosmetic purposes. If you override the [MonoBehaviour] + /// methods, [OnEnable()] or [OnDisable()], always call the base Agent class implementations. + /// + /// You can implement the function to specify agent actions using + /// your own heuristic algorithm. Implementing a heuristic function can be useful + /// for debugging. For example, you can use keyboard input to select agent actions in + /// order to manually control an agent's behavior. + /// + /// Note that you can change the inference model assigned to an agent at any step + /// by calling . + /// + /// See [Agents] and [Reinforcement Learning in Unity] in the [Unity ML-Agents Toolkit manual] for + /// more information on creating and training agents. + /// + /// For sample implementations of agent behavior, see the examples available in the + /// [Unity ML-Agents Toolkit] on Github. + /// + /// [MonoBehaviour]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.html + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// [Unity scene]: https://docs.unity3d.com/Manual/CreatingScenes.html + /// [FixedUpdate]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.FixedUpdate.html + /// [MonoBehaviour.Update]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.Update.html + /// [OnEnable()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnEnable.html + /// [OnDisable()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnDisable.html] + /// [OnBeforeSerialize()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnBeforeSerialize.html + /// [OnAfterSerialize()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnAfterSerialize.html + /// [Agents]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md + /// [Reinforcement Learning in Unity]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design.md + /// [Unity ML-Agents Toolkit]: https://github.com/Unity-Technologies/ml-agents + /// [Unity ML-Agents Toolkit manual]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Readme.md + /// + /// + [HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/" + + "docs/Learning-Environment-Design-Agents.md")] + [Serializable] + [RequireComponent(typeof(BehaviorParameters))] + [DefaultExecutionOrder(-50)] + public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver, IHeuristicProvider + { + IPolicy m_Brain; + BehaviorParameters m_PolicyFactory; + + /// This code is here to make the upgrade path for users using MaxStep + /// easier. 
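// [Editor's illustrative sketch, not part of this changeset.] A minimal Agent
// subclass wiring together the virtual methods discussed above; the observation
// and action layout here is hypothetical.
public class RollerAgent : Agent
{
    public override void OnEpisodeBegin()
    {
        transform.localPosition = Vector3.zero;          // reset state for the new episode
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition);  // contributes 3 floats
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var move = actions.ContinuousActions[0];
        transform.Translate(Vector3.forward * move * Time.fixedDeltaTime);
        AddReward(-0.001f);                              // small per-step time penalty
    }
}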
We will hook into the Serialization code and make sure that + /// agentParameters.maxStep and this.maxStep are in sync. + [Serializable] + internal struct AgentParameters + { + public int maxStep; + } + + [SerializeField] + [HideInInspector] + internal AgentParameters agentParameters; + [SerializeField] + [HideInInspector] + internal bool hasUpgradedFromAgentParameters; + + /// + /// The maximum number of steps the agent takes before being done. + /// + /// The maximum steps for an agent to take before it resets; or 0 for + /// unlimited steps. + /// + /// The max step value determines the maximum length of an agent's episodes. + /// Set to a positive integer to limit the episode length to that many steps. + /// Set to 0 for unlimited episode length. + /// + /// When an episode ends and a new one begins, the Agent object's + /// function is called. You can implement + /// to reset the agent or remove it from the + /// environment. An agent's episode can also end if you call its + /// method or an external process resets the environment through the . + /// + /// Consider limiting the number of steps in an episode to avoid wasting time during + /// training. If you set the max step value to a reasonable estimate of the time it should + /// take to complete a task, then agents that haven’t succeeded in that time frame will + /// reset and start a new training episode rather than continue to fail. + /// + /// + /// To use a step limit when training while allowing agents to run without resetting + /// outside of training, you can set the max step to 0 in + /// if the is not connected to an external process. + /// + /// using Unity.MLAgents; + /// + /// public class MyAgent : Agent + /// { + /// public override void Initialize() + /// { + /// if (!Academy.Instance.IsCommunicatorOn) + /// { + /// this.MaxStep = 0; + /// } + /// } + /// } + /// + /// **Note:** in general, you should limit the differences between the code you execute + /// during training and the code you run during inference. + /// + [FormerlySerializedAs("maxStep")] + [HideInInspector] public int MaxStep; + + /// Current Agent information (message sent to Brain). + AgentInfo m_Info; + + /// Represents the reward the agent accumulated during the current step. + /// It is reset to 0 at the beginning of every step. + /// Should be set to a positive value when the agent performs a "good" + /// action that we wish to reinforce/reward, and set to a negative value + /// when the agent performs a "bad" action that we wish to punish/deter. + /// Additionally, the magnitude of the reward should not exceed 1.0 + float m_Reward; + + /// Represents the group reward the agent accumulated during the current step. + float m_GroupReward; + + /// Keeps track of the cumulative reward in this episode. + float m_CumulativeReward; + + /// Whether or not the agent requests an action. + bool m_RequestAction; + + /// Whether or not the agent requests a decision. + bool m_RequestDecision; + + /// Keeps track of the number of steps taken by the agent in this episode. + /// Note that this value is different for each agent, and may not overlap + /// with the step counter in the Academy, since agents reset based on + /// their own experience. + int m_StepCount; + + /// Number of times the Agent has completed an episode. + int m_CompletedEpisodes; + + /// Episode identifier each agent receives. It is used + /// to separate between different agents in the environment. + /// This Id will be changed every time the Agent resets. 
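// [Editor's illustrative sketch, not part of this changeset.] How the step reward
// (m_Reward) and cumulative reward fields above interact within one step;
// "agent" is hypothetical:
agent.AddReward(0.1f);  // step reward 0.1, cumulative += 0.1
agent.AddReward(0.2f);  // step reward 0.3, cumulative += 0.2
agent.SetReward(1.0f);  // step reward overwritten to 1.0, cumulative += (1.0 - 0.3)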
+ int m_EpisodeId; + + /// Whether or not the Agent has been initialized already + bool m_Initialized; + + /// + /// Set of DemonstrationWriters that the Agent will write its step information to. + /// If you use a DemonstrationRecorder component, this will automatically register its DemonstrationWriter. + /// You can also add your own DemonstrationWriter by calling + /// DemonstrationRecorder.AddDemonstrationWriterToAgent() + /// + internal ISet<DemonstrationWriter> DemonstrationWriters = new HashSet<DemonstrationWriter>(); + + /// + /// List of sensors used to generate observations. + /// Currently generated from attached SensorComponents, and a legacy VectorSensor + /// + internal List<ISensor> sensors; + + /// + /// VectorSensor which is written to by AddVectorObs + /// + internal VectorSensor collectObservationsSensor; + + /// + /// StackingSensor which is written to by AddVectorObs + /// + internal StackingSensor stackedCollectObservationsSensor; + + private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations"); + private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin"); + + /// + /// List of IActuators that this Agent will delegate actions to if any exist. + /// + ActuatorManager m_ActuatorManager; + + /// + /// VectorActuator which is used by default if no other actuators exist on this Agent. This VectorActuator will + /// delegate its actions to the Agent by default in order to keep backward compatibility + /// with the current behavior of Agent. + /// + IActuator m_VectorActuator; + + /// Current MultiAgentGroup ID. Defaults to 0 (meaning no group). + int m_GroupId; + + /// Delegate for the agent to unregister itself from the MultiAgentGroup without cyclic reference + /// between agent and the group + internal event Action<Agent> OnAgentDisabled; + + /// + /// Called when the Agent is being loaded (before OnEnable()). + /// + /// + /// This function registers the RpcCommunicator delegate if no delegate has been registered with CommunicatorFactory. + /// Always call the base Agent class version of this function if you implement `Awake()` in your + /// own Agent subclasses. + /// + /// + /// + /// protected override void Awake() + /// { + /// base.Awake(); + /// // additional Awake logic... + /// } + /// + /// + protected internal virtual void Awake() + { +#if UNITY_EDITOR || UNITY_STANDALONE + if (!CommunicatorFactory.CommunicatorRegistered) + { + Debug.Log("Registered Communicator in Agent."); + CommunicatorFactory.Register(RpcCommunicator.Create); + } +#endif + } + + /// + /// Called when the attached [GameObject] becomes enabled and active. + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// + /// + /// This function initializes the Agent instance, if it hasn't been initialized yet. + /// Always call the base Agent class version of this function if you implement `OnEnable()` + /// in your own Agent subclasses. + /// + /// + /// + /// protected override void OnEnable() + /// { + /// base.OnEnable(); + /// // additional OnEnable logic... + /// } + /// + /// + protected virtual void OnEnable() + { + LazyInitialize(); + } + + /// + /// Called by Unity immediately before serializing this object. + /// + /// + /// The Agent class uses OnBeforeSerialize() for internal housekeeping. Call the + /// base class implementation if you need your own custom serialization logic. + /// + /// See [OnBeforeSerialize] for more information.
+ /// + /// [OnBeforeSerialize]: https://docs.unity3d.com/ScriptReference/ISerializationCallbackReceiver.OnAfterDeserialize.html + /// + /// + /// + /// public new void OnBeforeSerialize() + /// { + /// base.OnBeforeSerialize(); + /// // additional serialization logic... + /// } + /// + /// + public void OnBeforeSerialize() + { + // Manages a serialization upgrade issue from v0.13 to v0.14 where MaxStep moved + // from AgentParameters (since removed) to Agent + if (MaxStep == 0 && MaxStep != agentParameters.maxStep && !hasUpgradedFromAgentParameters) + { + MaxStep = agentParameters.maxStep; + } + hasUpgradedFromAgentParameters = true; + } + + /// + /// Called by Unity immediately after deserializing this object. + /// + /// + /// The Agent class uses OnAfterDeserialize() for internal housekeeping. Call the + /// base class implementation if you need your own custom deserialization logic. + /// + /// See [OnAfterDeserialize] for more information. + /// + /// [OnAfterDeserialize]: https://docs.unity3d.com/ScriptReference/ISerializationCallbackReceiver.OnAfterDeserialize.html + /// + /// + /// + /// public new void OnAfterDeserialize() + /// { + /// base.OnAfterDeserialize(); + /// // additional deserialization logic... + /// } + /// + /// + public void OnAfterDeserialize() + { + // Manages a serialization upgrade issue from v0.13 to v0.14 where MaxStep moved + // from AgentParameters (since removed) to Agent + if (MaxStep == 0 && MaxStep != agentParameters.maxStep && !hasUpgradedFromAgentParameters) + { + MaxStep = agentParameters.maxStep; + } + hasUpgradedFromAgentParameters = true; + } + + /// + /// Initializes the agent. Can be safely called multiple times. + /// + /// + /// This function calls your implementation, if one exists. + /// + public void LazyInitialize() + { + if (m_Initialized) + { + return; + } + m_Initialized = true; + + // Grab the "static" properties for the Agent. + m_EpisodeId = EpisodeIdCounter.GetEpisodeId(); + m_PolicyFactory = GetComponent(); + + m_Info = new AgentInfo(); + sensors = new List(); + + Academy.Instance.AgentIncrementStep += AgentIncrementStep; + Academy.Instance.AgentSendState += SendInfo; + Academy.Instance.DecideAction += DecideAction; + Academy.Instance.AgentAct += AgentStep; + Academy.Instance.AgentForceReset += _AgentReset; + + using (TimerStack.Instance.Scoped("InitializeActuators")) + { + InitializeActuators(); + } + + m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager); + ResetData(); + Initialize(); + + using (TimerStack.Instance.Scoped("InitializeSensors")) + { + InitializeSensors(); + } + + m_Info.storedActions = new ActionBuffers( + new float[m_ActuatorManager.NumContinuousActions], + new int[m_ActuatorManager.NumDiscreteActions] + ); + + m_Info.groupId = m_GroupId; + + // The first time the Academy resets, all Agents in the scene will be + // forced to reset through the event. + // To avoid the Agent resetting twice, the Agents will not begin their + // episode when initializing until after the Academy had its first reset. + if (Academy.Instance.TotalStepCount != 0) + { + using (m_OnEpisodeBeginChecker.Start()) + { + OnEpisodeBegin(); + } + } + } + + /// + /// The reason that the Agent has been set to "done". + /// + enum DoneReason + { + /// + /// The episode was ended manually by calling . + /// + DoneCalled, + + /// + /// The max steps for the Agent were reached. + /// + MaxStepReached, + + /// + /// The Agent was disabled. 
+ /// + Disabled, + } + + /// + /// Called when the attached [GameObject] becomes disabled and inactive. + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// + /// + /// Always call the base Agent class version of this function if you implement `OnDisable()` + /// in your own Agent subclasses. + /// + /// + /// + /// protected override void OnDisable() + /// { + /// base.OnDisable(); + /// // additional OnDisable logic... + /// } + /// + /// + /// + protected virtual void OnDisable() + { + DemonstrationWriters.Clear(); + + // If Academy.Dispose has already been called, we don't need to unregister with it. + // We don't want to even try, because this will lazily create a new Academy! + if (Academy.IsInitialized) + { + Academy.Instance.AgentIncrementStep -= AgentIncrementStep; + Academy.Instance.AgentSendState -= SendInfo; + Academy.Instance.DecideAction -= DecideAction; + Academy.Instance.AgentAct -= AgentStep; + Academy.Instance.AgentForceReset -= _AgentReset; + NotifyAgentDone(DoneReason.Disabled); + } + + CleanupSensors(); + m_Brain?.Dispose(); + OnAgentDisabled?.Invoke(this); + m_Initialized = false; + } + + void NotifyAgentDone(DoneReason doneReason) + { + if (m_Info.done) + { + // The Agent was already marked as Done and should not be notified again + return; + } + m_Info.episodeId = m_EpisodeId; + m_Info.reward = m_Reward; + m_Info.groupReward = m_GroupReward; + m_Info.done = true; + m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached; + m_Info.groupId = m_GroupId; + UpdateSensors(); + // Make sure the latest observations are being passed to training. + using (m_CollectObservationsChecker.Start()) + { + CollectObservations(collectObservationsSensor); + } + // Request the last decision with no callbacks + // We request a decision so Python knows the Agent is done immediately + m_Brain?.RequestDecision(m_Info, sensors); + + // We also have to write any to any DemonstationStores so that they get the "done" flag. + if (DemonstrationWriters.Count != 0) + { + foreach (var demoWriter in DemonstrationWriters) + { + demoWriter.Record(m_Info, sensors); + } + } + + ResetSensors(); + + if (doneReason != DoneReason.Disabled) + { + // We don't want to update the reward stats when the Agent is disabled, because this will make + // the rewards look lower than they actually are during shutdown. + m_CompletedEpisodes++; + UpdateRewardStats(); + } + + m_Reward = 0f; + m_GroupReward = 0f; + m_CumulativeReward = 0f; + m_RequestAction = false; + m_RequestDecision = false; + m_Info.storedActions.Clear(); + } + + /// + /// Updates the Model assigned to this Agent instance. + /// + /// + /// If the agent already has an assigned model, that model is replaced with the + /// the provided one. However, if you call this function with arguments that are + /// identical to the current parameters of the agent, then no changes are made. + /// + /// **Note:** the parameter is ignored when not training. + /// The and parameters + /// are ignored when not using inference. + /// + /// The identifier of the behavior. This + /// will categorize the agent when training. + /// + /// The model to use for inference. + /// Define the device on which the model + /// will be run. 
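// [Editor's illustrative sketch, not part of this changeset.] Swapping the
// inference model at runtime with the SetModel call documented above, inside a
// hypothetical Agent subclass; the field names and behavior name are hypothetical.
public NNModel walkModel;
public NNModel swimModel;

void OnEnteredWater()
{
    // Marks the current episode done internally, then reloads the policy.
    SetModel("Swimmer", swimModel, InferenceDevice.Default);
}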
+ public void SetModel( + string behaviorName, + NNModel model, + InferenceDevice inferenceDevice = InferenceDevice.Default) + { + if (behaviorName == m_PolicyFactory.BehaviorName && + model == m_PolicyFactory.Model && + inferenceDevice == m_PolicyFactory.InferenceDevice) + { + // If everything is the same, don't make any changes. + return; + } + NotifyAgentDone(DoneReason.Disabled); + m_PolicyFactory.Model = model; + m_PolicyFactory.InferenceDevice = inferenceDevice; + m_PolicyFactory.BehaviorName = behaviorName; + ReloadPolicy(); + } + + internal void ReloadPolicy() + { + if (!m_Initialized) + { + // If we haven't initialized yet, no need to make any changes now; they'll + // happen in LazyInitialize later. + return; + } + m_Brain?.Dispose(); + m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager); + } + + /// + /// Returns the current step counter (within the current episode). + /// + /// + /// Current step count. + /// + public int StepCount + { + get { return m_StepCount; } + } + + /// + /// Returns the number of episodes that the Agent has completed (either + /// was called, or maxSteps was reached). + /// + /// + /// Current episode count. + /// + public int CompletedEpisodes + { + get { return m_CompletedEpisodes; } + } + + /// + /// Overrides the current step reward of the agent and updates the episode + /// reward accordingly. + /// + /// + /// This function replaces any rewards given to the agent during the current step. + /// Use to incrementally change the reward rather than + /// overriding it. + /// + /// Typically, you assign rewards in the Agent subclass's + /// implementation after carrying out the received action and evaluating its success. + /// + /// Rewards are used during reinforcement learning; they are ignored during inference. + /// + /// See [Agents - Rewards] for general advice on implementing rewards and [Reward Signals] + /// for information about mixing reward signals from curiosity and Generative Adversarial + /// Imitation Learning (GAIL) with rewards supplied through this method. + /// + /// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#rewards + /// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals + /// + /// The new value of the reward. + public void SetReward(float reward) + { + Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetReward)); + m_CumulativeReward += (reward - m_Reward); + m_Reward = reward; + } + + /// + /// Increments the step and episode rewards by the provided value. + /// + /// Use a positive reward to reinforce desired behavior. You can use a + /// negative reward to penalize mistakes. Use to + /// set the reward assigned to the current step with a specific value rather than + /// increasing or decreasing it. + /// + /// Typically, you assign rewards in the Agent subclass's + /// implementation after carrying out the received action and evaluating its success. + /// + /// Rewards are used during reinforcement learning; they are ignored during inference. + /// + /// See [Agents - Rewards] for general advice on implementing rewards and [Reward Signals] + /// for information about mixing reward signals from curiosity and Generative Adversarial + /// Imitation Learning (GAIL) with rewards supplied through this method. 
+ /// + /// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#rewards + /// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals + /// + /// Incremental reward value. + public void AddReward(float increment) + { + Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddReward)); + m_Reward += increment; + m_CumulativeReward += increment; + } + + internal void SetGroupReward(float reward) + { + Utilities.DebugCheckNanAndInfinity(reward, nameof(reward), nameof(SetGroupReward)); + m_GroupReward = reward; + } + + internal void AddGroupReward(float increment) + { + Utilities.DebugCheckNanAndInfinity(increment, nameof(increment), nameof(AddGroupReward)); + m_GroupReward += increment; + } + + /// + /// Retrieves the episode reward for the Agent. + /// + /// The episode reward. + public float GetCumulativeReward() + { + return m_CumulativeReward; + } + + void UpdateRewardStats() + { + var gaugeName = $"{m_PolicyFactory.BehaviorName}.CumulativeReward"; + TimerStack.Instance.SetGauge(gaugeName, GetCumulativeReward()); + } + + /// + /// Sets the done flag to true and resets the agent. + /// + /// + /// This should be used when the episode can no longer continue, such as when the Agent + /// reaches the goal or fails at the task. + /// + /// + /// + public void EndEpisode() + { + EndEpisodeAndReset(DoneReason.DoneCalled); + } + + /// + /// Indicate that the episode is over but not due to the "fault" of the Agent. + /// This has the same end result as calling , but has a + /// slightly different effect on training. + /// + /// + /// This should be used when the episode could continue, but has gone on for + /// a sufficient number of steps. + /// + /// + /// + public void EpisodeInterrupted() + { + EndEpisodeAndReset(DoneReason.MaxStepReached); + } + + /// + /// Internal method to end the episode and reset the Agent. + /// + /// + void EndEpisodeAndReset(DoneReason reason) + { + NotifyAgentDone(reason); + _AgentReset(); + } + + /// + /// Requests a new decision for this agent. + /// + /// + /// Call `RequestDecision()` whenever an agent needs a decision. You often + /// want to request a decision every environment step. However, if an agent + /// cannot use the decision every step, then you can request a decision less + /// frequently. + /// + /// You can add a component to the agent's + /// [GameObject] to drive the agent's decision making. When you use this component, + /// do not call `RequestDecision()` separately. + /// + /// Note that this function calls ; you do not need to + /// call both functions at the same time. + /// + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// + public void RequestDecision() + { + m_RequestDecision = true; + RequestAction(); + } + + /// + /// Requests an action for this agent. + /// + /// + /// Call `RequestAction()` to repeat the previous action returned by the agent's + /// most recent decision. A new decision is not requested. When you call this function, + /// the Agent instance invokes with the + /// existing action vector. + /// + /// You can use `RequestAction()` in situations where an agent must take an action + /// every update, but doesn't need to make a decision as often. 
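// [Editor's illustrative sketch, not part of this changeset.] Choosing between the
// two episode-ending calls defined above; ReachedGoal() and m_Timer are hypothetical.
if (ReachedGoal())
{
    SetReward(1.0f);
    EndEpisode();          // outcome attributable to the agent's behavior
}
else if (m_Timer > 30f)
{
    EpisodeInterrupted();  // episode cut short through no "fault" of the agent
}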
For example, an + /// agent that moves through its environment might need to apply an action to keep + /// moving, but only needs to make a decision to change course or speed occasionally. + /// + /// You can add a component to the agent's + /// [GameObject] to drive the agent's decision making and action frequency. When you + /// use this component, do not call `RequestAction()` separately. + /// + /// Note that calls `RequestAction()`; you do not need to + /// call both functions at the same time. + /// + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// + public void RequestAction() + { + m_RequestAction = true; + } + + /// Helper function that resets all the data structures associated with + /// the agent. Typically used when the agent is being initialized or reset + /// at the end of an episode. + void ResetData() + { + m_ActuatorManager?.ResetData(); + } + + /// + /// Implement `Initialize()` to perform one-time initialization or set up of the + /// Agent instance. + /// + /// + /// `Initialize()` is called once when the agent is first enabled. If, for example, + /// the Agent object needs references to other [GameObjects] in the scene, you + /// can collect and store those references here. + /// + /// Note that is called at the start of each of + /// the agent's "episodes". You can use that function for items that need to be reset + /// for each episode. + /// + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// + public virtual void Initialize() { } + + /// + /// Implement to choose an action for this agent using a custom heuristic. + /// + /// + /// Implement this function to provide custom decision making logic or to support manual + /// control of an agent using keyboard, mouse, game controller input, or a script. + /// + /// Your heuristic implementation can use any decision making logic you specify. Assign decision + /// values to the and + /// arrays , passed to your function as a parameter. + /// The same array will be reused between steps. It is up to the user to initialize + /// the values on each call, for example by calling `Array.Clear(actionsOut, 0, actionsOut.Length);`. + /// Add values to the array at the same indexes as they are used in your + /// function, which receives this array and + /// implements the corresponding agent behavior. See [Actions] for more information + /// about agent actions. + /// Note : Do not create a new float array of action in the `Heuristic()` method, + /// as this will prevent writing floats to the original action array. + /// + /// An agent calls this `Heuristic()` function to make a decision when you set its behavior + /// type to . The agent also calls this function if + /// you set its behavior type to when the + /// is not connected to an external training process and you do not + /// assign a trained model to the agent. + /// + /// To perform imitation learning, implement manual control of the agent in the `Heuristic()` + /// function so that you can record the demonstrations required for the imitation learning + /// algorithms. (Attach a [Demonstration Recorder] component to the agent's [GameObject] to + /// record the demonstration session to a file.) + /// + /// Even when you don’t plan to use heuristic decisions for an agent or imitation learning, + /// implementing a simple heuristic function can aid in debugging agent actions and interactions + /// with its environment. 
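// [Editor's illustrative sketch, not part of this changeset.] A heuristic for a
// single discrete branch (0 = stay, 1 = left, 2 = right). It writes into the
// provided buffers instead of allocating a new array, as the remarks above require.
public override void Heuristic(in ActionBuffers actionsOut)
{
    var discreteActionsOut = actionsOut.DiscreteActions;
    discreteActionsOut[0] = 0; // reset to "stay"; the buffer is reused between steps
    if (Input.GetKey(KeyCode.LeftArrow)) discreteActionsOut[0] = 1;
    if (Input.GetKey(KeyCode.RightArrow)) discreteActionsOut[0] = 2;
}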
+ /// + /// [Demonstration Recorder]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#recording-demonstrations + /// [Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#actions + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// + /// + /// The following example illustrates a `Heuristic()` function that provides WASD-style + /// keyboard control for an agent that can move in two dimensions as well as jump. See + /// [Input Manager] for more information about the built-in Unity input functions. + /// You can also use the [Input System package], which provides a more flexible and + /// configurable input system. + /// + /// public override void Heuristic(in ActionBuffers actionsOut) + /// { + /// var continuousActionsOut = actionsOut.ContinuousActions; + /// continuousActionsOut[0] = Input.GetAxis("Horizontal"); + /// continuousActionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f; + /// continuousActionsOut[2] = Input.GetAxis("Vertical"); + /// } + /// + /// [Input Manager]: https://docs.unity3d.com/Manual/class-InputManager.html + /// [Input System package]: https://docs.unity3d.com/Packages/com.unity.inputsystem@1.0/manual/index.html + /// + /// The which contain the continuous and + /// discrete action buffers to write to. + /// + public virtual void Heuristic(in ActionBuffers actionsOut) + { + Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions."); + } + + /// + /// Set up the list of ISensors on the Agent. By default, this will select any + /// SensorComponent's attached to the Agent. + /// + internal void InitializeSensors() + { + if (m_PolicyFactory == null) + { + m_PolicyFactory = GetComponent(); + } + if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore) + { + var excludeInherited = + m_PolicyFactory.ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited; + using (TimerStack.Instance.Scoped("CreateObservableSensors")) + { + var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited); + sensors.AddRange(observableSensors); + } + } + + // Get all attached sensor components + SensorComponent[] attachedSensorComponents; + if (m_PolicyFactory.UseChildSensors) + { + attachedSensorComponents = GetComponentsInChildren(); + } + else + { + attachedSensorComponents = GetComponents(); + } + + sensors.Capacity += attachedSensorComponents.Length; + foreach (var component in attachedSensorComponents) + { + sensors.AddRange(component.CreateSensors()); + } + + // Support legacy CollectObservations + var param = m_PolicyFactory.BrainParameters; + if (param.VectorObservationSize > 0) + { + collectObservationsSensor = new VectorSensor(param.VectorObservationSize); + if (param.NumStackedVectorObservations > 1) + { + stackedCollectObservationsSensor = new StackingSensor( + collectObservationsSensor, param.NumStackedVectorObservations); + sensors.Add(stackedCollectObservationsSensor); + } + else + { + sensors.Add(collectObservationsSensor); + } + } + + // Sort the Sensors by name to ensure determinism + SensorUtils.SortSensors(sensors); + +#if DEBUG + // Make sure the names are actually unique + + for (var i = 0; i < sensors.Count - 1; i++) + { + Debug.Assert( + !sensors[i].GetName().Equals(sensors[i + 1].GetName()), + "Sensor names must be unique."); + } +#endif + } + + void CleanupSensors() + { + // Dispose all attached 
sensor + for (var i = 0; i < sensors.Count; i++) + { + var sensor = sensors[i]; + if (sensor is IDisposable disposableSensor) + { + disposableSensor.Dispose(); + } + } + } + + void InitializeActuators() + { + ActuatorComponent[] attachedActuators; + if (m_PolicyFactory.UseChildActuators) + { + attachedActuators = GetComponentsInChildren(); + } + else + { + attachedActuators = GetComponents(); + } + + // Support legacy OnActionReceived + // TODO don't set this up if the sizes are 0? + var param = m_PolicyFactory.BrainParameters; + m_VectorActuator = new AgentVectorActuator(this, this, param.ActionSpec); + m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1); + + m_ActuatorManager.Add(m_VectorActuator); + + foreach (var actuatorComponent in attachedActuators) + { + m_ActuatorManager.AddActuators(actuatorComponent.CreateActuators()); + } + } + + /// + /// Sends the Agent info to the linked Brain. + /// + void SendInfoToBrain() + { + if (!m_Initialized) + { + throw new UnityAgentsException("Call to SendInfoToBrain when Agent hasn't been initialized." + + "Please ensure that you are calling 'base.OnEnable()' if you have overridden OnEnable."); + } + + if (m_Brain == null) + { + return; + } + + if (m_Info.done) + { + m_Info.ClearActions(); + } + else + { + m_Info.CopyActions(m_ActuatorManager.StoredActions); + } + + UpdateSensors(); + using (TimerStack.Instance.Scoped("CollectObservations")) + { + using (m_CollectObservationsChecker.Start()) + { + CollectObservations(collectObservationsSensor); + } + } + using (TimerStack.Instance.Scoped("WriteActionMask")) + { + m_ActuatorManager.WriteActionMask(); + } + + m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask(); + m_Info.reward = m_Reward; + m_Info.groupReward = m_GroupReward; + m_Info.done = false; + m_Info.maxStepReached = false; + m_Info.episodeId = m_EpisodeId; + m_Info.groupId = m_GroupId; + + using (TimerStack.Instance.Scoped("RequestDecision")) + { + m_Brain.RequestDecision(m_Info, sensors); + } + + // If we have any DemonstrationWriters, write the AgentInfo and sensors to them. + if (DemonstrationWriters.Count != 0) + { + foreach (var demoWriter in DemonstrationWriters) + { + demoWriter.Record(m_Info, sensors); + } + } + } + + void UpdateSensors() + { + foreach (var sensor in sensors) + { + sensor.Update(); + } + } + + void ResetSensors() + { + foreach (var sensor in sensors) + { + sensor.Reset(); + } + } + + /// + /// Implement `CollectObservations()` to collect the vector observations of + /// the agent for the step. The agent observation describes the current + /// environment from the perspective of the agent. + /// + /// + /// The vector observations for the agent. + /// + /// + /// An agent's observation is any environment information that helps + /// the agent achieve its goal. For example, for a fighting agent, its + /// observation could include distances to friends or enemies, or the + /// current level of ammunition at its disposal. + /// + /// You can use a combination of vector, visual, and raycast observations for an + /// agent. If you only use visual or raycast observations, you do not need to + /// implement a `CollectObservations()` function. + /// + /// Add vector observations to the parameter passed to + /// this method by calling the helper methods: + /// - + /// - + /// - + /// - + /// - + /// - + /// - + /// - + /// + /// You can use any combination of these helper functions to build the agent's + /// vector of observations. 
You must build the vector in the same order + /// each time `CollectObservations()` is called and the length of the vector + /// must always be the same. In addition, the length of the observation must + /// match the vector observation size + /// attribute of the linked Brain, which is set in the Editor on the + /// **Behavior Parameters** component attached to the agent's [GameObject]. + /// + /// For more information about observations, see [Observations and Sensors]. + /// + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + /// [Observations and Sensors]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#observations-and-sensors + /// + public virtual void CollectObservations(VectorSensor sensor) + { + } + + /// + /// Returns a read-only view of the observations that were generated in + /// CollectObservations(). This is mainly useful inside of a Heuristic() + /// method to avoid recomputing the observations. + /// + /// A read-only view of the observations list. + public ReadOnlyCollection<float> GetObservations() + { + return collectObservationsSensor.GetObservations(); + } + + /// + /// Returns a read-only view of the stacked observations that were generated in + /// CollectObservations(). This is mainly useful inside of a Heuristic() + /// method to avoid recomputing the observations. + /// + /// A read-only view of the stacked observations list. + public ReadOnlyCollection<float> GetStackedObservations() + { + return stackedCollectObservationsSensor.GetStackedObservations(); + } + + /// + /// Implement `WriteDiscreteActionMask()` to collect the masks for discrete + /// actions. When using discrete actions, the agent will not perform the masked + /// action. + /// + /// + /// The action mask for the agent. + /// + /// + /// When using Discrete Control, you can prevent the Agent from using a certain + /// action by masking it with IDiscreteActionMask.SetActionEnabled(). + /// + /// See [Agents - Actions] for more information on masking actions. + /// + /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#actions + /// + /// + public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { } + + /// + /// Implement `OnActionReceived()` to specify agent behavior at every step, based + /// on the provided action. + /// + /// + /// An action is passed to this function in the form of an ActionBuffers. + /// Your implementation must use the array to direct the agent's behavior for the + /// current step. + /// + /// You decide how many elements you need in the ActionBuffers to control your + /// agent and what each element means. For example, if you want to apply a + /// force to move an agent around the environment, you can arbitrarily pick + /// three values in the ActionBuffers.ContinuousActions array to use as the force components. + /// During training, the agent's policy learns to set those particular elements of + /// the array to maximize the training rewards the agent receives. (Of course, + /// if you implement a Heuristic() function, it must use the same + /// elements of the action array for the same purpose since there is no learning + /// involved.) + /// + /// An Agent can use continuous and/or discrete actions. Configure this along with the size + /// of the action array, in the BrainParameters of the agent's associated + /// BehaviorParameters component. + /// + /// When an agent uses continuous actions, the values in the ActionBuffers.ContinuousActions + /// array are floating point numbers. You should clamp the values to the range + /// -1..1 to increase numerical stability during training.
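// [Editor's illustrative sketch, not part of this changeset.] Applying continuous
// actions with the [-1, 1] clamping recommended above; m_Rigidbody and
// m_ForceMultiplier are hypothetical fields.
public override void OnActionReceived(ActionBuffers actions)
{
    var x = Mathf.Clamp(actions.ContinuousActions[0], -1f, 1f);
    var z = Mathf.Clamp(actions.ContinuousActions[1], -1f, 1f);
    m_Rigidbody.AddForce(new Vector3(x, 0f, z) * m_ForceMultiplier);
}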
+ /// + /// When an agent uses discrete actions, the values in the ActionBuffers.DiscreteActions array + /// are integers that each represent a specific, discrete action. For example, + /// you could define a set of discrete actions such as: + /// + /// + /// 0 = Do nothing + /// 1 = Move one space left + /// 2 = Move one space right + /// 3 = Move one space up + /// 4 = Move one space down + /// + /// + /// When making a decision, the agent picks one of the five actions and puts the + /// corresponding integer value in the ActionBuffers.DiscreteActions array. For example, if the agent + /// decided to move left, the ActionBuffers.DiscreteActions parameter would be an array with + /// a single element with the value 1. + /// + /// You can define multiple sets, or branches, of discrete actions to allow an + /// agent to perform simultaneous, independent actions. For example, you could + /// use one branch for movement and another branch for throwing a ball left, right, + /// up, or down, to allow the agent to do both in the same step. + /// + /// The ActionBuffers.DiscreteActions array of an agent with discrete actions contains one + /// element for each branch. The value of each element is the integer representing the + /// chosen action for that branch. The agent always chooses one action for each branch. + /// + /// When you use the discrete actions, you can prevent the training process + /// or the neural network model from choosing specific actions in a step by + /// implementing the + /// method. For example, if your agent is next to a wall, you could mask out any + /// actions that would result in the agent trying to move into the wall. + /// + /// For more information about implementing agent actions see [Agents - Actions]. + /// + /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#actions + /// + /// + /// Struct containing the buffers of actions to be executed at this step. + /// + public virtual void OnActionReceived(ActionBuffers actions) { } + + /// + /// Implement `OnEpisodeBegin()` to set up an Agent instance at the beginning + /// of an episode. + /// + /// + /// + public virtual void OnEpisodeBegin() { } + + /// + /// Gets the most recent ActionBuffer for this agent. + /// + /// The most recent ActionBuffer for this agent + public ActionBuffers GetStoredActionBuffers() + { + return m_ActuatorManager.StoredActions; + } + + /// + /// An internal reset method that updates internal data structures in + /// addition to calling . + /// + void _AgentReset() + { + ResetData(); + m_StepCount = 0; + using (m_OnEpisodeBeginChecker.Start()) + { + OnEpisodeBegin(); + } + } + + /// + /// Scales continuous action from [-1, 1] to arbitrary range. + /// + /// The input action value. + /// The minimum output value. + /// The maximum output value. + /// The scaled from [-1,1] to + /// [, ]. + protected static float ScaleAction(float rawAction, float min, float max) + { + var middle = (min + max) / 2; + var range = (max - min) / 2; + return rawAction * range + middle; + } + + /// + /// Signals the agent that it must send its decision to the brain. + /// + void SendInfo() + { + // If the Agent is done, it has just reset and thus requires a new decision + if (m_RequestDecision) + { + SendInfoToBrain(); + m_Reward = 0f; + m_GroupReward = 0f; + m_RequestDecision = false; + } + } + + void AgentIncrementStep() + { + m_StepCount += 1; + } + + /// Used by the brain to make the agent perform a step. 
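// [Editor's illustrative sketch, not part of this changeset.] The five-action
// movement scheme described above, written as a discrete OnActionReceived;
// moving by whole units per step is a hypothetical simplification.
public override void OnActionReceived(ActionBuffers actions)
{
    switch (actions.DiscreteActions[0])
    {
        case 1: transform.Translate(Vector3.left);  break;
        case 2: transform.Translate(Vector3.right); break;
        case 3: transform.Translate(Vector3.up);    break;
        case 4: transform.Translate(Vector3.down);  break;
        default: break;  // 0 = do nothing
    }
}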
+ void AgentStep() + { + if ((m_RequestAction) && (m_Brain != null)) + { + m_RequestAction = false; + m_ActuatorManager.ExecuteActions(); + } + + if ((m_StepCount >= MaxStep) && (MaxStep > 0)) + { + NotifyAgentDone(DoneReason.MaxStepReached); + _AgentReset(); + } + } + + void DecideAction() + { + if (m_ActuatorManager.StoredActions.ContinuousActions.Array == null) + { + ResetData(); + } + var actions = m_Brain?.DecideAction() ?? new ActionBuffers(); + m_Info.CopyActions(actions); + m_ActuatorManager.UpdateActions(actions); + } + + internal void SetMultiAgentGroup(IMultiAgentGroup multiAgentGroup) + { + if (multiAgentGroup == null) + { + m_GroupId = 0; + } + else + { + var newGroupId = multiAgentGroup.GetId(); + if (m_GroupId == 0 || m_GroupId == newGroupId) + { + m_GroupId = newGroupId; + } + else + { + throw new UnityAgentsException("Agent is already registered with a group. Unregister it first."); + } + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Agent.cs.meta b/com.unity.ml-agents/Runtime/Agent.cs.meta new file mode 100644 index 0000000000..5463d244fb --- /dev/null +++ b/com.unity.ml-agents/Runtime/Agent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 88b6042bc9a5d4aa58d931eae49442e5 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Analytics.meta b/com.unity.ml-agents/Runtime/Analytics.meta new file mode 100644 index 0000000000..260b85a9b3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 8b12ac54c5224758af88c67e2af4a01e +timeCreated: 1604359666 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs new file mode 100644 index 0000000000..b206f6bd98 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs @@ -0,0 +1,66 @@ +using System; +using System.Text; +using System.Security.Cryptography; +using UnityEngine; + +namespace Unity.MLAgents.Analytics +{ + + internal static class AnalyticsUtils + { + /// + /// Conversion function from byte array to hex string + /// + /// + /// A byte array to be hex encoded. + private static string ToHexString(byte[] array) + { + StringBuilder hex = new StringBuilder(array.Length * 2); + foreach (byte b in array) + { + hex.AppendFormat("{0:x2}", b); + } + return hex.ToString(); + } + + /// + /// Hash a string to remove PII or secret info before sending to analytics + /// + /// + /// A string containing the key to be used for HMAC encoding. + /// + /// A string containing the value to be encoded. + public static string Hash(string key, string value) + { + string hash; + UTF8Encoding encoder = new UTF8Encoding(); + using (HMACSHA256 hmac = new HMACSHA256(encoder.GetBytes(key))) + { + Byte[] hmBytes = hmac.ComputeHash(encoder.GetBytes(value)); + hash = ToHexString(hmBytes); + } + return hash; + } + + internal static bool s_SendEditorAnalytics = true; + + /// + /// Helper class to temporarily disable sending analytics from unit tests. 
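// [Editor's illustrative sketch, not part of this changeset.] What the Hash()
// helper above guarantees: a stable, non-reversible identifier, so analytics can
// correlate events without transmitting the raw behavior name.
var hashed = AnalyticsUtils.Hash("unity.ml-agents", "MyBehavior");
// "hashed" is a 64-character lowercase hex string (HMAC-SHA256: 32 bytes, 2 hex
// characters each); the same key and value always produce the same digest.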
+ /// + internal class DisableAnalyticsSending : IDisposable + { + private bool m_PreviousSendEditorAnalytics; + + public DisableAnalyticsSending() + { + m_PreviousSendEditorAnalytics = s_SendEditorAnalytics; + s_SendEditorAnalytics = false; + } + + public void Dispose() + { + s_SendEditorAnalytics = m_PreviousSendEditorAnalytics; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta new file mode 100644 index 0000000000..b00fab1c90 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: af1ef3e70f1242938d7b39284b1a892b +timeCreated: 1610575760 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs b/com.unity.ml-agents/Runtime/Analytics/Events.cs new file mode 100644 index 0000000000..4a34273c04 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs @@ -0,0 +1,194 @@ +using System; +using System.Collections.Generic; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Analytics +{ + internal struct InferenceEvent + { + /// + /// Hash of the BehaviorName. + /// + public string BehaviorName; + public string BarracudaModelSource; + public string BarracudaModelVersion; + public string BarracudaModelProducer; + public string BarracudaPackageVersion; + /// + /// Whether inference is performed on CPU (0) or GPU (1). + /// + public int InferenceDevice; + public List ObservationSpecs; + public EventActionSpec ActionSpec; + public List ActuatorInfos; + public int MemorySize; + public long TotalWeightSizeBytes; + public string ModelHash; + } + + /// + /// Simplified version of ActionSpec struct for use in analytics + /// + [Serializable] + internal struct EventActionSpec + { + public int NumContinuousActions; + public int NumDiscreteActions; + public int[] BranchSizes; + + public static EventActionSpec FromActionSpec(ActionSpec actionSpec) + { + var branchSizes = actionSpec.BranchSizes ?? Array.Empty(); + return new EventActionSpec + { + NumContinuousActions = actionSpec.NumContinuousActions, + NumDiscreteActions = actionSpec.NumDiscreteActions, + BranchSizes = branchSizes, + }; + } + } + + /// + /// Information about an actuator. + /// + [Serializable] + internal struct EventActuatorInfo + { + public int BuiltInActuatorType; + public int NumContinuousActions; + public int NumDiscreteActions; + + public static EventActuatorInfo FromActuator(IActuator actuator) + { + BuiltInActuatorType builtInActuatorType = Actuators.BuiltInActuatorType.Unknown; + if (actuator is IBuiltInActuator builtInActuator) + { + builtInActuatorType = builtInActuator.GetBuiltInActuatorType(); + } + + var actionSpec = actuator.ActionSpec; + + return new EventActuatorInfo + { + BuiltInActuatorType = (int)builtInActuatorType, + NumContinuousActions = actionSpec.NumContinuousActions, + NumDiscreteActions = actionSpec.NumDiscreteActions + }; + } + } + + /// + /// Information about one dimension of an observation. 
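// [Editor's illustrative sketch, not part of this changeset.] What FromActionSpec
// above produces for a concrete spec:
var spec = ActionSpec.MakeDiscrete(3, 2);       // two discrete branches, sizes 3 and 2
var evt = EventActionSpec.FromActionSpec(spec);
// evt.NumContinuousActions == 0, evt.NumDiscreteActions == 2, evt.BranchSizes == {3, 2};
// a spec with null BranchSizes is normalized to an empty array via Array.Empty.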
+ /// + [Serializable] + internal struct EventObservationDimensionInfo + { + public int Size; + public int Flags; + } + + /// + /// Simplified summary of Agent observations for use in analytics + /// + [Serializable] + internal struct EventObservationSpec + { + public string SensorName; + public string CompressionType; + public int BuiltInSensorType; + public int ObservationType; + public EventObservationDimensionInfo[] DimensionInfos; + + public static EventObservationSpec FromSensor(ISensor sensor) + { + var obsSpec = sensor.GetObservationSpec(); + var shape = obsSpec.Shape; + var dimProps = obsSpec.DimensionProperties; + var dimInfos = new EventObservationDimensionInfo[shape.Length]; + for (var i = 0; i < shape.Length; i++) + { + dimInfos[i].Size = shape[i]; + dimInfos[i].Flags = (int)dimProps[i]; + } + + var builtInSensorType = + (sensor as IBuiltInSensor)?.GetBuiltInSensorType() ?? Sensors.BuiltInSensorType.Unknown; + + return new EventObservationSpec + { + SensorName = sensor.GetName(), + CompressionType = sensor.GetCompressionSpec().SensorCompressionType.ToString(), + BuiltInSensorType = (int)builtInSensorType, + ObservationType = (int)obsSpec.ObservationType, + DimensionInfos = dimInfos, + }; + } + } + + internal struct RemotePolicyInitializedEvent + { + public string TrainingSessionGuid; + /// + /// Hash of the BehaviorName. + /// + public string BehaviorName; + public List ObservationSpecs; + public EventActionSpec ActionSpec; + public List ActuatorInfos; + + /// + /// This will be the same as TrainingEnvironmentInitializedEvent if available, but + /// TrainingEnvironmentInitializedEvent maybe not always be available with older trainers. + /// + public string MLAgentsEnvsVersion; + public string TrainerCommunicationVersion; + } + + internal struct TrainingEnvironmentInitializedEvent + { + public string TrainingSessionGuid; + + public string TrainerPythonVersion; + public string MLAgentsVersion; + public string MLAgentsEnvsVersion; + public string TorchVersion; + public string TorchDeviceType; + public int NumEnvironments; + public int NumEnvironmentParameters; + public string RunOptions; + } + + [Flags] + internal enum RewardSignals + { + Extrinsic = 1 << 0, + Gail = 1 << 1, + Curiosity = 1 << 2, + Rnd = 1 << 3, + } + + [Flags] + internal enum TrainingFeatures + { + BehavioralCloning = 1 << 0, + Recurrent = 1 << 1, + Threaded = 1 << 2, + SelfPlay = 1 << 3, + Curriculum = 1 << 4, + } + + internal struct TrainingBehaviorInitializedEvent + { + public string TrainingSessionGuid; + + public string BehaviorName; + public string TrainerType; + public RewardSignals RewardSignalFlags; + public TrainingFeatures TrainingFeatureFlags; + public string VisualEncoder; + public int NumNetworkLayers; + public int NumNetworkHiddenUnits; + public string Config; + } +} diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs.meta b/com.unity.ml-agents/Runtime/Analytics/Events.cs.meta new file mode 100644 index 0000000000..347eebcd51 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 0a0d7cda6d74425a80775769a9283ba6 +timeCreated: 1604359798 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs new file mode 100644 index 0000000000..b7b466155a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs @@ -0,0 +1,283 @@ +using System.Collections.Generic; +using System.Diagnostics; +using Unity.Barracuda; 
+using Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors; +using UnityEngine; + +#if MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS +using UnityEngine.Analytics; +#endif + + +#if UNITY_EDITOR +using UnityEditor; +#if MLA_UNITY_ANALYTICS_MODULE +using UnityEditor.Analytics; +#endif // MLA_UNITY_ANALYTICS_MODULE +#endif // UNITY_EDITOR + + +namespace Unity.MLAgents.Analytics +{ + internal class InferenceAnalytics + { + const string k_VendorKey = "unity.ml-agents"; + const string k_EventName = "ml_agents_inferencemodelset"; + const int k_EventVersion = 1; + + /// + /// Whether or not we've registered this particular event yet + /// + static bool s_EventRegistered; + + /// + /// Hourly limit for this event name + /// + const int k_MaxEventsPerHour = 1000; + + /// + /// Maximum number of items in this event. + /// + const int k_MaxNumberOfElements = 1000; + + +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + /// + /// Models that we've already sent events for. + /// + private static HashSet s_SentModels; +#endif + + static bool EnableAnalytics() + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + if (s_EventRegistered) + { + return true; + } + + AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(k_EventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey, k_EventVersion); + if (result == AnalyticsResult.Ok) + { + s_EventRegistered = true; + } + if (s_EventRegistered && s_SentModels == null) + { + s_SentModels = new HashSet(); + } + +#else // no editor, no analytics + s_EventRegistered = false; +#endif + return s_EventRegistered; + } + + public static bool IsAnalyticsEnabled() + { +#if UNITY_EDITOR + return EditorAnalytics.enabled; +#else + return false; +#endif + } + + /// + /// Send an analytics event for the NNModel when it is set up for inference. + /// No events will be sent if analytics are disabled, and at most one event + /// will be sent per model instance. + /// + /// The NNModel being used for inference. + /// The BehaviorName of the Agent using the model + /// Whether inference is being performed on the CPU or GPU + /// List of ISensors for the Agent. Used to generate information about the observation space. + /// ActionSpec for the Agent. Used to generate information about the action space. + /// List of IActuators for the Agent. Used to generate information about the action space. + /// + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + public static void InferenceModelSet( + NNModel nnModel, + string behaviorName, + InferenceDevice inferenceDevice, + IList sensors, + ActionSpec actionSpec, + IList actuators + ) + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + // The event shouldn't be able to report if this is disabled but if we know we're not going to report + // Lets early out and not waste time gathering all the data + if (!IsAnalyticsEnabled()) + return; + + if (!EnableAnalytics()) + return; + + var added = s_SentModels.Add(nnModel); + + if (!added) + { + // We previously added this model. Exit so we don't resend. + return; + } + + var data = GetEventForModel(nnModel, behaviorName, inferenceDevice, sensors, actionSpec, actuators); + // Note - to debug, use JsonUtility.ToJson on the event. 
+ // Debug.Log(JsonUtility.ToJson(data, true)); + if (AnalyticsUtils.s_SendEditorAnalytics) + { + EditorAnalytics.SendEventWithLimit(k_EventName, data, k_EventVersion); + } +#endif + } + + /// + /// Generate an InferenceEvent for the model. + /// + /// + /// + /// + /// + /// + /// + /// + internal static InferenceEvent GetEventForModel( + NNModel nnModel, + string behaviorName, + InferenceDevice inferenceDevice, + IList sensors, + ActionSpec actionSpec, + IList actuators + ) + { + var barracudaModel = ModelLoader.Load(nnModel); + var inferenceEvent = new InferenceEvent(); + + // Hash the behavior name so that there's no concern about PII or "secret" data being leaked. + inferenceEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName); + + inferenceEvent.BarracudaModelSource = barracudaModel.IrSource; + inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion; + inferenceEvent.BarracudaModelProducer = barracudaModel.ProducerName; + inferenceEvent.MemorySize = (int)barracudaModel.GetTensorByName(TensorNames.MemorySize)[0]; + inferenceEvent.InferenceDevice = (int)inferenceDevice; + + if (barracudaModel.ProducerName == "Script") + { + // .nn files don't have these fields set correctly. Assign some placeholder values. + inferenceEvent.BarracudaModelSource = "NN"; + inferenceEvent.BarracudaModelProducer = "tensorflow_to_barracuda.py"; + } + +#if UNITY_EDITOR + var barracudaPackageInfo = UnityEditor.PackageManager.PackageInfo.FindForAssembly(typeof(Tensor).Assembly); + inferenceEvent.BarracudaPackageVersion = barracudaPackageInfo.version; +#else + inferenceEvent.BarracudaPackageVersion = null; +#endif + + inferenceEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec); + inferenceEvent.ObservationSpecs = new List(sensors.Count); + foreach (var sensor in sensors) + { + inferenceEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor)); + } + + inferenceEvent.ActuatorInfos = new List(actuators.Count); + foreach (var actuator in actuators) + { + inferenceEvent.ActuatorInfos.Add(EventActuatorInfo.FromActuator(actuator)); + } + + inferenceEvent.TotalWeightSizeBytes = GetModelWeightSize(barracudaModel); + inferenceEvent.ModelHash = GetModelHash(barracudaModel); + return inferenceEvent; + } + + /// + /// Compute the total model weight size in bytes. + /// This corresponds to the "Total weight size" display in the Barracuda inspector, + /// and the calculations are the same. + /// + /// + /// + static long GetModelWeightSize(Model barracudaModel) + { + long totalWeightsSizeInBytes = 0; + for (var l = 0; l < barracudaModel.layers.Count; ++l) + { + for (var d = 0; d < barracudaModel.layers[l].datasets.Length; ++d) + { + totalWeightsSizeInBytes += barracudaModel.layers[l].datasets[d].length; + } + } + return totalWeightsSizeInBytes; + } + + /// + /// Wrapper around Hash128 that supports Append(float[], int, int) + /// + struct MLAgentsHash128 + { + private Hash128 m_Hash; + + public void Append(float[] values, int count) + { + if (values == null) + { + return; + } + + // Pre-2020 versions of Unity don't have Hash128.Append() (can only hash strings and scalars) + // For these versions, we'll hash element by element. 
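+            // (Note: the two paths below are not guaranteed to produce identical hash values across Unity
+            // versions; the hash is only assumed to serve as a stable model identifier for analytics.)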
+#if UNITY_2020_1_OR_NEWER + m_Hash.Append(values, 0, count); +#else + for (var i = 0; i < count; i++) + { + var tempHash = new Hash128(); + HashUtilities.ComputeHash128(ref values[i], ref tempHash); + HashUtilities.AppendHash(ref tempHash, ref m_Hash); + } +#endif + } + + public void Append(string value) + { + var tempHash = Hash128.Compute(value); + HashUtilities.AppendHash(ref tempHash, ref m_Hash); + } + + public override string ToString() + { + return m_Hash.ToString(); + } + } + + /// + /// Compute a hash of the model's layer data and return it as a string. + /// A subset of the layer weights are used for performance. + /// This increases the chance of a collision, but this should still be extremely rare. + /// + /// + /// + static string GetModelHash(Model barracudaModel) + { + var hash = new MLAgentsHash128(); + + // Limit the max number of float bytes that we hash for performance. + const int kMaxFloats = 256; + + foreach (var layer in barracudaModel.layers) + { + hash.Append(layer.name); + var numFloatsToHash = Mathf.Min(layer.weights.Length, kMaxFloats); + hash.Append(layer.weights, numFloatsToHash); + } + + return hash.ToString(); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs.meta b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs.meta new file mode 100644 index 0000000000..e81b2ecbb6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: ac4c40c2394d481ebf602caa600a32f3 +timeCreated: 1604359787 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs new file mode 100644 index 0000000000..08c205bfc6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs @@ -0,0 +1,276 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; +using UnityEngine; +#if MLA_UNITY_ANALYTICS_MODULE + +#if ENABLE_CLOUD_SERVICES_ANALYTICS +using UnityEngine.Analytics; +#endif + +#if UNITY_EDITOR +using UnityEditor.Analytics; +#endif +#endif + +#if UNITY_EDITOR +using UnityEditor; +#endif + +namespace Unity.MLAgents.Analytics +{ + internal static class TrainingAnalytics + { + const string k_VendorKey = "unity.ml-agents"; + const string k_TrainingEnvironmentInitializedEventName = "ml_agents_training_environment_initialized"; + const string k_TrainingBehaviorInitializedEventName = "ml_agents_training_behavior_initialized"; + const string k_RemotePolicyInitializedEventName = "ml_agents_remote_policy_initialized"; + + private static readonly string[] s_EventNames = + { + k_TrainingEnvironmentInitializedEventName, + k_TrainingBehaviorInitializedEventName, + k_RemotePolicyInitializedEventName + }; + + /// + /// Hourly limit for this event name + /// + const int k_MaxEventsPerHour = 1000; + + /// + /// Maximum number of items in this event. + /// + const int k_MaxNumberOfElements = 1000; + + private static bool s_SentEnvironmentInitialized; + +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + /// + /// Whether or not we've registered this particular event yet + /// + static bool s_EventsRegistered; + + /// + /// Behaviors that we've already sent events for. 
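+        /// (Keyed by behavior name, so each behavior is only reported once per editor session.)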
+ /// + private static HashSet s_SentRemotePolicyInitialized; + private static HashSet s_SentTrainingBehaviorInitialized; +#endif + + private static Guid s_TrainingSessionGuid; + + // These are set when the RpcCommunicator connects + private static string s_TrainerPackageVersion = ""; + private static string s_TrainerCommunicationVersion = ""; + + internal static bool EnableAnalytics() + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + if (s_EventsRegistered) + { + return true; + } + foreach (var eventName in s_EventNames) + { + AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(eventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey); + if (result != AnalyticsResult.Ok) + { + return false; + } + } + s_EventsRegistered = true; + + if (s_SentRemotePolicyInitialized == null) + { + s_SentRemotePolicyInitialized = new HashSet(); + s_SentTrainingBehaviorInitialized = new HashSet(); + s_TrainingSessionGuid = Guid.NewGuid(); + } + + return s_EventsRegistered; +#else + return false; +#endif // MLA_UNITY_ANALYTICS_MODULE + } + + /// + /// Cache information about the trainer when it becomes available in the RpcCommunicator. + /// + /// + /// + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + public static void SetTrainerInformation(string packageVersion, string communicationVersion) + { + s_TrainerPackageVersion = packageVersion; + s_TrainerCommunicationVersion = communicationVersion; + } + + public static bool IsAnalyticsEnabled() + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + return EditorAnalytics.enabled; +#else + return false; +#endif + } + + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + public static void TrainingEnvironmentInitialized(TrainingEnvironmentInitializedEvent tbiEvent) + { + if (!IsAnalyticsEnabled()) + return; + + if (!EnableAnalytics()) + return; + + if (s_SentEnvironmentInitialized) + { + // We already sent an TrainingEnvironmentInitializedEvent. Exit so we don't resend. + return; + } + + s_SentEnvironmentInitialized = true; + tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString(); + + // Note - to debug, use JsonUtility.ToJson on the event. + // Debug.Log( + // $"Would send event {k_TrainingEnvironmentInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}" + // ); +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + if (AnalyticsUtils.s_SendEditorAnalytics) + { + EditorAnalytics.SendEventWithLimit(k_TrainingEnvironmentInitializedEventName, tbiEvent); + } +#endif + } + + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + public static void RemotePolicyInitialized( + string fullyQualifiedBehaviorName, + IList sensors, + ActionSpec actionSpec, + IList actuators + ) + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + if (!IsAnalyticsEnabled()) + return; + + if (!EnableAnalytics()) + return; + + // Extract base behavior name (no team ID) + var behaviorName = ParseBehaviorName(fullyQualifiedBehaviorName); + var added = s_SentRemotePolicyInitialized.Add(behaviorName); + + if (!added) + { + // We previously added this model. Exit so we don't resend. + return; + } + + var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec, actuators); + // Note - to debug, use JsonUtility.ToJson on the event. 
+ // Debug.Log( + // $"Would send event {k_RemotePolicyInitializedEventName} with body {JsonUtility.ToJson(data, true)}" + // ); + if (AnalyticsUtils.s_SendEditorAnalytics) + { + EditorAnalytics.SendEventWithLimit(k_RemotePolicyInitializedEventName, data); + } +#endif + } + + internal static string ParseBehaviorName(string fullyQualifiedBehaviorName) + { + var lastQuestionIndex = fullyQualifiedBehaviorName.LastIndexOf("?"); + if (lastQuestionIndex < 0) + { + // Nothing to remove + return fullyQualifiedBehaviorName; + } + + return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex); + } + + internal static TrainingBehaviorInitializedEvent SanitizeTrainingBehaviorInitializedEvent(TrainingBehaviorInitializedEvent tbiEvent) + { + // Hash the behavior name if the message version is from an older version of ml-agents that doesn't do trainer-side hashing. + // We'll also, for extra safety, verify that the BehaviorName is the size of the expected SHA256 hash. + // Context: The config field was added at the same time as trainer side hashing, so messages including it should already be hashed. + if (tbiEvent.Config.Length == 0 || tbiEvent.BehaviorName.Length != 64) + { + tbiEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, tbiEvent.BehaviorName); + } + + return tbiEvent; + } + + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent rawTbiEvent) + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE && ENABLE_CLOUD_SERVICES_ANALYTICS + if (!IsAnalyticsEnabled()) + return; + + if (!EnableAnalytics()) + return; + + var tbiEvent = SanitizeTrainingBehaviorInitializedEvent(rawTbiEvent); + var behaviorName = tbiEvent.BehaviorName; + var added = s_SentTrainingBehaviorInitialized.Add(behaviorName); + + if (!added) + { + // We previously added this model. Exit so we don't resend. + return; + } + + tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString(); + + // Note - to debug, use JsonUtility.ToJson on the event. + // Debug.Log( + // $"Would send event {k_TrainingBehaviorInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}" + // ); + if (AnalyticsUtils.s_SendEditorAnalytics) + { + EditorAnalytics.SendEventWithLimit(k_TrainingBehaviorInitializedEventName, tbiEvent); + } +#endif + } + + internal static RemotePolicyInitializedEvent GetEventForRemotePolicy( + string behaviorName, + IList sensors, + ActionSpec actionSpec, + IList actuators + ) + { + var remotePolicyEvent = new RemotePolicyInitializedEvent(); + + // Hash the behavior name so that there's no concern about PII or "secret" data being leaked. 
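+            // (Assumed: AnalyticsUtils.Hash produces a stable SHA256 hex digest; this matches the
+            // 64-character length check in SanitizeTrainingBehaviorInitializedEvent.)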
+            remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);
+
+            remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
+            remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec);
+            remotePolicyEvent.ObservationSpecs = new List<EventObservationSpec>(sensors.Count);
+            foreach (var sensor in sensors)
+            {
+                remotePolicyEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor));
+            }
+
+            remotePolicyEvent.ActuatorInfos = new List<EventActuatorInfo>(actuators.Count);
+            foreach (var actuator in actuators)
+            {
+                remotePolicyEvent.ActuatorInfos.Add(EventActuatorInfo.FromActuator(actuator));
+            }
+
+            remotePolicyEvent.MLAgentsEnvsVersion = s_TrainerPackageVersion;
+            remotePolicyEvent.TrainerCommunicationVersion = s_TrainerCommunicationVersion;
+            return remotePolicyEvent;
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta
new file mode 100644
index 0000000000..9109c265a2
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 5ad0bc6b45614bb7929d25dd59d5ac38
+timeCreated: 1608168600
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Areas.meta b/com.unity.ml-agents/Runtime/Areas.meta
new file mode 100644
index 0000000000..d00b0cf67c
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Areas.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 4774a04ed09a1405cb957aace235adcb
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs
new file mode 100644
index 0000000000..ef4a9d0633
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs
@@ -0,0 +1,114 @@
+using System;
+using Unity.Mathematics;
+using UnityEngine;
+
+namespace Unity.MLAgents.Areas
+{
+    /// <summary>
+    /// The Training Area Replicator allows a training area object group to be replicated dynamically at runtime.
+    /// </summary>
+    [DefaultExecutionOrder(-5)]
+    public class TrainingAreaReplicator : MonoBehaviour
+    {
+        /// <summary>
+        /// The base training area to be replicated.
+        /// </summary>
+        public GameObject baseArea;
+
+        /// <summary>
+        /// The number of training areas to replicate.
+        /// </summary>
+        public int numAreas = 1;
+
+        /// <summary>
+        /// The separation between each training area.
+        /// </summary>
+        public float separation = 10f;
+
+        int3 m_GridSize = new int3(1, 1, 1);
+        int m_areaCount = 0;
+        string m_TrainingAreaName;
+
+        /// <summary>
+        /// The size of the computed grid to pack the training areas into.
+        /// </summary>
+        public int3 GridSize => m_GridSize;
+
+        /// <summary>
+        /// The name of the training area.
+        /// </summary>
+        public string TrainingAreaName => m_TrainingAreaName;
+
+        /// <summary>
+        /// Called before the simulation begins to compute the grid size for distributing
+        /// the replicated training areas and to set the area name.
+        /// </summary>
+        public void Awake()
+        {
+            // Computes the grid size on Awake.
+            ComputeGridSize();
+            // Sets the training area name to the name of the base area.
+            m_TrainingAreaName = baseArea.name;
+        }
+
+        /// <summary>
+        /// Called after Awake and before the simulation begins; adds the training areas before
+        /// the Academy begins.
+        /// </summary>
+        public void OnEnable()
+        {
+            // Adds the training area replicas during OnEnable to ensure they are added before the Academy begins its work.
+            AddEnvironments();
+        }
+
+        /// <summary>
+        /// Computes the grid size for replicating the training area.
+ /// + void ComputeGridSize() + { + // check if running inference, if so, use the num areas set through the component, + // otherwise, pull it from the academy + if (Academy.Instance.Communicator != null) + numAreas = Academy.Instance.NumAreas; + + var rootNumAreas = Mathf.Pow(numAreas, 1.0f / 3.0f); + m_GridSize.x = Mathf.CeilToInt(rootNumAreas); + m_GridSize.y = Mathf.CeilToInt(rootNumAreas); + var zSize = Mathf.CeilToInt((float)numAreas / (m_GridSize.x * m_GridSize.y)); + m_GridSize.z = zSize == 0 ? 1 : zSize; + } + + /// + /// Adds replicas of the training area to the scene. + /// + /// + void AddEnvironments() + { + if (numAreas > m_GridSize.x * m_GridSize.y * m_GridSize.z) + { + throw new UnityAgentsException("The number of training areas that you have specified exceeds the size of the grid."); + } + + for (int z = 0; z < m_GridSize.z; z++) + { + for (int y = 0; y < m_GridSize.y; y++) + { + for (int x = 0; x < m_GridSize.x; x++) + { + if (m_areaCount == 0) + { + // Skip this first area since it already exists. + m_areaCount = 1; + } + else if (m_areaCount < numAreas) + { + m_areaCount++; + var area = Instantiate(baseArea, new Vector3(x * separation, y * separation, z * separation), Quaternion.identity); + area.name = m_TrainingAreaName; + } + } + } + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta new file mode 100644 index 0000000000..84ac36d789 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7fc26c3bda6fe4937b2264ffe43190b7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/AssemblyInfo.cs b/com.unity.ml-agents/Runtime/AssemblyInfo.cs new file mode 100644 index 0000000000..377c8b0870 --- /dev/null +++ b/com.unity.ml-agents/Runtime/AssemblyInfo.cs @@ -0,0 +1,14 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Sensor.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Utils.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.Input")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Pro")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Pro.Tests")] + diff --git a/com.unity.ml-agents/Runtime/AssemblyInfo.cs.meta b/com.unity.ml-agents/Runtime/AssemblyInfo.cs.meta new file mode 100644 index 0000000000..1672ad458e --- /dev/null +++ b/com.unity.ml-agents/Runtime/AssemblyInfo.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: b433ecadea36c4af9a3dc65e359a3ca0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Communicator.meta b/com.unity.ml-agents/Runtime/Communicator.meta new file mode 100644 index 0000000000..dc3a8bac9b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Communicator.meta @@ -0,0 +1,8 @@ 
+fileFormatVersion: 2
+guid: 432bb08962b3944c6964c0db6af43669
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Communicator/CommunicatorFactory.cs b/com.unity.ml-agents/Runtime/Communicator/CommunicatorFactory.cs
new file mode 100644
index 0000000000..02d1e4efbd
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Communicator/CommunicatorFactory.cs
@@ -0,0 +1,43 @@
+using System;
+
+namespace Unity.MLAgents
+{
+    /// <summary>
+    /// Factory class for an ICommunicator instance. This is used by the <see cref="Academy"/> at startup.
+    /// By default, on desktop platforms, an ICommunicator will be created and attempt to connect
+    /// to a trainer. This behavior can be prevented by setting <see cref="Enabled"/> to false
+    /// *before* the <see cref="Academy"/> is initialized.
+    /// </summary>
+    public static class CommunicatorFactory
+    {
+        static Func<ICommunicator> s_Creator;
+        static bool s_Enabled = true;
+
+        /// <summary>
+        /// Whether or not an ICommunicator instance will be created when the <see cref="Academy"/> is initialized.
+        /// Changing this has no effect after the <see cref="Academy"/> has already been initialized.
+        /// </summary>
+        public static bool Enabled
+        {
+            get => s_Enabled;
+            set => s_Enabled = value;
+        }
+
+        public static bool CommunicatorRegistered => s_Creator != null;
+
+        internal static ICommunicator Create()
+        {
+            return s_Enabled ? s_Creator() : null;
+        }
+
+        public static void Register<T>(Func<T> creator) where T : ICommunicator
+        {
+            s_Creator = () => creator();
+        }
+
+        public static void ClearCreator()
+        {
+            s_Creator = null;
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Communicator/CommunicatorFactory.cs.meta b/com.unity.ml-agents/Runtime/Communicator/CommunicatorFactory.cs.meta
new file mode 100644
index 0000000000..1d208003e3
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Communicator/CommunicatorFactory.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 0b604cddc07e4484a2cdaba630a971ea
+timeCreated: 1613617949
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
new file mode 100644
index 0000000000..e5a97cd167
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -0,0 +1,541 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Google.Protobuf;
+using Unity.MLAgents.CommunicatorObjects;
+using UnityEngine;
+using System.Runtime.CompilerServices;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Demonstrations;
+using Unity.MLAgents.Policies;
+
+using Unity.MLAgents.Analytics;
+
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Utils.Tests")]
+
+namespace Unity.MLAgents
+{
+    internal static class GrpcExtensions
+    {
+        #region AgentInfo
+        /// <summary>
+        /// Static flag to make sure that we only fire the warning once.
+        /// </summary>
+        private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup;
+
+        /// <summary>
+        /// Converts an AgentInfo to a protobuf generated AgentInfoActionPairProto
+        /// </summary>
+        /// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
+        public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
+        {
+            var agentInfoProto = ai.ToAgentInfoProto();
+
+            var agentActionProto = new AgentActionProto();
+
+            if (!ai.storedActions.IsEmpty())
+            {
+                if (!ai.storedActions.ContinuousActions.IsEmpty())
+                {
+                    agentActionProto.ContinuousActions.AddRange(ai.storedActions.ContinuousActions.Array);
+                }
+                if (!ai.storedActions.DiscreteActions.IsEmpty())
+                {
+                    agentActionProto.DiscreteActions.AddRange(ai.storedActions.DiscreteActions.Array);
+                }
+            }
+
+            return new AgentInfoActionPairProto
+            {
+                AgentInfo = agentInfoProto,
+                ActionInfo = agentActionProto
+            };
+        }
+
+        /// <summary>
+        /// Converts an AgentInfo to a protobuf generated AgentInfoProto
+        /// </summary>
+        /// <returns>The protobuf version of the AgentInfo.</returns>
+        public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
+        {
+            if (ai.groupId > 0)
+            {
+                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
+                if (!trainerCanHandle)
+                {
+                    if (!s_HaveWarnedTrainerCapabilitiesAgentGroup)
+                    {
+                        Debug.LogWarning(
+                            "Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored. " +
+                            "Please find the versions that work best together from our release page: " +
+                            "https://github.com/Unity-Technologies/ml-agents/releases"
+                        );
+                        s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
+                    }
+                }
+            }
+            var agentInfoProto = new AgentInfoProto
+            {
+                Reward = ai.reward,
+                GroupReward = ai.groupReward,
+                MaxStepReached = ai.maxStepReached,
+                Done = ai.done,
+                Id = ai.episodeId,
+                GroupId = ai.groupId,
+            };
+
+            if (ai.discreteActionMasks != null)
+            {
+                agentInfoProto.ActionMask.AddRange(ai.discreteActionMasks);
+            }
+
+            return agentInfoProto;
+        }
+
+        /// <summary>
+        /// Get summaries for the observations in the AgentInfo part of the AgentInfoActionPairProto.
+        /// </summary>
+        /// <param name="infoActionPair"></param>
+        /// <returns>A list of ObservationSummary, one per observation.</returns>
+        public static List<ObservationSummary> GetObservationSummaries(this AgentInfoActionPairProto infoActionPair)
+        {
+            List<ObservationSummary> summariesOut = new List<ObservationSummary>();
+            var agentInfo = infoActionPair.AgentInfo;
+            foreach (var obs in agentInfo.Observations)
+            {
+                var summary = new ObservationSummary();
+                summary.shape = obs.Shape.ToArray();
+                summariesOut.Add(summary);
+            }
+
+            return summariesOut;
+        }
+
+        #endregion
+
+        #region BrainParameters
+        /// <summary>
+        /// Converts a BrainParameters into a BrainParametersProto so it can be sent.
+        /// </summary>
+        /// <returns>The BrainParametersProto generated.</returns>
+        /// <param name="bp">The instance of BrainParameters to extend.</param>
+        /// <param name="name">The name of the brain.</param>
+        /// <param name="isTraining">Whether or not the Brain is training.</param>
+        public static BrainParametersProto ToProto(this BrainParameters bp, string name, bool isTraining)
+        {
+            // Disable deprecation warnings so we can set legacy fields
+#pragma warning disable CS0618
+            var brainParametersProto = new BrainParametersProto
+            {
+                VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
+                BrainName = name,
+                IsTraining = isTraining,
+                ActionSpec = ToActionSpecProto(bp.ActionSpec),
+            };
+            if (bp.VectorActionSize != null)
+            {
+                brainParametersProto.VectorActionSizeDeprecated.AddRange(bp.VectorActionSize);
+            }
+            if (bp.VectorActionDescriptions != null)
+            {
+                brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
+            }
+#pragma warning restore CS0618
+            return brainParametersProto;
+        }
+
+        /// <summary>
+        /// Converts an ActionSpec into a Protobuf BrainParametersProto so it can be sent.
+        /// </summary>
+        /// <returns>The BrainParametersProto generated.</returns>
+        /// <param name="actionSpec">Description of the actions for the Agent.</param>
+        /// <param name="name">The name of the brain.</param>
+        /// <param name="isTraining">Whether or not the Brain is training.</param>
+ public static BrainParametersProto ToBrainParametersProto(this ActionSpec actionSpec, string name, bool isTraining) + { + var brainParametersProto = new BrainParametersProto + { + BrainName = name, + IsTraining = isTraining, + ActionSpec = ToActionSpecProto(actionSpec), + }; + + var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions; + if (!supportHybrid) + { + actionSpec.CheckAllContinuousOrDiscrete(); + if (actionSpec.NumContinuousActions > 0) + { + brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions); + brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous; + } + else if (actionSpec.NumDiscreteActions > 0) + { + brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes); + brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete; + } + } + + // TODO handle ActionDescriptions? + return brainParametersProto; + } + + /// + /// Convert a BrainParametersProto to a BrainParameters struct. + /// + /// An instance of a brain parameters protobuf object. + /// A BrainParameters struct. + public static BrainParameters ToBrainParameters(this BrainParametersProto bpp) + { + ActionSpec actionSpec; + if (bpp.ActionSpec == null) + { + // Disable deprecation warnings so we can set legacy fields +#pragma warning disable CS0618 + var spaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated; + if (spaceType == SpaceType.Continuous) + { + actionSpec = ActionSpec.MakeContinuous(bpp.VectorActionSizeDeprecated.ToArray()[0]); + } + else + { + actionSpec = ActionSpec.MakeDiscrete(bpp.VectorActionSizeDeprecated.ToArray()); + } +#pragma warning restore CS0618 + } + else + { + actionSpec = ToActionSpec(bpp.ActionSpec); + } + var bp = new BrainParameters + { + VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(), + ActionSpec = actionSpec, + }; + return bp; + } + + /// + /// Convert a ActionSpecProto to a ActionSpec struct. + /// + /// An instance of an action spec protobuf object. + /// An ActionSpec struct. + public static ActionSpec ToActionSpec(this ActionSpecProto actionSpecProto) + { + var actionSpec = new ActionSpec(actionSpecProto.NumContinuousActions); + if (actionSpecProto.DiscreteBranchSizes != null) + { + actionSpec.BranchSizes = actionSpecProto.DiscreteBranchSizes.ToArray(); + } + return actionSpec; + } + + /// + /// Convert a ActionSpec struct to a ActionSpecProto. + /// + /// An instance of an action spec struct. + /// An ActionSpecProto. + public static ActionSpecProto ToActionSpecProto(this ActionSpec actionSpec) + { + var actionSpecProto = new ActionSpecProto + { + NumContinuousActions = actionSpec.NumContinuousActions, + NumDiscreteActions = actionSpec.NumDiscreteActions, + }; + if (actionSpec.BranchSizes != null) + { + actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes); + } + return actionSpecProto; + } + + #endregion + + #region DemonstrationMetaData + /// + /// Convert metadata object to proto object. + /// + public static DemonstrationMetaProto ToProto(this DemonstrationMetaData dm) + { + var demonstrationName = dm.demonstrationName ?? ""; + var demoProto = new DemonstrationMetaProto + { + ApiVersion = DemonstrationMetaData.ApiVersion, + MeanReward = dm.meanReward, + NumberSteps = dm.numberSteps, + NumberEpisodes = dm.numberEpisodes, + DemonstrationName = demonstrationName + }; + return demoProto; + } + + /// + /// Initialize metadata values based on proto object. 
+ /// + public static DemonstrationMetaData ToDemonstrationMetaData(this DemonstrationMetaProto demoProto) + { + var dm = new DemonstrationMetaData + { + numberEpisodes = demoProto.NumberEpisodes, + numberSteps = demoProto.NumberSteps, + meanReward = demoProto.MeanReward, + demonstrationName = demoProto.DemonstrationName + }; + if (demoProto.ApiVersion != DemonstrationMetaData.ApiVersion) + { + throw new Exception("API versions of demonstration are incompatible."); + } + return dm; + } + + #endregion + + public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto) + { + return new UnityRLInitParameters + { + seed = inputProto.Seed, + numAreas = inputProto.NumAreas, + pythonLibraryVersion = inputProto.PackageVersion, + pythonCommunicationVersion = inputProto.CommunicationVersion, + TrainerCapabilities = inputProto.Capabilities.ToRLCapabilities() + }; + } + + #region AgentAction + public static List ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto) + { + var agentActions = new List(proto.Value.Count); + foreach (var ap in proto.Value) + { + agentActions.Add(ap.ToActionBuffers()); + } + return agentActions; + } + + public static ActionBuffers ToActionBuffers(this AgentActionProto proto) + { + return new ActionBuffers(proto.ContinuousActions.ToArray(), proto.DiscreteActions.ToArray()); + } + + #endregion + + #region Observations + /// + /// Static flag to make sure that we only fire the warning once. + /// + private static bool s_HaveWarnedTrainerCapabilitiesMultiPng; + private static bool s_HaveWarnedTrainerCapabilitiesMapping; + + /// + /// Generate an ObservationProto for the sensor using the provided ObservationWriter. + /// This is equivalent to producing an Observation and calling Observation.ToProto(), + /// but avoid some intermediate memory allocations. + /// + /// + /// + /// + public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter) + { + var obsSpec = sensor.GetObservationSpec(); + var shape = obsSpec.Shape; + ObservationProto observationProto = null; + var compressionSpec = sensor.GetCompressionSpec(); + var compressionType = compressionSpec.SensorCompressionType; + // Check capabilities if we need to concatenate PNGs + if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3) + { + var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations; + if (!trainerCanHandle) + { + if (!s_HaveWarnedTrainerCapabilitiesMultiPng) + { + Debug.LogWarning( + $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. 
" + + "Please find the versions that work best together from our release page: " + + "https://github.com/Unity-Technologies/ml-agents/releases" + ); + s_HaveWarnedTrainerCapabilitiesMultiPng = true; + } + compressionType = SensorCompressionType.None; + } + } + // Check capabilities if we need mapping for compressed observations + if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3) + { + var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping; + var isTrivialMapping = compressionSpec.IsTrivialMapping(); + if (!trainerCanHandleMapping && !isTrivialMapping) + { + if (!s_HaveWarnedTrainerCapabilitiesMapping) + { + Debug.LogWarning( + $"The sensor {sensor.GetName()} is using non-trivial mapping and " + + "the attached trainer doesn't support compression mapping. " + + "Switching to uncompressed observations. " + + "Please find the versions that work best together from our release page: " + + "https://github.com/Unity-Technologies/ml-agents/releases" + ); + s_HaveWarnedTrainerCapabilitiesMapping = true; + } + compressionType = SensorCompressionType.None; + } + } + + if (compressionType == SensorCompressionType.None) + { + var numFloats = sensor.ObservationSize(); + var floatDataProto = new ObservationProto.Types.FloatData(); + // Resize the float array + // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530 + for (var i = 0; i < numFloats; i++) + { + floatDataProto.Data.Add(0.0f); + } + + observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0); + sensor.Write(observationWriter); + + observationProto = new ObservationProto + { + FloatData = floatDataProto, + CompressionType = (CompressionTypeProto)SensorCompressionType.None, + }; + } + else + { + var compressedObs = sensor.GetCompressedObservation(); + if (compressedObs == null) + { + throw new UnityAgentsException( + $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " + + "You must return a byte[]. If you don't want to use compressed observations, " + + "return CompressionSpec.Default() from GetCompressionSpec()." 
+ ); + } + observationProto = new ObservationProto + { + CompressedData = ByteString.CopyFrom(compressedObs), + CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType, + }; + if (compressionSpec.CompressedChannelMapping != null) + { + observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping); + } + } + + // Add the dimension properties to the observationProto + var dimensionProperties = obsSpec.DimensionProperties; + for (int i = 0; i < dimensionProperties.Length; i++) + { + observationProto.DimensionProperties.Add((int)dimensionProperties[i]); + } + + // Checking trainer compatibility with variable length observations + if (dimensionProperties == new InplaceArray(DimensionProperty.VariableSize, DimensionProperty.None)) + { + var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation; + if (!trainerCanHandleVarLenObs) + { + throw new UnityAgentsException("Variable Length Observations are not supported by the trainer"); + } + } + + for (var i = 0; i < shape.Length; i++) + { + observationProto.Shape.Add(shape[i]); + } + + var sensorName = sensor.GetName(); + if (!string.IsNullOrEmpty(sensorName)) + { + observationProto.Name = sensorName; + } + + observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType; + return observationProto; + } + + #endregion + + public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto proto) + { + return new UnityRLCapabilities + { + BaseRLCapabilities = proto.BaseRLCapabilities, + ConcatenatedPngObservations = proto.ConcatenatedPngObservations, + CompressedChannelMapping = proto.CompressedChannelMapping, + HybridActions = proto.HybridActions, + TrainingAnalytics = proto.TrainingAnalytics, + VariableLengthObservation = proto.VariableLengthObservation, + MultiAgentGroups = proto.MultiAgentGroups, + }; + } + + public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps) + { + return new UnityRLCapabilitiesProto + { + BaseRLCapabilities = rlCaps.BaseRLCapabilities, + ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations, + CompressedChannelMapping = rlCaps.CompressedChannelMapping, + HybridActions = rlCaps.HybridActions, + TrainingAnalytics = rlCaps.TrainingAnalytics, + VariableLengthObservation = rlCaps.VariableLengthObservation, + MultiAgentGroups = rlCaps.MultiAgentGroups, + }; + } + + #region Analytics + internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent( + this TrainingEnvironmentInitialized inputProto) + { + return new TrainingEnvironmentInitializedEvent + { + TrainerPythonVersion = inputProto.PythonVersion, + MLAgentsVersion = inputProto.MlagentsVersion, + MLAgentsEnvsVersion = inputProto.MlagentsEnvsVersion, + TorchVersion = inputProto.TorchVersion, + TorchDeviceType = inputProto.TorchDeviceType, + NumEnvironments = inputProto.NumEnvs, + NumEnvironmentParameters = inputProto.NumEnvironmentParameters, + RunOptions = inputProto.RunOptions, + }; + } + + internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEvent( + this TrainingBehaviorInitialized inputProto) + { + RewardSignals rewardSignals = 0; + rewardSignals |= inputProto.ExtrinsicRewardEnabled ? RewardSignals.Extrinsic : 0; + rewardSignals |= inputProto.GailRewardEnabled ? RewardSignals.Gail : 0; + rewardSignals |= inputProto.CuriosityRewardEnabled ? 
RewardSignals.Curiosity : 0; + rewardSignals |= inputProto.RndRewardEnabled ? RewardSignals.Rnd : 0; + + TrainingFeatures trainingFeatures = 0; + trainingFeatures |= inputProto.BehavioralCloningEnabled ? TrainingFeatures.BehavioralCloning : 0; + trainingFeatures |= inputProto.RecurrentEnabled ? TrainingFeatures.Recurrent : 0; + trainingFeatures |= inputProto.TrainerThreaded ? TrainingFeatures.Threaded : 0; + trainingFeatures |= inputProto.SelfPlayEnabled ? TrainingFeatures.SelfPlay : 0; + trainingFeatures |= inputProto.CurriculumEnabled ? TrainingFeatures.Curriculum : 0; + + + return new TrainingBehaviorInitializedEvent + { + BehaviorName = inputProto.BehaviorName, + TrainerType = inputProto.TrainerType, + RewardSignalFlags = rewardSignals, + TrainingFeatureFlags = trainingFeatures, + VisualEncoder = inputProto.VisualEncoder, + NumNetworkLayers = inputProto.NumNetworkLayers, + NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits, + Config = inputProto.Config, + }; + } + + #endregion + } +} diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs.meta b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs.meta new file mode 100644 index 0000000000..31c109f8fa --- /dev/null +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 02e8742d8a124607bef3b5ff8b9dd3d0 +timeCreated: 1569444771 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs new file mode 100644 index 0000000000..2036a2aa28 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs @@ -0,0 +1,173 @@ +using System; +using System.Collections.Generic; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents +{ + public struct CommunicatorInitParameters + { + /// + /// Port to listen for connections on. + /// + public int port; + + /// + /// The name of the environment. + /// + public string name; + + /// + /// The version of the Unity SDK. + /// + public string unityPackageVersion; + + /// + /// The version of the communication API. + /// + public string unityCommunicationVersion; + + /// + /// The RL capabilities of the C# codebase. + /// + public UnityRLCapabilities CSharpCapabilities; + } + public struct UnityRLInitParameters + { + /// + /// A random number generator (RNG) seed sent from the python process to Unity. + /// + public int seed; + + /// + /// The number of areas to replicate if Training Area Replication is used in the scene. + /// + public int numAreas; + + /// + /// The library version of the python process. + /// + public string pythonLibraryVersion; + + /// + /// The version of the communication API that python is using. + /// + public string pythonCommunicationVersion; + + /// + /// The RL capabilities of the Trainer codebase. + /// + public UnityRLCapabilities TrainerCapabilities; + } + internal struct UnityRLInputParameters + { + /// + /// Boolean sent back from python to indicate whether or not training is happening. + /// + public bool isTraining; + } + + /// + /// Delegate for handling quit events sent back from the communicator. + /// + public delegate void QuitCommandHandler(); + + /// + /// Delegate for handling reset parameter updates sent from the communicator. + /// + public delegate void ResetCommandHandler(); + + /// + /// Delegate to handle UnityRLInputParameters updates from the communicator. 
+ /// + /// + internal delegate void RLInputReceivedHandler(UnityRLInputParameters inputParams); + + /** + This is the interface of the Communicators. + This does not need to be modified nor implemented to create a Unity environment. + + When the Unity Communicator is initialized, it will wait for the External Communicator + to be initialized as well. The two communicators will then exchange their first messages + that will usually contain information for initialization (information that does not need + to be resent at each new exchange). + + By convention a Unity input is from External to Unity and a Unity output is from Unity to + External. Inputs and outputs are relative to Unity. + + By convention, when the Unity Communicator and External Communicator call exchange, the + exchange is NOT simultaneous but sequential. This means that when a side of the + communication calls exchange, the other will receive the result of its previous + exchange call. + This is what happens when A calls exchange a single time: + A sends data_1 to B -> B receives data_1 -> B generates and sends data_2 -> A receives data_2 + When A calls exchange, it sends data_1 and receives data_2 + + Since the messages are sent back and forth with exchange and simultaneously when calling + initialize, External sends two messages at initialization. + + The structure of the messages is as follows: + UnityMessage + ...Header + ...UnityOutput + ......UnityRLOutput + ......UnityRLInitializationOutput + ...UnityInput + ......UnityRLInput + ......UnityRLInitializationInput + + UnityOutput and UnityInput can be extended to provide functionalities beyond RL + UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities + */ + public interface ICommunicator : IDisposable + { + /// + /// Quit was received by the communicator. + /// + event QuitCommandHandler QuitCommandReceived; + + /// + /// Reset command sent back from the communicator. + /// + event ResetCommandHandler ResetCommandReceived; + + /// + /// Sends the academy parameters through the Communicator. + /// Is used by the academy to send the AcademyParameters to the communicator. + /// + /// Whether the connection was successful. + /// The Unity Initialization Parameters to be sent. + /// The External Initialization Parameters received + bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut); + + /// + /// Registers a new Brain to the Communicator. + /// + /// The name or key uniquely identifying the Brain. + /// Description of the actions for the Agent. + void SubscribeBrain(string name, ActionSpec actionSpec); + + /// + /// Sends the observations of one Agent. + /// + /// Batch Key. + /// Agent info. + /// The list of ISensors of the Agent. + void PutObservations(string brainKey, AgentInfo info, List sensors); + + /// + /// Signals the ICommunicator that the Agents are now ready to receive their action + /// and that if the communicator has not yet received an action for one of the Agents + /// it needs to get one at this point. + /// + void DecideBatch(); + + /// + /// Gets the AgentActions based on the batching key. + /// + /// A key to identify which behavior actions to get. + /// A key to identify which Agent actions to get. 
+ /// + ActionBuffers GetActions(string key, int agentId); + } +} diff --git a/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs.meta b/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs.meta new file mode 100644 index 0000000000..15f8a01eb3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 53977f05e5684d4a9e2ef86f225934a2 +timeCreated: 1568395551 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs new file mode 100644 index 0000000000..24d8ae563e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs @@ -0,0 +1,608 @@ +#if UNITY_EDITOR || UNITY_STANDALONE +#define MLA_SUPPORTED_TRAINING_PLATFORM +#endif + +#if MLA_SUPPORTED_TRAINING_PLATFORM +using Grpc.Core; +#if UNITY_EDITOR +using UnityEditor; +#endif +using System; +using System.Collections.Generic; +using System.Linq; +using UnityEngine; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.CommunicatorObjects; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.SideChannels; +using Google.Protobuf; + +using Unity.MLAgents.Analytics; + +namespace Unity.MLAgents +{ + /// Responsible for communication with External using gRPC. + public class RpcCommunicator : ICommunicator + { + public event QuitCommandHandler QuitCommandReceived; + public event ResetCommandHandler ResetCommandReceived; + + /// If true, the communication is active. + bool m_IsOpen; + + List m_BehaviorNames = new List(); + bool m_NeedCommunicateThisStep; + ObservationWriter m_ObservationWriter = new ObservationWriter(); + Dictionary m_SensorShapeValidators = new Dictionary(); + Dictionary> m_OrderedAgentsRequestingDecisions = new Dictionary>(); + + /// The current UnityRLOutput to be sent when all the brains queried the communicator + UnityRLOutputProto m_CurrentUnityRlOutput = + new UnityRLOutputProto(); + + Dictionary> m_LastActionsReceived = + new Dictionary>(); + + // Brains that we have sent over the communicator with agents. + HashSet m_SentBrainKeys = new HashSet(); + Dictionary m_UnsentBrainKeys = new Dictionary(); + + + /// The Unity to External client. + UnityToExternalProto.UnityToExternalProtoClient m_Client; + Channel m_Channel; + + /// + /// Initializes a new instance of the RPCCommunicator class. + /// + protected RpcCommunicator() + { + } + + public static RpcCommunicator Create() + { +#if MLA_SUPPORTED_TRAINING_PLATFORM + return new RpcCommunicator(); +#else + return null; +#endif + } + +#region Initialization + + internal static bool CheckCommunicationVersionsAreCompatible( + string unityCommunicationVersion, + string pythonApiVersion + ) + { + var unityVersion = new Version(unityCommunicationVersion); + var pythonVersion = new Version(pythonApiVersion); + if (unityVersion.Major == 0) + { + if (unityVersion.Major != pythonVersion.Major || unityVersion.Minor != pythonVersion.Minor) + { + return false; + } + } + else if (unityVersion.Major != pythonVersion.Major) + { + return false; + } + else if (unityVersion.Minor != pythonVersion.Minor) + { + // If a feature is used in Unity but not supported in the trainer, + // we will warn at the point it's used. Don't warn here to avoid noise. + } + return true; + } + + /// + /// Sends the initialization parameters through the Communicator. + /// Is used by the academy to send initialization parameters to the communicator. + /// + /// Whether the connection was successful. 
+ /// The Unity Initialization Parameters to be sent. + /// The External Initialization Parameters received. + public bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut) + { +#if MLA_SUPPORTED_TRAINING_PLATFORM + var academyParameters = new UnityRLInitializationOutputProto + { + Name = initParameters.name, + PackageVersion = initParameters.unityPackageVersion, + CommunicationVersion = initParameters.unityCommunicationVersion, + Capabilities = initParameters.CSharpCapabilities.ToProto() + }; + + UnityInputProto input; + UnityInputProto initializationInput; + try + { + initializationInput = Initialize( + initParameters.port, + new UnityOutputProto + { + RlInitializationOutput = academyParameters + }, + out input + ); + } + catch (Exception ex) + { + if (ex is RpcException rpcException) + { + + switch (rpcException.Status.StatusCode) + { + case StatusCode.Unavailable: + // This is the common case where there's no trainer to connect to. + break; + case StatusCode.DeadlineExceeded: + // We don't currently set a deadline for connection, but likely will in the future. + break; + default: + Debug.Log($"Unexpected gRPC exception when trying to initialize communication: {rpcException}"); + break; + } + } + else + { + Debug.Log($"Unexpected exception when trying to initialize communication: {ex}"); + } + initParametersOut = new UnityRLInitParameters(); + NotifyQuitAndShutDownChannel(); + return false; + } + + var pythonPackageVersion = initializationInput.RlInitializationInput.PackageVersion; + var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion; + TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion); + + var communicationIsCompatible = CheckCommunicationVersionsAreCompatible( + initParameters.unityCommunicationVersion, + pythonCommunicationVersion + ); + + // Initialization succeeded part-way. The most likely cause is a mismatch between the communicator + // API strings, so log an explicit warning if that's the case. + if (initializationInput != null && input == null) + { + if (!communicationIsCompatible) + { + Debug.LogWarningFormat( + "Communication protocol between python ({0}) and Unity ({1}) have different " + + "versions which make them incompatible. Python library version: {2}.", + pythonCommunicationVersion, initParameters.unityCommunicationVersion, + pythonPackageVersion + ); + } + else + { + Debug.LogWarningFormat( + "Unknown communication error between Python. Python communication protocol: {0}, " + + "Python library version: {1}.", + pythonCommunicationVersion, + pythonPackageVersion + ); + } + + initParametersOut = new UnityRLInitParameters(); + return false; + } + + UpdateEnvironmentWithInput(input.RlInput); + initParametersOut = initializationInput.RlInitializationInput.ToUnityRLInitParameters(); + // Be sure to shut down the grpc channel when the application is quitting. + Application.quitting += NotifyQuitAndShutDownChannel; + return true; +#else + initParametersOut = new UnityRLInitParameters(); + return false; +#endif + } + + /// + /// Adds the brain to the list of brains which will be sending information to External. + /// + /// Brain key. + /// Description of the actions for the Agent. 
+ public void SubscribeBrain(string brainKey, ActionSpec actionSpec) + { + if (m_BehaviorNames.Contains(brainKey)) + { + return; + } + m_BehaviorNames.Add(brainKey); + m_CurrentUnityRlOutput.AgentInfos.Add( + brainKey, + new UnityRLOutputProto.Types.ListAgentInfoProto() + ); + + CacheActionSpec(brainKey, actionSpec); + } + + void UpdateEnvironmentWithInput(UnityRLInputProto rlInput) + { + SideChannelManager.ProcessSideChannelData(rlInput.SideChannel.ToArray()); + SendCommandEvent(rlInput.Command); + } + + UnityInputProto Initialize(int port, UnityOutputProto unityOutput, out UnityInputProto unityInput) + { + m_IsOpen = true; + m_Channel = new Channel($"localhost:{port}", ChannelCredentials.Insecure); + + m_Client = new UnityToExternalProto.UnityToExternalProtoClient(m_Channel); + var result = m_Client.Exchange(WrapMessage(unityOutput, 200)); + var inputMessage = m_Client.Exchange(WrapMessage(null, 200)); + unityInput = inputMessage.UnityInput; +#if UNITY_EDITOR + EditorApplication.playModeStateChanged += HandleOnPlayModeChanged; +#endif + if (result.Header.Status != 200 || inputMessage.Header.Status != 200) + { + m_IsOpen = false; + NotifyQuitAndShutDownChannel(); + } + return result.UnityInput; + } + + void NotifyQuitAndShutDownChannel() + { + QuitCommandReceived?.Invoke(); + try + { + m_Channel.ShutdownAsync().Wait(); + } + catch (Exception) + { + // do nothing + } + } + +#endregion + +#region Destruction + + /// + /// Close the communicator gracefully on both sides of the communication. + /// + public void Dispose() + { + if (!m_IsOpen) + { + return; + } + + try + { + m_Client.Exchange(WrapMessage(null, 400)); + m_IsOpen = false; + } + catch + { + // ignored + } + } + +#endregion + +#region Sending Events + + void SendCommandEvent(CommandProto command) + { + switch (command) + { + case CommandProto.Quit: + { + NotifyQuitAndShutDownChannel(); + return; + } + case CommandProto.Reset: + { + foreach (var brainName in m_OrderedAgentsRequestingDecisions.Keys) + { + m_OrderedAgentsRequestingDecisions[brainName].Clear(); + } + ResetCommandReceived?.Invoke(); + return; + } + default: + { + return; + } + } + } + +#endregion + +#region Sending and retreiving data + + public void DecideBatch() + { + if (!m_NeedCommunicateThisStep) + { + return; + } + m_NeedCommunicateThisStep = false; + + SendBatchedMessageHelper(); + } + + /// + /// Sends the observations of one Agent. + /// + /// Batch Key. + /// Agent info. 
+ /// Sensors that will produce the observations + public void PutObservations(string behaviorName, AgentInfo info, List sensors) + { +#if DEBUG + if (!m_SensorShapeValidators.ContainsKey(behaviorName)) + { + m_SensorShapeValidators[behaviorName] = new SensorShapeValidator(); + } + m_SensorShapeValidators[behaviorName].ValidateSensors(sensors); +#endif + + using (TimerStack.Instance.Scoped("AgentInfo.ToProto")) + { + var agentInfoProto = info.ToAgentInfoProto(); + + using (TimerStack.Instance.Scoped("GenerateSensorData")) + { + foreach (var sensor in sensors) + { + var obsProto = sensor.GetObservationProto(m_ObservationWriter); + agentInfoProto.Observations.Add(obsProto); + } + } + m_CurrentUnityRlOutput.AgentInfos[behaviorName].Value.Add(agentInfoProto); + } + + m_NeedCommunicateThisStep = true; + if (!m_OrderedAgentsRequestingDecisions.ContainsKey(behaviorName)) + { + m_OrderedAgentsRequestingDecisions[behaviorName] = new List(); + } + if (!info.done) + { + m_OrderedAgentsRequestingDecisions[behaviorName].Add(info.episodeId); + } + if (!m_LastActionsReceived.ContainsKey(behaviorName)) + { + m_LastActionsReceived[behaviorName] = new Dictionary(); + } + m_LastActionsReceived[behaviorName][info.episodeId] = ActionBuffers.Empty; + if (info.done) + { + m_LastActionsReceived[behaviorName].Remove(info.episodeId); + } + } + + /// + /// Helper method that sends the current UnityRLOutput, receives the next UnityInput and + /// Applies the appropriate AgentAction to the agents. + /// + void SendBatchedMessageHelper() + { + var message = new UnityOutputProto + { + RlOutput = m_CurrentUnityRlOutput, + }; + var tempUnityRlInitializationOutput = GetTempUnityRlInitializationOutput(); + if (tempUnityRlInitializationOutput != null) + { + message.RlInitializationOutput = tempUnityRlInitializationOutput; + } + + byte[] messageAggregated = SideChannelManager.GetSideChannelMessage(); + message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated); + + var input = Exchange(message); + UpdateSentActionSpec(tempUnityRlInitializationOutput); + + foreach (var k in m_CurrentUnityRlOutput.AgentInfos.Keys) + { + m_CurrentUnityRlOutput.AgentInfos[k].Value.Clear(); + } + + var rlInput = input?.RlInput; + + if (rlInput?.AgentActions == null) + { + return; + } + + UpdateEnvironmentWithInput(rlInput); + + foreach (var brainName in rlInput.AgentActions.Keys) + { + if (!m_OrderedAgentsRequestingDecisions[brainName].Any()) + { + continue; + } + + if (!rlInput.AgentActions[brainName].Value.Any()) + { + continue; + } + + var agentActions = rlInput.AgentActions[brainName].ToAgentActionList(); + var numAgents = m_OrderedAgentsRequestingDecisions[brainName].Count; + for (var i = 0; i < numAgents; i++) + { + var agentAction = agentActions[i]; + var agentId = m_OrderedAgentsRequestingDecisions[brainName][i]; + if (m_LastActionsReceived[brainName].ContainsKey(agentId)) + { + m_LastActionsReceived[brainName][agentId] = agentAction; + } + } + } + foreach (var brainName in m_OrderedAgentsRequestingDecisions.Keys) + { + m_OrderedAgentsRequestingDecisions[brainName].Clear(); + } + } + + public ActionBuffers GetActions(string behaviorName, int agentId) + { + if (m_LastActionsReceived.ContainsKey(behaviorName)) + { + if (m_LastActionsReceived[behaviorName].ContainsKey(agentId)) + { + return m_LastActionsReceived[behaviorName][agentId]; + } + } + return ActionBuffers.Empty; + } + + /// + /// Send a UnityOutput and receives a UnityInput. + /// + /// The next UnityInput. + /// The UnityOutput to be sent. 
+
+ /// <summary>
+ /// Sends a UnityOutput and receives a UnityInput.
+ /// </summary>
+ /// <returns>The next UnityInput.</returns>
+ /// <param name="unityOutput">The UnityOutput to be sent.</param>
+ UnityInputProto Exchange(UnityOutputProto unityOutput)
+ {
+ if (!m_IsOpen)
+ {
+ return null;
+ }
+
+ try
+ {
+ var message = m_Client.Exchange(WrapMessage(unityOutput, 200));
+ if (message.Header.Status == 200)
+ {
+ return message.UnityInput;
+ }
+
+ m_IsOpen = false;
+ // Not sure if the quit command is actually sent when a
+ // non 200 message is received. Notify that we are indeed
+ // quitting.
+ NotifyQuitAndShutDownChannel();
+ return message.UnityInput;
+ }
+ catch (Exception ex)
+ {
+ if (ex is RpcException rpcException)
+ {
+ // Log more verbose errors if the user can possibly do something about them.
+ switch (rpcException.Status.StatusCode)
+ {
+ case StatusCode.Unavailable:
+ // This can happen when python disconnects. Ignore it to avoid noisy logs.
+ break;
+ case StatusCode.ResourceExhausted:
+ // This happens if the message body is too large. There's no way to
+ // gracefully handle this, but at least we can show the message and the
+ // user can try to reduce the number of agents or observation sizes.
+ Debug.LogError($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
+ break;
+ default:
+ // Other unknown errors. Log at INFO level.
+ Debug.Log($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
+ break;
+ }
+ }
+ else
+ {
+ // Fall-through for other error types
+ Debug.LogError($"Communication Exception: {ex.Message}. Disconnecting from trainer.");
+ }
+
+ m_IsOpen = false;
+ NotifyQuitAndShutDownChannel();
+ return null;
+ }
+ }
+
+ /// <summary>
+ /// Wraps the UnityOutput into a message with the appropriate status.
+ /// </summary>
+ /// <returns>The corresponding UnityMessage.</returns>
+ /// <param name="content">The UnityOutput to be wrapped.</param>
+ /// <param name="status">The status of the message.</param>
+ static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
+ {
+ return new UnityMessageProto
+ {
+ Header = new HeaderProto { Status = status },
+ UnityOutput = content
+ };
+ }
+
+ void CacheActionSpec(string behaviorName, ActionSpec actionSpec)
+ {
+ if (m_SentBrainKeys.Contains(behaviorName))
+ {
+ return;
+ }
+
+ // TODO We should check that if m_UnsentBrainKeys has brainKey, it equals actionSpec
+ m_UnsentBrainKeys[behaviorName] = actionSpec;
+ }
+
+ UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
+ {
+ UnityRLInitializationOutputProto output = null;
+ foreach (var behaviorName in m_UnsentBrainKeys.Keys)
+ {
+ if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
+ {
+ if (m_CurrentUnityRlOutput.AgentInfos[behaviorName].CalculateSize() > 0)
+ {
+ // Only send the actionSpec if there is a non-empty list of
+ // AgentInfos ready to be sent.
+ // This is to ensure that the Python side will always have a first
+ // observation when receiving the ActionSpec
+ if (output == null)
+ {
+ output = new UnityRLInitializationOutputProto();
+ }
+
+ var actionSpec = m_UnsentBrainKeys[behaviorName];
+ output.BrainParameters.Add(actionSpec.ToBrainParametersProto(behaviorName, true));
+ }
+ }
+ }
+
+ return output;
+ }
+
+ void UpdateSentActionSpec(UnityRLInitializationOutputProto output)
+ {
+ if (output == null)
+ {
+ return;
+ }
+
+ foreach (var brainProto in output.BrainParameters)
+ {
+ m_SentBrainKeys.Add(brainProto.BrainName);
+ m_UnsentBrainKeys.Remove(brainProto.BrainName);
+ }
+ }
+
+#endregion
+
+#if UNITY_EDITOR
+ /// <summary>
+ /// When the editor exits, the communicator must be closed.
+ /// </summary>
+ /// <param name="state">State.</param>
+ void HandleOnPlayModeChanged(PlayModeStateChange state)
+ {
+ // This method is run whenever the playmode state is changed.
+ if (state == PlayModeStateChange.ExitingPlayMode)
+ {
+ Dispose();
+ }
+ }
+
+#endif
+ }
+}
+#endif // UNITY_EDITOR || UNITY_STANDALONE
diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs.meta b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs.meta
new file mode 100644
index 0000000000..d1903d74cc
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs.meta
@@ -0,0 +1,13 @@
+fileFormatVersion: 2
+guid: 57a3dc12d3b88408688bb490b65a838e
+timeCreated: 1523046536
+licenseType: Free
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
new file mode 100644
index 0000000000..3dffcf52b7
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
@@ -0,0 +1,54 @@
+using UnityEngine;
+
+namespace Unity.MLAgents
+{
+ public class UnityRLCapabilities
+ {
+ public bool BaseRLCapabilities;
+ public bool ConcatenatedPngObservations;
+ public bool CompressedChannelMapping;
+ public bool HybridActions;
+ public bool TrainingAnalytics;
+ public bool VariableLengthObservation;
+ public bool MultiAgentGroups;
+
+ /// <summary>
+ /// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
+ /// class will be used to inform users if and when they are using C# / Trainer features that are mismatched.
+ /// </summary>
+ public UnityRLCapabilities(
+ bool baseRlCapabilities = true,
+ bool concatenatedPngObservations = true,
+ bool compressedChannelMapping = true,
+ bool hybridActions = true,
+ bool trainingAnalytics = true,
+ bool variableLengthObservation = true,
+ bool multiAgentGroups = true)
+ {
+ BaseRLCapabilities = baseRlCapabilities;
+ ConcatenatedPngObservations = concatenatedPngObservations;
+ CompressedChannelMapping = compressedChannelMapping;
+ HybridActions = hybridActions;
+ TrainingAnalytics = trainingAnalytics;
+ VariableLengthObservation = variableLengthObservation;
+ MultiAgentGroups = multiAgentGroups;
+ }
+
+ /// <summary>
+ /// Will print a warning to the console if Python does not support base capabilities and will
+ /// return true if the warning was printed.
+ /// </summary>
+ /// <returns>True if the warning was logged.</returns>
+ public bool WarnOnPythonMissingBaseRLCapabilities()
+ {
+ if (BaseRLCapabilities)
+ {
+ return false;
+ }
+ Debug.LogWarning("Unity has connected to a Training process that does not support " +
+ "Base Reinforcement Learning Capabilities. Please make sure you have the " +
+ "latest training codebase installed for this version of the ML-Agents package.");
+ return true;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs.meta b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs.meta
new file mode 100644
index 0000000000..6cdc57628e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: f95d271af72d4b75aa94d308222f79d8
+timeCreated: 1587670989
\ No newline at end of file
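Every capability defaults to true on the C# side and is meant to be downgraded during the trainer handshake; WarnOnPythonMissingBaseRLCapabilities is the hook for surfacing a mismatch. A short sketch with an illustrative, hand-set flag (in practice the flags come from the connected trainer, not user code):

    var caps = new UnityRLCapabilities(baseRlCapabilities: false);
    if (caps.WarnOnPythonMissingBaseRLCapabilities())
    {
        // The trainer is too old for this package; training behavior is unreliable.
    }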
diff --git a/com.unity.ml-agents/Runtime/Constants.cs b/com.unity.ml-agents/Runtime/Constants.cs
new file mode 100644
index 0000000000..4be9eba042
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Constants.cs
@@ -0,0 +1,12 @@
+namespace Unity.MLAgents
+{
+ /// <summary>
+ /// Grouping for use in AddComponentMenu (instead of nesting the menus).
+ /// </summary>
+ internal enum MenuGroup
+ {
+ Default = 0,
+ Sensors = 50,
+ Actuators = 100
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Constants.cs.meta b/com.unity.ml-agents/Runtime/Constants.cs.meta
new file mode 100644
index 0000000000..f963ba55aa
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Constants.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 0622d88401ec464d9d2cf2fb03ce17b5
+timeCreated: 1579215785
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs b/com.unity.ml-agents/Runtime/DecisionRequester.cs
new file mode 100644
index 0000000000..49590c7be4
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs
@@ -0,0 +1,123 @@
+using System;
+using UnityEngine;
+using UnityEngine.Serialization;
+
+namespace Unity.MLAgents
+{
+ /// <summary>
+ /// The DecisionRequester component automatically requests decisions for an
+ /// <see cref="Agent"/> instance at regular intervals.
+ /// </summary>
+ /// <remarks>
+ /// Attach a DecisionRequester component to the same [GameObject] as the
+ /// <see cref="Agent"/> component.
+ ///
+ /// The DecisionRequester component provides a convenient and flexible way to
+ /// trigger the agent decision making process. Without a DecisionRequester,
+ /// your <see cref="Agent"/> implementation must manually call its
+ /// <see cref="Agent.RequestDecision"/> function.
+ /// </remarks>
+ [AddComponentMenu("ML Agents/Decision Requester", (int)MenuGroup.Default)]
+ [RequireComponent(typeof(Agent))]
+ [DefaultExecutionOrder(-10)]
+ public class DecisionRequester : MonoBehaviour
+ {
+ /// <summary>
+ /// The frequency with which the agent requests a decision. A DecisionPeriod of 5 means
+ /// that the Agent will request a decision every 5 Academy steps.
+ /// </summary>
+ [Range(1, 20)]
+ [Tooltip("The frequency with which the agent requests a decision. A DecisionPeriod " +
+ "of 5 means that the Agent will request a decision every 5 Academy steps.")]
+ public int DecisionPeriod = 5;
+
+ /// <summary>
+ /// Indicates whether or not the agent will take an action during the Academy steps where
+ /// it does not request a decision. Has no effect when DecisionPeriod is set to 1.
+ /// </summary>
+ [Tooltip("Indicates whether or not the agent will take an action during the Academy " +
+ "steps where it does not request a decision. Has no effect when DecisionPeriod " +
+ "is set to 1.")]
+ [FormerlySerializedAs("RepeatAction")]
+ public bool TakeActionsBetweenDecisions = true;
+
+ [NonSerialized]
+ Agent m_Agent;
+
+ /// <summary>
+ /// Get the Agent attached to the DecisionRequester.
+ /// </summary>
+ public Agent Agent
+ {
+ get => m_Agent;
+ }
+
+ internal void Awake()
+ {
+ m_Agent = gameObject.GetComponent<Agent>();
+ Debug.Assert(m_Agent != null, "Agent component was not found on this gameObject and is required.");
+ Academy.Instance.AgentPreStep += MakeRequests;
+ }
+
+ void OnDestroy()
+ {
+ if (Academy.IsInitialized)
+ {
+ Academy.Instance.AgentPreStep -= MakeRequests;
+ }
+ }
+
+ /// <summary>
+ /// Information about Academy step used to make decisions about whether to request a decision.
+ /// </summary>
+ public struct DecisionRequestContext
+ {
+ /// <summary>
+ /// The current step count of the Academy, equivalent to Academy.StepCount.
+ /// </summary>
+ public int AcademyStepCount;
+ }
+
+ /// <summary>
+ /// Method that hooks into the Academy in order to inform the Agent on whether or not it should request a
+ /// decision, and whether or not it should take actions between decisions.
+ /// </summary>
+ /// <param name="academyStepCount">The current step count of the academy.</param>
+ void MakeRequests(int academyStepCount)
+ {
+ var context = new DecisionRequestContext
+ {
+ AcademyStepCount = academyStepCount
+ };
+
+ if (ShouldRequestDecision(context))
+ {
+ m_Agent?.RequestDecision();
+ }
+
+ if (ShouldRequestAction(context))
+ {
+ m_Agent?.RequestAction();
+ }
+ }
+
+ /// <summary>
+ /// Whether Agent.RequestDecision should be called on this update step.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <returns></returns>
+ protected virtual bool ShouldRequestDecision(DecisionRequestContext context)
+ {
+ return context.AcademyStepCount % DecisionPeriod == 0;
+ }
+
+ /// <summary>
+ /// Whether Agent.RequestAction should be called on this update step.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <returns></returns>
+ protected virtual bool ShouldRequestAction(DecisionRequestContext context)
+ {
+ return TakeActionsBetweenDecisions;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs.meta b/com.unity.ml-agents/Runtime/DecisionRequester.cs.meta
new file mode 100644
index 0000000000..bdc416b94b
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 3a5c9d521e5ef4759a8246a07d52221e
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Demonstrations.meta b/com.unity.ml-agents/Runtime/Demonstrations.meta
new file mode 100644
index 0000000000..85288b5325
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 85e02c21d231b4f5fa0c5f87e5f907a2
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
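Because ShouldRequestDecision and ShouldRequestAction are virtual, the decision cadence can be customized without touching the Agent itself. A hedged sketch that staggers decision steps across agents (the StepOffset field is illustrative, not part of the package):

    using Unity.MLAgents;

    public class StaggeredDecisionRequester : DecisionRequester
    {
        // Illustrative: offset so that groups of agents decide on different Academy steps.
        public int StepOffset = 1;

        protected override bool ShouldRequestDecision(DecisionRequestContext context)
        {
            return (context.AcademyStepCount + StepOffset) % DecisionPeriod == 0;
        }
    }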
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationMetaData.cs b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationMetaData.cs
new file mode 100644
index 0000000000..42f67733df
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationMetaData.cs
@@ -0,0 +1,20 @@
+using System;
+using UnityEngine.Serialization;
+
+namespace Unity.MLAgents.Demonstrations
+{
+ /// <summary>
+ /// Demonstration meta-data.
+ /// Kept in a small class for easy serialization and deserialization.
+ /// </summary>
+ [Serializable]
+ internal class DemonstrationMetaData
+ {
+ [FormerlySerializedAs("numberExperiences")]
+ public int numberSteps;
+ public int numberEpisodes;
+ public float meanReward;
+ public string demonstrationName;
+ public const int ApiVersion = 1;
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationMetaData.cs.meta b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationMetaData.cs.meta
new file mode 100644
index 0000000000..8e6ff39275
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationMetaData.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: af5f3b4258a2d4ead90e733f30cfaa7a
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
new file mode 100644
index 0000000000..b6daeace32
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
@@ -0,0 +1,228 @@
+using System.IO.Abstractions;
+using System.Text.RegularExpressions;
+using UnityEngine;
+using System.IO;
+using Unity.MLAgents.Policies;
+using UnityEngine.Serialization;
+
+namespace Unity.MLAgents.Demonstrations
+{
+ /// <summary>
+ /// The Demonstration Recorder component facilitates the recording of demonstrations
+ /// used for imitation learning.
+ /// </summary>
+ /// <remarks>
+ /// Add this component to the [GameObject] containing an <see cref="Agent"/>
+ /// to enable recording the agent for imitation learning. You must implement the
+ /// <see cref="Agent.Heuristic"/> function of the agent to provide manual control
+ /// in order to record demonstrations.
+ ///
+ /// See [Imitation Learning - Recording Demonstrations] for more information.
+ ///
+ /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
+ /// [Imitation Learning - Recording Demonstrations]: https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Learning-Environment-Design-Agents.md#recording-demonstrations
+ /// </remarks>
+ [RequireComponent(typeof(Agent))]
+ [AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]
+ public class DemonstrationRecorder : MonoBehaviour
+ {
+ /// <summary>
+ /// Whether or not to record demonstrations.
+ /// </summary>
+ [FormerlySerializedAs("record")]
+ [Tooltip("Whether or not to record demonstrations.")]
+ public bool Record;
+
+ /// <summary>
+ /// Number of steps to record. The editor will stop playing when it reaches this threshold.
+ /// Set to zero to record indefinitely.
+ /// </summary>
+ [Tooltip("Number of steps to record. The editor will stop playing when it reaches this threshold. " +
+ "Set to zero to record indefinitely.")]
+ public int NumStepsToRecord;
+
+ /// <summary>
+ /// Base demonstration file name. If multiple files are saved, the additional filenames
+ /// will have a sequence of unique numbers appended.
+ /// </summary>
+ [FormerlySerializedAs("demonstrationName")]
+ [Tooltip("Base demonstration file name. If multiple files are saved, the additional " +
+ "filenames will have a unique number appended.")]
+ public string DemonstrationName;
+
+ /// <summary>
+ /// Directory to save the demo files. Will default to a "Demonstrations/" folder in the
+ /// Application data path if not specified.
+ /// </summary>
+ [FormerlySerializedAs("demonstrationDirectory")]
+ [Tooltip("Directory to save the demo files. Will default to " +
+ "{Application.dataPath}/Demonstrations if not specified.")]
+ public string DemonstrationDirectory;
+
+ DemonstrationWriter m_DemoWriter;
+ internal const int MaxNameLength = 16;
+
+ const string k_ExtensionType = ".demo";
+ const string k_DefaultDirectoryName = "Demonstrations";
+ IFileSystem m_FileSystem;
+
+ Agent m_Agent;
+
+ void OnEnable()
+ {
+ m_Agent = GetComponent<Agent>();
+ }
+
+ void Update()
+ {
+ if (!Record)
+ {
+ return;
+ }
+
+ LazyInitialize();
+
+ // Quit when num steps to record is reached
+ if (NumStepsToRecord > 0 && m_DemoWriter.NumSteps >= NumStepsToRecord)
+ {
+ Application.Quit(0);
+#if UNITY_EDITOR
+ UnityEditor.EditorApplication.isPlaying = false;
+#endif
+ }
+ }
+
+ /// <summary>
+ /// Creates demonstration store for use in recording.
+ /// Has no effect if the demonstration store was already created.
+ /// </summary>
+ internal DemonstrationWriter LazyInitialize(IFileSystem fileSystem = null)
+ {
+ if (m_DemoWriter != null)
+ {
+ return m_DemoWriter;
+ }
+
+ if (m_Agent == null)
+ {
+ m_Agent = GetComponent<Agent>();
+ }
+
+ m_FileSystem = fileSystem ?? new FileSystem();
+ var behaviorParams = GetComponent<BehaviorParameters>();
+ if (string.IsNullOrEmpty(DemonstrationName))
+ {
+ DemonstrationName = behaviorParams.BehaviorName;
+ }
+ if (string.IsNullOrEmpty(DemonstrationDirectory))
+ {
+ DemonstrationDirectory = Path.Combine(Application.dataPath, k_DefaultDirectoryName);
+ }
+
+ DemonstrationName = SanitizeName(DemonstrationName, MaxNameLength);
+ var filePath = MakeDemonstrationFilePath(m_FileSystem, DemonstrationDirectory, DemonstrationName);
+ var stream = m_FileSystem.File.Create(filePath);
+ m_DemoWriter = new DemonstrationWriter(stream);
+
+ AddDemonstrationWriterToAgent(m_DemoWriter);
+
+ return m_DemoWriter;
+ }
+
+ /// <summary>
+ /// Removes all characters except alphanumerics from demonstration name.
+ /// Shortens the name if it is longer than the maxNameLength.
+ /// </summary>
+ internal static string SanitizeName(string demoName, int maxNameLength)
+ {
+ var rgx = new Regex("[^a-zA-Z0-9 -]");
+ demoName = rgx.Replace(demoName, "");
+ // If the string is too long, it will overflow the metadata.
+ if (demoName.Length > maxNameLength)
+ {
+ demoName = demoName.Substring(0, maxNameLength);
+ }
+ return demoName;
+ }
+
+ /// <summary>
+ /// Gets a unique path for the DemonstrationName in the DemonstrationDirectory.
+ /// </summary>
+ /// <param name="fileSystem"></param>
+ /// <param name="demonstrationDirectory"></param>
+ /// <param name="demonstrationName"></param>
+ /// <returns></returns>
+ internal static string MakeDemonstrationFilePath(
+ IFileSystem fileSystem, string demonstrationDirectory, string demonstrationName
+ )
+ {
+ // Create the directory if it doesn't already exist
+ if (!fileSystem.Directory.Exists(demonstrationDirectory))
+ {
+ fileSystem.Directory.CreateDirectory(demonstrationDirectory);
+ }
+
+ var literalName = demonstrationName;
+ var filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType);
+ var uniqueNameCounter = 0;
+ while (fileSystem.File.Exists(filePath))
+ {
+ // TODO should we use a timestamp instead of a counter here? This loops an increasing number of times
+ // as the number of demos increases.
+ literalName = demonstrationName + "_" + uniqueNameCounter;
+ filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType);
+ uniqueNameCounter++;
+ }
+
+ return filePath;
+ }
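Concretely, SanitizeName keeps only characters matching [a-zA-Z0-9 -] and truncates to MaxNameLength (16), and MakeDemonstrationFilePath appends _0, _1, ... until it finds an unused file name. An illustrative check (both methods are internal, so this would live in the package's own test assembly):

    // "Agent #1 (vision)!" -> "Agent 1 vision": '#', '(', ')' and '!' are stripped.
    var name = DemonstrationRecorder.SanitizeName("Agent #1 (vision)!", DemonstrationRecorder.MaxNameLength);

    // First file: <dir>/Agent 1 vision.demo; if that exists: <dir>/Agent 1 vision_0.demo, then _1, ...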
+
+ /// <summary>
+ /// Close the DemonstrationWriter and remove it from the Agent.
+ /// Has no effect if the DemonstrationWriter is already closed (or wasn't opened).
+ /// </summary>
+ public void Close()
+ {
+ if (m_DemoWriter != null)
+ {
+ RemoveDemonstrationWriterFromAgent(m_DemoWriter);
+
+ m_DemoWriter.Close();
+ m_DemoWriter = null;
+ }
+ }
+
+ /// <summary>
+ /// Clean up the DemonstrationWriter when shutting down or destroying the Agent.
+ /// </summary>
+ void OnDestroy()
+ {
+ Close();
+ }
+
+ /// <summary>
+ /// Add an additional DemonstrationWriter to the Agent. It is still up to the user to Close this
+ /// DemonstrationWriter when recording is done.
+ /// </summary>
+ /// <param name="demoWriter"></param>
+ public void AddDemonstrationWriterToAgent(DemonstrationWriter demoWriter)
+ {
+ var behaviorParams = GetComponent<BehaviorParameters>();
+ demoWriter.Initialize(
+ DemonstrationName,
+ behaviorParams.BrainParameters,
+ behaviorParams.FullyQualifiedBehaviorName
+ );
+ m_Agent.DemonstrationWriters.Add(demoWriter);
+ }
+
+ /// <summary>
+ /// Remove an additional DemonstrationWriter from the Agent. It is still up to the user to Close this
+ /// DemonstrationWriter when recording is done.
+ /// </summary>
+ /// <param name="demoWriter"></param>
+ public void RemoveDemonstrationWriterFromAgent(DemonstrationWriter demoWriter)
+ {
+ m_Agent.DemonstrationWriters.Remove(demoWriter);
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta
new file mode 100644
index 0000000000..cde4db8f20
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: f2902496c0120472b90269f94a0aec7e
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationSummary.cs b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationSummary.cs
new file mode 100644
index 0000000000..cb32409913
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationSummary.cs
@@ -0,0 +1,37 @@
+using System;
+using System.Collections.Generic;
+using UnityEngine;
+using Unity.MLAgents.Policies;
+
+namespace Unity.MLAgents.Demonstrations
+{
+ /// <summary>
+ /// Summary of a loaded Demonstration file. Only used for display in the Inspector.
+ /// </summary>
+ [Serializable]
+ internal class DemonstrationSummary : ScriptableObject
+ {
+ public DemonstrationMetaData metaData;
+ public BrainParameters brainParameters;
+ public List<ObservationSummary> observationSummaries;
+
+ public void Initialize(BrainParameters brainParams,
+ DemonstrationMetaData demonstrationMetaData, List<ObservationSummary> obsSummaries)
+ {
+ brainParameters = brainParams;
+ metaData = demonstrationMetaData;
+ observationSummaries = obsSummaries;
+ }
+ }
+
+
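Recording can also be configured from code instead of the Inspector; a hedged sketch that attaches a recorder to an existing Agent's GameObject at runtime (agentGameObject is assumed to already carry Agent and BehaviorParameters components):

    var recorder = agentGameObject.AddComponent<DemonstrationRecorder>();
    recorder.Record = true;
    recorder.DemonstrationName = "PushBlockDemo";
    recorder.NumStepsToRecord = 1000; // stop play mode after 1000 recorded steps
    // ... later, to stop early and flush the .demo file:
    recorder.Close();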
+ /// <summary>
+ /// Summary of a loaded Observation. Currently only contains the shape of the Observation.
+ /// </summary>
+ /// <remarks>This is necessary because serialization doesn't support nested containers or arrays.</remarks>
+ [Serializable]
+ internal struct ObservationSummary
+ {
+ public int[] shape;
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationSummary.cs.meta b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationSummary.cs.meta
new file mode 100644
index 0000000000..91e53800d5
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationSummary.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: a5e0cbcbc514b473399c262dd37541ea
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs
new file mode 100644
index 0000000000..c29bec2c40
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs
@@ -0,0 +1,161 @@
+using System.IO;
+using Google.Protobuf;
+using System.Collections.Generic;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Policies;
+
+namespace Unity.MLAgents.Demonstrations
+{
+ /// <summary>
+ /// Responsible for writing demonstration data to stream (typically a file stream).
+ /// </summary>
+ /// <seealso cref="DemonstrationRecorder"/>
+ public class DemonstrationWriter
+ {
+ /// <summary>
+ /// Number of bytes reserved for the <see cref="DemonstrationMetaData"/> at the start of the demo file.
+ /// </summary>
+ internal const int MetaDataBytes = 32;
+
+ DemonstrationMetaData m_MetaData;
+ Stream m_Writer;
+ float m_CumulativeReward;
+ ObservationWriter m_ObservationWriter = new ObservationWriter();
+
+ /// <summary>
+ /// Create a DemonstrationWriter that will write to the specified stream.
+ /// The stream must support writes and seeking.
+ /// </summary>
+ /// <param name="stream"></param>
+ public DemonstrationWriter(Stream stream)
+ {
+ m_Writer = stream;
+ }
+
+ /// <summary>
+ /// Number of steps written so far.
+ /// </summary>
+ internal int NumSteps
+ {
+ get { return m_MetaData.numberSteps; }
+ }
+
+ /// <summary>
+ /// Writes the initial data to the stream.
+ /// </summary>
+ /// <param name="demonstrationName">Base name of the demonstration file(s).</param>
+ /// <param name="brainName">The name of the Brain the agent is attached to.</param>
+ /// <param name="brainParameters">The parameters of the Brain the agent is attached to.</param>
+ internal void Initialize(
+ string demonstrationName, BrainParameters brainParameters, string brainName)
+ {
+ if (m_Writer == null)
+ {
+ // Already closed
+ return;
+ }
+
+ m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
+ var metaProto = m_MetaData.ToProto();
+ metaProto.WriteDelimitedTo(m_Writer);
+
+ WriteBrainParameters(brainName, brainParameters);
+ }
+
+ /// <summary>
+ /// Writes meta-data. Note that this is called at the *end* of recording, but writes to the
+ /// beginning of the file.
+ /// </summary>
+ void WriteMetadata()
+ {
+ if (m_Writer == null)
+ {
+ // Already closed
+ return;
+ }
+
+ var metaProto = m_MetaData.ToProto();
+ var metaProtoBytes = metaProto.ToByteArray();
+ m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
+ m_Writer.Seek(0, 0);
+ metaProto.WriteDelimitedTo(m_Writer);
+ }
+
+ /// <summary>
+ /// Writes brain parameters to file.
+ /// </summary>
+ /// <param name="brainName">The name of the Brain the agent is attached to.</param>
+ /// <param name="brainParameters">The parameters of the Brain the agent is attached to.</param>
+ void WriteBrainParameters(string brainName, BrainParameters brainParameters)
+ {
+ if (m_Writer == null)
+ {
+ // Already closed
+ return;
+ }
+
+ // Writes BrainParameters to file.
+ m_Writer.Seek(MetaDataBytes + 1, 0);
+ var brainProto = brainParameters.ToProto(brainName, false);
+ brainProto.WriteDelimitedTo(m_Writer);
+ }
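The resulting .demo layout is: a length-delimited metadata message inside a reserved 32-byte (MetaDataBytes) header slot, BrainParameters at offset MetaDataBytes + 1, then a stream of length-delimited experience records; Close re-seeks to offset 0 to overwrite the header with final counts. A reader-side sketch under those assumptions, using the protobuf message types generated elsewhere in this package (they are internal, so this only works from within the package's own assemblies):

    using System.IO;
    using Unity.MLAgents.CommunicatorObjects;

    const int metaDataBytes = 32; // mirrors DemonstrationWriter.MetaDataBytes
    using (var stream = File.OpenRead("Demonstrations/PushBlockDemo.demo"))
    {
        var meta = DemonstrationMetaProto.Parser.ParseDelimitedFrom(stream);
        stream.Seek(metaDataBytes + 1, SeekOrigin.Begin);
        var brain = BrainParametersProto.Parser.ParseDelimitedFrom(stream);
        // What follows are delimited AgentInfoActionPairProto records, one per recorded step.
    }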
+
+ /// <summary>
+ /// Write AgentInfo experience to file.
+ /// </summary>
+ /// <param name="info"><see cref="AgentInfo"/> for the agent being recorded.</param>
+ /// <param name="sensors">List of sensors to record for the agent.</param>
+ internal void Record(AgentInfo info, List<ISensor> sensors)
+ {
+ if (m_Writer == null)
+ {
+ // Already closed
+ return;
+ }
+
+ // Increment meta-data counters.
+ m_MetaData.numberSteps++;
+ m_CumulativeReward += info.reward;
+ if (info.done)
+ {
+ EndEpisode();
+ }
+
+ // Generate observations and add AgentInfo to file.
+ var agentProto = info.ToInfoActionPairProto();
+ foreach (var sensor in sensors)
+ {
+ agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
+ }
+
+ agentProto.WriteDelimitedTo(m_Writer);
+ }
+
+
+ /// <summary>
+ /// Performs all clean-up necessary.
+ /// </summary>
+ public void Close()
+ {
+ if (m_Writer == null)
+ {
+ // Already closed
+ return;
+ }
+
+ EndEpisode();
+ m_MetaData.meanReward = m_CumulativeReward / m_MetaData.numberEpisodes;
+ WriteMetadata();
+ m_Writer.Close();
+ m_Writer = null;
+ }
+
+ /// <summary>
+ /// Performs necessary episode-completion steps.
+ /// </summary>
+ void EndEpisode()
+ {
+ m_MetaData.numberEpisodes += 1;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta
new file mode 100644
index 0000000000..f30f1b22c1
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: ebaf7878a8cc74ee3aae07daf9e1b6f2
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/EnvironmentParameters.cs b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs
new file mode 100644
index 0000000000..fc1c667cd6
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs
@@ -0,0 +1,70 @@
+using System;
+using System.Collections.Generic;
+using Unity.MLAgents.SideChannels;
+
+namespace Unity.MLAgents
+{
+ /// <summary>
+ /// A container for the Environment Parameters that may be modified during training.
+ /// The keys for those parameters are defined in the trainer configurations and the
+ /// values are generated from the training process in features such as Curriculum Learning
+ /// and Environment Parameter Randomization.
+ ///
+ /// One current assumption for all the environment parameters is that they are of type float.
+ /// </summary>
+ public sealed class EnvironmentParameters
+ {
+ /// <summary>
+ /// The side channel that is used to receive the new parameter values.
+ /// </summary>
+ readonly EnvironmentParametersChannel m_Channel;
+
+ /// <summary>
+ /// Constructor.
+ /// </summary>
+ internal EnvironmentParameters()
+ {
+ m_Channel = new EnvironmentParametersChannel();
+ SideChannelManager.RegisterSideChannel(m_Channel);
+ }
+
+ /// <summary>
+ /// Returns the parameter value for the specified key. Returns the default value provided
+ /// if this parameter key does not have a value. Only returns a parameter value if it is
+ /// of type float.
+ /// </summary>
+ /// <param name="key">The parameter key</param>
+ /// <param name="defaultValue">Default value for this parameter.</param>
+ /// <returns></returns>
+ public float GetWithDefault(string key, float defaultValue)
+ {
+ return m_Channel.GetWithDefault(key, defaultValue);
+ }
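In practice these values are read when an episode resets, and GetWithDefault makes unset keys safe to query. A short sketch inside an Agent subclass, using the Academy.Instance.EnvironmentParameters accessor that the Academy exposes for this container; "block_scale" is an illustrative key from a trainer configuration:

    public override void OnEpisodeBegin()
    {
        var scale = Academy.Instance.EnvironmentParameters.GetWithDefault("block_scale", 1.0f);
        transform.localScale = new Vector3(scale, scale, scale);
    }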
+
+ /// <summary>
+ /// Registers a callback action for the provided parameter key. Will overwrite any
+ /// existing action for that parameter. The callback will be called whenever the parameter
+ /// receives a value from the training process.
+ /// </summary>
+ /// <param name="key">The parameter key</param>
+ /// <param name="action">The callback action</param>
+ public void RegisterCallback(string key, Action<float> action)
+ {
+ m_Channel.RegisterCallback(key, action);
+ }
+
+ /// <summary>
+ /// Returns a list of all the parameter keys that have received values.
+ /// </summary>
+ /// <returns>List of parameter keys.</returns>
+ public IList<string> Keys()
+ {
+ return m_Channel.ListParameters();
+ }
+
+ internal void Dispose()
+ {
+ SideChannelManager.UnregisterSideChannel(m_Channel);
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta
new file mode 100644
index 0000000000..9e7a85f810
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 90ce0b26bef35484890eac0633b85eed
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/EpisodeIdCounter.cs b/com.unity.ml-agents/Runtime/EpisodeIdCounter.cs
new file mode 100644
index 0000000000..735c7fff96
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/EpisodeIdCounter.cs
@@ -0,0 +1,11 @@
+namespace Unity.MLAgents
+{
+ internal static class EpisodeIdCounter
+ {
+ static int s_Counter;
+ public static int GetEpisodeId()
+ {
+ return s_Counter++;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/EpisodeIdCounter.cs.meta b/com.unity.ml-agents/Runtime/EpisodeIdCounter.cs.meta
new file mode 100644
index 0000000000..c377f5004b
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/EpisodeIdCounter.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 847786b7bcf9d4817b3f3879d57517c7
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Grpc.meta b/com.unity.ml-agents/Runtime/Grpc.meta
new file mode 100644
index 0000000000..f9d48bfc0f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 418327e202c7464bb6649d025df1b539
+timeCreated: 1569444731
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Grpc/AssemblyInfo.cs b/com.unity.ml-agents/Runtime/Grpc/AssemblyInfo.cs
new file mode 100644
index 0000000000..b740e05db8
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/AssemblyInfo.cs
@@ -0,0 +1,7 @@
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("Unity.ML-Agents")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Sensor.Tests")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Utils.Tests")]
diff --git a/com.unity.ml-agents/Runtime/Grpc/AssemblyInfo.cs.meta b/com.unity.ml-agents/Runtime/Grpc/AssemblyInfo.cs.meta
new file mode 100644
index 0000000000..cf7b4f0f10
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/AssemblyInfo.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 54959ce8e2e574f09b91f80a516acee3
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects.meta
new file mode 100644
index 0000000000..cef92044c3
--- /dev/null
+++ 
b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 7ebeef5df83b74a048b7f99681672f3b +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs new file mode 100644 index 0000000000..3eb0a357a2 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs @@ -0,0 +1,242 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/agent_action.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/agent_action.proto + internal static partial class AgentActionReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/agent_action.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AgentActionReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj", + "dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMijAEKEEFnZW50QWN0", + "aW9uUHJvdG8SIQoZdmVjdG9yX2FjdGlvbnNfZGVwcmVjYXRlZBgBIAMoAhIN", + "CgV2YWx1ZRgEIAEoAhIaChJjb250aW51b3VzX2FjdGlvbnMYBiADKAISGAoQ", + "ZGlzY3JldGVfYWN0aW9ucxgHIAMoBUoECAIQA0oECAMQBEoECAUQBkIlqgIi", + "VW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentActionProto), global::Unity.MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActionsDeprecated", "Value", "ContinuousActions", "DiscreteActions" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class AgentActionProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentActionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentActionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentActionProto(AgentActionProto other) : this() { + vectorActionsDeprecated_ = other.vectorActionsDeprecated_.Clone(); + value_ = other.value_; + continuousActions_ = 
other.continuousActions_.Clone(); + discreteActions_ = other.discreteActions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentActionProto Clone() { + return new AgentActionProto(this); + } + + /// Field number for the "vector_actions_deprecated" field. + public const int VectorActionsDeprecatedFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_vectorActionsDeprecated_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField vectorActionsDeprecated_ = new pbc::RepeatedField(); + /// + /// mark as deprecated in communicator v1.3.0 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VectorActionsDeprecated { + get { return vectorActionsDeprecated_; } + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 4; + private float value_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Value { + get { return value_; } + set { + value_ = value; + } + } + + /// Field number for the "continuous_actions" field. + public const int ContinuousActionsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_continuousActions_codec + = pb::FieldCodec.ForFloat(50); + private readonly pbc::RepeatedField continuousActions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ContinuousActions { + get { return continuousActions_; } + } + + /// Field number for the "discrete_actions" field. + public const int DiscreteActionsFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_discreteActions_codec + = pb::FieldCodec.ForInt32(58); + private readonly pbc::RepeatedField discreteActions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DiscreteActions { + get { return discreteActions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AgentActionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AgentActionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!vectorActionsDeprecated_.Equals(other.vectorActionsDeprecated_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false; + if(!continuousActions_.Equals(other.continuousActions_)) return false; + if(!discreteActions_.Equals(other.discreteActions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= vectorActionsDeprecated_.GetHashCode(); + if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value); + hash ^= continuousActions_.GetHashCode(); + hash ^= discreteActions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + vectorActionsDeprecated_.WriteTo(output, 
_repeated_vectorActionsDeprecated_codec); + if (Value != 0F) { + output.WriteRawTag(37); + output.WriteFloat(Value); + } + continuousActions_.WriteTo(output, _repeated_continuousActions_codec); + discreteActions_.WriteTo(output, _repeated_discreteActions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += vectorActionsDeprecated_.CalculateSize(_repeated_vectorActionsDeprecated_codec); + if (Value != 0F) { + size += 1 + 4; + } + size += continuousActions_.CalculateSize(_repeated_continuousActions_codec); + size += discreteActions_.CalculateSize(_repeated_discreteActions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AgentActionProto other) { + if (other == null) { + return; + } + vectorActionsDeprecated_.Add(other.vectorActionsDeprecated_); + if (other.Value != 0F) { + Value = other.Value; + } + continuousActions_.Add(other.continuousActions_); + discreteActions_.Add(other.discreteActions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + vectorActionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionsDeprecated_codec); + break; + } + case 37: { + Value = input.ReadFloat(); + break; + } + case 50: + case 53: { + continuousActions_.AddEntriesFrom(input, _repeated_continuousActions_codec); + break; + } + case 58: + case 56: { + discreteActions_.AddEntriesFrom(input, _repeated_discreteActions_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs.meta new file mode 100644 index 0000000000..f47d94375b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentAction.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: b1fa94db54b734224927bb4b322227cd +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs new file mode 100644 index 0000000000..187f2fdab7 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs @@ -0,0 +1,361 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/agent_info.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/agent_info.proto + internal static partial class AgentInfoReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/agent_info.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AgentInfoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu", + "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz", + "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIvkBCg5B", + "Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY", + "ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv", + "bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj", + "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SEAoIZ3JvdXBfaWQYDiAB", + "KAUSFAoMZ3JvdXBfcmV3YXJkGA8gASgCSgQIARACSgQIAhADSgQIAxAESgQI", + "BBAFSgQIBRAGSgQIBhAHSgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21t", + "dW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "GroupId", "GroupReward" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class AgentInfoProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentInfoProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoProto(AgentInfoProto other) : this() { + reward_ = other.reward_; + done_ = other.done_; + maxStepReached_ = other.maxStepReached_; + id_ = other.id_; + actionMask_ = other.actionMask_.Clone(); + observations_ = other.observations_.Clone(); + groupId_ = other.groupId_; + groupReward_ = other.groupReward_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoProto Clone() { + return new 
AgentInfoProto(this); + } + + /// Field number for the "reward" field. + public const int RewardFieldNumber = 7; + private float reward_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Reward { + get { return reward_; } + set { + reward_ = value; + } + } + + /// Field number for the "done" field. + public const int DoneFieldNumber = 8; + private bool done_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Done { + get { return done_; } + set { + done_ = value; + } + } + + /// Field number for the "max_step_reached" field. + public const int MaxStepReachedFieldNumber = 9; + private bool maxStepReached_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MaxStepReached { + get { return maxStepReached_; } + set { + maxStepReached_ = value; + } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 10; + private int id_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "action_mask" field. + public const int ActionMaskFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_actionMask_codec + = pb::FieldCodec.ForBool(90); + private readonly pbc::RepeatedField actionMask_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ActionMask { + get { return actionMask_; } + } + + /// Field number for the "observations" field. + public const int ObservationsFieldNumber = 13; + private static readonly pb::FieldCodec _repeated_observations_codec + = pb::FieldCodec.ForMessage(106, global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser); + private readonly pbc::RepeatedField observations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Observations { + get { return observations_; } + } + + /// Field number for the "group_id" field. + public const int GroupIdFieldNumber = 14; + private int groupId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int GroupId { + get { return groupId_; } + set { + groupId_ = value; + } + } + + /// Field number for the "group_reward" field. 
+ public const int GroupRewardFieldNumber = 15; + private float groupReward_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float GroupReward { + get { return groupReward_; } + set { + groupReward_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AgentInfoProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AgentInfoProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false; + if (Done != other.Done) return false; + if (MaxStepReached != other.MaxStepReached) return false; + if (Id != other.Id) return false; + if(!actionMask_.Equals(other.actionMask_)) return false; + if(!observations_.Equals(other.observations_)) return false; + if (GroupId != other.GroupId) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GroupReward, other.GroupReward)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward); + if (Done != false) hash ^= Done.GetHashCode(); + if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode(); + if (Id != 0) hash ^= Id.GetHashCode(); + hash ^= actionMask_.GetHashCode(); + hash ^= observations_.GetHashCode(); + if (GroupId != 0) hash ^= GroupId.GetHashCode(); + if (GroupReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GroupReward); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Reward != 0F) { + output.WriteRawTag(61); + output.WriteFloat(Reward); + } + if (Done != false) { + output.WriteRawTag(64); + output.WriteBool(Done); + } + if (MaxStepReached != false) { + output.WriteRawTag(72); + output.WriteBool(MaxStepReached); + } + if (Id != 0) { + output.WriteRawTag(80); + output.WriteInt32(Id); + } + actionMask_.WriteTo(output, _repeated_actionMask_codec); + observations_.WriteTo(output, _repeated_observations_codec); + if (GroupId != 0) { + output.WriteRawTag(112); + output.WriteInt32(GroupId); + } + if (GroupReward != 0F) { + output.WriteRawTag(125); + output.WriteFloat(GroupReward); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Reward != 0F) { + size += 1 + 4; + } + if (Done != false) { + size += 1 + 1; + } + if (MaxStepReached != false) { + size += 1 + 1; + } + if (Id != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); + } + size += actionMask_.CalculateSize(_repeated_actionMask_codec); + size += observations_.CalculateSize(_repeated_observations_codec); + if (GroupId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(GroupId); + } + if (GroupReward != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += 
_unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AgentInfoProto other) { + if (other == null) { + return; + } + if (other.Reward != 0F) { + Reward = other.Reward; + } + if (other.Done != false) { + Done = other.Done; + } + if (other.MaxStepReached != false) { + MaxStepReached = other.MaxStepReached; + } + if (other.Id != 0) { + Id = other.Id; + } + actionMask_.Add(other.actionMask_); + observations_.Add(other.observations_); + if (other.GroupId != 0) { + GroupId = other.GroupId; + } + if (other.GroupReward != 0F) { + GroupReward = other.GroupReward; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 61: { + Reward = input.ReadFloat(); + break; + } + case 64: { + Done = input.ReadBool(); + break; + } + case 72: { + MaxStepReached = input.ReadBool(); + break; + } + case 80: { + Id = input.ReadInt32(); + break; + } + case 90: + case 88: { + actionMask_.AddEntriesFrom(input, _repeated_actionMask_codec); + break; + } + case 106: { + observations_.AddEntriesFrom(input, _repeated_observations_codec); + break; + } + case 112: { + GroupId = input.ReadInt32(); + break; + } + case 125: { + GroupReward = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs.meta new file mode 100644 index 0000000000..07ed361456 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ecaddd3a8141a4854a4d2c7fe8bd6a75 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfoActionPair.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfoActionPair.cs new file mode 100644 index 0000000000..37cd219c73 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfoActionPair.cs @@ -0,0 +1,219 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/agent_info_action_pair.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/agent_info_action_pair.proto + internal static partial class AgentInfoActionPairReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/agent_info_action_pair.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AgentInfoActionPairReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cj9tbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu", + "Zm9fYWN0aW9uX3BhaXIucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNt", + "bGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2luZm8u", + "cHJvdG8aNW1sYWdlbnRzX2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdl", + "bnRfYWN0aW9uLnByb3RvIpEBChhBZ2VudEluZm9BY3Rpb25QYWlyUHJvdG8S", + "OAoKYWdlbnRfaW5mbxgBIAEoCzIkLmNvbW11bmljYXRvcl9vYmplY3RzLkFn", + "ZW50SW5mb1Byb3RvEjsKC2FjdGlvbl9pbmZvGAIgASgLMiYuY29tbXVuaWNh", + "dG9yX29iamVjdHMuQWdlbnRBY3Rpb25Qcm90b0IlqgIiVW5pdHkuTUxBZ2Vu", + "dHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoActionPairProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoActionPairProto.Parser, new[]{ "AgentInfo", "ActionInfo" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class AgentInfoActionPairProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentInfoActionPairProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.AgentInfoActionPairReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoActionPairProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoActionPairProto(AgentInfoActionPairProto other) : this() { + AgentInfo = other.agentInfo_ != null ? other.AgentInfo.Clone() : null; + ActionInfo = other.actionInfo_ != null ? 
other.ActionInfo.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoActionPairProto Clone() { + return new AgentInfoActionPairProto(this); + } + + /// Field number for the "agent_info" field. + public const int AgentInfoFieldNumber = 1; + private global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto agentInfo_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto AgentInfo { + get { return agentInfo_; } + set { + agentInfo_ = value; + } + } + + /// Field number for the "action_info" field. + public const int ActionInfoFieldNumber = 2; + private global::Unity.MLAgents.CommunicatorObjects.AgentActionProto actionInfo_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.AgentActionProto ActionInfo { + get { return actionInfo_; } + set { + actionInfo_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AgentInfoActionPairProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AgentInfoActionPairProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(AgentInfo, other.AgentInfo)) return false; + if (!object.Equals(ActionInfo, other.ActionInfo)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (agentInfo_ != null) hash ^= AgentInfo.GetHashCode(); + if (actionInfo_ != null) hash ^= ActionInfo.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (agentInfo_ != null) { + output.WriteRawTag(10); + output.WriteMessage(AgentInfo); + } + if (actionInfo_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ActionInfo); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (agentInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AgentInfo); + } + if (actionInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionInfo); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AgentInfoActionPairProto other) { + if (other == null) { + return; + } + if (other.agentInfo_ != null) { + if (agentInfo_ == null) { + agentInfo_ = new global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto(); + } + AgentInfo.MergeFrom(other.AgentInfo); + } + if (other.actionInfo_ != null) { + if (actionInfo_ == null) { + actionInfo_ = new global::Unity.MLAgents.CommunicatorObjects.AgentActionProto(); + } + ActionInfo.MergeFrom(other.ActionInfo); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + 
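The MergeFrom(AgentInfoActionPairProto) overload above follows protobuf's merge rule for singular message fields: sub-messages are combined field-by-field rather than replaced, which is why the target sub-message is only allocated when it is still null. A hedged usage sketch, assuming the calling code can see these internal generated types:

// Two partial pairs; merging combines their sub-messages rather than overwriting.
var a = new AgentInfoActionPairProto { AgentInfo = new AgentInfoProto { Reward = 1f } };
var b = new AgentInfoActionPairProto { AgentInfo = new AgentInfoProto { Done = true } };
a.MergeFrom(b);
// a.AgentInfo now carries both Reward == 1f and Done == true.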
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (agentInfo_ == null) { + agentInfo_ = new global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto(); + } + input.ReadMessage(agentInfo_); + break; + } + case 18: { + if (actionInfo_ == null) { + actionInfo_ = new global::Unity.MLAgents.CommunicatorObjects.AgentActionProto(); + } + input.ReadMessage(actionInfo_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfoActionPair.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfoActionPair.cs.meta new file mode 100644 index 0000000000..7474dcae69 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfoActionPair.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 29577366657494c678558b0643abcb30 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs new file mode 100644 index 0000000000..65b57f4ea3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs @@ -0,0 +1,524 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/brain_parameters.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/brain_parameters.proto + internal static partial class BrainParametersReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/brain_parameters.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static BrainParametersReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjltbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2JyYWluX3Bh", + "cmFtZXRlcnMucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNtbGFnZW50", + "c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3NwYWNlX3R5cGUucHJvdG8i", + "iwEKD0FjdGlvblNwZWNQcm90bxIeChZudW1fY29udGludW91c19hY3Rpb25z", + "GAEgASgFEhwKFG51bV9kaXNjcmV0ZV9hY3Rpb25zGAIgASgFEh0KFWRpc2Ny", + "ZXRlX2JyYW5jaF9zaXplcxgDIAMoBRIbChNhY3Rpb25fZGVzY3JpcHRpb25z", + "GAQgAygJIrYCChRCcmFpblBhcmFtZXRlcnNQcm90bxIlCh12ZWN0b3JfYWN0", + "aW9uX3NpemVfZGVwcmVjYXRlZBgDIAMoBRItCiV2ZWN0b3JfYWN0aW9uX2Rl", + "c2NyaXB0aW9uc19kZXByZWNhdGVkGAUgAygJElEKI3ZlY3Rvcl9hY3Rpb25f", + "c3BhY2VfdHlwZV9kZXByZWNhdGVkGAYgASgOMiQuY29tbXVuaWNhdG9yX29i", + "amVjdHMuU3BhY2VUeXBlUHJvdG8SEgoKYnJhaW5fbmFtZRgHIAEoCRITCgtp", + "c190cmFpbmluZxgIIAEoCBI6CgthY3Rpb25fc3BlYxgJIAEoCzIlLmNvbW11", + "bmljYXRvcl9vYmplY3RzLkFjdGlvblNwZWNQcm90b0oECAEQAkoECAIQA0oE", + "CAQQBUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IG", + "cHJvdG8z")); + descriptor = 
pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.SpaceTypeReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto), global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto.Parser, new[]{ "NumContinuousActions", "NumDiscreteActions", "DiscreteBranchSizes", "ActionDescriptions" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto), global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSizeDeprecated", "VectorActionDescriptionsDeprecated", "VectorActionSpaceTypeDeprecated", "BrainName", "IsTraining", "ActionSpec" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class ActionSpecProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ActionSpecProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ActionSpecProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ActionSpecProto(ActionSpecProto other) : this() { + numContinuousActions_ = other.numContinuousActions_; + numDiscreteActions_ = other.numDiscreteActions_; + discreteBranchSizes_ = other.discreteBranchSizes_.Clone(); + actionDescriptions_ = other.actionDescriptions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ActionSpecProto Clone() { + return new ActionSpecProto(this); + } + + /// Field number for the "num_continuous_actions" field. + public const int NumContinuousActionsFieldNumber = 1; + private int numContinuousActions_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumContinuousActions { + get { return numContinuousActions_; } + set { + numContinuousActions_ = value; + } + } + + /// Field number for the "num_discrete_actions" field. + public const int NumDiscreteActionsFieldNumber = 2; + private int numDiscreteActions_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumDiscreteActions { + get { return numDiscreteActions_; } + set { + numDiscreteActions_ = value; + } + } + + /// Field number for the "discrete_branch_sizes" field. + public const int DiscreteBranchSizesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_discreteBranchSizes_codec + = pb::FieldCodec.ForInt32(26); + private readonly pbc::RepeatedField discreteBranchSizes_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DiscreteBranchSizes { + get { return discreteBranchSizes_; } + } + + /// Field number for the "action_descriptions" field. 
+ public const int ActionDescriptionsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_actionDescriptions_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField actionDescriptions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ActionDescriptions { + get { return actionDescriptions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ActionSpecProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ActionSpecProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumContinuousActions != other.NumContinuousActions) return false; + if (NumDiscreteActions != other.NumDiscreteActions) return false; + if(!discreteBranchSizes_.Equals(other.discreteBranchSizes_)) return false; + if(!actionDescriptions_.Equals(other.actionDescriptions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NumContinuousActions != 0) hash ^= NumContinuousActions.GetHashCode(); + if (NumDiscreteActions != 0) hash ^= NumDiscreteActions.GetHashCode(); + hash ^= discreteBranchSizes_.GetHashCode(); + hash ^= actionDescriptions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumContinuousActions != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumContinuousActions); + } + if (NumDiscreteActions != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumDiscreteActions); + } + discreteBranchSizes_.WriteTo(output, _repeated_discreteBranchSizes_codec); + actionDescriptions_.WriteTo(output, _repeated_actionDescriptions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NumContinuousActions != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumContinuousActions); + } + if (NumDiscreteActions != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumDiscreteActions); + } + size += discreteBranchSizes_.CalculateSize(_repeated_discreteBranchSizes_codec); + size += actionDescriptions_.CalculateSize(_repeated_actionDescriptions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ActionSpecProto other) { + if (other == null) { + return; + } + if (other.NumContinuousActions != 0) { + NumContinuousActions = other.NumContinuousActions; + } + if (other.NumDiscreteActions != 0) { + NumDiscreteActions = other.NumDiscreteActions; + } + discreteBranchSizes_.Add(other.discreteBranchSizes_); + actionDescriptions_.Add(other.actionDescriptions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint 
tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumContinuousActions = input.ReadInt32(); + break; + } + case 16: { + NumDiscreteActions = input.ReadInt32(); + break; + } + case 26: + case 24: { + discreteBranchSizes_.AddEntriesFrom(input, _repeated_discreteBranchSizes_codec); + break; + } + case 34: { + actionDescriptions_.AddEntriesFrom(input, _repeated_actionDescriptions_codec); + break; + } + } + } + } + + } + + internal sealed partial class BrainParametersProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BrainParametersProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BrainParametersProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BrainParametersProto(BrainParametersProto other) : this() { + vectorActionSizeDeprecated_ = other.vectorActionSizeDeprecated_.Clone(); + vectorActionDescriptionsDeprecated_ = other.vectorActionDescriptionsDeprecated_.Clone(); + vectorActionSpaceTypeDeprecated_ = other.vectorActionSpaceTypeDeprecated_; + brainName_ = other.brainName_; + isTraining_ = other.isTraining_; + ActionSpec = other.actionSpec_ != null ? other.ActionSpec.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BrainParametersProto Clone() { + return new BrainParametersProto(this); + } + + /// Field number for the "vector_action_size_deprecated" field. + public const int VectorActionSizeDeprecatedFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_vectorActionSizeDeprecated_codec + = pb::FieldCodec.ForInt32(26); + private readonly pbc::RepeatedField vectorActionSizeDeprecated_ = new pbc::RepeatedField(); + /// + /// mark as deprecated in communicator v1.3.0 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VectorActionSizeDeprecated { + get { return vectorActionSizeDeprecated_; } + } + + /// Field number for the "vector_action_descriptions_deprecated" field. + public const int VectorActionDescriptionsDeprecatedFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_vectorActionDescriptionsDeprecated_codec + = pb::FieldCodec.ForString(42); + private readonly pbc::RepeatedField vectorActionDescriptionsDeprecated_ = new pbc::RepeatedField(); + /// + /// mark as deprecated in communicator v1.3.0 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VectorActionDescriptionsDeprecated { + get { return vectorActionDescriptionsDeprecated_; } + } + + /// Field number for the "vector_action_space_type_deprecated" field. 
+ public const int VectorActionSpaceTypeDeprecatedFieldNumber = 6; + private global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceTypeDeprecated_ = 0; + /// + /// mark as deprecated in communicator v1.3.0 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceTypeDeprecated { + get { return vectorActionSpaceTypeDeprecated_; } + set { + vectorActionSpaceTypeDeprecated_ = value; + } + } + + /// Field number for the "brain_name" field. + public const int BrainNameFieldNumber = 7; + private string brainName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string BrainName { + get { return brainName_; } + set { + brainName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "is_training" field. + public const int IsTrainingFieldNumber = 8; + private bool isTraining_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsTraining { + get { return isTraining_; } + set { + isTraining_ = value; + } + } + + /// Field number for the "action_spec" field. + public const int ActionSpecFieldNumber = 9; + private global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto actionSpec_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto ActionSpec { + get { return actionSpec_; } + set { + actionSpec_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BrainParametersProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BrainParametersProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!vectorActionSizeDeprecated_.Equals(other.vectorActionSizeDeprecated_)) return false; + if(!vectorActionDescriptionsDeprecated_.Equals(other.vectorActionDescriptionsDeprecated_)) return false; + if (VectorActionSpaceTypeDeprecated != other.VectorActionSpaceTypeDeprecated) return false; + if (BrainName != other.BrainName) return false; + if (IsTraining != other.IsTraining) return false; + if (!object.Equals(ActionSpec, other.ActionSpec)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= vectorActionSizeDeprecated_.GetHashCode(); + hash ^= vectorActionDescriptionsDeprecated_.GetHashCode(); + if (VectorActionSpaceTypeDeprecated != 0) hash ^= VectorActionSpaceTypeDeprecated.GetHashCode(); + if (BrainName.Length != 0) hash ^= BrainName.GetHashCode(); + if (IsTraining != false) hash ^= IsTraining.GetHashCode(); + if (actionSpec_ != null) hash ^= ActionSpec.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + vectorActionSizeDeprecated_.WriteTo(output, _repeated_vectorActionSizeDeprecated_codec); + vectorActionDescriptionsDeprecated_.WriteTo(output, _repeated_vectorActionDescriptionsDeprecated_codec); + if (VectorActionSpaceTypeDeprecated != 0) { + 
output.WriteRawTag(48); + output.WriteEnum((int) VectorActionSpaceTypeDeprecated); + } + if (BrainName.Length != 0) { + output.WriteRawTag(58); + output.WriteString(BrainName); + } + if (IsTraining != false) { + output.WriteRawTag(64); + output.WriteBool(IsTraining); + } + if (actionSpec_ != null) { + output.WriteRawTag(74); + output.WriteMessage(ActionSpec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += vectorActionSizeDeprecated_.CalculateSize(_repeated_vectorActionSizeDeprecated_codec); + size += vectorActionDescriptionsDeprecated_.CalculateSize(_repeated_vectorActionDescriptionsDeprecated_codec); + if (VectorActionSpaceTypeDeprecated != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceTypeDeprecated); + } + if (BrainName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BrainName); + } + if (IsTraining != false) { + size += 1 + 1; + } + if (actionSpec_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionSpec); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BrainParametersProto other) { + if (other == null) { + return; + } + vectorActionSizeDeprecated_.Add(other.vectorActionSizeDeprecated_); + vectorActionDescriptionsDeprecated_.Add(other.vectorActionDescriptionsDeprecated_); + if (other.VectorActionSpaceTypeDeprecated != 0) { + VectorActionSpaceTypeDeprecated = other.VectorActionSpaceTypeDeprecated; + } + if (other.BrainName.Length != 0) { + BrainName = other.BrainName; + } + if (other.IsTraining != false) { + IsTraining = other.IsTraining; + } + if (other.actionSpec_ != null) { + if (actionSpec_ == null) { + actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto(); + } + ActionSpec.MergeFrom(other.ActionSpec); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 26: + case 24: { + vectorActionSizeDeprecated_.AddEntriesFrom(input, _repeated_vectorActionSizeDeprecated_codec); + break; + } + case 42: { + vectorActionDescriptionsDeprecated_.AddEntriesFrom(input, _repeated_vectorActionDescriptionsDeprecated_codec); + break; + } + case 48: { + vectorActionSpaceTypeDeprecated_ = (global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum(); + break; + } + case 58: { + BrainName = input.ReadString(); + break; + } + case 64: { + IsTraining = input.ReadBool(); + break; + } + case 74: { + if (actionSpec_ == null) { + actionSpec_ = new global::Unity.MLAgents.CommunicatorObjects.ActionSpecProto(); + } + input.ReadMessage(actionSpec_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs.meta new file mode 100644 index 0000000000..447602fcc2 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 
26f9a93df956e4ee88c1cf5f31017f0e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs new file mode 100644 index 0000000000..ac267f4c2f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs @@ -0,0 +1,373 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/capabilities.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/capabilities.proto + internal static partial class CapabilitiesReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/capabilities.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CapabilitiesReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp", + "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7AEKGFVuaXR5UkxD", + "YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS", + "IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy", + "ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg", + "ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIEiEKGXZhcmlhYmxlTGVu", + "Z3RoT2JzZXJ2YXRpb24YBiABKAgSGAoQbXVsdGlBZ2VudEdyb3VwcxgHIAEo", + "CEIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv", + "dG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation", "MultiAgentGroups" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// + /// A Capabilities message that will communicate both C# and Python + /// what features are available to both. 
+ /// + internal sealed partial class UnityRLCapabilitiesProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLCapabilitiesProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLCapabilitiesProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() { + baseRLCapabilities_ = other.baseRLCapabilities_; + concatenatedPngObservations_ = other.concatenatedPngObservations_; + compressedChannelMapping_ = other.compressedChannelMapping_; + hybridActions_ = other.hybridActions_; + trainingAnalytics_ = other.trainingAnalytics_; + variableLengthObservation_ = other.variableLengthObservation_; + multiAgentGroups_ = other.multiAgentGroups_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLCapabilitiesProto Clone() { + return new UnityRLCapabilitiesProto(this); + } + + /// Field number for the "baseRLCapabilities" field. + public const int BaseRLCapabilitiesFieldNumber = 1; + private bool baseRLCapabilities_; + /// + /// These are the 1.0 capabilities. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool BaseRLCapabilities { + get { return baseRLCapabilities_; } + set { + baseRLCapabilities_ = value; + } + } + + /// Field number for the "concatenatedPngObservations" field. + public const int ConcatenatedPngObservationsFieldNumber = 2; + private bool concatenatedPngObservations_; + /// + /// concatenated PNG files for compressed visual observations with >3 channels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ConcatenatedPngObservations { + get { return concatenatedPngObservations_; } + set { + concatenatedPngObservations_ = value; + } + } + + /// Field number for the "compressedChannelMapping" field. + public const int CompressedChannelMappingFieldNumber = 3; + private bool compressedChannelMapping_; + /// + /// compression mapping for stacking compressed observations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CompressedChannelMapping { + get { return compressedChannelMapping_; } + set { + compressedChannelMapping_ = value; + } + } + + /// Field number for the "hybridActions" field. + public const int HybridActionsFieldNumber = 4; + private bool hybridActions_; + /// + /// support for hybrid action spaces (discrete + continuous) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool HybridActions { + get { return hybridActions_; } + set { + hybridActions_ = value; + } + } + + /// Field number for the "trainingAnalytics" field. 
+ public const int TrainingAnalyticsFieldNumber = 5; + private bool trainingAnalytics_; + /// + /// support for training analytics + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool TrainingAnalytics { + get { return trainingAnalytics_; } + set { + trainingAnalytics_ = value; + } + } + + /// Field number for the "variableLengthObservation" field. + public const int VariableLengthObservationFieldNumber = 6; + private bool variableLengthObservation_; + /// + /// Support for variable length observations of rank 2 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool VariableLengthObservation { + get { return variableLengthObservation_; } + set { + variableLengthObservation_ = value; + } + } + + /// Field number for the "multiAgentGroups" field. + public const int MultiAgentGroupsFieldNumber = 7; + private bool multiAgentGroups_; + /// + /// Support for multi agent groups and group rewards + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MultiAgentGroups { + get { return multiAgentGroups_; } + set { + multiAgentGroups_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLCapabilitiesProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLCapabilitiesProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BaseRLCapabilities != other.BaseRLCapabilities) return false; + if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false; + if (CompressedChannelMapping != other.CompressedChannelMapping) return false; + if (HybridActions != other.HybridActions) return false; + if (TrainingAnalytics != other.TrainingAnalytics) return false; + if (VariableLengthObservation != other.VariableLengthObservation) return false; + if (MultiAgentGroups != other.MultiAgentGroups) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode(); + if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode(); + if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode(); + if (HybridActions != false) hash ^= HybridActions.GetHashCode(); + if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode(); + if (VariableLengthObservation != false) hash ^= VariableLengthObservation.GetHashCode(); + if (MultiAgentGroups != false) hash ^= MultiAgentGroups.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (BaseRLCapabilities != false) { + output.WriteRawTag(8); + output.WriteBool(BaseRLCapabilities); + } + if (ConcatenatedPngObservations != false) { + output.WriteRawTag(16); + output.WriteBool(ConcatenatedPngObservations); + } + if (CompressedChannelMapping != false) { + output.WriteRawTag(24); + output.WriteBool(CompressedChannelMapping); + } + if (HybridActions != false) { + output.WriteRawTag(32); + 
output.WriteBool(HybridActions); + } + if (TrainingAnalytics != false) { + output.WriteRawTag(40); + output.WriteBool(TrainingAnalytics); + } + if (VariableLengthObservation != false) { + output.WriteRawTag(48); + output.WriteBool(VariableLengthObservation); + } + if (MultiAgentGroups != false) { + output.WriteRawTag(56); + output.WriteBool(MultiAgentGroups); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (BaseRLCapabilities != false) { + size += 1 + 1; + } + if (ConcatenatedPngObservations != false) { + size += 1 + 1; + } + if (CompressedChannelMapping != false) { + size += 1 + 1; + } + if (HybridActions != false) { + size += 1 + 1; + } + if (TrainingAnalytics != false) { + size += 1 + 1; + } + if (VariableLengthObservation != false) { + size += 1 + 1; + } + if (MultiAgentGroups != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLCapabilitiesProto other) { + if (other == null) { + return; + } + if (other.BaseRLCapabilities != false) { + BaseRLCapabilities = other.BaseRLCapabilities; + } + if (other.ConcatenatedPngObservations != false) { + ConcatenatedPngObservations = other.ConcatenatedPngObservations; + } + if (other.CompressedChannelMapping != false) { + CompressedChannelMapping = other.CompressedChannelMapping; + } + if (other.HybridActions != false) { + HybridActions = other.HybridActions; + } + if (other.TrainingAnalytics != false) { + TrainingAnalytics = other.TrainingAnalytics; + } + if (other.VariableLengthObservation != false) { + VariableLengthObservation = other.VariableLengthObservation; + } + if (other.MultiAgentGroups != false) { + MultiAgentGroups = other.MultiAgentGroups; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + BaseRLCapabilities = input.ReadBool(); + break; + } + case 16: { + ConcatenatedPngObservations = input.ReadBool(); + break; + } + case 24: { + CompressedChannelMapping = input.ReadBool(); + break; + } + case 32: { + HybridActions = input.ReadBool(); + break; + } + case 40: { + TrainingAnalytics = input.ReadBool(); + break; + } + case 48: { + VariableLengthObservation = input.ReadBool(); + break; + } + case 56: { + MultiAgentGroups = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs.meta new file mode 100644 index 0000000000..1e65cf6ee3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e8388443b440343299cab2e88988e14e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Command.cs 
b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Command.cs
new file mode 100644
index 0000000000..1220f9f9ee
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Command.cs
@@ -0,0 +1,49 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: mlagents_envs/communicator_objects/command.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Unity.MLAgents.CommunicatorObjects {
+
+  /// <summary>Holder for reflection information generated from mlagents_envs/communicator_objects/command.proto</summary>
+  internal static partial class CommandReflection {
+
+    #region Descriptor
+    /// <summary>File descriptor for mlagents_envs/communicator_objects/command.proto</summary>
+    public static pbr::FileDescriptor Descriptor {
+      get { return descriptor; }
+    }
+    private static pbr::FileDescriptor descriptor;
+
+    static CommandReflection() {
+      byte[] descriptorData = global::System.Convert.FromBase64String(
+          string.Concat(
+            "CjBtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NvbW1hbmQu",
+            "cHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzKi0KDENvbW1hbmRQcm90bxII",
+            "CgRTVEVQEAASCQoFUkVTRVQQARIICgRRVUlUEAJCJaoCIlVuaXR5Lk1MQWdl",
+            "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+      descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+          new pbr::FileDescriptor[] { },
+          new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CommandProto), }, null));
+    }
+    #endregion
+
+  }
+  #region Enums
+  internal enum CommandProto {
+    [pbr::OriginalName("STEP")] Step = 0,
+    [pbr::OriginalName("RESET")] Reset = 1,
+    [pbr::OriginalName("QUIT")] Quit = 2,
+  }
+
+  #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Command.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Command.cs.meta
new file mode 100644
index 0000000000..f47033a7c1
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Command.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 9be6f5025f61540eabbc831436642adc
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/CustomResetParameters.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/CustomResetParameters.cs
new file mode 100644
index 0000000000..45099b04c7
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/CustomResetParameters.cs
@@ -0,0 +1,146 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
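One property of the CommandProto enum above is worth calling out: proto3 enums are open, so a field of this type can legally carry an integer outside the declared STEP/RESET/QUIT set, and the C# binding simply stores the raw value. A small sketch (it assumes the caller can see this internal type):

using System;

CommandProto c = (CommandProto)42;                          // legal; round-trips on the wire
Console.WriteLine(Enum.IsDefined(typeof(CommandProto), c)); // False: not a declared member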
+// source: mlagents_envs/communicator_objects/custom_reset_parameters.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/custom_reset_parameters.proto + internal static partial class CustomResetParametersReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/custom_reset_parameters.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CustomResetParametersReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CkBtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2N1c3RvbV9y", + "ZXNldF9wYXJhbWV0ZXJzLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyIc", + "ChpDdXN0b21SZXNldFBhcmFtZXRlcnNQcm90b0IlqgIiVW5pdHkuTUxBZ2Vu", + "dHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.CustomResetParametersProto), global::Unity.MLAgents.CommunicatorObjects.CustomResetParametersProto.Parser, null, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class CustomResetParametersProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CustomResetParametersProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.CustomResetParametersReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CustomResetParametersProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CustomResetParametersProto(CustomResetParametersProto other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CustomResetParametersProto Clone() { + return new CustomResetParametersProto(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CustomResetParametersProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CustomResetParametersProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= 
_unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CustomResetParametersProto other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/CustomResetParameters.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/CustomResetParameters.cs.meta new file mode 100644 index 0000000000..aa357195f6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/CustomResetParameters.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 62f03717ee98042bf8990733358f2dbd +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/DemonstrationMeta.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/DemonstrationMeta.cs new file mode 100644 index 0000000000..58f8ad8022 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/DemonstrationMeta.cs @@ -0,0 +1,289 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
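CustomResetParametersProto, defined just above, declares no fields at all: anything it parses lands in _unknownFields, and the WriteTo, CalculateSize, and MergeFrom bodies do nothing else. That makes it a pure placeholder whose payload should survive a round trip. A sketch of that behavior, hedged: it assumes the internal generated types are visible to the caller and that the input bytes form a valid protobuf message.

using Google.Protobuf;

static byte[] RoundTrip(byte[] someBytes)
{
    // Every field in someBytes is unknown to this message type,
    // yet is preserved and re-emitted on serialization.
    var parsed = CustomResetParametersProto.Parser.ParseFrom(someBytes);
    return parsed.ToByteArray();
}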
+// source: mlagents_envs/communicator_objects/demonstration_meta.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/demonstration_meta.proto + internal static partial class DemonstrationMetaReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/demonstration_meta.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static DemonstrationMetaReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2RlbW9uc3Ry", + "YXRpb25fbWV0YS5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMijQEKFkRl", + "bW9uc3RyYXRpb25NZXRhUHJvdG8SEwoLYXBpX3ZlcnNpb24YASABKAUSGgoS", + "ZGVtb25zdHJhdGlvbl9uYW1lGAIgASgJEhQKDG51bWJlcl9zdGVwcxgDIAEo", + "BRIXCg9udW1iZXJfZXBpc29kZXMYBCABKAUSEwoLbWVhbl9yZXdhcmQYBSAB", + "KAJCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy", + "b3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.DemonstrationMetaProto), global::Unity.MLAgents.CommunicatorObjects.DemonstrationMetaProto.Parser, new[]{ "ApiVersion", "DemonstrationName", "NumberSteps", "NumberEpisodes", "MeanReward" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class DemonstrationMetaProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DemonstrationMetaProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.DemonstrationMetaReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DemonstrationMetaProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DemonstrationMetaProto(DemonstrationMetaProto other) : this() { + apiVersion_ = other.apiVersion_; + demonstrationName_ = other.demonstrationName_; + numberSteps_ = other.numberSteps_; + numberEpisodes_ = other.numberEpisodes_; + meanReward_ = other.meanReward_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DemonstrationMetaProto Clone() { + return new DemonstrationMetaProto(this); + } + + /// Field number for the "api_version" field. 
+ public const int ApiVersionFieldNumber = 1; + private int apiVersion_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ApiVersion { + get { return apiVersion_; } + set { + apiVersion_ = value; + } + } + + /// Field number for the "demonstration_name" field. + public const int DemonstrationNameFieldNumber = 2; + private string demonstrationName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DemonstrationName { + get { return demonstrationName_; } + set { + demonstrationName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "number_steps" field. + public const int NumberStepsFieldNumber = 3; + private int numberSteps_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumberSteps { + get { return numberSteps_; } + set { + numberSteps_ = value; + } + } + + /// Field number for the "number_episodes" field. + public const int NumberEpisodesFieldNumber = 4; + private int numberEpisodes_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumberEpisodes { + get { return numberEpisodes_; } + set { + numberEpisodes_ = value; + } + } + + /// Field number for the "mean_reward" field. + public const int MeanRewardFieldNumber = 5; + private float meanReward_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MeanReward { + get { return meanReward_; } + set { + meanReward_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as DemonstrationMetaProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(DemonstrationMetaProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ApiVersion != other.ApiVersion) return false; + if (DemonstrationName != other.DemonstrationName) return false; + if (NumberSteps != other.NumberSteps) return false; + if (NumberEpisodes != other.NumberEpisodes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MeanReward, other.MeanReward)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ApiVersion != 0) hash ^= ApiVersion.GetHashCode(); + if (DemonstrationName.Length != 0) hash ^= DemonstrationName.GetHashCode(); + if (NumberSteps != 0) hash ^= NumberSteps.GetHashCode(); + if (NumberEpisodes != 0) hash ^= NumberEpisodes.GetHashCode(); + if (MeanReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MeanReward); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ApiVersion != 0) { + output.WriteRawTag(8); + output.WriteInt32(ApiVersion); + } + if (DemonstrationName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DemonstrationName); + } + if (NumberSteps != 0) { + output.WriteRawTag(24); + output.WriteInt32(NumberSteps); + } + if (NumberEpisodes != 0) { + output.WriteRawTag(32); + output.WriteInt32(NumberEpisodes); + } + if (MeanReward != 0F) { + 
output.WriteRawTag(45); + output.WriteFloat(MeanReward); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ApiVersion != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ApiVersion); + } + if (DemonstrationName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DemonstrationName); + } + if (NumberSteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumberSteps); + } + if (NumberEpisodes != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumberEpisodes); + } + if (MeanReward != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(DemonstrationMetaProto other) { + if (other == null) { + return; + } + if (other.ApiVersion != 0) { + ApiVersion = other.ApiVersion; + } + if (other.DemonstrationName.Length != 0) { + DemonstrationName = other.DemonstrationName; + } + if (other.NumberSteps != 0) { + NumberSteps = other.NumberSteps; + } + if (other.NumberEpisodes != 0) { + NumberEpisodes = other.NumberEpisodes; + } + if (other.MeanReward != 0F) { + MeanReward = other.MeanReward; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ApiVersion = input.ReadInt32(); + break; + } + case 18: { + DemonstrationName = input.ReadString(); + break; + } + case 24: { + NumberSteps = input.ReadInt32(); + break; + } + case 32: { + NumberEpisodes = input.ReadInt32(); + break; + } + case 45: { + MeanReward = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/DemonstrationMeta.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/DemonstrationMeta.cs.meta new file mode 100644 index 0000000000..41176197e9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/DemonstrationMeta.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7248e2660150f4a39bb99dfabb9bae7d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/EngineConfiguration.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/EngineConfiguration.cs new file mode 100644 index 0000000000..6a05c09f28 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/EngineConfiguration.cs @@ -0,0 +1,317 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
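The CalculateSize body in DemonstrationMetaProto above mixes fixed costs (1 + 4 for the float mean_reward: one tag byte plus four payload bytes) with ComputeInt32Size for the int32 counters, which prices varints by magnitude. A quick check of that pricing with the real Google.Protobuf helper:

using System;
using Google.Protobuf;

Console.WriteLine(CodedOutputStream.ComputeInt32Size(127)); // 1 byte
Console.WriteLine(CodedOutputStream.ComputeInt32Size(128)); // 2 bytes
Console.WriteLine(CodedOutputStream.ComputeInt32Size(-1));  // 10 bytes: negative int32s use the full-width varint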
+// source: mlagents_envs/communicator_objects/engine_configuration.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/engine_configuration.proto + internal static partial class EngineConfigurationReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/engine_configuration.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static EngineConfigurationReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cj1tbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2VuZ2luZV9j", + "b25maWd1cmF0aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKVAQoY", + "RW5naW5lQ29uZmlndXJhdGlvblByb3RvEg0KBXdpZHRoGAEgASgFEg4KBmhl", + "aWdodBgCIAEoBRIVCg1xdWFsaXR5X2xldmVsGAMgASgFEhIKCnRpbWVfc2Nh", + "bGUYBCABKAISGQoRdGFyZ2V0X2ZyYW1lX3JhdGUYBSABKAUSFAoMc2hvd19t", + "b25pdG9yGAYgASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JP", + "YmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.EngineConfigurationProto), global::Unity.MLAgents.CommunicatorObjects.EngineConfigurationProto.Parser, new[]{ "Width", "Height", "QualityLevel", "TimeScale", "TargetFrameRate", "ShowMonitor" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class EngineConfigurationProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EngineConfigurationProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.EngineConfigurationReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EngineConfigurationProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EngineConfigurationProto(EngineConfigurationProto other) : this() { + width_ = other.width_; + height_ = other.height_; + qualityLevel_ = other.qualityLevel_; + timeScale_ = other.timeScale_; + targetFrameRate_ = other.targetFrameRate_; + showMonitor_ = other.showMonitor_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EngineConfigurationProto Clone() { + return new EngineConfigurationProto(this); + } + + /// Field number for the "width" field. 
+ public const int WidthFieldNumber = 1; + private int width_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Width { + get { return width_; } + set { + width_ = value; + } + } + + /// Field number for the "height" field. + public const int HeightFieldNumber = 2; + private int height_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Height { + get { return height_; } + set { + height_ = value; + } + } + + /// Field number for the "quality_level" field. + public const int QualityLevelFieldNumber = 3; + private int qualityLevel_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int QualityLevel { + get { return qualityLevel_; } + set { + qualityLevel_ = value; + } + } + + /// Field number for the "time_scale" field. + public const int TimeScaleFieldNumber = 4; + private float timeScale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float TimeScale { + get { return timeScale_; } + set { + timeScale_ = value; + } + } + + /// Field number for the "target_frame_rate" field. + public const int TargetFrameRateFieldNumber = 5; + private int targetFrameRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int TargetFrameRate { + get { return targetFrameRate_; } + set { + targetFrameRate_ = value; + } + } + + /// Field number for the "show_monitor" field. + public const int ShowMonitorFieldNumber = 6; + private bool showMonitor_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ShowMonitor { + get { return showMonitor_; } + set { + showMonitor_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as EngineConfigurationProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(EngineConfigurationProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Width != other.Width) return false; + if (Height != other.Height) return false; + if (QualityLevel != other.QualityLevel) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TimeScale, other.TimeScale)) return false; + if (TargetFrameRate != other.TargetFrameRate) return false; + if (ShowMonitor != other.ShowMonitor) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Width != 0) hash ^= Width.GetHashCode(); + if (Height != 0) hash ^= Height.GetHashCode(); + if (QualityLevel != 0) hash ^= QualityLevel.GetHashCode(); + if (TimeScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TimeScale); + if (TargetFrameRate != 0) hash ^= TargetFrameRate.GetHashCode(); + if (ShowMonitor != false) hash ^= ShowMonitor.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Width != 0) { + output.WriteRawTag(8); + output.WriteInt32(Width); + } + if (Height != 0) { + output.WriteRawTag(16); + output.WriteInt32(Height); + } + if (QualityLevel != 0) { + output.WriteRawTag(24); + 
output.WriteInt32(QualityLevel); + } + if (TimeScale != 0F) { + output.WriteRawTag(37); + output.WriteFloat(TimeScale); + } + if (TargetFrameRate != 0) { + output.WriteRawTag(40); + output.WriteInt32(TargetFrameRate); + } + if (ShowMonitor != false) { + output.WriteRawTag(48); + output.WriteBool(ShowMonitor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Width != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width); + } + if (Height != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height); + } + if (QualityLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(QualityLevel); + } + if (TimeScale != 0F) { + size += 1 + 4; + } + if (TargetFrameRate != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TargetFrameRate); + } + if (ShowMonitor != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(EngineConfigurationProto other) { + if (other == null) { + return; + } + if (other.Width != 0) { + Width = other.Width; + } + if (other.Height != 0) { + Height = other.Height; + } + if (other.QualityLevel != 0) { + QualityLevel = other.QualityLevel; + } + if (other.TimeScale != 0F) { + TimeScale = other.TimeScale; + } + if (other.TargetFrameRate != 0) { + TargetFrameRate = other.TargetFrameRate; + } + if (other.ShowMonitor != false) { + ShowMonitor = other.ShowMonitor; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Width = input.ReadInt32(); + break; + } + case 16: { + Height = input.ReadInt32(); + break; + } + case 24: { + QualityLevel = input.ReadInt32(); + break; + } + case 37: { + TimeScale = input.ReadFloat(); + break; + } + case 40: { + TargetFrameRate = input.ReadInt32(); + break; + } + case 48: { + ShowMonitor = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/EngineConfiguration.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/EngineConfiguration.cs.meta new file mode 100644 index 0000000000..cb08edae85 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/EngineConfiguration.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 129a5bbec69fc4f42bc70e422660c8f0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Header.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Header.cs new file mode 100644 index 0000000000..2f38cc8f44 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Header.cs @@ -0,0 +1,202 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/header.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/header.proto + internal static partial class HeaderReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/header.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static HeaderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci9tbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2hlYWRlci5w", + "cm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiLgoLSGVhZGVyUHJvdG8SDgoG", + "c3RhdHVzGAEgASgFEg8KB21lc3NhZ2UYAiABKAlCJaoCIlVuaXR5Lk1MQWdl", + "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.HeaderProto), global::Unity.MLAgents.CommunicatorObjects.HeaderProto.Parser, new[]{ "Status", "Message" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class HeaderProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HeaderProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.HeaderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public HeaderProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public HeaderProto(HeaderProto other) : this() { + status_ = other.status_; + message_ = other.message_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public HeaderProto Clone() { + return new HeaderProto(this); + } + + /// Field number for the "status" field. + public const int StatusFieldNumber = 1; + private int status_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Status { + get { return status_; } + set { + status_ = value; + } + } + + /// Field number for the "message" field. 
+ public const int MessageFieldNumber = 2; + private string message_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Message { + get { return message_; } + set { + message_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as HeaderProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(HeaderProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Status != other.Status) return false; + if (Message != other.Message) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Status != 0) hash ^= Status.GetHashCode(); + if (Message.Length != 0) hash ^= Message.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Status != 0) { + output.WriteRawTag(8); + output.WriteInt32(Status); + } + if (Message.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Message); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Status != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Status); + } + if (Message.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Message); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(HeaderProto other) { + if (other == null) { + return; + } + if (other.Status != 0) { + Status = other.Status; + } + if (other.Message.Length != 0) { + Message = other.Message; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Status = input.ReadInt32(); + break; + } + case 18: { + Message = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Header.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Header.cs.meta new file mode 100644 index 0000000000..3084742c95 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Header.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 870996bd75a1a4fbcbb120b1e1e66c37 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs new file 
mode 100644 index 0000000000..3e23c8d991 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs @@ -0,0 +1,546 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/observation.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/observation.proto + internal static partial class ObservationReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/observation.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ObservationReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0", + "aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKPAwoQT2JzZXJ2YXRp", + "b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg", + "ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv", + "dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE", + "IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u", + "RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD", + "KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUSRAoQb2JzZXJ2YXRp", + "b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0", + "aW9uVHlwZVByb3RvEgwKBG5hbWUYCCABKAkaGQoJRmxvYXREYXRhEgwKBGRh", + "dGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5", + "cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqQAoUT2JzZXJ2YXRpb25UeXBl", + "UHJvdG8SCwoHREVGQVVMVBAAEg8KC0dPQUxfU0lHTkFMEAEiBAgCEAIiBAgD", + "EANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy", + "b3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "ObservationType", "Name" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)}) + })); + } + #endregion + + } + #region Enums + internal enum CompressionTypeProto { + [pbr::OriginalName("NONE")] None = 0, + [pbr::OriginalName("PNG")] Png = 1, + } + + internal enum ObservationTypeProto { + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("GOAL_SIGNAL")] GoalSignal = 1, + } + + #endregion + + #region Messages + internal sealed partial class ObservationProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ObservationProto()); + private pb::UnknownFieldSet 
_unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ObservationProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ObservationProto(ObservationProto other) : this() { + shape_ = other.shape_.Clone(); + compressionType_ = other.compressionType_; + compressedChannelMapping_ = other.compressedChannelMapping_.Clone(); + dimensionProperties_ = other.dimensionProperties_.Clone(); + observationType_ = other.observationType_; + name_ = other.name_; + switch (other.ObservationDataCase) { + case ObservationDataOneofCase.CompressedData: + CompressedData = other.CompressedData; + break; + case ObservationDataOneofCase.FloatData: + FloatData = other.FloatData.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ObservationProto Clone() { + return new ObservationProto(this); + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_shape_codec + = pb::FieldCodec.ForInt32(10); + private readonly pbc::RepeatedField shape_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Shape { + get { return shape_; } + } + + /// Field number for the "compression_type" field. + public const int CompressionTypeFieldNumber = 2; + private global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto compressionType_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto CompressionType { + get { return compressionType_; } + set { + compressionType_ = value; + } + } + + /// Field number for the "compressed_data" field. + public const int CompressedDataFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString CompressedData { + get { return observationDataCase_ == ObservationDataOneofCase.CompressedData ? (pb::ByteString) observationData_ : pb::ByteString.Empty; } + set { + observationData_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + observationDataCase_ = ObservationDataOneofCase.CompressedData; + } + } + + /// Field number for the "float_data" field. + public const int FloatDataFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData FloatData { + get { return observationDataCase_ == ObservationDataOneofCase.FloatData ? (global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData) observationData_ : null; } + set { + observationData_ = value; + observationDataCase_ = value == null ? ObservationDataOneofCase.None : ObservationDataOneofCase.FloatData; + } + } + + /// Field number for the "compressed_channel_mapping" field. 
+ public const int CompressedChannelMappingFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_compressedChannelMapping_codec + = pb::FieldCodec.ForInt32(42); + private readonly pbc::RepeatedField compressedChannelMapping_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField CompressedChannelMapping { + get { return compressedChannelMapping_; } + } + + /// Field number for the "dimension_properties" field. + public const int DimensionPropertiesFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_dimensionProperties_codec + = pb::FieldCodec.ForInt32(50); + private readonly pbc::RepeatedField dimensionProperties_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DimensionProperties { + get { return dimensionProperties_; } + } + + /// Field number for the "observation_type" field. + public const int ObservationTypeFieldNumber = 7; + private global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto observationType_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto ObservationType { + get { return observationType_; } + set { + observationType_ = value; + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 8; + private string name_ = ""; + /// + /// Optional name of the observation. + /// This will be set to the ISensor name when writing, + /// and read into the ObservationSpec in the low-level API + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + private object observationData_; + /// Enum of possible cases for the "observation_data" oneof. 
+ public enum ObservationDataOneofCase { + None = 0, + CompressedData = 3, + FloatData = 4, + } + private ObservationDataOneofCase observationDataCase_ = ObservationDataOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ObservationDataOneofCase ObservationDataCase { + get { return observationDataCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearObservationData() { + observationDataCase_ = ObservationDataOneofCase.None; + observationData_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ObservationProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ObservationProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!shape_.Equals(other.shape_)) return false; + if (CompressionType != other.CompressionType) return false; + if (CompressedData != other.CompressedData) return false; + if (!object.Equals(FloatData, other.FloatData)) return false; + if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false; + if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false; + if (ObservationType != other.ObservationType) return false; + if (Name != other.Name) return false; + if (ObservationDataCase != other.ObservationDataCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= shape_.GetHashCode(); + if (CompressionType != 0) hash ^= CompressionType.GetHashCode(); + if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode(); + if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode(); + hash ^= compressedChannelMapping_.GetHashCode(); + hash ^= dimensionProperties_.GetHashCode(); + if (ObservationType != 0) hash ^= ObservationType.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= (int) observationDataCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + shape_.WriteTo(output, _repeated_shape_codec); + if (CompressionType != 0) { + output.WriteRawTag(16); + output.WriteEnum((int) CompressionType); + } + if (observationDataCase_ == ObservationDataOneofCase.CompressedData) { + output.WriteRawTag(26); + output.WriteBytes(CompressedData); + } + if (observationDataCase_ == ObservationDataOneofCase.FloatData) { + output.WriteRawTag(34); + output.WriteMessage(FloatData); + } + compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec); + dimensionProperties_.WriteTo(output, _repeated_dimensionProperties_codec); + if (ObservationType != 0) { + output.WriteRawTag(56); + output.WriteEnum((int) ObservationType); + } + if (Name.Length != 0) { + output.WriteRawTag(66); + output.WriteString(Name); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += 
shape_.CalculateSize(_repeated_shape_codec); + if (CompressionType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) CompressionType); + } + if (observationDataCase_ == ObservationDataOneofCase.CompressedData) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(CompressedData); + } + if (observationDataCase_ == ObservationDataOneofCase.FloatData) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData); + } + size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec); + size += dimensionProperties_.CalculateSize(_repeated_dimensionProperties_codec); + if (ObservationType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ObservationType); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ObservationProto other) { + if (other == null) { + return; + } + shape_.Add(other.shape_); + if (other.CompressionType != 0) { + CompressionType = other.CompressionType; + } + compressedChannelMapping_.Add(other.compressedChannelMapping_); + dimensionProperties_.Add(other.dimensionProperties_); + if (other.ObservationType != 0) { + ObservationType = other.ObservationType; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + switch (other.ObservationDataCase) { + case ObservationDataOneofCase.CompressedData: + CompressedData = other.CompressedData; + break; + case ObservationDataOneofCase.FloatData: + if (FloatData == null) { + FloatData = new global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData(); + } + FloatData.MergeFrom(other.FloatData); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + shape_.AddEntriesFrom(input, _repeated_shape_codec); + break; + } + case 16: { + compressionType_ = (global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto) input.ReadEnum(); + break; + } + case 26: { + CompressedData = input.ReadBytes(); + break; + } + case 34: { + global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData subBuilder = new global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData(); + if (observationDataCase_ == ObservationDataOneofCase.FloatData) { + subBuilder.MergeFrom(FloatData); + } + input.ReadMessage(subBuilder); + FloatData = subBuilder; + break; + } + case 42: + case 40: { + compressedChannelMapping_.AddEntriesFrom(input, _repeated_compressedChannelMapping_codec); + break; + } + case 50: + case 48: { + dimensionProperties_.AddEntriesFrom(input, _repeated_dimensionProperties_codec); + break; + } + case 56: { + observationType_ = (global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto) input.ReadEnum(); + break; + } + case 66: { + Name = input.ReadString(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the ObservationProto message type. 
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + internal sealed partial class FloatData : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FloatData()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FloatData() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FloatData(FloatData other) : this() { + data_ = other.data_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FloatData Clone() { + return new FloatData(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_data_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField data_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Data { + get { return data_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FloatData); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FloatData other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!data_.Equals(other.data_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= data_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + data_.WriteTo(output, _repeated_data_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += data_.CalculateSize(_repeated_data_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FloatData other) { + if (other == null) { + return; + } + data_.Add(other.data_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); 
+ break; + case 10: + case 13: { + data_.AddEntriesFrom(input, _repeated_data_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs.meta new file mode 100644 index 0000000000..971fead69c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 9fbba5f80821d4f02b4239a8e16eebfa +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/SpaceType.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/SpaceType.cs new file mode 100644 index 0000000000..d3bf7cf220 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/SpaceType.cs @@ -0,0 +1,48 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/space_type.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/space_type.proto + internal static partial class SpaceTypeReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/space_type.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SpaceTypeReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3NwYWNlX3R5", + "cGUucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzKi4KDlNwYWNlVHlwZVBy", + "b3RvEgwKCGRpc2NyZXRlEAASDgoKY29udGludW91cxABQiWqAiJVbml0eS5N", + "TEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.SpaceTypeProto), }, null)); + } + #endregion + + } + #region Enums + internal enum SpaceTypeProto { + [pbr::OriginalName("discrete")] Discrete = 0, + [pbr::OriginalName("continuous")] Continuous = 1, + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/SpaceType.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/SpaceType.cs.meta new file mode 100644 index 0000000000..7b6ada73ea --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/SpaceType.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 3934602aadbe9471ca973685059ef04a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs new file mode 100644 index 0000000000..042357f280 --- /dev/null +++ 
b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs @@ -0,0 +1,907 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/training_analytics.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/training_analytics.proto + internal static partial class TrainingAnalyticsReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/training_analytics.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TrainingAnalyticsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n", + "X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7gEKHlRy", + "YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz", + "aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w", + "eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK", + "EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK", + "Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFEhMKC3J1bl9vcHRp", + "b25zGAggASgJIr0DChtUcmFpbmluZ0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoN", + "YmVoYXZpb3JfbmFtZRgBIAEoCRIUCgx0cmFpbmVyX3R5cGUYAiABKAkSIAoY", + "ZXh0cmluc2ljX3Jld2FyZF9lbmFibGVkGAMgASgIEhsKE2dhaWxfcmV3YXJk", + "X2VuYWJsZWQYBCABKAgSIAoYY3VyaW9zaXR5X3Jld2FyZF9lbmFibGVkGAUg", + "ASgIEhoKEnJuZF9yZXdhcmRfZW5hYmxlZBgGIAEoCBIiChpiZWhhdmlvcmFs", + "X2Nsb25pbmdfZW5hYmxlZBgHIAEoCBIZChFyZWN1cnJlbnRfZW5hYmxlZBgI", + "IAEoCBIWCg52aXN1YWxfZW5jb2RlchgJIAEoCRIaChJudW1fbmV0d29ya19s", + "YXllcnMYCiABKAUSIAoYbnVtX25ldHdvcmtfaGlkZGVuX3VuaXRzGAsgASgF", + "EhgKEHRyYWluZXJfdGhyZWFkZWQYDCABKAgSGQoRc2VsZl9wbGF5X2VuYWJs", + "ZWQYDSABKAgSGgoSY3VycmljdWx1bV9lbmFibGVkGA4gASgIEg4KBmNvbmZp", + "ZxgPIAEoCUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0", + "c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters", "RunOptions" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled", "Config" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class TrainingEnvironmentInitialized : pb::IMessage { + private static readonly pb::MessageParser _parser = new 
pb::MessageParser(() => new TrainingEnvironmentInitialized()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingEnvironmentInitialized() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : this() { + mlagentsVersion_ = other.mlagentsVersion_; + mlagentsEnvsVersion_ = other.mlagentsEnvsVersion_; + pythonVersion_ = other.pythonVersion_; + torchVersion_ = other.torchVersion_; + torchDeviceType_ = other.torchDeviceType_; + numEnvs_ = other.numEnvs_; + numEnvironmentParameters_ = other.numEnvironmentParameters_; + runOptions_ = other.runOptions_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingEnvironmentInitialized Clone() { + return new TrainingEnvironmentInitialized(this); + } + + /// Field number for the "mlagents_version" field. + public const int MlagentsVersionFieldNumber = 1; + private string mlagentsVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MlagentsVersion { + get { return mlagentsVersion_; } + set { + mlagentsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "mlagents_envs_version" field. + public const int MlagentsEnvsVersionFieldNumber = 2; + private string mlagentsEnvsVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MlagentsEnvsVersion { + get { return mlagentsEnvsVersion_; } + set { + mlagentsEnvsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "python_version" field. + public const int PythonVersionFieldNumber = 3; + private string pythonVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PythonVersion { + get { return pythonVersion_; } + set { + pythonVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "torch_version" field. + public const int TorchVersionFieldNumber = 4; + private string torchVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TorchVersion { + get { return torchVersion_; } + set { + torchVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "torch_device_type" field. + public const int TorchDeviceTypeFieldNumber = 5; + private string torchDeviceType_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TorchDeviceType { + get { return torchDeviceType_; } + set { + torchDeviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_envs" field. 
+ public const int NumEnvsFieldNumber = 6; + private int numEnvs_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumEnvs { + get { return numEnvs_; } + set { + numEnvs_ = value; + } + } + + /// Field number for the "num_environment_parameters" field. + public const int NumEnvironmentParametersFieldNumber = 7; + private int numEnvironmentParameters_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumEnvironmentParameters { + get { return numEnvironmentParameters_; } + set { + numEnvironmentParameters_ = value; + } + } + + /// Field number for the "run_options" field. + public const int RunOptionsFieldNumber = 8; + private string runOptions_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string RunOptions { + get { return runOptions_; } + set { + runOptions_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TrainingEnvironmentInitialized); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TrainingEnvironmentInitialized other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MlagentsVersion != other.MlagentsVersion) return false; + if (MlagentsEnvsVersion != other.MlagentsEnvsVersion) return false; + if (PythonVersion != other.PythonVersion) return false; + if (TorchVersion != other.TorchVersion) return false; + if (TorchDeviceType != other.TorchDeviceType) return false; + if (NumEnvs != other.NumEnvs) return false; + if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false; + if (RunOptions != other.RunOptions) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MlagentsVersion.Length != 0) hash ^= MlagentsVersion.GetHashCode(); + if (MlagentsEnvsVersion.Length != 0) hash ^= MlagentsEnvsVersion.GetHashCode(); + if (PythonVersion.Length != 0) hash ^= PythonVersion.GetHashCode(); + if (TorchVersion.Length != 0) hash ^= TorchVersion.GetHashCode(); + if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode(); + if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode(); + if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode(); + if (RunOptions.Length != 0) hash ^= RunOptions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MlagentsVersion.Length != 0) { + output.WriteRawTag(10); + output.WriteString(MlagentsVersion); + } + if (MlagentsEnvsVersion.Length != 0) { + output.WriteRawTag(18); + output.WriteString(MlagentsEnvsVersion); + } + if (PythonVersion.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PythonVersion); + } + if (TorchVersion.Length != 0) { + output.WriteRawTag(34); + output.WriteString(TorchVersion); + } + if (TorchDeviceType.Length != 0) { + output.WriteRawTag(42); + output.WriteString(TorchDeviceType); + } + if (NumEnvs != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumEnvs); + } + if 
(NumEnvironmentParameters != 0) { + output.WriteRawTag(56); + output.WriteInt32(NumEnvironmentParameters); + } + if (RunOptions.Length != 0) { + output.WriteRawTag(66); + output.WriteString(RunOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MlagentsVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsVersion); + } + if (MlagentsEnvsVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsEnvsVersion); + } + if (PythonVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PythonVersion); + } + if (TorchVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchVersion); + } + if (TorchDeviceType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchDeviceType); + } + if (NumEnvs != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvs); + } + if (NumEnvironmentParameters != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters); + } + if (RunOptions.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RunOptions); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TrainingEnvironmentInitialized other) { + if (other == null) { + return; + } + if (other.MlagentsVersion.Length != 0) { + MlagentsVersion = other.MlagentsVersion; + } + if (other.MlagentsEnvsVersion.Length != 0) { + MlagentsEnvsVersion = other.MlagentsEnvsVersion; + } + if (other.PythonVersion.Length != 0) { + PythonVersion = other.PythonVersion; + } + if (other.TorchVersion.Length != 0) { + TorchVersion = other.TorchVersion; + } + if (other.TorchDeviceType.Length != 0) { + TorchDeviceType = other.TorchDeviceType; + } + if (other.NumEnvs != 0) { + NumEnvs = other.NumEnvs; + } + if (other.NumEnvironmentParameters != 0) { + NumEnvironmentParameters = other.NumEnvironmentParameters; + } + if (other.RunOptions.Length != 0) { + RunOptions = other.RunOptions; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + MlagentsVersion = input.ReadString(); + break; + } + case 18: { + MlagentsEnvsVersion = input.ReadString(); + break; + } + case 26: { + PythonVersion = input.ReadString(); + break; + } + case 34: { + TorchVersion = input.ReadString(); + break; + } + case 42: { + TorchDeviceType = input.ReadString(); + break; + } + case 48: { + NumEnvs = input.ReadInt32(); + break; + } + case 56: { + NumEnvironmentParameters = input.ReadInt32(); + break; + } + case 66: { + RunOptions = input.ReadString(); + break; + } + } + } + } + + } + + internal sealed partial class TrainingBehaviorInitialized : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrainingBehaviorInitialized()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + 
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingBehaviorInitialized() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() { + behaviorName_ = other.behaviorName_; + trainerType_ = other.trainerType_; + extrinsicRewardEnabled_ = other.extrinsicRewardEnabled_; + gailRewardEnabled_ = other.gailRewardEnabled_; + curiosityRewardEnabled_ = other.curiosityRewardEnabled_; + rndRewardEnabled_ = other.rndRewardEnabled_; + behavioralCloningEnabled_ = other.behavioralCloningEnabled_; + recurrentEnabled_ = other.recurrentEnabled_; + visualEncoder_ = other.visualEncoder_; + numNetworkLayers_ = other.numNetworkLayers_; + numNetworkHiddenUnits_ = other.numNetworkHiddenUnits_; + trainerThreaded_ = other.trainerThreaded_; + selfPlayEnabled_ = other.selfPlayEnabled_; + curriculumEnabled_ = other.curriculumEnabled_; + config_ = other.config_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingBehaviorInitialized Clone() { + return new TrainingBehaviorInitialized(this); + } + + /// Field number for the "behavior_name" field. + public const int BehaviorNameFieldNumber = 1; + private string behaviorName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string BehaviorName { + get { return behaviorName_; } + set { + behaviorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "trainer_type" field. + public const int TrainerTypeFieldNumber = 2; + private string trainerType_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TrainerType { + get { return trainerType_; } + set { + trainerType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "extrinsic_reward_enabled" field. + public const int ExtrinsicRewardEnabledFieldNumber = 3; + private bool extrinsicRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ExtrinsicRewardEnabled { + get { return extrinsicRewardEnabled_; } + set { + extrinsicRewardEnabled_ = value; + } + } + + /// Field number for the "gail_reward_enabled" field. + public const int GailRewardEnabledFieldNumber = 4; + private bool gailRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool GailRewardEnabled { + get { return gailRewardEnabled_; } + set { + gailRewardEnabled_ = value; + } + } + + /// Field number for the "curiosity_reward_enabled" field. + public const int CuriosityRewardEnabledFieldNumber = 5; + private bool curiosityRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CuriosityRewardEnabled { + get { return curiosityRewardEnabled_; } + set { + curiosityRewardEnabled_ = value; + } + } + + /// Field number for the "rnd_reward_enabled" field. 
+ public const int RndRewardEnabledFieldNumber = 6; + private bool rndRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RndRewardEnabled { + get { return rndRewardEnabled_; } + set { + rndRewardEnabled_ = value; + } + } + + /// Field number for the "behavioral_cloning_enabled" field. + public const int BehavioralCloningEnabledFieldNumber = 7; + private bool behavioralCloningEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool BehavioralCloningEnabled { + get { return behavioralCloningEnabled_; } + set { + behavioralCloningEnabled_ = value; + } + } + + /// Field number for the "recurrent_enabled" field. + public const int RecurrentEnabledFieldNumber = 8; + private bool recurrentEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RecurrentEnabled { + get { return recurrentEnabled_; } + set { + recurrentEnabled_ = value; + } + } + + /// Field number for the "visual_encoder" field. + public const int VisualEncoderFieldNumber = 9; + private string visualEncoder_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string VisualEncoder { + get { return visualEncoder_; } + set { + visualEncoder_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_network_layers" field. + public const int NumNetworkLayersFieldNumber = 10; + private int numNetworkLayers_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumNetworkLayers { + get { return numNetworkLayers_; } + set { + numNetworkLayers_ = value; + } + } + + /// Field number for the "num_network_hidden_units" field. + public const int NumNetworkHiddenUnitsFieldNumber = 11; + private int numNetworkHiddenUnits_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumNetworkHiddenUnits { + get { return numNetworkHiddenUnits_; } + set { + numNetworkHiddenUnits_ = value; + } + } + + /// Field number for the "trainer_threaded" field. + public const int TrainerThreadedFieldNumber = 12; + private bool trainerThreaded_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool TrainerThreaded { + get { return trainerThreaded_; } + set { + trainerThreaded_ = value; + } + } + + /// Field number for the "self_play_enabled" field. + public const int SelfPlayEnabledFieldNumber = 13; + private bool selfPlayEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SelfPlayEnabled { + get { return selfPlayEnabled_; } + set { + selfPlayEnabled_ = value; + } + } + + /// Field number for the "curriculum_enabled" field. + public const int CurriculumEnabledFieldNumber = 14; + private bool curriculumEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CurriculumEnabled { + get { return curriculumEnabled_; } + set { + curriculumEnabled_ = value; + } + } + + /// Field number for the "config" field. 
+ public const int ConfigFieldNumber = 15; + private string config_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Config { + get { return config_; } + set { + config_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TrainingBehaviorInitialized); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TrainingBehaviorInitialized other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BehaviorName != other.BehaviorName) return false; + if (TrainerType != other.TrainerType) return false; + if (ExtrinsicRewardEnabled != other.ExtrinsicRewardEnabled) return false; + if (GailRewardEnabled != other.GailRewardEnabled) return false; + if (CuriosityRewardEnabled != other.CuriosityRewardEnabled) return false; + if (RndRewardEnabled != other.RndRewardEnabled) return false; + if (BehavioralCloningEnabled != other.BehavioralCloningEnabled) return false; + if (RecurrentEnabled != other.RecurrentEnabled) return false; + if (VisualEncoder != other.VisualEncoder) return false; + if (NumNetworkLayers != other.NumNetworkLayers) return false; + if (NumNetworkHiddenUnits != other.NumNetworkHiddenUnits) return false; + if (TrainerThreaded != other.TrainerThreaded) return false; + if (SelfPlayEnabled != other.SelfPlayEnabled) return false; + if (CurriculumEnabled != other.CurriculumEnabled) return false; + if (Config != other.Config) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (BehaviorName.Length != 0) hash ^= BehaviorName.GetHashCode(); + if (TrainerType.Length != 0) hash ^= TrainerType.GetHashCode(); + if (ExtrinsicRewardEnabled != false) hash ^= ExtrinsicRewardEnabled.GetHashCode(); + if (GailRewardEnabled != false) hash ^= GailRewardEnabled.GetHashCode(); + if (CuriosityRewardEnabled != false) hash ^= CuriosityRewardEnabled.GetHashCode(); + if (RndRewardEnabled != false) hash ^= RndRewardEnabled.GetHashCode(); + if (BehavioralCloningEnabled != false) hash ^= BehavioralCloningEnabled.GetHashCode(); + if (RecurrentEnabled != false) hash ^= RecurrentEnabled.GetHashCode(); + if (VisualEncoder.Length != 0) hash ^= VisualEncoder.GetHashCode(); + if (NumNetworkLayers != 0) hash ^= NumNetworkLayers.GetHashCode(); + if (NumNetworkHiddenUnits != 0) hash ^= NumNetworkHiddenUnits.GetHashCode(); + if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode(); + if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode(); + if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode(); + if (Config.Length != 0) hash ^= Config.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (BehaviorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(BehaviorName); + } + if (TrainerType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(TrainerType); + } + if (ExtrinsicRewardEnabled != false) { + output.WriteRawTag(24); + 
output.WriteBool(ExtrinsicRewardEnabled); + } + if (GailRewardEnabled != false) { + output.WriteRawTag(32); + output.WriteBool(GailRewardEnabled); + } + if (CuriosityRewardEnabled != false) { + output.WriteRawTag(40); + output.WriteBool(CuriosityRewardEnabled); + } + if (RndRewardEnabled != false) { + output.WriteRawTag(48); + output.WriteBool(RndRewardEnabled); + } + if (BehavioralCloningEnabled != false) { + output.WriteRawTag(56); + output.WriteBool(BehavioralCloningEnabled); + } + if (RecurrentEnabled != false) { + output.WriteRawTag(64); + output.WriteBool(RecurrentEnabled); + } + if (VisualEncoder.Length != 0) { + output.WriteRawTag(74); + output.WriteString(VisualEncoder); + } + if (NumNetworkLayers != 0) { + output.WriteRawTag(80); + output.WriteInt32(NumNetworkLayers); + } + if (NumNetworkHiddenUnits != 0) { + output.WriteRawTag(88); + output.WriteInt32(NumNetworkHiddenUnits); + } + if (TrainerThreaded != false) { + output.WriteRawTag(96); + output.WriteBool(TrainerThreaded); + } + if (SelfPlayEnabled != false) { + output.WriteRawTag(104); + output.WriteBool(SelfPlayEnabled); + } + if (CurriculumEnabled != false) { + output.WriteRawTag(112); + output.WriteBool(CurriculumEnabled); + } + if (Config.Length != 0) { + output.WriteRawTag(122); + output.WriteString(Config); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (BehaviorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BehaviorName); + } + if (TrainerType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TrainerType); + } + if (ExtrinsicRewardEnabled != false) { + size += 1 + 1; + } + if (GailRewardEnabled != false) { + size += 1 + 1; + } + if (CuriosityRewardEnabled != false) { + size += 1 + 1; + } + if (RndRewardEnabled != false) { + size += 1 + 1; + } + if (BehavioralCloningEnabled != false) { + size += 1 + 1; + } + if (RecurrentEnabled != false) { + size += 1 + 1; + } + if (VisualEncoder.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(VisualEncoder); + } + if (NumNetworkLayers != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkLayers); + } + if (NumNetworkHiddenUnits != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkHiddenUnits); + } + if (TrainerThreaded != false) { + size += 1 + 1; + } + if (SelfPlayEnabled != false) { + size += 1 + 1; + } + if (CurriculumEnabled != false) { + size += 1 + 1; + } + if (Config.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Config); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TrainingBehaviorInitialized other) { + if (other == null) { + return; + } + if (other.BehaviorName.Length != 0) { + BehaviorName = other.BehaviorName; + } + if (other.TrainerType.Length != 0) { + TrainerType = other.TrainerType; + } + if (other.ExtrinsicRewardEnabled != false) { + ExtrinsicRewardEnabled = other.ExtrinsicRewardEnabled; + } + if (other.GailRewardEnabled != false) { + GailRewardEnabled = other.GailRewardEnabled; + } + if (other.CuriosityRewardEnabled != false) { + CuriosityRewardEnabled = other.CuriosityRewardEnabled; + } + if (other.RndRewardEnabled != false) { + RndRewardEnabled = other.RndRewardEnabled; + } + if (other.BehavioralCloningEnabled != false) { + BehavioralCloningEnabled = 
other.BehavioralCloningEnabled; + } + if (other.RecurrentEnabled != false) { + RecurrentEnabled = other.RecurrentEnabled; + } + if (other.VisualEncoder.Length != 0) { + VisualEncoder = other.VisualEncoder; + } + if (other.NumNetworkLayers != 0) { + NumNetworkLayers = other.NumNetworkLayers; + } + if (other.NumNetworkHiddenUnits != 0) { + NumNetworkHiddenUnits = other.NumNetworkHiddenUnits; + } + if (other.TrainerThreaded != false) { + TrainerThreaded = other.TrainerThreaded; + } + if (other.SelfPlayEnabled != false) { + SelfPlayEnabled = other.SelfPlayEnabled; + } + if (other.CurriculumEnabled != false) { + CurriculumEnabled = other.CurriculumEnabled; + } + if (other.Config.Length != 0) { + Config = other.Config; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + BehaviorName = input.ReadString(); + break; + } + case 18: { + TrainerType = input.ReadString(); + break; + } + case 24: { + ExtrinsicRewardEnabled = input.ReadBool(); + break; + } + case 32: { + GailRewardEnabled = input.ReadBool(); + break; + } + case 40: { + CuriosityRewardEnabled = input.ReadBool(); + break; + } + case 48: { + RndRewardEnabled = input.ReadBool(); + break; + } + case 56: { + BehavioralCloningEnabled = input.ReadBool(); + break; + } + case 64: { + RecurrentEnabled = input.ReadBool(); + break; + } + case 74: { + VisualEncoder = input.ReadString(); + break; + } + case 80: { + NumNetworkLayers = input.ReadInt32(); + break; + } + case 88: { + NumNetworkHiddenUnits = input.ReadInt32(); + break; + } + case 96: { + TrainerThreaded = input.ReadBool(); + break; + } + case 104: { + SelfPlayEnabled = input.ReadBool(); + break; + } + case 112: { + CurriculumEnabled = input.ReadBool(); + break; + } + case 122: { + Config = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta new file mode 100644 index 0000000000..8e9d358feb --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 9e6ac06a3931742d798cf922de6b99f0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityInput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityInput.cs new file mode 100644 index 0000000000..9497aa0ad8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityInput.cs @@ -0,0 +1,220 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/unity_input.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_input.proto + internal static partial class UnityInputReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_input.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityInputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X2lu", + "cHV0LnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxo3bWxhZ2VudHNfZW52", + "cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbnB1dC5wcm90bxpG", + "bWxhZ2VudHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9p", + "bml0aWFsaXphdGlvbl9pbnB1dC5wcm90byKkAQoPVW5pdHlJbnB1dFByb3Rv", + "EjkKCHJsX2lucHV0GAEgASgLMicuY29tbXVuaWNhdG9yX29iamVjdHMuVW5p", + "dHlSTElucHV0UHJvdG8SVgoXcmxfaW5pdGlhbGl6YXRpb25faW5wdXQYAiAB", + "KAsyNS5jb21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMSW5pdGlhbGl6YXRp", + "b25JbnB1dFByb3RvQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JP", + "YmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.UnityRlInputReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.UnityRlInitializationInputReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityInputProto), global::Unity.MLAgents.CommunicatorObjects.UnityInputProto.Parser, new[]{ "RlInput", "RlInitializationInput" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class UnityInputProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityInputProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityInputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityInputProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityInputProto(UnityInputProto other) : this() { + RlInput = other.rlInput_ != null ? other.RlInput.Clone() : null; + RlInitializationInput = other.rlInitializationInput_ != null ? 
other.RlInitializationInput.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityInputProto Clone() { + return new UnityInputProto(this); + } + + /// Field number for the "rl_input" field. + public const int RlInputFieldNumber = 1; + private global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto rlInput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto RlInput { + get { return rlInput_; } + set { + rlInput_ = value; + } + } + + /// Field number for the "rl_initialization_input" field. + public const int RlInitializationInputFieldNumber = 2; + private global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto rlInitializationInput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto RlInitializationInput { + get { return rlInitializationInput_; } + set { + rlInitializationInput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityInputProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityInputProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(RlInput, other.RlInput)) return false; + if (!object.Equals(RlInitializationInput, other.RlInitializationInput)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (rlInput_ != null) hash ^= RlInput.GetHashCode(); + if (rlInitializationInput_ != null) hash ^= RlInitializationInput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (rlInput_ != null) { + output.WriteRawTag(10); + output.WriteMessage(RlInput); + } + if (rlInitializationInput_ != null) { + output.WriteRawTag(18); + output.WriteMessage(RlInitializationInput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (rlInput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInput); + } + if (rlInitializationInput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInitializationInput); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityInputProto other) { + if (other == null) { + return; + } + if (other.rlInput_ != null) { + if (rlInput_ == null) { + rlInput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto(); + } + RlInput.MergeFrom(other.RlInput); + } + if (other.rlInitializationInput_ != null) { + if (rlInitializationInput_ == null) { + rlInitializationInput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto(); + } + 
RlInitializationInput.MergeFrom(other.RlInitializationInput); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (rlInput_ == null) { + rlInput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto(); + } + input.ReadMessage(rlInput_); + break; + } + case 18: { + if (rlInitializationInput_ == null) { + rlInitializationInput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto(); + } + input.ReadMessage(rlInitializationInput_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityInput.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityInput.cs.meta new file mode 100644 index 0000000000..32f1aa8334 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityInput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 13de5026cc0834f558fe971eb93c850e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityMessage.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityMessage.cs new file mode 100644 index 0000000000..b98264dc44 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityMessage.cs @@ -0,0 +1,255 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
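+// UnityMessageProto (below) is the top-level envelope exchanged between the trainer
+// and the environment: a HeaderProto plus optional UnityOutputProto / UnityInputProto.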
+// source: mlagents_envs/communicator_objects/unity_message.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_message.proto + internal static partial class UnityMessageReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_message.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityMessageReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjZtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X21l", + "c3NhZ2UucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjVtbGFnZW50c19l", + "bnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X291dHB1dC5wcm90bxo0", + "bWxhZ2VudHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9pbnB1", + "dC5wcm90bxovbWxhZ2VudHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9o", + "ZWFkZXIucHJvdG8iwAEKEVVuaXR5TWVzc2FnZVByb3RvEjEKBmhlYWRlchgB", + "IAEoCzIhLmNvbW11bmljYXRvcl9vYmplY3RzLkhlYWRlclByb3RvEjwKDHVu", + "aXR5X291dHB1dBgCIAEoCzImLmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5", + "T3V0cHV0UHJvdG8SOgoLdW5pdHlfaW5wdXQYAyABKAsyJS5jb21tdW5pY2F0", + "b3Jfb2JqZWN0cy5Vbml0eUlucHV0UHJvdG9CJaoCIlVuaXR5Lk1MQWdlbnRz", + "LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.UnityOutputReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.UnityInputReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.HeaderReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto), global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto.Parser, new[]{ "Header", "UnityOutput", "UnityInput" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class UnityMessageProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityMessageProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityMessageReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityMessageProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityMessageProto(UnityMessageProto other) : this() { + Header = other.header_ != null ? other.Header.Clone() : null; + UnityOutput = other.unityOutput_ != null ? other.UnityOutput.Clone() : null; + UnityInput = other.unityInput_ != null ? 
other.UnityInput.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityMessageProto Clone() { + return new UnityMessageProto(this); + } + + /// Field number for the "header" field. + public const int HeaderFieldNumber = 1; + private global::Unity.MLAgents.CommunicatorObjects.HeaderProto header_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.HeaderProto Header { + get { return header_; } + set { + header_ = value; + } + } + + /// Field number for the "unity_output" field. + public const int UnityOutputFieldNumber = 2; + private global::Unity.MLAgents.CommunicatorObjects.UnityOutputProto unityOutput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityOutputProto UnityOutput { + get { return unityOutput_; } + set { + unityOutput_ = value; + } + } + + /// Field number for the "unity_input" field. + public const int UnityInputFieldNumber = 3; + private global::Unity.MLAgents.CommunicatorObjects.UnityInputProto unityInput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityInputProto UnityInput { + get { return unityInput_; } + set { + unityInput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityMessageProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityMessageProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Header, other.Header)) return false; + if (!object.Equals(UnityOutput, other.UnityOutput)) return false; + if (!object.Equals(UnityInput, other.UnityInput)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (header_ != null) hash ^= Header.GetHashCode(); + if (unityOutput_ != null) hash ^= UnityOutput.GetHashCode(); + if (unityInput_ != null) hash ^= UnityInput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (header_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Header); + } + if (unityOutput_ != null) { + output.WriteRawTag(18); + output.WriteMessage(UnityOutput); + } + if (unityInput_ != null) { + output.WriteRawTag(26); + output.WriteMessage(UnityInput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (header_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Header); + } + if (unityOutput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(UnityOutput); + } + if (unityInput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(UnityInput); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + 
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityMessageProto other) { + if (other == null) { + return; + } + if (other.header_ != null) { + if (header_ == null) { + header_ = new global::Unity.MLAgents.CommunicatorObjects.HeaderProto(); + } + Header.MergeFrom(other.Header); + } + if (other.unityOutput_ != null) { + if (unityOutput_ == null) { + unityOutput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityOutputProto(); + } + UnityOutput.MergeFrom(other.UnityOutput); + } + if (other.unityInput_ != null) { + if (unityInput_ == null) { + unityInput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityInputProto(); + } + UnityInput.MergeFrom(other.UnityInput); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (header_ == null) { + header_ = new global::Unity.MLAgents.CommunicatorObjects.HeaderProto(); + } + input.ReadMessage(header_); + break; + } + case 18: { + if (unityOutput_ == null) { + unityOutput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityOutputProto(); + } + input.ReadMessage(unityOutput_); + break; + } + case 26: { + if (unityInput_ == null) { + unityInput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityInputProto(); + } + input.ReadMessage(unityInput_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityMessage.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityMessage.cs.meta new file mode 100644 index 0000000000..fe03de4e4f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityMessage.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e2189c32296994576b0ef0aaa2b78142 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityOutput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityOutput.cs new file mode 100644 index 0000000000..efb255d251 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityOutput.cs @@ -0,0 +1,220 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/unity_output.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_output.proto + internal static partial class UnityOutputReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_output.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityOutputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X291", + "dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaOG1sYWdlbnRzX2Vu", + "dnMvY29tbXVuaWNhdG9yX29iamVjdHMvdW5pdHlfcmxfb3V0cHV0LnByb3Rv", + "GkdtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js", + "X2luaXRpYWxpemF0aW9uX291dHB1dC5wcm90byKpAQoQVW5pdHlPdXRwdXRQ", + "cm90bxI7CglybF9vdXRwdXQYASABKAsyKC5jb21tdW5pY2F0b3Jfb2JqZWN0", + "cy5Vbml0eVJMT3V0cHV0UHJvdG8SWAoYcmxfaW5pdGlhbGl6YXRpb25fb3V0", + "cHV0GAIgASgLMjYuY29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlSTEluaXRp", + "YWxpemF0aW9uT3V0cHV0UHJvdG9CJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11", + "bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.UnityRlOutputReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.UnityRlInitializationOutputReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityOutputProto), global::Unity.MLAgents.CommunicatorObjects.UnityOutputProto.Parser, new[]{ "RlOutput", "RlInitializationOutput" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class UnityOutputProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityOutputProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityOutputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityOutputProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityOutputProto(UnityOutputProto other) : this() { + RlOutput = other.rlOutput_ != null ? other.RlOutput.Clone() : null; + RlInitializationOutput = other.rlInitializationOutput_ != null ? 
other.RlInitializationOutput.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityOutputProto Clone() { + return new UnityOutputProto(this); + } + + /// Field number for the "rl_output" field. + public const int RlOutputFieldNumber = 1; + private global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto rlOutput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto RlOutput { + get { return rlOutput_; } + set { + rlOutput_ = value; + } + } + + /// Field number for the "rl_initialization_output" field. + public const int RlInitializationOutputFieldNumber = 2; + private global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto rlInitializationOutput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto RlInitializationOutput { + get { return rlInitializationOutput_; } + set { + rlInitializationOutput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityOutputProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityOutputProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(RlOutput, other.RlOutput)) return false; + if (!object.Equals(RlInitializationOutput, other.RlInitializationOutput)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (rlOutput_ != null) hash ^= RlOutput.GetHashCode(); + if (rlInitializationOutput_ != null) hash ^= RlInitializationOutput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (rlOutput_ != null) { + output.WriteRawTag(10); + output.WriteMessage(RlOutput); + } + if (rlInitializationOutput_ != null) { + output.WriteRawTag(18); + output.WriteMessage(RlInitializationOutput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (rlOutput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlOutput); + } + if (rlInitializationOutput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInitializationOutput); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityOutputProto other) { + if (other == null) { + return; + } + if (other.rlOutput_ != null) { + if (rlOutput_ == null) { + rlOutput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto(); + } + RlOutput.MergeFrom(other.RlOutput); + } + if (other.rlInitializationOutput_ != null) { + if (rlInitializationOutput_ == null) { + rlInitializationOutput_ = new 
global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto(); + } + RlInitializationOutput.MergeFrom(other.RlInitializationOutput); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (rlOutput_ == null) { + rlOutput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto(); + } + input.ReadMessage(rlOutput_); + break; + } + case 18: { + if (rlInitializationOutput_ == null) { + rlInitializationOutput_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto(); + } + input.ReadMessage(rlInitializationOutput_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityOutput.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityOutput.cs.meta new file mode 100644 index 0000000000..e1ae734459 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityOutput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e1c19e75c7657497fbc05cfa40dd6783 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs new file mode 100644 index 0000000000..1b83a17e68 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs @@ -0,0 +1,312 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
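+// UnityRLInitializationInputProto (below) carries the Python trainer's side of the
+// initial handshake (seed, version strings, capabilities, num_areas); the next file,
+// UnityRlInitializationOutput.cs, defines the C# environment's reply.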
+// source: mlagents_envs/communicator_objects/unity_rl_initialization_input.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Unity.MLAgents.CommunicatorObjects {
+
+  /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_rl_initialization_input.proto
+  internal static partial class UnityRlInitializationInputReflection {
+
+    #region Descriptor
+    /// File descriptor for mlagents_envs/communicator_objects/unity_rl_initialization_input.proto
+    public static pbr::FileDescriptor Descriptor {
+      get { return descriptor; }
+    }
+    private static pbr::FileDescriptor descriptor;
+
+    static UnityRlInitializationInputReflection() {
+      byte[] descriptorData = global::System.Convert.FromBase64String(
+          string.Concat(
+            "CkZtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js",
+            "X2luaXRpYWxpemF0aW9uX2lucHV0LnByb3RvEhRjb21tdW5pY2F0b3Jfb2Jq",
+            "ZWN0cxo1bWxhZ2VudHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9jYXBh",
+            "YmlsaXRpZXMucHJvdG8iwAEKH1VuaXR5UkxJbml0aWFsaXphdGlvbklucHV0",
+            "UHJvdG8SDAoEc2VlZBgBIAEoBRIdChVjb21tdW5pY2F0aW9uX3ZlcnNpb24Y",
+            "AiABKAkSFwoPcGFja2FnZV92ZXJzaW9uGAMgASgJEkQKDGNhcGFiaWxpdGll",
+            "cxgEIAEoCzIuLmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxDYXBhYmls",
+            "aXRpZXNQcm90bxIRCgludW1fYXJlYXMYBSABKAVCJaoCIlVuaXR5Lk1MQWdl",
+            "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+      descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+          new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor, },
+          new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+            new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto.Parser, new[]{ "Seed", "CommunicationVersion", "PackageVersion", "Capabilities", "NumAreas" }, null, null, null)
+          }));
+    }
+    #endregion
+
+  }
+  #region Messages
+  /// <summary>
+  /// The initialization message - this is typically sent from the Python trainer to the C# environment.
+  /// </summary>
+  internal sealed partial class UnityRLInitializationInputProto : pb::IMessage<UnityRLInitializationInputProto> {
+    private static readonly pb::MessageParser<UnityRLInitializationInputProto> _parser = new pb::MessageParser<UnityRLInitializationInputProto>(() => new UnityRLInitializationInputProto());
+    private pb::UnknownFieldSet _unknownFields;
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pb::MessageParser<UnityRLInitializationInputProto> Parser { get { return _parser; } }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pbr::MessageDescriptor Descriptor {
+      get { return global::Unity.MLAgents.CommunicatorObjects.UnityRlInitializationInputReflection.Descriptor.MessageTypes[0]; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    pbr::MessageDescriptor pb::IMessage.Descriptor {
+      get { return Descriptor; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public UnityRLInitializationInputProto() {
+      OnConstruction();
+    }
+
+    partial void OnConstruction();
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public UnityRLInitializationInputProto(UnityRLInitializationInputProto other) : this() {
+      seed_ = other.seed_;
+      communicationVersion_ = other.communicationVersion_;
+      packageVersion_ = other.packageVersion_;
+      Capabilities = other.capabilities_ != null ?
other.Capabilities.Clone() : null; + numAreas_ = other.numAreas_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationInputProto Clone() { + return new UnityRLInitializationInputProto(this); + } + + /// Field number for the "seed" field. + public const int SeedFieldNumber = 1; + private int seed_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Seed { + get { return seed_; } + set { + seed_ = value; + } + } + + /// Field number for the "communication_version" field. + public const int CommunicationVersionFieldNumber = 2; + private string communicationVersion_ = ""; + /// + /// Communication protocol version that the initiating side (typically the Python trainer) is using. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string CommunicationVersion { + get { return communicationVersion_; } + set { + communicationVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "package_version" field. + public const int PackageVersionFieldNumber = 3; + private string packageVersion_ = ""; + /// + /// Package/library version that the initiating side (typically the Python trainer) is using. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PackageVersion { + get { return packageVersion_; } + set { + packageVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "capabilities" field. + public const int CapabilitiesFieldNumber = 4; + private global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto capabilities_; + /// + /// The RL Capabilities of the Python trainer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto Capabilities { + get { return capabilities_; } + set { + capabilities_ = value; + } + } + + /// Field number for the "num_areas" field. 
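+    // Proto3 note: fields still at their default value (0 for num_areas) are skipped
+    // during serialization, which is why WriteTo and CalculateSize below guard on
+    // NumAreas != 0 before emitting tag 40 = (5 << 3) | 0 (varint).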
+ public const int NumAreasFieldNumber = 5; + private int numAreas_; + /// + /// The number of training areas to instantiate + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumAreas { + get { return numAreas_; } + set { + numAreas_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLInitializationInputProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLInitializationInputProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Seed != other.Seed) return false; + if (CommunicationVersion != other.CommunicationVersion) return false; + if (PackageVersion != other.PackageVersion) return false; + if (!object.Equals(Capabilities, other.Capabilities)) return false; + if (NumAreas != other.NumAreas) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Seed != 0) hash ^= Seed.GetHashCode(); + if (CommunicationVersion.Length != 0) hash ^= CommunicationVersion.GetHashCode(); + if (PackageVersion.Length != 0) hash ^= PackageVersion.GetHashCode(); + if (capabilities_ != null) hash ^= Capabilities.GetHashCode(); + if (NumAreas != 0) hash ^= NumAreas.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Seed != 0) { + output.WriteRawTag(8); + output.WriteInt32(Seed); + } + if (CommunicationVersion.Length != 0) { + output.WriteRawTag(18); + output.WriteString(CommunicationVersion); + } + if (PackageVersion.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PackageVersion); + } + if (capabilities_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Capabilities); + } + if (NumAreas != 0) { + output.WriteRawTag(40); + output.WriteInt32(NumAreas); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Seed != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Seed); + } + if (CommunicationVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CommunicationVersion); + } + if (PackageVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PackageVersion); + } + if (capabilities_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Capabilities); + } + if (NumAreas != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumAreas); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLInitializationInputProto other) { + if (other == null) { + return; + } + if (other.Seed != 0) { + Seed = other.Seed; + } + if (other.CommunicationVersion.Length != 0) { + CommunicationVersion = other.CommunicationVersion; + } + if (other.PackageVersion.Length != 0) { + PackageVersion = other.PackageVersion; + } + if (other.capabilities_ != null) { 
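+        // Singular message fields merge recursively: allocate Capabilities lazily,
+        // then fold the incoming submessage into it instead of overwriting it.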
+ if (capabilities_ == null) { + capabilities_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto(); + } + Capabilities.MergeFrom(other.Capabilities); + } + if (other.NumAreas != 0) { + NumAreas = other.NumAreas; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Seed = input.ReadInt32(); + break; + } + case 18: { + CommunicationVersion = input.ReadString(); + break; + } + case 26: { + PackageVersion = input.ReadString(); + break; + } + case 34: { + if (capabilities_ == null) { + capabilities_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto(); + } + input.ReadMessage(capabilities_); + break; + } + case 40: { + NumAreas = input.ReadInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs.meta new file mode 100644 index 0000000000..c0c9fc2fd4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e1542ad34ffb34317b74b239135d0477 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs new file mode 100644 index 0000000000..1e073a5965 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs @@ -0,0 +1,332 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/unity_rl_initialization_output.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_rl_initialization_output.proto + internal static partial class UnityRlInitializationOutputReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_rl_initialization_output.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlInitializationOutputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CkdtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js", + "X2luaXRpYWxpemF0aW9uX291dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29i", + "amVjdHMaNW1sYWdlbnRzX2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvY2Fw", + "YWJpbGl0aWVzLnByb3RvGjltbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9v", + "YmplY3RzL2JyYWluX3BhcmFtZXRlcnMucHJvdG8ijAIKIFVuaXR5UkxJbml0", + "aWFsaXphdGlvbk91dHB1dFByb3RvEgwKBG5hbWUYASABKAkSHQoVY29tbXVu", + "aWNhdGlvbl92ZXJzaW9uGAIgASgJEhAKCGxvZ19wYXRoGAMgASgJEkQKEGJy", + "YWluX3BhcmFtZXRlcnMYBSADKAsyKi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5C", + "cmFpblBhcmFtZXRlcnNQcm90bxIXCg9wYWNrYWdlX3ZlcnNpb24YByABKAkS", + "RAoMY2FwYWJpbGl0aWVzGAggASgLMi4uY29tbXVuaWNhdG9yX29iamVjdHMu", + "VW5pdHlSTENhcGFiaWxpdGllc1Byb3RvSgQIBhAHQiWqAiJVbml0eS5NTEFn", + "ZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.BrainParametersReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationOutputProto.Parser, new[]{ "Name", "CommunicationVersion", "LogPath", "BrainParameters", "PackageVersion", "Capabilities" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// The request message containing the academy's parameters. 
+ /// + internal sealed partial class UnityRLInitializationOutputProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInitializationOutputProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityRlInitializationOutputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationOutputProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationOutputProto(UnityRLInitializationOutputProto other) : this() { + name_ = other.name_; + communicationVersion_ = other.communicationVersion_; + logPath_ = other.logPath_; + brainParameters_ = other.brainParameters_.Clone(); + packageVersion_ = other.packageVersion_; + Capabilities = other.capabilities_ != null ? other.Capabilities.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationOutputProto Clone() { + return new UnityRLInitializationOutputProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "communication_version" field. + public const int CommunicationVersionFieldNumber = 2; + private string communicationVersion_ = ""; + /// + /// Communication protocol version that the responding side (typically the C# environment) is using. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string CommunicationVersion { + get { return communicationVersion_; } + set { + communicationVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "log_path" field. + public const int LogPathFieldNumber = 3; + private string logPath_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string LogPath { + get { return logPath_; } + set { + logPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "brain_parameters" field. + public const int BrainParametersFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_brainParameters_codec + = pb::FieldCodec.ForMessage(42, global::Unity.MLAgents.CommunicatorObjects.BrainParametersProto.Parser); + private readonly pbc::RepeatedField brainParameters_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField BrainParameters { + get { return brainParameters_; } + } + + /// Field number for the "package_version" field. + public const int PackageVersionFieldNumber = 7; + private string packageVersion_ = ""; + /// + /// Package/library version that the responding side (typically the C# environment) is using. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PackageVersion { + get { return packageVersion_; } + set { + packageVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "capabilities" field. + public const int CapabilitiesFieldNumber = 8; + private global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto capabilities_; + /// + /// The RL Capabilities of the C# package. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto Capabilities { + get { return capabilities_; } + set { + capabilities_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLInitializationOutputProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLInitializationOutputProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (CommunicationVersion != other.CommunicationVersion) return false; + if (LogPath != other.LogPath) return false; + if(!brainParameters_.Equals(other.brainParameters_)) return false; + if (PackageVersion != other.PackageVersion) return false; + if (!object.Equals(Capabilities, other.Capabilities)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (CommunicationVersion.Length != 0) hash ^= CommunicationVersion.GetHashCode(); + if (LogPath.Length != 0) hash ^= LogPath.GetHashCode(); + hash ^= brainParameters_.GetHashCode(); + if (PackageVersion.Length != 0) hash ^= PackageVersion.GetHashCode(); + if (capabilities_ != null) hash ^= Capabilities.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (CommunicationVersion.Length != 0) { + output.WriteRawTag(18); + output.WriteString(CommunicationVersion); + } + if (LogPath.Length != 0) { + output.WriteRawTag(26); + output.WriteString(LogPath); + } + brainParameters_.WriteTo(output, _repeated_brainParameters_codec); + if (PackageVersion.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PackageVersion); + } + if (capabilities_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Capabilities); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (CommunicationVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CommunicationVersion); + } + if (LogPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(LogPath); + } + size += brainParameters_.CalculateSize(_repeated_brainParameters_codec); + if (PackageVersion.Length != 0) { + size 
+= 1 + pb::CodedOutputStream.ComputeStringSize(PackageVersion); + } + if (capabilities_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Capabilities); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLInitializationOutputProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.CommunicationVersion.Length != 0) { + CommunicationVersion = other.CommunicationVersion; + } + if (other.LogPath.Length != 0) { + LogPath = other.LogPath; + } + brainParameters_.Add(other.brainParameters_); + if (other.PackageVersion.Length != 0) { + PackageVersion = other.PackageVersion; + } + if (other.capabilities_ != null) { + if (capabilities_ == null) { + capabilities_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto(); + } + Capabilities.MergeFrom(other.Capabilities); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + CommunicationVersion = input.ReadString(); + break; + } + case 26: { + LogPath = input.ReadString(); + break; + } + case 42: { + brainParameters_.AddEntriesFrom(input, _repeated_brainParameters_codec); + break; + } + case 58: { + PackageVersion = input.ReadString(); + break; + } + case 66: { + if (capabilities_ == null) { + capabilities_ = new global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto(); + } + input.ReadMessage(capabilities_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs.meta new file mode 100644 index 0000000000..bbc4dba7c4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e0bcb88495d5d48229140a2080dfd297 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInput.cs new file mode 100644 index 0000000000..e8553e8b8e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInput.cs @@ -0,0 +1,361 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/unity_rl_input.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_rl_input.proto + internal static partial class UnityRlInputReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_rl_input.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlInputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjdtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js", + "X2lucHV0LnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxo1bWxhZ2VudHNf", + "ZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb24ucHJvdG8a", + "MG1sYWdlbnRzX2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvY29tbWFuZC5w", + "cm90byL+AgoRVW5pdHlSTElucHV0UHJvdG8SUAoNYWdlbnRfYWN0aW9ucxgB", + "IAMoCzI5LmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxJbnB1dFByb3Rv", + "LkFnZW50QWN0aW9uc0VudHJ5EjMKB2NvbW1hbmQYBCABKA4yIi5jb21tdW5p", + "Y2F0b3Jfb2JqZWN0cy5Db21tYW5kUHJvdG8SFAoMc2lkZV9jaGFubmVsGAUg", + "ASgMGk0KFExpc3RBZ2VudEFjdGlvblByb3RvEjUKBXZhbHVlGAEgAygLMiYu", + "Y29tbXVuaWNhdG9yX29iamVjdHMuQWdlbnRBY3Rpb25Qcm90bxpxChFBZ2Vu", + "dEFjdGlvbnNFbnRyeRILCgNrZXkYASABKAkSSwoFdmFsdWUYAiABKAsyPC5j", + "b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMSW5wdXRQcm90by5MaXN0QWdl", + "bnRBY3Rpb25Qcm90bzoCOAFKBAgCEANKBAgDEARCJaoCIlVuaXR5Lk1MQWdl", + "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor, global::Unity.MLAgents.CommunicatorObjects.CommandReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "Command", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null), + null, }) + })); + } + #endregion + + } + #region Messages + internal sealed partial class UnityRLInputProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInputProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityRlInputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInputProto() { + 
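+      // OnConstruction() below is a C# partial-method hook: the generated
+      // declaration has no body, so the call compiles away entirely unless
+      // user code implements it in another partial class.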
OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInputProto(UnityRLInputProto other) : this() { + agentActions_ = other.agentActions_.Clone(); + command_ = other.command_; + sideChannel_ = other.sideChannel_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInputProto Clone() { + return new UnityRLInputProto(this); + } + + /// Field number for the "agent_actions" field. + public const int AgentActionsFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_agentActions_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser), 10); + private readonly pbc::MapField agentActions_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField AgentActions { + get { return agentActions_; } + } + + /// Field number for the "command" field. + public const int CommandFieldNumber = 4; + private global::Unity.MLAgents.CommunicatorObjects.CommandProto command_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Unity.MLAgents.CommunicatorObjects.CommandProto Command { + get { return command_; } + set { + command_ = value; + } + } + + /// Field number for the "side_channel" field. + public const int SideChannelFieldNumber = 5; + private pb::ByteString sideChannel_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString SideChannel { + get { return sideChannel_; } + set { + sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLInputProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLInputProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!AgentActions.Equals(other.AgentActions)) return false; + if (Command != other.Command) return false; + if (SideChannel != other.SideChannel) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= AgentActions.GetHashCode(); + if (Command != 0) hash ^= Command.GetHashCode(); + if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + agentActions_.WriteTo(output, _map_agentActions_codec); + if (Command != 0) { + output.WriteRawTag(32); + output.WriteEnum((int) Command); + } + if (SideChannel.Length != 0) { + output.WriteRawTag(42); + output.WriteBytes(SideChannel); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += agentActions_.CalculateSize(_map_agentActions_codec); + if 
(Command != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Command); + } + if (SideChannel.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLInputProto other) { + if (other == null) { + return; + } + agentActions_.Add(other.agentActions_); + if (other.Command != 0) { + Command = other.Command; + } + if (other.SideChannel.Length != 0) { + SideChannel = other.SideChannel; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + agentActions_.AddEntriesFrom(input, _map_agentActions_codec); + break; + } + case 32: { + command_ = (global::Unity.MLAgents.CommunicatorObjects.CommandProto) input.ReadEnum(); + break; + } + case 42: { + SideChannel = input.ReadBytes(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the UnityRLInputProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + internal sealed partial class ListAgentActionProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListAgentActionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityRLInputProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentActionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentActionProto(ListAgentActionProto other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentActionProto Clone() { + return new ListAgentActionProto(this); + } + + /// Field number for the "value" field. 
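+        // The tag passed to ForMessage below is the protobuf wire-format key
+        // (field_number << 3) | wire_type: field 1 with length-delimited wire
+        // type 2 gives (1 << 3) | 2 = 10.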
+ public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::Unity.MLAgents.CommunicatorObjects.AgentActionProto.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ListAgentActionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ListAgentActionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ListAgentActionProto other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInput.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInput.cs.meta new file mode 100644 index 0000000000..c2e8eb03e9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: c9d247f0bc49d468da0f9f0cc6484d34 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlOutput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlOutput.cs new file mode 100644 index 0000000000..0971d50299 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlOutput.cs @@ -0,0 +1,331 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/unity_rl_output.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_rl_output.proto + internal static partial class UnityRlOutputReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_rl_output.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlOutputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjhtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js", + "X291dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaM21sYWdlbnRz", + "X2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byK5", + "AgoSVW5pdHlSTE91dHB1dFByb3RvEkwKCmFnZW50SW5mb3MYAiADKAsyOC5j", + "b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0UHJvdG8uQWdlbnRJ", + "bmZvc0VudHJ5EhQKDHNpZGVfY2hhbm5lbBgDIAEoDBpJChJMaXN0QWdlbnRJ", + "bmZvUHJvdG8SMwoFdmFsdWUYASADKAsyJC5jb21tdW5pY2F0b3Jfb2JqZWN0", + "cy5BZ2VudEluZm9Qcm90bxpuCg9BZ2VudEluZm9zRW50cnkSCwoDa2V5GAEg", + "ASgJEkoKBXZhbHVlGAIgASgLMjsuY29tbXVuaWNhdG9yX29iamVjdHMuVW5p", + "dHlSTE91dHB1dFByb3RvLkxpc3RBZ2VudEluZm9Qcm90bzoCOAFKBAgBEAJC", + "JaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3Rv", + "Mw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null), + null, }) + })); + } + #endregion + + } + #region Messages + internal sealed partial class UnityRLOutputProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLOutputProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityRlOutputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLOutputProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLOutputProto(UnityRLOutputProto other) : this() { + agentInfos_ = other.agentInfos_.Clone(); + sideChannel_ = 
other.sideChannel_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLOutputProto Clone() { + return new UnityRLOutputProto(this); + } + + /// Field number for the "agentInfos" field. + public const int AgentInfosFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_agentInfos_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser), 18); + private readonly pbc::MapField agentInfos_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField AgentInfos { + get { return agentInfos_; } + } + + /// Field number for the "side_channel" field. + public const int SideChannelFieldNumber = 3; + private pb::ByteString sideChannel_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString SideChannel { + get { return sideChannel_; } + set { + sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLOutputProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLOutputProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!AgentInfos.Equals(other.AgentInfos)) return false; + if (SideChannel != other.SideChannel) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= AgentInfos.GetHashCode(); + if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + agentInfos_.WriteTo(output, _map_agentInfos_codec); + if (SideChannel.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(SideChannel); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += agentInfos_.CalculateSize(_map_agentInfos_codec); + if (SideChannel.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLOutputProto other) { + if (other == null) { + return; + } + agentInfos_.Add(other.agentInfos_); + if (other.SideChannel.Length != 0) { + SideChannel = other.SideChannel; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 18: { + 
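+              // Tag 18 is field 2 ("agentInfos") with length-delimited wire
+              // type 2: (2 << 3) | 2 == 18.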
agentInfos_.AddEntriesFrom(input, _map_agentInfos_codec); + break; + } + case 26: { + SideChannel = input.ReadBytes(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the UnityRLOutputProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + internal sealed partial class ListAgentInfoProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListAgentInfoProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityRLOutputProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentInfoProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentInfoProto(ListAgentInfoProto other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentInfoProto Clone() { + return new ListAgentInfoProto(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ListAgentInfoProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ListAgentInfoProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ListAgentInfoProto other) { + 
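+          // Standard protobuf merge semantics apply here: repeated fields are
+          // concatenated (value_.Add appends other's entries), never replaced.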
if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlOutput.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlOutput.cs.meta new file mode 100644 index 0000000000..d2607bf220 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlOutput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7b039d6d52b5142a78431d1758f5bf53 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternal.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternal.cs new file mode 100644 index 0000000000..ddb47dea96 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternal.cs @@ -0,0 +1,43 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/unity_to_external.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/unity_to_external.proto + public static partial class UnityToExternalReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/unity_to_external.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityToExternalReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjptbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Rv", + "X2V4dGVybmFsLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxo2bWxhZ2Vu", + "dHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9tZXNzYWdlLnBy", + "b3RvMnYKFFVuaXR5VG9FeHRlcm5hbFByb3RvEl4KCEV4Y2hhbmdlEicuY29t", + "bXVuaWNhdG9yX29iamVjdHMuVW5pdHlNZXNzYWdlUHJvdG8aJy5jb21tdW5p", + "Y2F0b3Jfb2JqZWN0cy5Vbml0eU1lc3NhZ2VQcm90byIAQiWqAiJVbml0eS5N", + "TEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.UnityMessageReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null)); + } + #endregion + + } +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternal.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternal.cs.meta new file mode 100644 index 0000000000..e8ad13fd02 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternal.cs.meta 
@@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: c03819ddc4c30416ab6ecc83c9cee562 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternalGrpc.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternalGrpc.cs new file mode 100644 index 0000000000..273c57cb1d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternalGrpc.cs @@ -0,0 +1,135 @@ +#if UNITY_EDITOR || UNITY_STANDALONE +#define MLA_SUPPORTED_TRAINING_PLATFORM +#endif +#if MLA_SUPPORTED_TRAINING_PLATFORM +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mlagents_envs/communicator_objects/unity_to_external.proto +// +#pragma warning disable 0414, 1591 +#region Designer generated code + +using grpc = global::Grpc.Core; + +namespace Unity.MLAgents.CommunicatorObjects { + internal static partial class UnityToExternalProto + { + static readonly string __ServiceName = "communicator_objects.UnityToExternalProto"; + + static readonly grpc::Marshaller __Marshaller_communicator_objects_UnityMessageProto = grpc::Marshallers.Create((arg) => global::Google.Protobuf.MessageExtensions.ToByteArray(arg), global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto.Parser.ParseFrom); + + static readonly grpc::Method __Method_Exchange = new grpc::Method( + grpc::MethodType.Unary, + __ServiceName, + "Exchange", + __Marshaller_communicator_objects_UnityMessageProto, + __Marshaller_communicator_objects_UnityMessageProto); + + /// Service descriptor + public static global::Google.Protobuf.Reflection.ServiceDescriptor Descriptor + { + get { return global::Unity.MLAgents.CommunicatorObjects.UnityToExternalReflection.Descriptor.Services[0]; } + } + + /// Base class for server-side implementations of UnityToExternalProto + public abstract partial class UnityToExternalProtoBase + { + /// + /// Sends the academy parameters + /// + /// The request received from the client. + /// The context of the server-side call handler being invoked. + /// The response to send back to the client (wrapped by a task). + public virtual global::System.Threading.Tasks.Task Exchange(global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto request, grpc::ServerCallContext context) + { + throw new grpc::RpcException(new grpc::Status(grpc::StatusCode.Unimplemented, "")); + } + + } + + /// Client for UnityToExternalProto + public partial class UnityToExternalProtoClient : grpc::ClientBase + { + /// Creates a new client for UnityToExternalProto + /// The channel to use to make remote calls. + public UnityToExternalProtoClient(grpc::Channel channel) : base(channel) + { + } + /// Creates a new client for UnityToExternalProto that uses a custom CallInvoker. + /// The callInvoker to use to make remote calls. + public UnityToExternalProtoClient(grpc::CallInvoker callInvoker) : base(callInvoker) + { + } + /// Protected parameterless constructor to allow creation of test doubles. + protected UnityToExternalProtoClient() : base() + { + } + /// Protected constructor to allow creation of configured clients. + /// The client configuration. + protected UnityToExternalProtoClient(ClientBaseConfiguration configuration) : base(configuration) + { + } + + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The initial metadata to send with the call. This parameter is optional. 
+ /// An optional deadline for the call. The call will be cancelled if deadline is hit. + /// An optional token for canceling the call. + /// The response received from the server. + public virtual global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto Exchange(global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto request, grpc::Metadata headers = null, global::System.DateTime? deadline = null, global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken)) + { + return Exchange(request, new grpc::CallOptions(headers, deadline, cancellationToken)); + } + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The options for the call. + /// The response received from the server. + public virtual global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto Exchange(global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto request, grpc::CallOptions options) + { + return CallInvoker.BlockingUnaryCall(__Method_Exchange, null, options, request); + } + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The initial metadata to send with the call. This parameter is optional. + /// An optional deadline for the call. The call will be cancelled if deadline is hit. + /// An optional token for canceling the call. + /// The call object. + public virtual grpc::AsyncUnaryCall ExchangeAsync(global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto request, grpc::Metadata headers = null, global::System.DateTime? deadline = null, global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken)) + { + return ExchangeAsync(request, new grpc::CallOptions(headers, deadline, cancellationToken)); + } + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The options for the call. + /// The call object. + public virtual grpc::AsyncUnaryCall ExchangeAsync(global::Unity.MLAgents.CommunicatorObjects.UnityMessageProto request, grpc::CallOptions options) + { + return CallInvoker.AsyncUnaryCall(__Method_Exchange, null, options, request); + } + /// Creates a new instance of client from given ClientBaseConfiguration. + protected override UnityToExternalProtoClient NewInstance(ClientBaseConfiguration configuration) + { + return new UnityToExternalProtoClient(configuration); + } + } + + /// Creates service definition that can be registered with a server + /// An object implementing the server-side handling logic. 
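+      // Illustrative sketch only: in ML-Agents the Python trainer hosts the
+      // server side, but a C# host could register this service roughly as
+      // follows ("MyExchangeImpl" is a hypothetical UnityToExternalProtoBase subclass):
+      //   var server = new grpc::Server
+      //   {
+      //       Services = { BindService(new MyExchangeImpl()) },
+      //       Ports = { new grpc::ServerPort("localhost", 5004, grpc::ServerCredentials.Insecure) }
+      //   };
+      //   server.Start();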
+ public static grpc::ServerServiceDefinition BindService(UnityToExternalProtoBase serviceImpl) + { + return grpc::ServerServiceDefinition.CreateBuilder() + .AddMethod(__Method_Exchange, serviceImpl.Exchange).Build(); + } + + } +} +#endregion +#endif diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternalGrpc.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternalGrpc.cs.meta new file mode 100644 index 0000000000..620a3f1d44 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityToExternalGrpc.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6c0f560328e7343499ad203c75c11741 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/Unity.ML-Agents.CommunicatorObjects.asmdef b/com.unity.ml-agents/Runtime/Grpc/Unity.ML-Agents.CommunicatorObjects.asmdef new file mode 100755 index 0000000000..469aab355b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/Unity.ML-Agents.CommunicatorObjects.asmdef @@ -0,0 +1,18 @@ +{ + "name": "Unity.ML-Agents.CommunicatorObjects", + "rootNamespace": "", + "references": [], + "includePlatforms": [], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": false, + "precompiledReferences": [ + "System.IO.Abstractions.dll", + "Google.Protobuf.dll", + "Grpc.Core.dll" + ], + "autoReferenced": true, + "defineConstraints": [], + "versionDefines": [], + "noEngineReferences": false +} \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Grpc/Unity.ML-Agents.CommunicatorObjects.asmdef.meta b/com.unity.ml-agents/Runtime/Grpc/Unity.ML-Agents.CommunicatorObjects.asmdef.meta new file mode 100644 index 0000000000..1bf58ee5ea --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/Unity.ML-Agents.CommunicatorObjects.asmdef.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 61c5b659adf544b4baf3eef86248e13a +AssemblyDefinitionImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs new file mode 100644 index 0000000000..0dcdd0a5af --- /dev/null +++ b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs @@ -0,0 +1,28 @@ +namespace Unity.MLAgents +{ + /// + /// MultiAgentGroup interface for grouping agents to support multi-agent training. + /// + public interface IMultiAgentGroup + { + /// + /// Get the ID of MultiAgentGroup. + /// + /// + /// MultiAgentGroup ID. + /// + int GetId(); + + /// + /// Register agent to the MultiAgentGroup. + /// + /// The Agent to register. + void RegisterAgent(Agent agent); + + /// + /// Unregister agent from the MultiAgentGroup. + /// + /// The Agent to unregister. 
+        void UnregisterAgent(Agent agent);
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta
new file mode 100644
index 0000000000..b9171ab040
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 3744ac27d956e43e1a39c7ba2550ab82
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Inference.meta b/com.unity.ml-agents/Runtime/Inference.meta
new file mode 100644
index 0000000000..cb7450c10e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: ccb5b186c34bc48d8bd81e9d9bd5cd95
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs b/com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
new file mode 100644
index 0000000000..e80ccd9fdd
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
@@ -0,0 +1,207 @@
+using System.Collections.Generic;
+using System.Linq;
+using Unity.MLAgents.Inference.Utils;
+using Unity.MLAgents.Actuators;
+using Unity.Barracuda;
+using UnityEngine;
+
+namespace Unity.MLAgents.Inference
+{
+    /// <summary>
+    /// The Applier for the Continuous Action output tensor. Tensor is assumed to contain the
+    /// continuous action data of the agents in the batch.
+    /// </summary>
+    internal class ContinuousActionOutputApplier : TensorApplier.IApplier
+    {
+        readonly ActionSpec m_ActionSpec;
+
+        public ContinuousActionOutputApplier(ActionSpec actionSpec)
+        {
+            m_ActionSpec = actionSpec;
+        }
+
+        public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
+        {
+            var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
+            var agentIndex = 0;
+            for (var i = 0; i < actionIds.Count; i++)
+            {
+                var agentId = actionIds[i];
+                if (lastActions.ContainsKey(agentId))
+                {
+                    var actionBuffer = lastActions[agentId];
+                    if (actionBuffer.IsEmpty())
+                    {
+                        actionBuffer = new ActionBuffers(m_ActionSpec);
+                        lastActions[agentId] = actionBuffer;
+                    }
+                    var continuousBuffer = actionBuffer.ContinuousActions;
+                    for (var j = 0; j < actionSize; j++)
+                    {
+                        continuousBuffer[j] = tensorProxy.data[agentIndex, j];
+                    }
+                }
+                agentIndex++;
+            }
+        }
+    }
+
+    /// <summary>
+    /// The Applier for the Discrete Action output tensor.
+    /// </summary>
+    internal class DiscreteActionOutputApplier : TensorApplier.IApplier
+    {
+        readonly ActionSpec m_ActionSpec;
+
+        public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
+        {
+            m_ActionSpec = actionSpec;
+        }
+
+        public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
+        {
+            var agentIndex = 0;
+            var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
+            for (var i = 0; i < actionIds.Count; i++)
+            {
+                var agentId = actionIds[i];
+                if (lastActions.ContainsKey(agentId))
+                {
+                    var actionBuffer = lastActions[agentId];
+                    if (actionBuffer.IsEmpty())
+                    {
+                        actionBuffer = new ActionBuffers(m_ActionSpec);
+                        lastActions[agentId] = actionBuffer;
+                    }
+                    var discreteBuffer = actionBuffer.DiscreteActions;
+                    for (var j = 0; j < actionSize; j++)
+                    {
+                        discreteBuffer[j] = (int)tensorProxy.data[agentIndex, j];
+                    }
+                }
+                agentIndex++;
+            }
+        }
+    }
+
+    /// <summary>
+    /// The Applier for the Discrete Action output tensor. Uses multinomial to sample discrete
+    /// actions from the logits contained in the tensor.
+    /// </summary>
+    internal class LegacyDiscreteActionOutputApplier : TensorApplier.IApplier
+    {
+        readonly int[] m_ActionSize;
+        readonly Multinomial m_Multinomial;
+        readonly ActionSpec m_ActionSpec;
+        readonly int[] m_StartActionIndices;
+        readonly float[] m_CdfBuffer;
+
+        public LegacyDiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
+        {
+            m_ActionSize = actionSpec.BranchSizes;
+            m_Multinomial = new Multinomial(seed);
+            m_ActionSpec = actionSpec;
+            m_StartActionIndices = Utilities.CumSum(m_ActionSize);
+
+            // Scratch space for computing the cumulative distribution function.
+            // In order to reuse it, make it the size of the largest branch.
+            var largestBranch = Mathf.Max(m_ActionSize);
+            m_CdfBuffer = new float[largestBranch];
+        }
+
+        public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
+        {
+            var agentIndex = 0;
+            for (var i = 0; i < actionIds.Count; i++)
+            {
+                var agentId = actionIds[i];
+                if (lastActions.ContainsKey(agentId))
+                {
+                    var actionBuffer = lastActions[agentId];
+                    if (actionBuffer.IsEmpty())
+                    {
+                        actionBuffer = new ActionBuffers(m_ActionSpec);
+                        lastActions[agentId] = actionBuffer;
+                    }
+                    var discreteBuffer = actionBuffer.DiscreteActions;
+                    for (var j = 0; j < m_ActionSize.Length; j++)
+                    {
+                        ComputeCdf(tensorProxy, agentIndex, m_StartActionIndices[j], m_ActionSize[j]);
+                        discreteBuffer[j] = m_Multinomial.Sample(m_CdfBuffer, m_ActionSize[j]);
+                    }
+                }
+                agentIndex++;
+            }
+        }
+
+        /// <summary>
+        /// Compute the cumulative distribution function for a given agent's action
+        /// given the log-probabilities.
+        /// The results are stored in m_CdfBuffer, which is the size of the largest action's number of branches.
+        /// </summary>
+        /// <param name="logProbs">Tensor containing the log-probabilities.</param>
+        /// <param name="batch">Index of the agent being considered.</param>
+        /// <param name="channelOffset">Offset into the tensor's channel.</param>
+        /// <param name="branchSize">Number of actions in the branch.</param>
+        internal void ComputeCdf(TensorProxy logProbs, int batch, int channelOffset, int branchSize)
+        {
+            // Find the class maximum
+            var maxProb = float.NegativeInfinity;
+            for (var cls = 0; cls < branchSize; ++cls)
+            {
+                maxProb = Mathf.Max(logProbs.data[batch, cls + channelOffset], maxProb);
+            }
+
+            // Sum the log probabilities and compute CDF
+            var sumProb = 0.0f;
+            for (var cls = 0; cls < branchSize; ++cls)
+            {
+                sumProb += Mathf.Exp(logProbs.data[batch, cls + channelOffset] - maxProb);
+                m_CdfBuffer[cls] = sumProb;
+            }
+        }
+    }
+
+    /// <summary>
+    /// The Applier for the Memory output tensor. Tensor is assumed to contain the new
+    /// memory data of the agents in the batch.
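+    /// Memories are written back per agent along the tensor's width dimension
+    /// (data[agentIndex, 0, j, 0]), and an agent with no stored memory yet is
+    /// initialized with zeros before being overwritten.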
+    /// </summary>
+    internal class MemoryOutputApplier : TensorApplier.IApplier
+    {
+        Dictionary<int, List<float>> m_Memories;
+
+        public MemoryOutputApplier(
+            Dictionary<int, List<float>> memories)
+        {
+            m_Memories = memories;
+        }
+
+        public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
+        {
+            var agentIndex = 0;
+            var memorySize = tensorProxy.data.width;
+            for (var i = 0; i < actionIds.Count; i++)
+            {
+                var agentId = actionIds[i];
+                List<float> memory;
+                if (!m_Memories.TryGetValue(agentId, out memory)
+                    || memory.Count < memorySize)
+                {
+                    memory = new List<float>();
+                    memory.AddRange(Enumerable.Repeat(0f, memorySize));
+                }
+
+                for (var j = 0; j < memorySize; j++)
+                {
+                    memory[j] = tensorProxy.data[agentIndex, 0, j, 0];
+                }
+
+                m_Memories[agentId] = memory;
+                agentIndex++;
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs.meta b/com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs.meta
new file mode 100644
index 0000000000..b6ecb20fa6
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 99d5dc2d52e442d1a1f466a246cfb28d
+timeCreated: 1539118675
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
new file mode 100644
index 0000000000..5e7338c057
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
@@ -0,0 +1,449 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Unity.Barracuda;
+using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;
+
+namespace Unity.MLAgents.Inference
+{
+    /// <summary>
+    /// Barracuda Model extension methods.
+    /// </summary>
+    internal static class BarracudaModelExtensions
+    {
+        /// <summary>
+        /// Get array of the input tensor names of the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>Array of the input tensor names of the model.</returns>
+        public static string[] GetInputNames(this Model model)
+        {
+            var names = new List<string>();
+
+            if (model == null)
+                return names.ToArray();
+
+            foreach (var input in model.inputs)
+            {
+                names.Add(input.name);
+            }
+
+            foreach (var mem in model.memories)
+            {
+                names.Add(mem.input);
+            }
+
+            names.Sort(StringComparer.InvariantCulture);
+
+            return names.ToArray();
+        }
+
+        /// <summary>
+        /// Get the version of the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>The API version of the model.</returns>
+        public static int GetVersion(this Model model)
+        {
+            return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
+        }
+
+        /// <summary>
+        /// Generates the Tensor inputs that are expected to be present in the Model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>TensorProxy IEnumerable with the expected Tensor inputs.</returns>
+        public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
+        {
+            var tensors = new List<TensorProxy>();
+
+            if (model == null)
+                return tensors;
+
+            foreach (var input in model.inputs)
+            {
+                tensors.Add(new TensorProxy
+                {
+                    name = input.name,
+                    valueType = TensorProxy.TensorType.FloatingPoint,
+                    data = null,
+                    shape = input.shape.Select(i => (long)i).ToArray()
+                });
+            }
+
+            tensors.Sort((el1, el2) => string.Compare(el1.name, el2.name, StringComparison.InvariantCulture));
+
+            return tensors;
+        }
+
+        /// <summary>
+        /// Get number of visual observation inputs to the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>Number of visual observation inputs to the model.</returns>
+        public static int GetNumVisualInputs(this Model model)
+        {
+            var count = 0;
+            if (model == null)
+                return count;
+
+            foreach (var input in model.inputs)
+            {
+                if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
+                {
+                    count++;
+                }
+            }
+
+            return count;
+        }
+
+        /// <summary>
+        /// Get array of the output tensor names of the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
+        /// <returns>Array of the output tensor names of the model.</returns>
+        public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
+        {
+            var names = new List<string>();
+
+            if (model == null)
+            {
+                return names.ToArray();
+            }
+
+            if (model.HasContinuousOutputs(deterministicInference))
+            {
+                names.Add(model.ContinuousOutputName(deterministicInference));
+            }
+            if (model.HasDiscreteOutputs(deterministicInference))
+            {
+                names.Add(model.DiscreteOutputName(deterministicInference));
+            }
+
+            var modelVersion = model.GetVersion();
+            var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
+            if (memory > 0)
+            {
+                names.Add(TensorNames.RecurrentOutput);
+            }
+
+            names.Sort(StringComparer.InvariantCulture);
+
+            return names.ToArray();
+        }
+
+        /// <summary>
+        /// Check if the model has continuous action outputs.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
+        /// <returns>True if the model has continuous action outputs.</returns>
+        public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
+        {
+            if (model == null)
+                return false;
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
+            }
+            else
+            {
+                bool hasStochasticOutput = !deterministicInference &&
+                    model.outputs.Contains(TensorNames.ContinuousActionOutput);
+                bool hasDeterministicOutput = deterministicInference &&
+                    model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);
+
+                return (hasStochasticOutput || hasDeterministicOutput) &&
+                    (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
+            }
+        }
+
+        /// <summary>
+        /// Continuous action output size of the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>Size of continuous action output.</returns>
+        public static int ContinuousOutputSize(this Model model)
+        {
+            if (model == null)
+                return 0;
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
+                    (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0;
+            }
+            else
+            {
+                var continuousOutputShape = model.GetTensorByName(TensorNames.ContinuousActionOutputShape);
+                return continuousOutputShape == null ? 0 : (int)continuousOutputShape[0];
+            }
+        }
+
+        /// <summary>
+        /// Continuous action output tensor name of the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
+        /// <returns>Tensor name of continuous action output.</returns>
+        public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
+        {
+            if (model == null)
+                return null;
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                return TensorNames.ActionOutputDeprecated;
+            }
+            else
+            {
+                return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
+            }
+        }
+
+        /// <summary>
+        /// Check if the model has discrete action outputs.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
+        /// <returns>True if the model has discrete action outputs.</returns>
+        public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
+        {
+            if (model == null)
+                return false;
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
+            }
+            else
+            {
+                bool hasStochasticOutput = !deterministicInference &&
+                    model.outputs.Contains(TensorNames.DiscreteActionOutput);
+                bool hasDeterministicOutput = deterministicInference &&
+                    model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
+                return (hasStochasticOutput || hasDeterministicOutput) &&
+                    model.DiscreteOutputSize() > 0;
+            }
+        }
+
+        /// <summary>
+        /// Discrete action output size of the model. This is equal to the sum of the branch sizes.
+        /// This method gets the tensor representing the list of branch size and returns the
+        /// sum of all the elements in the Tensor.
+        /// - In version 1.X this tensor contains a single number, the sum of all branch
+        /// size values.
+        /// - In version 2.X this tensor contains a 1D Tensor with each element corresponding
+        /// to a branch size.
+        /// Since this method does the sum of all elements in the tensor, the output
+        /// will be the same on both 1.X and 2.X.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>Size of discrete action output.</returns>
+        public static int DiscreteOutputSize(this Model model)
+        {
+            if (model == null)
+                return 0;
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
+                    0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
+            }
+            else
+            {
+                var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
+                if (discreteOutputShape == null)
+                {
+                    return 0;
+                }
+                else
+                {
+                    int result = 0;
+                    for (int i = 0; i < discreteOutputShape.length; i++)
+                    {
+                        result += (int)discreteOutputShape[i];
+                    }
+                    return result;
+                }
+            }
+        }
+
+        /// <summary>
+        /// Discrete action output tensor name of the model.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
+        /// <returns>Tensor name of discrete action output.</returns>
+        public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
+        {
+            if (model == null)
+                return null;
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                return TensorNames.ActionOutputDeprecated;
+            }
+            else
+            {
+                return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
+            }
+        }
+
+        /// <summary>
+        /// Check if the model supports both continuous and discrete actions.
+        /// If not, the model should be handled differently and use the deprecated fields.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <returns>True if the model supports both continuous and discrete actions.</returns>
+        public static bool SupportsContinuousAndDiscrete(this Model model)
+        {
+            return model == null ||
+                model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
+                model.outputs.Contains(TensorNames.DiscreteActionOutput);
+        }
+
+        /// <summary>
+        /// Check if the model contains all the expected input/output tensors.
+        /// </summary>
+        /// <param name="model">
+        /// The Barracuda engine model for loading static parameters.
+        /// </param>
+        /// <param name="failedModelChecks">Output list of failure messages.</param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
+        /// <returns>True if the model contains all the expected tensors.</returns>
+        /// TODO: add checks for deterministic actions
+        public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
+        {
+            // Check the presence of model version
+            var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
+            if (modelApiVersionTensor == null)
+            {
+                failedModelChecks.Add(
+                    FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
+                );
+                return false;
+            }
+
+            // Check the presence of memory size
+            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
+            if (memorySizeTensor == null)
+            {
+                failedModelChecks.Add(
+                    FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
+                );
+                return false;
+            }
+
+            // Check the presence of action output tensor
+            if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
+                !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
+                !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
+                !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
+                !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
+            {
+                failedModelChecks.Add(
+                    FailedCheck.Warning("The model does not contain any Action Output Node.")
+                );
+                return false;
+            }
+
+            // Check the presence of action output shape tensor
+            if (!model.SupportsContinuousAndDiscrete())
+            {
+                if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
+                {
+                    failedModelChecks.Add(
+                        FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
+                    );
+                    return false;
+                }
+                if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
+                {
+                    failedModelChecks.Add(
+                        FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
+                            "not found in the model file. " +
+                            "This is only required for models that use a deprecated model format.")
+                    );
+                    return false;
+                }
+            }
+            else
+            {
+                if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
+                {
+                    if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
+                    {
+                        failedModelChecks.Add(
+                            FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
+                        );
+                        return false;
+                    }
+                    else if (!model.HasContinuousOutputs(deterministicInference))
+                    {
+                        var actionType = deterministicInference ? "deterministic" : "stochastic";
+                        var actionName = deterministicInference ? "Deterministic" : "";
+                        failedModelChecks.Add(
+                            FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag.")
+                        );
+                        return false;
+                    }
+                }
+
+                if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
+                {
+                    if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
+                    {
+                        failedModelChecks.Add(
+                            FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
+                        );
+                        return false;
+                    }
+                    else if (!model.HasDiscreteOutputs(deterministicInference))
+                    {
+                        var actionType = deterministicInference ? "deterministic" : "stochastic";
+                        var actionName = deterministicInference ? "Deterministic" : "";
+                        failedModelChecks.Add(
+                            FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
+                        );
+                        return false;
+                    }
+                }
+            }
+            return true;
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
new file mode 100644
index 0000000000..d43474fb46
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 1193c3bef93464baca0d8ba2d6ce1754
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
new file mode 100644
index 0000000000..724bf2750b
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
@@ -0,0 +1,913 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Unity.Barracuda;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Policies;
+
+namespace Unity.MLAgents.Inference
+{
+    /// <summary>
+    /// Prepares the Tensors for the Learning Brain and exposes a list of failed checks if Model
+    /// and BrainParameters are incompatible.
+    /// </summary>
+    internal class BarracudaModelParamLoader
+    {
+        internal enum ModelApiVersion
+        {
+            /// <summary>
+            /// ML-Agents model version for versions 1.x.y
+            /// The observations are split between vector and visual observations
+            /// There are legacy action outputs for discrete and continuous actions
+            /// LSTM inputs and outputs are handled by Barracuda
+            /// </summary>
+            MLAgents1_0 = 2,
+
+            /// <summary>
+            /// All observations are treated the same and named obs_{i} with i being
+            /// the sensor index
+            /// Legacy "action" output is no longer present
+            /// LSTM inputs and outputs are treated like regular inputs and outputs
+            /// and no longer managed by Barracuda
+            /// </summary>
+            MLAgents2_0 = 3,
+            MinSupportedVersion = MLAgents1_0,
+            MaxSupportedVersion = MLAgents2_0
+        }
+
+        internal class FailedCheck
+        {
+            public enum CheckTypeEnum
+            {
+                Info = 0,
+                Warning = 1,
+                Error = 2
+            }
+            public CheckTypeEnum CheckType;
+            public string Message;
+            public static FailedCheck Info(string message)
+            {
+                return new FailedCheck { CheckType = CheckTypeEnum.Info, Message = message };
+            }
+            public static FailedCheck Warning(string message)
+            {
+                return new FailedCheck { CheckType = CheckTypeEnum.Warning, Message = message };
+            }
+            public static FailedCheck Error(string message)
+            {
+                return new FailedCheck { CheckType = CheckTypeEnum.Error, Message = message };
+            }
+        }
+
+        /// <summary>
+        /// Checks that a model has the appropriate version.
+        ///
+        ///
+        /// The Barracuda engine model for loading static parameters
+        ///
+        /// A FailedCheck containing the error message if the version of the model does not match, else null
+        public static FailedCheck CheckModelVersion(Model model)
+        {
+            var modelApiVersion = model.GetVersion();
+            if (modelApiVersion < (int)ModelApiVersion.MinSupportedVersion)
+            {
+                return FailedCheck.Error(
+                    "Model was trained with an older version of the trainer than is supported. " +
+                    "Either retrain with a newer trainer, or use an older version of com.unity.ml-agents.\n" +
+                    $"Model version: {modelApiVersion} Minimum supported version: {(int)ModelApiVersion.MinSupportedVersion}"
+                );
+            }
+
+            if (modelApiVersion > (int)ModelApiVersion.MaxSupportedVersion)
+            {
+                return FailedCheck.Error(
+                    "Model was trained with a newer version of the trainer than is supported. " +
+                    "Either retrain with an older trainer, or update to a newer version of com.unity.ml-agents.\n" +
+                    $"Model version: {modelApiVersion} Maximum supported version: {(int)ModelApiVersion.MaxSupportedVersion}"
+                );
+            }
+
+            var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
+
+            if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0 && memorySize > 0)
+            {
+                // This block is to make sure that models that are trained with MLAgents version 1.x and have
+                // an LSTM (i.e. use the barracuda _c and _h inputs and outputs) will not work with MLAgents version
+                // 2.x. This is because Barracuda version 2.x will eventually drop support for the _c and _h inputs
+                // and only ML-Agents 2.x models will be compatible.
+                return FailedCheck.Error(
+                    "Models from com.unity.ml-agents 1.x that use recurrent neural networks are not supported in newer versions. " +
+                    "Either retrain with a newer trainer, or use an older version of com.unity.ml-agents.\n"
+                );
+            }
+            return null;
+
+        }
+
+        ///
+        /// Factory for the ModelParamLoader: creates a ModelParamLoader and runs the checks
+        /// on it.
+        ///
+        ///
+        /// The Barracuda engine model for loading static parameters
+        ///
+        ///
+        /// The BrainParameters that are used to verify the compatibility with the InferenceEngine
+        ///
+        /// Attached sensor components
+        /// Attached actuator components
+        /// Sum of the sizes of all ObservableAttributes.
+        /// BehaviorType of the Agent to check.
+        /// Inference only: set to true if the action selection from model should be
+        /// deterministic.
+        /// An IEnumerable of the checks that failed
+        public static IEnumerable CheckModel(
+            Model model,
+            BrainParameters brainParameters,
+            ISensor[] sensors,
+            ActuatorComponent[] actuatorComponents,
+            int observableAttributeTotalSize = 0,
+            BehaviorType behaviorType = BehaviorType.Default,
+            bool deterministicInference = false
+        )
+        {
+            List failedModelChecks = new List();
+            if (model == null)
+            {
+                var errorMsg = "There is no model for this Brain; cannot run inference. ";
+                if (behaviorType == BehaviorType.InferenceOnly)
+                {
+                    errorMsg += "Either assign a model, or change to a different Behavior Type.";
+                }
+                else
+                {
+                    errorMsg += "(But can still train)";
+                }
+                failedModelChecks.Add(FailedCheck.Info(errorMsg));
+                return failedModelChecks;
+            }
+
+            var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference);
+            if (!hasExpectedTensors)
+            {
+                return failedModelChecks;
+            }
+
+            var modelApiVersion = model.GetVersion();
+            var versionCheck = CheckModelVersion(model);
+            if (versionCheck != null)
+            {
+                failedModelChecks.Add(versionCheck);
+            }
+
+            var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
+            if (memorySize == -1)
+            {
+                failedModelChecks.Add(FailedCheck.Warning($"Missing node in the model provided: {TensorNames.MemorySize}"
+                ));
+                return failedModelChecks;
+            }
+
+            if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
+            {
+                failedModelChecks.AddRange(
+                    CheckInputTensorPresenceLegacy(model, brainParameters, memorySize, sensors)
+                );
+                failedModelChecks.AddRange(
+                    CheckInputTensorShapeLegacy(model, brainParameters, sensors, observableAttributeTotalSize)
+                );
+            }
+            else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
+            {
+                failedModelChecks.AddRange(
+                    CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference)
+                );
+                failedModelChecks.AddRange(
+                    CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
+                );
+            }
+
+            failedModelChecks.AddRange(
+                CheckOutputTensorShape(model, brainParameters, actuatorComponents)
+            );
+
+            failedModelChecks.AddRange(
+                CheckOutputTensorPresence(model, memorySize, deterministicInference)
+            );
+            return failedModelChecks;
+        }
+
+        ///
+        /// Generates failed checks that correspond to inputs expected by the model that are not
+        /// present in the BrainParameters. Tests the models created with the API of version 1.X
+        ///
+        ///
+        /// The Barracuda engine model for loading static parameters
+        ///
+        ///
+        /// The BrainParameters that are used to verify the compatibility with the InferenceEngine
+        ///
+        ///
+        /// The memory size that the model is expecting.
+        ///
+        /// Array of attached sensor components
+        ///
+        /// An IEnumerable of the checks that failed
+        ///
+        static IEnumerable CheckInputTensorPresenceLegacy(
+            Model model,
+            BrainParameters brainParameters,
+            int memory,
+            ISensor[] sensors
+        )
+        {
+            var failedModelChecks = new List();
+            var tensorsNames = model.GetInputNames();
+
+            // If there is no Vector Observation Input but the Brain Parameters expect one.
+            if ((brainParameters.VectorObservationSize != 0) &&
+                (!tensorsNames.Contains(TensorNames.VectorObservationPlaceholder)))
+            {
+                failedModelChecks.Add(
+                    FailedCheck.Warning("The model does not contain a Vector Observation Placeholder Input. " +
+                        "You must set the Vector Observation Space Size to 0.")
+                );
+            }
+
+            // If there are not enough Visual Observation Inputs compared to what the
+            // sensors expect.
+ var visObsIndex = 0; + for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) + { + var sensor = sensors[sensorIndex]; + if (sensor.GetObservationSpec().Shape.Length == 3) + { + if (!tensorsNames.Contains( + TensorNames.GetVisualObservationName(visObsIndex))) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain a Visual Observation Placeholder Input " + + $"for sensor component {visObsIndex} ({sensor.GetType().Name}).") + ); + } + visObsIndex++; + } + if (sensor.GetObservationSpec().Shape.Length == 2) + { + if (!tensorsNames.Contains( + TensorNames.GetObservationName(sensorIndex))) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain an Observation Placeholder Input " + + $"for sensor component {sensorIndex} ({sensor.GetType().Name}).") + ); + } + } + + } + + var expectedVisualObs = model.GetNumVisualInputs(); + // Check if there's not enough visual sensors (too many would be handled above) + if (expectedVisualObs > visObsIndex) + { + failedModelChecks.Add( + FailedCheck.Warning($"The model expects {expectedVisualObs} visual inputs," + + $" but only found {visObsIndex} visual sensors.") + ); + } + + // If the model has a non-negative memory size but requires a recurrent input + if (memory > 0) + { + if (!tensorsNames.Any(x => x.EndsWith("_h")) || + !tensorsNames.Any(x => x.EndsWith("_c"))) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.") + ); + } + } + + // If the model uses discrete control but does not have an input for action masks + if (model.HasDiscreteOutputs()) + { + if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder)) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.") + ); + } + } + return failedModelChecks; + } + + /// + /// Generates failed checks that correspond to inputs expected by the model that are not + /// present in the BrainParameters. + /// + /// + /// The Barracuda engine model for loading static parameters + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// + /// The memory size that the model is expecting. + /// + /// Array of attached sensor components + /// Inference only: set to true if the action selection from model should be + /// Deterministic. 
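+        /// Under the 2.0 API every sensor, regardless of rank, is expected under a
+        /// single "obs_{index}" input name, so presence is verified with one loop
+        /// over the sensor array.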
+ /// + /// A IEnumerable of the checks that failed + /// + static IEnumerable CheckInputTensorPresence( + Model model, + BrainParameters brainParameters, + int memory, + ISensor[] sensors, + bool deterministicInference = false + ) + { + var failedModelChecks = new List(); + var tensorsNames = model.GetInputNames(); + for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) + { + if (!tensorsNames.Contains( + TensorNames.GetObservationName(sensorIndex))) + { + var sensor = sensors[sensorIndex]; + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain an Observation Placeholder Input " + + $"for sensor component {sensorIndex} ({sensor.GetType().Name}).") + ); + } + } + + // If the model has a non-negative memory size but requires a recurrent input + if (memory > 0) + { + var modelVersion = model.GetVersion(); + if (!tensorsNames.Any(x => x == TensorNames.RecurrentInPlaceholder)) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.") + ); + } + } + + // If the model uses discrete control but does not have an input for action masks + if (model.HasDiscreteOutputs(deterministicInference)) + { + if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder)) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.") + ); + } + } + return failedModelChecks; + } + + /// + /// Generates failed checks that correspond to outputs expected by the model that are not + /// present in the BrainParameters. + /// + /// + /// The Barracuda engine model for loading static parameters + /// + /// The memory size that the model is expecting/ + /// Inference only: set to true if the action selection from model should be + /// deterministic. + /// + /// A IEnumerable of the checks that failed + /// + static IEnumerable CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false) + { + var failedModelChecks = new List(); + + // If there is no Recurrent Output but the model is Recurrent. + if (memory > 0) + { + var allOutputs = model.GetOutputNames(deterministicInference).ToList(); + if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput)) + { + failedModelChecks.Add( + FailedCheck.Warning("The model does not contain a Recurrent Output Node but has memory_size.") + ); + } + + } + return failedModelChecks; + } + + /// + /// Checks that the shape of the visual observation input placeholder is the same as the corresponding sensor. + /// + /// The tensor that is expected by the model + /// The sensor that produces the visual observation. + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. + /// + static FailedCheck CheckVisualObsShape( + TensorProxy tensorProxy, ISensor sensor) + { + var shape = sensor.GetObservationSpec().Shape; + var heightBp = shape[0]; + var widthBp = shape[1]; + var pixelBp = shape[2]; + var heightT = tensorProxy.Height; + var widthT = tensorProxy.Width; + var pixelT = tensorProxy.Channels; + if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT)) + { + return FailedCheck.Warning($"The visual Observation of the model does not match. " + + $"Received TensorProxy of shape [?x{widthBp}x{heightBp}x{pixelBp}] but " + + $"was expecting [?x{widthT}x{heightT}x{pixelT}] for the {sensor.GetName()} Sensor." 
+ ); + } + return null; + } + + /// + /// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor. + /// + /// The tensor that is expected by the model + /// The sensor that produces the visual observation. + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. + /// + static FailedCheck CheckRankTwoObsShape( + TensorProxy tensorProxy, ISensor sensor) + { + var shape = sensor.GetObservationSpec().Shape; + var dim1Bp = shape[0]; + var dim2Bp = shape[1]; + var dim1T = tensorProxy.Channels; + var dim2T = tensorProxy.Width; + var dim3T = tensorProxy.Height; + if ((dim1Bp != dim1T) || (dim2Bp != dim2T)) + { + var proxyDimStr = $"[?x{dim1T}x{dim2T}]"; + if (dim3T > 1) + { + proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]"; + } + return FailedCheck.Warning($"An Observation of the model does not match. " + + $"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " + + $"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor." + ); + } + return null; + } + + /// + /// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor. + /// + /// The tensor that is expected by the model + /// The sensor that produces the visual observation. + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. + /// + static FailedCheck CheckRankOneObsShape( + TensorProxy tensorProxy, ISensor sensor) + { + var shape = sensor.GetObservationSpec().Shape; + var dim1Bp = shape[0]; + var dim1T = tensorProxy.Channels; + var dim2T = tensorProxy.Width; + var dim3T = tensorProxy.Height; + if ((dim1Bp != dim1T)) + { + var proxyDimStr = $"[?x{dim1T}]"; + if (dim2T > 1) + { + proxyDimStr = $"[?x{dim1T}x{dim2T}]"; + } + if (dim3T > 1) + { + proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]"; + } + return FailedCheck.Warning($"An Observation of the model does not match. " + + $"Received TensorProxy of shape [?x{dim1Bp}] but " + + $"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor." + ); + } + return null; + } + + /// + /// Generates failed checks that correspond to inputs shapes incompatibilities between + /// the model and the BrainParameters. Tests the models created with the API of version 1.X + /// + /// + /// The Barracuda engine model for loading static parameters + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// Attached sensors + /// Sum of the sizes of all ObservableAttributes. 
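+        /// The check builds a map from expected input names to shape validators and
+        /// flags any model input whose name is not covered by that map.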
+ /// A IEnumerable of the checks that failed + static IEnumerable CheckInputTensorShapeLegacy( + Model model, BrainParameters brainParameters, ISensor[] sensors, + int observableAttributeTotalSize) + { + var failedModelChecks = new List(); + var tensorTester = + new Dictionary>() + { + {TensorNames.VectorObservationPlaceholder, CheckVectorObsShapeLegacy}, + {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape}, + {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)}, + }; + + foreach (var mem in model.memories) + { + tensorTester[mem.input] = ((bp, tensor, scs, i) => null); + } + + var visObsIndex = 0; + for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) + { + var sens = sensors[sensorIndex]; + if (sens.GetObservationSpec().Shape.Length == 3) + { + + tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] = + (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens); + visObsIndex++; + } + if (sens.GetObservationSpec().Shape.Length == 2) + { + tensorTester[TensorNames.GetObservationName(sensorIndex)] = + (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens); + } + } + + // If the model expects an input but it is not in this list + foreach (var tensor in model.GetInputTensors()) + { + if (!tensorTester.ContainsKey(tensor.name)) + { + if (!tensor.name.Contains("visual_observation")) + { + failedModelChecks.Add( + FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name) + ); + } + } + else + { + var tester = tensorTester[tensor.name]; + var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize); + if (error != null) + { + failedModelChecks.Add(error); + } + } + } + return failedModelChecks; + } + + /// + /// Checks that the shape of the Vector Observation input placeholder is the same in the + /// model and in the Brain Parameters. Tests the models created with the API of version 1.X + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// The tensor that is expected by the model + /// Array of attached sensor components + /// Sum of the sizes of all ObservableAttributes. + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. 
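+        /// As an illustrative example, three rank-1 sensors of sizes 8, 4 and 2
+        /// must add up to the placeholder's last dimension (14); on a mismatch the
+        /// per-sensor sizes are echoed back in the warning.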
+ /// + static FailedCheck CheckVectorObsShapeLegacy( + BrainParameters brainParameters, TensorProxy tensorProxy, ISensor[] sensors, + int observableAttributeTotalSize) + { + var vecObsSizeBp = brainParameters.VectorObservationSize; + var numStackedVector = brainParameters.NumStackedVectorObservations; + var totalVecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1]; + + var totalVectorSensorSize = 0; + foreach (var sens in sensors) + { + if ((sens.GetObservationSpec().Shape.Length == 1)) + { + totalVectorSensorSize += sens.GetObservationSpec().Shape[0]; + } + } + + if (totalVectorSensorSize != totalVecObsSizeT) + { + var sensorSizes = ""; + foreach (var sensorComp in sensors) + { + if (sensorComp.GetObservationSpec().Shape.Length == 1) + { + var vecSize = sensorComp.GetObservationSpec().Shape[0]; + if (sensorSizes.Length == 0) + { + sensorSizes = $"[{vecSize}"; + } + else + { + sensorSizes += $", {vecSize}"; + } + } + } + + sensorSizes += "]"; + return FailedCheck.Warning( + $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " + + $"but received: \n" + + $"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" + + $"Total [Observable] attributes: {observableAttributeTotalSize}\n" + + $"Sensor sizes: {sensorSizes}." + ); + } + return null; + } + + + /// + /// Generates failed checks that correspond to inputs shapes incompatibilities between + /// the model and the BrainParameters. + /// + /// + /// The Barracuda engine model for loading static parameters + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// Attached sensors + /// Sum of the sizes of all ObservableAttributes. + /// A IEnumerable of the checks that failed + static IEnumerable CheckInputTensorShape( + Model model, BrainParameters brainParameters, ISensor[] sensors, + int observableAttributeTotalSize) + { + var failedModelChecks = new List(); + var tensorTester = + new Dictionary>() + { + {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape}, + {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)}, + }; + + foreach (var mem in model.memories) + { + tensorTester[mem.input] = ((bp, tensor, scs, i) => null); + } + + for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++) + { + var sens = sensors[sensorIndex]; + if (sens.GetObservationSpec().Rank == 3) + { + tensorTester[TensorNames.GetObservationName(sensorIndex)] = + (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens); + } + if (sens.GetObservationSpec().Rank == 2) + { + tensorTester[TensorNames.GetObservationName(sensorIndex)] = + (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens); + } + if (sens.GetObservationSpec().Rank == 1) + { + tensorTester[TensorNames.GetObservationName(sensorIndex)] = + (bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens); + } + + } + + // If the model expects an input but it is not in this list + foreach (var tensor in model.GetInputTensors()) + { + if (!tensorTester.ContainsKey(tensor.name)) + { + failedModelChecks.Add(FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name + )); + } + else + { + var tester = tensorTester[tensor.name]; + var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize); + if 
(error != null) + { + failedModelChecks.Add(error); + } + } + } + return failedModelChecks; + } + + /// + /// Checks that the shape of the Previous Vector Action input placeholder is the same in the + /// model and in the Brain Parameters. + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// The tensor that is expected by the model + /// Array of attached sensor components (unused). + /// Sum of the sizes of all ObservableAttributes (unused). + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. + static FailedCheck CheckPreviousActionShape( + BrainParameters brainParameters, TensorProxy tensorProxy, + ISensor[] sensors, int observableAttributeTotalSize) + { + var numberActionsBp = brainParameters.ActionSpec.NumDiscreteActions; + var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1]; + if (numberActionsBp != numberActionsT) + { + return FailedCheck.Warning("Previous Action Size of the model does not match. " + + $"Received {numberActionsBp} but was expecting {numberActionsT}." + ); + } + return null; + } + + /// + /// Generates failed checks that correspond to output shapes incompatibilities between + /// the model and the BrainParameters. + /// + /// + /// The Barracuda engine model for loading static parameters + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// Array of attached actuator components. + /// + /// A IEnumerable of error messages corresponding to the incompatible shapes between model + /// and BrainParameters. + /// + static IEnumerable CheckOutputTensorShape( + Model model, + BrainParameters brainParameters, + ActuatorComponent[] actuatorComponents) + { + var failedModelChecks = new List(); + + // If the model expects an output but it is not in this list + var modelContinuousActionSize = model.ContinuousOutputSize(); + var continuousError = CheckContinuousActionOutputShape(brainParameters, actuatorComponents, modelContinuousActionSize); + if (continuousError != null) + { + failedModelChecks.Add(continuousError); + } + FailedCheck discreteError = null; + var modelApiVersion = model.GetVersion(); + if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0) + { + var modelSumDiscreteBranchSizes = model.DiscreteOutputSize(); + discreteError = CheckDiscreteActionOutputShapeLegacy(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes); + } + if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0) + { + var modelDiscreteBranches = model.GetTensorByName(TensorNames.DiscreteActionOutputShape); + discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelDiscreteBranches); + } + + if (discreteError != null) + { + failedModelChecks.Add(discreteError); + } + return failedModelChecks; + } + + /// + /// Checks that the shape of the discrete action output is the same in the + /// model and in the Brain Parameters. + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// Array of attached actuator components. + /// The Tensor of branch sizes. + /// + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. 
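+        /// Illustrative example: BrainParameters branches [3, 2] plus an actuator
+        /// contributing [4] must match a discrete_action_output_shape tensor of
+        /// length 3 with entries (3, 2, 4), compared element by element below.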
+ /// + static FailedCheck CheckDiscreteActionOutputShape( + BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, Tensor modelDiscreteBranches) + { + + var discreteActionBranches = brainParameters.ActionSpec.BranchSizes.ToList(); + foreach (var actuatorComponent in actuatorComponents) + { + var actionSpec = actuatorComponent.ActionSpec; + discreteActionBranches.AddRange(actionSpec.BranchSizes); + } + + int modelDiscreteBranchesLength = modelDiscreteBranches?.length ?? 0; + if (modelDiscreteBranchesLength != discreteActionBranches.Count) + { + return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " + + $"{discreteActionBranches.Count} branches but the model contains {modelDiscreteBranchesLength}." + ); + } + + for (int i = 0; i < modelDiscreteBranchesLength; i++) + { + if (modelDiscreteBranches != null && modelDiscreteBranches[i] != discreteActionBranches[i]) + { + return FailedCheck.Warning($"The number of Discrete Actions of branch {i} does not match. " + + $"Was expecting {discreteActionBranches[i]} but the model contains {modelDiscreteBranches[i]} " + ); + } + } + return null; + } + + /// + /// Checks that the shape of the discrete action output is the same in the + /// model and in the Brain Parameters. Tests the models created with the API of version 1.X + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// Array of attached actuator components. + /// + /// The size of the discrete action output that is expected by the model. + /// + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. + /// + static FailedCheck CheckDiscreteActionOutputShapeLegacy( + BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelSumDiscreteBranchSizes) + { + // TODO: check each branch size instead of sum of branch sizes + var sumOfDiscreteBranchSizes = brainParameters.ActionSpec.SumOfDiscreteBranchSizes; + + foreach (var actuatorComponent in actuatorComponents) + { + var actionSpec = actuatorComponent.ActionSpec; + sumOfDiscreteBranchSizes += actionSpec.SumOfDiscreteBranchSizes; + } + + if (modelSumDiscreteBranchSizes != sumOfDiscreteBranchSizes) + { + return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " + + $"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}." + ); + } + return null; + } + + /// + /// Checks that the shape of the continuous action output is the same in the + /// model and in the Brain Parameters. + /// + /// + /// The BrainParameters that are used verify the compatibility with the InferenceEngine + /// + /// Array of attached actuator components. + /// + /// The size of the continuous action output that is expected by the model. + /// + /// If the Check failed, returns a string containing information about why the + /// check failed. If the check passed, returns null. 
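+        /// The expected size is the sum of the BrainParameters' continuous action
+        /// count and the continuous counts of every attached ActuatorComponent,
+        /// e.g. (hypothetically) 2 + 1 + 1 = 4.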
+ static FailedCheck CheckContinuousActionOutputShape( + BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelContinuousActionSize) + { + var numContinuousActions = brainParameters.ActionSpec.NumContinuousActions; + + foreach (var actuatorComponent in actuatorComponents) + { + var actionSpec = actuatorComponent.ActionSpec; + numContinuousActions += actionSpec.NumContinuousActions; + } + + if (modelContinuousActionSize != numContinuousActions) + { + return FailedCheck.Warning( + "Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " + + $"{numContinuousActions} but the model contains {modelContinuousActionSize}." + ); + } + return null; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs.meta b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs.meta new file mode 100644 index 0000000000..2029d8acbb --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 399c5e92395a1484cb2808ac397745e1 +timeCreated: 1539197357 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs b/com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs new file mode 100644 index 0000000000..68a997f2bb --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs @@ -0,0 +1,284 @@ +using System.Collections.Generic; +using System; +using Unity.Barracuda; +using Unity.MLAgents.Inference.Utils; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Inference +{ + /// + /// Reshapes a Tensor so that its first dimension becomes equal to the current batch size + /// and initializes its content to be zeros. Will only work on 2-dimensional tensors. + /// The second dimension of the Tensor will not be modified. + /// + internal class BiDimensionalOutputGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + + public BiDimensionalOutputGenerator(ITensorAllocator allocator) + { + m_Allocator = allocator; + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); + } + } + + /// + /// Generates the Tensor corresponding to the BatchSize input : Will be a one dimensional + /// integer array of size 1 containing the batch size. + /// + internal class BatchSizeGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + + public BatchSizeGenerator(ITensorAllocator allocator) + { + m_Allocator = allocator; + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + tensorProxy.data?.Dispose(); + tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1)); + tensorProxy.data[0] = batchSize; + } + } + + /// + /// Generates the Tensor corresponding to the SequenceLength input : Will be a one + /// dimensional integer array of size 1 containing 1. + /// Note : the sequence length is always one since recurrent networks only predict for + /// one step at the time. 
+ /// + internal class SequenceLengthGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + + public SequenceLengthGenerator(ITensorAllocator allocator) + { + m_Allocator = allocator; + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + tensorProxy.shape = new long[0]; + tensorProxy.data?.Dispose(); + tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1)); + tensorProxy.data[0] = 1; + } + } + + /// + /// Generates the Tensor corresponding to the Recurrent input : Will be a two + /// dimensional float array of dimension [batchSize x memorySize]. + /// It will use the Memory data contained in the agentInfo to fill the data + /// of the tensor. + /// + internal class RecurrentInputGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + Dictionary> m_Memories; + + public RecurrentInputGenerator( + ITensorAllocator allocator, + Dictionary> memories) + { + m_Allocator = allocator; + m_Memories = memories; + } + + public void Generate( + TensorProxy tensorProxy, int batchSize, IList infos) + { + TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); + + var memorySize = tensorProxy.data.width; + + var agentIndex = 0; + for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++) + { + var infoSensorPair = infos[infoIndex]; + var info = infoSensorPair.agentInfo; + List memory; + + if (info.done) + { + m_Memories.Remove(info.episodeId); + } + if (!m_Memories.TryGetValue(info.episodeId, out memory)) + { + for (var j = 0; j < memorySize; j++) + { + tensorProxy.data[agentIndex, 0, j, 0] = 0; + } + agentIndex++; + continue; + } + for (var j = 0; j < Math.Min(memorySize, memory.Count); j++) + { + if (j >= memory.Count) + { + break; + } + tensorProxy.data[agentIndex, 0, j, 0] = memory[j]; + } + agentIndex++; + } + } + } + + /// + /// Generates the Tensor corresponding to the Previous Action input : Will be a two + /// dimensional integer array of dimension [batchSize x actionSize]. + /// It will use the previous action data contained in the agentInfo to fill the data + /// of the tensor. + /// + internal class PreviousActionInputGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + + public PreviousActionInputGenerator(ITensorAllocator allocator) + { + m_Allocator = allocator; + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); + + var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1]; + var agentIndex = 0; + for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++) + { + var infoSensorPair = infos[infoIndex]; + var info = infoSensorPair.agentInfo; + var pastAction = info.storedActions.DiscreteActions; + if (!pastAction.IsEmpty()) + { + for (var j = 0; j < actionSize; j++) + { + tensorProxy.data[agentIndex, j] = pastAction[j]; + } + } + + agentIndex++; + } + } + } + + /// + /// Generates the Tensor corresponding to the Action Mask input : Will be a two + /// dimensional float array of dimension [batchSize x numActionLogits]. + /// It will use the Action Mask data contained in the agentInfo to fill the data + /// of the tensor. 
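+    /// Note the encoding used in Generate: an action that is masked out is written
+    /// as 0.0 and an allowed action as 1.0 (i.e. the tensor stores "is unmasked").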
+ /// + internal class ActionMaskInputGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + + public ActionMaskInputGenerator(ITensorAllocator allocator) + { + m_Allocator = allocator; + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); + + var maskSize = tensorProxy.shape[tensorProxy.shape.Length - 1]; + var agentIndex = 0; + for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++) + { + var infoSensorPair = infos[infoIndex]; + var agentInfo = infoSensorPair.agentInfo; + var maskList = agentInfo.discreteActionMasks; + for (var j = 0; j < maskSize; j++) + { + var isUnmasked = (maskList != null && maskList[j]) ? 0.0f : 1.0f; + tensorProxy.data[agentIndex, j] = isUnmasked; + } + agentIndex++; + } + } + } + + /// + /// Generates the Tensor corresponding to the Epsilon input : Will be a two + /// dimensional float array of dimension [batchSize x actionSize]. + /// It will use the generate random input data from a normal Distribution. + /// + internal class RandomNormalInputGenerator : TensorGenerator.IGenerator + { + readonly RandomNormal m_RandomNormal; + readonly ITensorAllocator m_Allocator; + + public RandomNormalInputGenerator(int seed, ITensorAllocator allocator) + { + m_RandomNormal = new RandomNormal(seed); + m_Allocator = allocator; + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); + TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal); + } + } + + /// + /// Generates the Tensor corresponding to the Observation input : Will be a multi + /// dimensional float array. + /// It will use the Observation data contained in the sensors to fill the data + /// of the tensor. + /// + internal class ObservationGenerator : TensorGenerator.IGenerator + { + readonly ITensorAllocator m_Allocator; + List m_SensorIndices = new List(); + ObservationWriter m_ObservationWriter = new ObservationWriter(); + + public ObservationGenerator(ITensorAllocator allocator) + { + m_Allocator = allocator; + } + + public void AddSensorIndex(int sensorIndex) + { + m_SensorIndices.Add(sensorIndex); + } + + public void Generate(TensorProxy tensorProxy, int batchSize, IList infos) + { + TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); + var agentIndex = 0; + for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++) + { + var info = infos[infoIndex]; + if (info.agentInfo.done) + { + // If the agent is done, we might have a stale reference to the sensors + // e.g. a dependent object might have been disposed. + // To avoid this, just fill observation with zeroes instead of calling sensor.Write. 
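+                    // (The stale action computed for a done agent is discarded
+                    // anyway: ModelRunner.PutObservations removes its entry from
+                    // m_LastActionsReceived, so zero observations are harmless.)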
+ TensorUtils.FillTensorBatch(tensorProxy, agentIndex, 0.0f); + } + else + { + var tensorOffset = 0; + // Write each sensor consecutively to the tensor + for (var sensorIndexIndex = 0; sensorIndexIndex < m_SensorIndices.Count; sensorIndexIndex++) + { + var sensorIndex = m_SensorIndices[sensorIndexIndex]; + var sensor = info.sensors[sensorIndex]; + m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset); + var numWritten = sensor.Write(m_ObservationWriter); + tensorOffset += numWritten; + } + } + agentIndex++; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs.meta b/com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs.meta new file mode 100644 index 0000000000..1f628e0e51 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c57a4989c7e54b93ab56293698d7d237 +timeCreated: 1539109542 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs new file mode 100644 index 0000000000..1ded24d115 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs @@ -0,0 +1,253 @@ +using System.Collections.Generic; +using Unity.Barracuda; +using UnityEngine.Profiling; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Inference +{ + internal struct AgentInfoSensorsPair + { + public AgentInfo agentInfo; + public List sensors; + } + + internal class ModelRunner + { + List m_Infos = new List(); + Dictionary m_LastActionsReceived = new Dictionary(); + List m_OrderedAgentsRequestingDecisions = new List(); + + ITensorAllocator m_TensorAllocator; + TensorGenerator m_TensorGenerator; + TensorApplier m_TensorApplier; + + NNModel m_Model; + string m_ModelName; + InferenceDevice m_InferenceDevice; + IWorker m_Engine; + bool m_Verbose = false; + bool m_DeterministicInference; + string[] m_OutputNames; + IReadOnlyList m_InferenceInputs; + List m_InferenceOutputs; + Dictionary m_InputsByName; + Dictionary> m_Memories = new Dictionary>(); + + SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator(); + + bool m_ObservationsInitialized; + + /// + /// Initializes the Brain with the Model that it will use when selecting actions for + /// the agents + /// + /// The Barracuda model to load + /// Description of the actions for the Agent. + /// Inference execution device. CPU is the fastest + /// option for most of ML Agents models. + /// The seed that will be used to initialize the RandomNormal + /// and Multinomial objects used when running inference. + /// Inference only: set to true if the action selection from model should be + /// deterministic. 
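+        /// (As the constructor's switch below shows, InferenceDevice.Burst and
+        /// InferenceDevice.Default both select the CSharpBurst Barracuda worker,
+        /// CPU selects CSharp, and GPU selects ComputePrecompiled.)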
+ /// Throws an error when the model is null + /// + public ModelRunner( + NNModel model, + ActionSpec actionSpec, + InferenceDevice inferenceDevice, + int seed = 0, + bool deterministicInference = false) + { + Model barracudaModel; + m_Model = model; + m_ModelName = model?.name; + m_InferenceDevice = inferenceDevice; + m_DeterministicInference = deterministicInference; + m_TensorAllocator = new TensorCachingAllocator(); + if (model != null) + { +#if BARRACUDA_VERBOSE + m_Verbose = true; +#endif + + D.logEnabled = m_Verbose; + + barracudaModel = ModelLoader.Load(model); + + var failedCheck = BarracudaModelParamLoader.CheckModelVersion( + barracudaModel + ); + if (failedCheck != null) + { + if (failedCheck.CheckType == BarracudaModelParamLoader.FailedCheck.CheckTypeEnum.Error) + { + throw new UnityAgentsException(failedCheck.Message); + } + } + + WorkerFactory.Type executionDevice; + switch (inferenceDevice) + { + case InferenceDevice.CPU: + executionDevice = WorkerFactory.Type.CSharp; + break; + case InferenceDevice.GPU: + executionDevice = WorkerFactory.Type.ComputePrecompiled; + break; + case InferenceDevice.Burst: + executionDevice = WorkerFactory.Type.CSharpBurst; + break; + case InferenceDevice.Default: // fallthrough + default: + executionDevice = WorkerFactory.Type.CSharpBurst; + break; + } + m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose); + } + else + { + barracudaModel = null; + m_Engine = null; + } + + m_InferenceInputs = barracudaModel.GetInputTensors(); + m_OutputNames = barracudaModel.GetOutputNames(m_DeterministicInference); + + m_TensorGenerator = new TensorGenerator( + seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference); + m_TensorApplier = new TensorApplier( + actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference); + m_InputsByName = new Dictionary(); + m_InferenceOutputs = new List(); + } + + public InferenceDevice InferenceDevice + { + get { return m_InferenceDevice; } + } + + public NNModel Model + { + get { return m_Model; } + } + + void PrepareBarracudaInputs(IReadOnlyList infInputs) + { + m_InputsByName.Clear(); + for (var i = 0; i < infInputs.Count; i++) + { + var inp = infInputs[i]; + m_InputsByName[inp.name] = inp.data; + } + } + + public void Dispose() + { + if (m_Engine != null) + m_Engine.Dispose(); + m_TensorAllocator?.Reset(false); + } + + void FetchBarracudaOutputs(string[] names) + { + m_InferenceOutputs.Clear(); + foreach (var n in names) + { + var output = m_Engine.PeekOutput(n); + m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n)); + } + } + + public void PutObservations(AgentInfo info, List sensors) + { +#if DEBUG + m_SensorShapeValidator.ValidateSensors(sensors); +#endif + m_Infos.Add(new AgentInfoSensorsPair + { + agentInfo = info, + sensors = sensors + }); + + // We add the episodeId to this list to maintain the order in which the decisions were requested + m_OrderedAgentsRequestingDecisions.Add(info.episodeId); + + if (!m_LastActionsReceived.ContainsKey(info.episodeId)) + { + m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty; + } + if (info.done) + { + // If the agent is done, we remove the key from the last action dictionary since no action + // should be taken. 
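+                // GetAction will then fall back to ActionBuffers.Empty for this
+                // episode id until a new decision is requested.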
+ m_LastActionsReceived.Remove(info.episodeId); + } + } + + public void DecideBatch() + { + var currentBatchSize = m_Infos.Count; + if (currentBatchSize == 0) + { + return; + } + if (!m_ObservationsInitialized) + { + // Just grab the first agent in the collection (any will suffice, really). + // We check for an empty Collection above, so this will always return successfully. + var firstInfo = m_Infos[0]; + m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator); + m_ObservationsInitialized = true; + } + + Profiler.BeginSample("ModelRunner.DecideAction"); + Profiler.BeginSample(m_ModelName); + + Profiler.BeginSample($"GenerateTensors"); + // Prepare the input tensors to be feed into the engine + m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos); + Profiler.EndSample(); + + Profiler.BeginSample($"PrepareBarracudaInputs"); + PrepareBarracudaInputs(m_InferenceInputs); + Profiler.EndSample(); + + // Execute the Model + Profiler.BeginSample($"ExecuteGraph"); + m_Engine.Execute(m_InputsByName); + Profiler.EndSample(); + + Profiler.BeginSample($"FetchBarracudaOutputs"); + FetchBarracudaOutputs(m_OutputNames); + Profiler.EndSample(); + + Profiler.BeginSample($"ApplyTensors"); + // Update the outputs + m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived); + Profiler.EndSample(); + + Profiler.EndSample(); // end name + Profiler.EndSample(); // end ModelRunner.DecideAction + + m_Infos.Clear(); + + m_OrderedAgentsRequestingDecisions.Clear(); + } + + public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice) + { + return m_Model == other && m_InferenceDevice == otherInferenceDevice; + } + + public ActionBuffers GetAction(int agentId) + { + if (m_LastActionsReceived.ContainsKey(agentId)) + { + return m_LastActionsReceived[agentId]; + } + return ActionBuffers.Empty; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs.meta b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs.meta new file mode 100644 index 0000000000..e4e8e67539 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 8f3f4b630ca3f4a4ba74922ec8249046 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs new file mode 100644 index 0000000000..a03b3d927e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs @@ -0,0 +1,112 @@ +using System.Collections.Generic; +using Unity.Barracuda; +using Unity.MLAgents.Actuators; + + +namespace Unity.MLAgents.Inference +{ + /// + /// Mapping between the output tensor names and the method that will use the + /// output tensors and the Agents present in the batch to update their action, memories and + /// value estimates. + /// A TensorApplier implements a Dictionary of strings (node names) to an Action. + /// This action takes as input the tensor and the Dictionary of Agent to AgentInfo for + /// the current batch. + /// + internal class TensorApplier + { + /// + /// A tensor Applier's Execute method takes a tensor and a Dictionary of Agent to AgentInfo. + /// Uses the data contained inside the tensor to modify the state of the Agent. 
The Tensors + /// are assumed to have the batch size on the first dimension and the agents to be ordered + /// the same way in the dictionary and in the tensor. + /// + public interface IApplier + { + /// + /// Applies the values in the Tensor to the Agents present in the agentInfos + /// + /// + /// The Tensor containing the data to be applied to the Agents + /// + /// List of Agents Ids that will be updated using the tensor's data + /// Dictionary of AgentId to Actions to be updated + void Apply(TensorProxy tensorProxy, IList actionIds, Dictionary lastActions); + } + + readonly Dictionary m_Dict = new Dictionary(); + + /// + /// Returns a new TensorAppliers object. + /// + /// Description of the actions for the Agent. + /// The seed the Appliers will be initialized with. + /// Tensor allocator + /// Dictionary of AgentInfo.id to memory used to pass to the inference model. + /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. + public TensorApplier( + ActionSpec actionSpec, + int seed, + ITensorAllocator allocator, + Dictionary> memories, + object barracudaModel = null, + bool deterministicInference = false) + { + // If model is null, no inference to run and exception is thrown before reaching here. + if (barracudaModel == null) + { + return; + } + + var model = (Model)barracudaModel; + if (!model.SupportsContinuousAndDiscrete()) + { + actionSpec.CheckAllContinuousOrDiscrete(); + } + if (actionSpec.NumContinuousActions > 0) + { + var tensorName = model.ContinuousOutputName(deterministicInference); + m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec); + } + var modelVersion = model.GetVersion(); + if (actionSpec.NumDiscreteActions > 0) + { + var tensorName = model.DiscreteOutputName(deterministicInference); + if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0) + { + m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator); + } + if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0) + { + m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator); + } + } + m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories); + } + + /// + /// Updates the state of the agents based on the data present in the tensor. + /// + /// Enumerable of tensors containing the data. + /// List of Agents Ids that will be updated using the tensor's data + /// Dictionary of AgentId to Actions to be updated + /// One of the tensor does not have an + /// associated applier. 
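+        /// Appliers are dispatched purely by tensor name, so the list may arrive in
+        /// any order; each applier consumes the batch dimension in the same agent
+        /// order as actionIds.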
+ public void ApplyTensors( + IReadOnlyList tensors, IList actionIds, Dictionary lastActions) + { + for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++) + { + var tensor = tensors[tensorIndex]; + if (!m_Dict.ContainsKey(tensor.name)) + { + throw new UnityAgentsException( + $"Unknown tensorProxy expected as output : {tensor.name}"); + } + m_Dict[tensor.name].Apply(tensor, actionIds, lastActions); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs.meta b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs.meta new file mode 100644 index 0000000000..d95eb26a15 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d1bef4f4ae72645108f16614355473e8 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs new file mode 100644 index 0000000000..39bed85792 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs @@ -0,0 +1,180 @@ +using System.Collections.Generic; +using Unity.Barracuda; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Inference +{ + /// + /// Mapping between Tensor names and generators. + /// A TensorGenerator implements a Dictionary of strings (node names) to an Action. + /// The Action take as argument the tensor, the current batch size and a Dictionary of + /// Agent to AgentInfo corresponding to the current batch. + /// Each Generator reshapes and fills the data of the tensor based of the data of the batch. + /// When the TensorProxy is an Input to the model, the shape of the Tensor will be modified + /// depending on the current batch size and the data of the Tensor will be filled using the + /// Dictionary of Agent to AgentInfo. + /// When the TensorProxy is an Output of the model, only the shape of the Tensor will be + /// modified using the current batch size. The data will be pre-filled with zeros. + /// + internal class TensorGenerator + { + public interface IGenerator + { + /// + /// Modifies the data inside a Tensor according to the information contained in the + /// AgentInfos contained in the current batch. + /// + /// The tensor the data and shape will be modified. + /// The number of agents present in the current batch. + /// + /// List of AgentInfos containing the information that will be used to populate + /// the tensor's data. + /// + void Generate( + TensorProxy tensorProxy, int batchSize, IList infos); + } + + readonly Dictionary m_Dict = new Dictionary(); + int m_ApiVersion; + + /// + /// Returns a new TensorGenerators object. + /// + /// The seed the Generators will be initialized with. + /// Tensor allocator. + /// Dictionary of AgentInfo.id to memory for use in the inference model. + /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. + public TensorGenerator( + int seed, + ITensorAllocator allocator, + Dictionary> memories, + object barracudaModel = null, + bool deterministicInference = false) + { + // If model is null, no inference to run and exception is thrown before reaching here. 
+ if (barracudaModel == null) + { + return; + } + var model = (Model)barracudaModel; + + m_ApiVersion = model.GetVersion(); + + // Generator for Inputs + m_Dict[TensorNames.BatchSizePlaceholder] = + new BatchSizeGenerator(allocator); + m_Dict[TensorNames.SequenceLengthPlaceholder] = + new SequenceLengthGenerator(allocator); + m_Dict[TensorNames.RecurrentInPlaceholder] = + new RecurrentInputGenerator(allocator, memories); + + m_Dict[TensorNames.PreviousActionPlaceholder] = + new PreviousActionInputGenerator(allocator); + m_Dict[TensorNames.ActionMaskPlaceholder] = + new ActionMaskInputGenerator(allocator); + m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] = + new RandomNormalInputGenerator(seed, allocator); + + + // Generators for Outputs + if (model.HasContinuousOutputs(deterministicInference)) + { + m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator); + } + if (model.HasDiscreteOutputs(deterministicInference)) + { + m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator); + } + m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator); + m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator); + } + + public void InitializeObservations(List sensors, ITensorAllocator allocator) + { + if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0) + { + // Loop through the sensors on a representative agent. + // All vector observations use a shared ObservationGenerator since they are concatenated. + // All other observations use a unique ObservationInputGenerator + var visIndex = 0; + ObservationGenerator vecObsGen = null; + for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++) + { + var sensor = sensors[sensorIndex]; + var rank = sensor.GetObservationSpec().Rank; + ObservationGenerator obsGen = null; + string obsGenName = null; + switch (rank) + { + case 1: + if (vecObsGen == null) + { + vecObsGen = new ObservationGenerator(allocator); + } + obsGen = vecObsGen; + obsGenName = TensorNames.VectorObservationPlaceholder; + break; + case 2: + // If the tensor is of rank 2, we use the index of the sensor + // to create the name + obsGen = new ObservationGenerator(allocator); + obsGenName = TensorNames.GetObservationName(sensorIndex); + break; + case 3: + // If the tensor is of rank 3, we use the "visual observation + // index", which only counts the rank 3 sensors + obsGen = new ObservationGenerator(allocator); + obsGenName = TensorNames.GetVisualObservationName(visIndex); + visIndex++; + break; + default: + throw new UnityAgentsException( + $"Sensor {sensor.GetName()} have an invalid rank {rank}"); + } + obsGen.AddSensorIndex(sensorIndex); + m_Dict[obsGenName] = obsGen; + } + } + + if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0) + { + for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++) + { + var obsGen = new ObservationGenerator(allocator); + var obsGenName = TensorNames.GetObservationName(sensorIndex); + obsGen.AddSensorIndex(sensorIndex); + m_Dict[obsGenName] = obsGen; + + } + } + } + + /// + /// Populates the data of the tensor inputs given the data contained in the current batch + /// of agents. + /// + /// Enumerable of tensors that will be modified. 
+ /// The number of agents present in the current batch + /// + /// List of AgentsInfos and Sensors that contains the + /// data that will be used to modify the tensors + /// One of the tensor does not have an + /// associated generator. + public void GenerateTensors( + IReadOnlyList tensors, int currentBatchSize, IList infos) + { + for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++) + { + var tensor = tensors[tensorIndex]; + if (!m_Dict.ContainsKey(tensor.name)) + { + throw new UnityAgentsException( + $"Unknown tensorProxy expected as input : {tensor.name}"); + } + m_Dict[tensor.name].Generate(tensor, currentBatchSize, infos); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs.meta b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs.meta new file mode 100644 index 0000000000..0bfd9673a9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 6a24e86bc77c4a5088a5fd04d6d30e81 +timeCreated: 1537484304 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Inference/TensorNames.cs b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs new file mode 100644 index 0000000000..48ae04b5f6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs @@ -0,0 +1,50 @@ +namespace Unity.MLAgents.Inference +{ + /// + /// Contains the names of the input and output tensors for the Inference Brain. + /// + internal static class TensorNames + { + public const string BatchSizePlaceholder = "batch_size"; + public const string SequenceLengthPlaceholder = "sequence_length"; + public const string VectorObservationPlaceholder = "vector_observation"; + public const string RecurrentInPlaceholder = "recurrent_in"; + public const string VisualObservationPlaceholderPrefix = "visual_observation_"; + public const string ObservationPlaceholderPrefix = "obs_"; + public const string PreviousActionPlaceholder = "prev_action"; + public const string ActionMaskPlaceholder = "action_masks"; + public const string RandomNormalEpsilonPlaceholder = "epsilon"; + + public const string ValueEstimateOutput = "value_estimate"; + public const string RecurrentOutput = "recurrent_out"; + public const string MemorySize = "memory_size"; + public const string VersionNumber = "version_number"; + public const string ContinuousActionOutputShape = "continuous_action_output_shape"; + public const string DiscreteActionOutputShape = "discrete_action_output_shape"; + public const string ContinuousActionOutput = "continuous_actions"; + public const string DiscreteActionOutput = "discrete_actions"; + public const string DeterministicContinuousActionOutput = "deterministic_continuous_actions"; + public const string DeterministicDiscreteActionOutput = "deterministic_discrete_actions"; + + // Deprecated TensorNames entries for backward compatibility + public const string IsContinuousControlDeprecated = "is_continuous_control"; + public const string ActionOutputDeprecated = "action"; + public const string ActionOutputShapeDeprecated = "action_output_shape"; + + /// + /// Returns the name of the visual observation with a given index + /// + public static string GetVisualObservationName(int index) + { + return VisualObservationPlaceholderPrefix + index; + } + + /// + /// Returns the name of the observation with a given index + /// + public static string GetObservationName(int index) + { + return ObservationPlaceholderPrefix + index; + } + } +} diff --git 
a/com.unity.ml-agents/Runtime/Inference/TensorNames.cs.meta b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs.meta new file mode 100644 index 0000000000..5c8c6b80b1 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: b28a46ea97c2445794d29d5a8a718a4a +timeCreated: 1538158527 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Inference/TensorProxy.cs b/com.unity.ml-agents/Runtime/Inference/TensorProxy.cs new file mode 100644 index 0000000000..b36b765ca8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorProxy.cs @@ -0,0 +1,161 @@ +using System; +using System.Collections.Generic; +using Unity.Barracuda; +using Unity.MLAgents.Inference.Utils; + +namespace Unity.MLAgents.Inference +{ + /// + /// Tensor - A class to encapsulate a Tensor used for inference. + /// + /// This class contains the Array that holds the data array, the shapes, type and the + /// placeholder in the execution graph. All the fields are editable in the inspector, + /// allowing the user to specify everything but the data in a graphical way. + /// + [Serializable] + internal class TensorProxy + { + public enum TensorType + { + Integer, + FloatingPoint + }; + + static readonly Dictionary k_TypeMap = + new Dictionary() + { + {TensorType.FloatingPoint, typeof(float)}, + {TensorType.Integer, typeof(int)} + }; + + public string name; + public TensorType valueType; + + // Since Type is not serializable, we use the DisplayType for the Inspector + public Type DataType => k_TypeMap[valueType]; + public long[] shape; + public Tensor data; + + public long Height + { + get { return shape.Length == 4 ? shape[1] : shape[5]; } + } + + public long Width + { + get { return shape.Length == 4 ? shape[2] : shape[6]; } + } + + public long Channels + { + get { return shape.Length == 4 ? shape[3] : shape[7]; } + } + } + + internal static class TensorUtils + { + public static void ResizeTensor(TensorProxy tensor, int batch, ITensorAllocator allocator) + { + if (tensor.shape[0] == batch && + tensor.data != null && tensor.data.batch == batch) + { + return; + } + + tensor.data?.Dispose(); + tensor.shape[0] = batch; + + if (tensor.shape.Length == 4 || tensor.shape.Length == 8) + { + tensor.data = allocator.Alloc( + new TensorShape( + batch, + (int)tensor.Height, + (int)tensor.Width, + (int)tensor.Channels)); + } + else + { + tensor.data = allocator.Alloc( + new TensorShape( + batch, + (int)tensor.shape[tensor.shape.Length - 1])); + } + } + + internal static long[] TensorShapeFromBarracuda(TensorShape src) + { + if (src.height == 1 && src.width == 1) + { + return new long[] { src.batch, src.channels }; + } + + return new long[] { src.batch, src.height, src.width, src.channels }; + } + + public static TensorProxy TensorProxyFromBarracuda(Tensor src, string nameOverride = null) + { + var shape = TensorShapeFromBarracuda(src.shape); + return new TensorProxy + { + name = nameOverride ?? src.name, + valueType = TensorProxy.TensorType.FloatingPoint, + shape = shape, + data = src + }; + } + + /// + /// Fill a specific batch of a TensorProxy with a given value + /// + /// + /// The batch index to fill. 
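// A minimal usage sketch for TensorProxy and TensorUtils. The allocator choice
// (Barracuda's TensorCachingAllocator) and the tensor sizes are assumptions for
// illustration, not part of this diff:
//
//     using Unity.Barracuda;
//
//     var raw = new Tensor(1, 8);                                   // batch 1, 8 features
//     var proxy = TensorUtils.TensorProxyFromBarracuda(raw, "vector_observation");
//     // proxy.shape is { 1, 8 }: height == width == 1 collapses to a rank-2 shape.
//
//     var allocator = new TensorCachingAllocator();
//     TensorUtils.ResizeTensor(proxy, batch: 4, allocator);         // grow to 4 agents
//     proxy.data.Dispose();
//     allocator.Dispose();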
+ /// + public static void FillTensorBatch(TensorProxy tensorProxy, int batch, float fillValue) + { + var height = tensorProxy.data.height; + var width = tensorProxy.data.width; + var channels = tensorProxy.data.channels; + for (var h = 0; h < height; h++) + { + for (var w = 0; w < width; w++) + { + for (var c = 0; c < channels; c++) + { + tensorProxy.data[batch, h, w, c] = fillValue; + } + } + } + } + + /// + /// Fill a pre-allocated Tensor with random numbers + /// + /// The pre-allocated Tensor to fill + /// RandomNormal object used to populate tensor + /// + /// Throws when trying to fill a Tensor of type other than float + /// + /// + /// Throws when the Tensor is not allocated + /// + public static void FillTensorWithRandomNormal( + TensorProxy tensorProxy, RandomNormal randomNormal) + { + if (tensorProxy.DataType != typeof(float)) + { + throw new NotImplementedException("Only float data types are currently supported"); + } + + if (tensorProxy.data == null) + { + throw new ArgumentNullException(); + } + + for (var i = 0; i < tensorProxy.data.length; i++) + { + tensorProxy.data[i] = (float)randomNormal.NextDouble(); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/TensorProxy.cs.meta b/com.unity.ml-agents/Runtime/Inference/TensorProxy.cs.meta new file mode 100644 index 0000000000..5e0dd9d08d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/TensorProxy.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 926149e757bc849689e00e12d8c6fbdb +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Inference/Utils.meta b/com.unity.ml-agents/Runtime/Inference/Utils.meta new file mode 100644 index 0000000000..5431b3d646 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/Utils.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 7872e22895343467b9fe96d336a7edba +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs b/com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs new file mode 100644 index 0000000000..41603dd3ba --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs @@ -0,0 +1,59 @@ +namespace Unity.MLAgents.Inference.Utils +{ + /// + /// Multinomial - Draws samples from a multinomial distribution given a (potentially unscaled) + /// cumulative mass function (CMF). This means that the CMF need not "end" with probability + /// mass of 1.0. For instance: [0.1, 0.2, 0.5] is a valid (unscaled). What is important is + /// that it is a cumulative function, not a probability function. In other words, + /// entry[i] = P(x \le i), NOT P(i - 1 \le x \lt i). + /// (\le stands for less than or equal to while \lt is strictly less than). + /// + internal class Multinomial + { + readonly System.Random m_Random; + + /// + /// Constructor. + /// + /// + /// Seed for the random number generator used in the sampling process. + /// + public Multinomial(int seed) + { + m_Random = new System.Random(seed); + } + + /// + /// Samples from the Multinomial distribution defined by the provided cumulative + /// mass function. + /// + /// + /// Cumulative mass function, which may be unscaled. The entries in this array need + /// to be monotonic (always increasing). If the CMF is scaled, then the last entry in + /// the array will be 1.0. 
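// A worked sketch of the unscaled-CMF convention described above, using made-up
// values: for cmf = { 0.1f, 0.2f, 0.5f }, scaling by the final entry (0.5) implies
// P(0) = 0.2, P(1) = 0.2, P(2) = 0.6.
//
//     var multinomial = new Multinomial(seed: 42);
//     var counts = new int[3];
//     for (var i = 0; i < 10000; i++)
//     {
//         counts[multinomial.Sample(new[] { 0.1f, 0.2f, 0.5f })]++;
//     }
//     // counts trends toward roughly { 2000, 2000, 6000 }.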
+ /// + /// The number of possible branches, i.e. the effective size of the cmf array. + /// A sampled index from the CMF ranging from 0 to branchSize-1. + public int Sample(float[] cmf, int branchSize) + { + var p = (float)m_Random.NextDouble() * cmf[branchSize - 1]; + var cls = 0; + while (cmf[cls] < p) + { + ++cls; + } + + return cls; + } + + /// + /// Samples from the Multinomial distribution defined by the provided cumulative + /// mass function. + /// + /// A sampled index from the CMF ranging from 0 to cmf.Length-1. + public int Sample(float[] cmf) + { + return Sample(cmf, cmf.Length); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs.meta b/com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs.meta new file mode 100644 index 0000000000..2467eb50b3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5c9e297dad748408db9e5ce26b940fe3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Inference/Utils/RandomNormal.cs b/com.unity.ml-agents/Runtime/Inference/Utils/RandomNormal.cs new file mode 100644 index 0000000000..4b0e1e7e2f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Inference/Utils/RandomNormal.cs @@ -0,0 +1,56 @@ +using System; + +namespace Unity.MLAgents.Inference.Utils +{ + /// + /// RandomNormal - A random number generator that produces normally distributed random + /// numbers using the Marsaglia polar method: + /// https://en.wikipedia.org/wiki/Marsaglia_polar_method + /// TODO: worth overriding System.Random instead of aggregating? + /// + internal class RandomNormal + { + readonly double m_Mean; + readonly double m_Stddev; + readonly Random m_Random; + + public RandomNormal(int seed, float mean = 0.0f, float stddev = 1.0f) + { + m_Mean = mean; + m_Stddev = stddev; + m_Random = new Random(seed); + } + + // Each iteration produces two numbers. Hold one here for next call + bool m_HasSpare; + double m_SpareUnscaled; + + /// + /// Return the next random double number. + /// + /// Next random double number. 
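// A quick sanity-check sketch for the Marsaglia polar sampler; the seed, mean,
// and stddev here are arbitrary example values:
//
//     var rn = new RandomNormal(seed: 1, mean: 5.0f, stddev: 2.0f);
//     double sum = 0.0, sumSq = 0.0;
//     const int n = 100000;
//     for (var i = 0; i < n; i++)
//     {
//         var x = rn.NextDouble();
//         sum += x;
//         sumSq += x * x;
//     }
//     var mean = sum / n;                                   // approaches 5.0
//     var std = Math.Sqrt(sumSq / n - mean * mean);         // approaches 2.0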
+        public double NextDouble()
+        {
+            if (m_HasSpare)
+            {
+                m_HasSpare = false;
+                return m_SpareUnscaled * m_Stddev + m_Mean;
+            }
+
+            double u, v, s;
+            do
+            {
+                u = m_Random.NextDouble() * 2.0 - 1.0;
+                v = m_Random.NextDouble() * 2.0 - 1.0;
+                s = u * u + v * v;
+            }
+            while (s >= 1.0 || Math.Abs(s) < double.Epsilon);
+
+            s = Math.Sqrt(-2.0 * Math.Log(s) / s);
+            m_SpareUnscaled = u * s;
+            m_HasSpare = true;
+
+            return v * s * m_Stddev + m_Mean;
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Inference/Utils/RandomNormal.cs.meta b/com.unity.ml-agents/Runtime/Inference/Utils/RandomNormal.cs.meta
new file mode 100644
index 0000000000..3cd152462d
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Inference/Utils/RandomNormal.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: df8528cf20f0e4c64a4a7596eccc1631
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs b/com.unity.ml-agents/Runtime/InplaceArray.cs
new file mode 100644
index 0000000000..91208136d8
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/InplaceArray.cs
@@ -0,0 +1,240 @@
+using System;
+using System.Collections.Generic;
+
+namespace Unity.MLAgents
+{
+    ///
+    /// An array-like object that stores up to four elements.
+    /// This is a value type that does not allocate any additional memory.
+    ///
+    ///
+    /// This does not implement any interfaces such as IList, in order to avoid any accidental boxing allocations.
+    ///
+    public struct InplaceArray<T> : IEquatable<InplaceArray<T>> where T : struct
+    {
+        private const int k_MaxLength = 4;
+        private readonly int m_Length;
+
+        private T m_Elem0;
+        private T m_Elem1;
+        private T m_Elem2;
+        private T m_Elem3;
+
+        ///
+        /// Create a length-1 array.
+        ///
+        public InplaceArray(T elem0)
+        {
+            m_Length = 1;
+            m_Elem0 = elem0;
+            m_Elem1 = new T();
+            m_Elem2 = new T();
+            m_Elem3 = new T();
+        }
+
+        ///
+        /// Create a length-2 array.
+        ///
+        public InplaceArray(T elem0, T elem1)
+        {
+            m_Length = 2;
+            m_Elem0 = elem0;
+            m_Elem1 = elem1;
+            m_Elem2 = new T();
+            m_Elem3 = new T();
+        }
+
+        ///
+        /// Create a length-3 array.
+        ///
+        public InplaceArray(T elem0, T elem1, T elem2)
+        {
+            m_Length = 3;
+            m_Elem0 = elem0;
+            m_Elem1 = elem1;
+            m_Elem2 = elem2;
+            m_Elem3 = new T();
+        }
+
+        ///
+        /// Create a length-4 array.
+        ///
+        public InplaceArray(T elem0, T elem1, T elem2, T elem3)
+        {
+            m_Length = 4;
+            m_Elem0 = elem0;
+            m_Elem1 = elem1;
+            m_Elem2 = elem2;
+            m_Elem3 = elem3;
+        }
+
+        ///
+        /// Construct an InplaceArray from an IList (e.g. Array or List).
+        /// The source must be non-empty and have at most 4 elements.
+        ///
+        public static InplaceArray<T> FromList(IList<T> elems)
+        {
+            switch (elems.Count)
+            {
+                case 1:
+                    return new InplaceArray<T>(elems[0]);
+                case 2:
+                    return new InplaceArray<T>(elems[0], elems[1]);
+                case 3:
+                    return new InplaceArray<T>(elems[0], elems[1], elems[2]);
+                case 4:
+                    return new InplaceArray<T>(elems[0], elems[1], elems[2], elems[3]);
+                default:
+                    throw new ArgumentOutOfRangeException();
+            }
+        }
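// A usage sketch; the element values are arbitrary:
//
//     var shape = new InplaceArray<int>(84, 84, 3);
//     var fromList = InplaceArray<int>.FromList(new[] { 84, 84, 3 });
//     // shape == fromList, shape.Length == 3, and shape[2] == 3
//     shape[2] = 1;                                         // per-element write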
+        ///
+        /// Per-element access.
+        ///
+        public T this[int index]
+        {
+            get
+            {
+                if (index >= Length)
+                {
+                    throw new IndexOutOfRangeException();
+                }
+
+                switch (index)
+                {
+                    case 0:
+                        return m_Elem0;
+                    case 1:
+                        return m_Elem1;
+                    case 2:
+                        return m_Elem2;
+                    case 3:
+                        return m_Elem3;
+                    default:
+                        throw new IndexOutOfRangeException();
+                }
+            }
+
+            set
+            {
+                if (index >= Length)
+                {
+                    throw new IndexOutOfRangeException();
+                }
+
+                switch (index)
+                {
+                    case 0:
+                        m_Elem0 = value;
+                        break;
+                    case 1:
+                        m_Elem1 = value;
+                        break;
+                    case 2:
+                        m_Elem2 = value;
+                        break;
+                    case 3:
+                        m_Elem3 = value;
+                        break;
+                    default:
+                        throw new IndexOutOfRangeException();
+                }
+            }
+        }
+
+        ///
+        /// The length of the array.
+        ///
+        public int Length
+        {
+            get => m_Length;
+        }
+
+        ///
+        /// Returns a string representation of the array's elements.
+        ///
+        public override string ToString()
+        {
+            switch (m_Length)
+            {
+                case 1:
+                    return $"[{m_Elem0}]";
+                case 2:
+                    return $"[{m_Elem0}, {m_Elem1}]";
+                case 3:
+                    return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}]";
+                case 4:
+                    return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}, {m_Elem3}]";
+                default:
+                    throw new IndexOutOfRangeException();
+            }
+        }
+
+        ///
+        /// Check that the arrays have the same length and have all equal values.
+        ///
+        /// Whether the arrays are equivalent.
+        public static bool operator ==(InplaceArray<T> lhs, InplaceArray<T> rhs)
+        {
+            return lhs.Equals(rhs);
+        }
+
+        ///
+        /// Check that the arrays are not equivalent.
+        ///
+        /// Whether the arrays are not equivalent.
+        public static bool operator !=(InplaceArray<T> lhs, InplaceArray<T> rhs) => !lhs.Equals(rhs);
+
+        ///
+        /// Check that the arrays are equivalent.
+        ///
+        /// Whether the arrays are equivalent.
+        public override bool Equals(object other) => other is InplaceArray<T> other1 && this.Equals(other1);
+
+        ///
+        /// Check that the arrays are equivalent.
+        ///
+        /// Whether the arrays are equivalent.
+        public bool Equals(InplaceArray<T> other)
+        {
+            // See https://montemagno.com/optimizing-c-struct-equality-with-iequatable/
+            var thisTuple = (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length);
+            var otherTuple = (other.m_Elem0, other.m_Elem1, other.m_Elem2, other.m_Elem3, other.Length);
+            return thisTuple.Equals(otherTuple);
+        }
+
+        ///
+        /// Get a hashcode for the array.
+ /// + /// + public override int GetHashCode() + { + return (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode(); + } + } +} diff --git a/com.unity.ml-agents/Runtime/InplaceArray.cs.meta b/com.unity.ml-agents/Runtime/InplaceArray.cs.meta new file mode 100644 index 0000000000..3e4ab0c928 --- /dev/null +++ b/com.unity.ml-agents/Runtime/InplaceArray.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c1a80abee18a41c8aee89aeb33f5985d +timeCreated: 1615506199 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations.meta b/com.unity.ml-agents/Runtime/Integrations.meta new file mode 100644 index 0000000000..f218be2521 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: f8f4fd0bc35f4e8f9867228591f663e3 +timeCreated: 1618359419 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3.meta b/com.unity.ml-agents/Runtime/Integrations/Match3.meta new file mode 100644 index 0000000000..27b09a9eff --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 569f8fa2b7dd477c9b71f09e9d633832 +timeCreated: 1600465975 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/AbstractBoard.cs b/com.unity.ml-agents/Runtime/Integrations/Match3/AbstractBoard.cs new file mode 100644 index 0000000000..d2f218a3f7 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/AbstractBoard.cs @@ -0,0 +1,312 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using UnityEngine; +using Debug = UnityEngine.Debug; + +namespace Unity.MLAgents.Integrations.Match3 +{ + /// + /// Representation of the AbstractBoard dimensions, and number of cell and special types. + /// + public struct BoardSize + { + /// + /// Number of rows on the board + /// + public int Rows; + + /// + /// Number of columns on the board + /// + public int Columns; + + /// + /// Maximum number of different types of cells (colors, pieces, etc). + /// + public int NumCellTypes; + + /// + /// Maximum number of special types. This can be zero, in which case + /// all cells of the same type are assumed to be equivalent. + /// + public int NumSpecialTypes; + + /// + /// Check that all fields of the left-hand BoardSize are less than or equal to the field of the right-hand BoardSize + /// + /// + /// + /// True if all fields are less than or equal. + public static bool operator <=(BoardSize lhs, BoardSize rhs) + { + return lhs.Rows <= rhs.Rows && lhs.Columns <= rhs.Columns && lhs.NumCellTypes <= rhs.NumCellTypes && + lhs.NumSpecialTypes <= rhs.NumSpecialTypes; + } + + /// + /// Check that all fields of the left-hand BoardSize are greater than or equal to the field of the right-hand BoardSize + /// + /// + /// + /// True if all fields are greater than or equal. + public static bool operator >=(BoardSize lhs, BoardSize rhs) + { + return lhs.Rows >= rhs.Rows && lhs.Columns >= rhs.Columns && lhs.NumCellTypes >= rhs.NumCellTypes && + lhs.NumSpecialTypes >= rhs.NumSpecialTypes; + } + + /// + /// Return a string representation of the BoardSize. + /// + /// + public override string ToString() + { + return + $"Rows: {Rows}, Columns: {Columns}, NumCellTypes: {NumCellTypes}, NumSpecialTypes: {NumSpecialTypes}"; + } + } + + /// + /// An adapter between ML Agents and a Match-3 game. + /// + public abstract class AbstractBoard : MonoBehaviour + { + /// + /// Return the maximum size of the board. 
This is used to determine the size of observations and actions,
+        /// so the returned values must not change.
+        ///
+        public abstract BoardSize GetMaxBoardSize();
+
+        ///
+        /// Return the current size of the board. The values must be less than or equal to the values returned from
+        /// GetMaxBoardSize().
+        /// By default, this will return GetMaxBoardSize(); if your board doesn't change size, you don't need to
+        /// override it.
+        ///
+        public virtual BoardSize GetCurrentBoardSize()
+        {
+            return GetMaxBoardSize();
+        }
+
+        ///
+        /// Returns the "color" of the piece at the given row and column.
+        /// This should be between 0 and BoardSize.NumCellTypes-1 (inclusive).
+        /// The actual order of the values doesn't matter.
+        ///
+        public abstract int GetCellType(int row, int col);
+
+        ///
+        /// Returns the special type of the piece at the given row and column.
+        /// This should be between 0 and BoardSize.NumSpecialTypes (inclusive).
+        /// The actual order of the values doesn't matter.
+        ///
+        public abstract int GetSpecialType(int row, int col);
+
+        ///
+        /// Check whether the particular Move is valid for the game.
+        /// The actual results will depend on the rules of the game, but we provide
+        /// SimpleIsMoveValid() that handles basic match3 rules with no special or immovable pieces.
+        ///
+        ///
+        /// Moves that would go outside of the current board size are filtered out before they are
+        /// passed to IsMoveValid().
+        ///
+        /// The move to check.
+        ///
+        public abstract bool IsMoveValid(Move m);
+
+        ///
+        /// Instruct the game to make the given Move. Returns true if the move was made.
+        /// Note that during training, a move that was marked as invalid may occasionally still be
+        /// requested. If this happens, it is safe to do nothing and request another move.
+        ///
+        /// The move to carry out.
+        ///
+        public abstract bool MakeMove(Move m);
+
+        ///
+        /// Return the total number of moves possible for the board.
+        ///
+        public int NumMoves()
+        {
+            return Move.NumPotentialMoves(GetMaxBoardSize());
+        }
+
+        ///
+        /// An optional callback for when all moves are invalid. Ideally, the game state should
+        /// be changed before this happens, but this is a way to get notified if not.
+        ///
+        public Action OnNoValidMovesAction;
+
+        ///
+        /// Iterate through all moves on the board.
+        ///
+        public IEnumerable<Move> AllMoves()
+        {
+            var maxBoardSize = GetMaxBoardSize();
+            var currentBoardSize = GetCurrentBoardSize();
+
+            var currentMove = Move.FromMoveIndex(0, maxBoardSize);
+            for (var i = 0; i < NumMoves(); i++)
+            {
+                if (currentMove.InRangeForBoard(currentBoardSize))
+                {
+                    yield return currentMove;
+                }
+                currentMove.Next(maxBoardSize);
+            }
+        }
+
+        ///
+        /// Iterate through all valid moves on the board.
+        ///
+        public IEnumerable<Move> ValidMoves()
+        {
+            var maxBoardSize = GetMaxBoardSize();
+            var currentBoardSize = GetCurrentBoardSize();
+
+            var currentMove = Move.FromMoveIndex(0, maxBoardSize);
+            for (var i = 0; i < NumMoves(); i++)
+            {
+                if (currentMove.InRangeForBoard(currentBoardSize) && IsMoveValid(currentMove))
+                {
+                    yield return currentMove;
+                }
+                currentMove.Next(maxBoardSize);
+            }
+        }
+
+        ///
+        /// Returns true if swapping the cells specified by the move would result in
+        /// 3 or more cells of the same type in a row. This assumes that all pieces are allowed
+        /// to be moved; to add extra logic, incorporate it into your IsMoveValid() method.
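// A minimal concrete board, sketched under the assumption of a fixed 9x9 grid with
// 6 cell types and no special types (the class and field names are illustrative):
//
//     internal class ExampleBoard : AbstractBoard
//     {
//         int[,] m_Cells = new int[9, 9];
//
//         public override BoardSize GetMaxBoardSize()
//         {
//             return new BoardSize { Rows = 9, Columns = 9, NumCellTypes = 6, NumSpecialTypes = 0 };
//         }
//
//         public override int GetCellType(int row, int col) => m_Cells[row, col];
//         public override int GetSpecialType(int row, int col) => 0;
//         public override bool IsMoveValid(Move m) => SimpleIsMoveValid(m);
//
//         public override bool MakeMove(Move m)
//         {
//             var (otherRow, otherCol) = m.OtherCell();
//             (m_Cells[m.Row, m.Column], m_Cells[otherRow, otherCol]) =
//                 (m_Cells[otherRow, otherCol], m_Cells[m.Row, m.Column]);
//             return true;
//         }
//     }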
+ /// + /// + /// + public bool SimpleIsMoveValid(Move move) + { + using (TimerStack.Instance.Scoped("SimpleIsMoveValid")) + { + var moveVal = GetCellType(move.Row, move.Column); + var (otherRow, otherCol) = move.OtherCell(); + var oppositeVal = GetCellType(otherRow, otherCol); + + // Simple check - if the values are the same, don't match + // This might not be valid for all games + { + if (moveVal == oppositeVal) + { + return false; + } + } + + bool moveMatches = CheckHalfMove(otherRow, otherCol, moveVal, move.Direction); + if (moveMatches) + { + // early out + return true; + } + + bool otherMatches = CheckHalfMove(move.Row, move.Column, oppositeVal, move.OtherDirection()); + return otherMatches; + } + } + + /// + /// Check if one of the cells that is swapped during a move matches 3 or more. + /// Since these checks are similar for each cell, we consider the move as two "half moves". + /// + /// + /// + /// + /// + /// + bool CheckHalfMove(int newRow, int newCol, int newValue, Direction incomingDirection) + { + var currentBoardSize = GetCurrentBoardSize(); + int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0; + + if (incomingDirection != Direction.Right) + { + for (var c = newCol - 1; c >= 0; c--) + { + if (GetCellType(newRow, c) == newValue) + matchedLeft++; + else + break; + } + } + + if (incomingDirection != Direction.Left) + { + for (var c = newCol + 1; c < currentBoardSize.Columns; c++) + { + if (GetCellType(newRow, c) == newValue) + matchedRight++; + else + break; + } + } + + if (incomingDirection != Direction.Down) + { + for (var r = newRow + 1; r < currentBoardSize.Rows; r++) + { + if (GetCellType(r, newCol) == newValue) + matchedUp++; + else + break; + } + } + + if (incomingDirection != Direction.Up) + { + for (var r = newRow - 1; r >= 0; r--) + { + if (GetCellType(r, newCol) == newValue) + matchedDown++; + else + break; + } + } + + if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2)) + { + return true; + } + + return false; + } + + /// + /// Make sure that the current BoardSize isn't larger than the original value of . + /// If it is, log a warning. + /// + /// + [Conditional("DEBUG")] + internal void CheckBoardSizes(BoardSize originalMaxBoardSize) + { + var currentBoardSize = GetCurrentBoardSize(); + if (!(currentBoardSize <= originalMaxBoardSize)) + { + Debug.LogWarning( + "Current BoardSize is larger than maximum board size was on initialization. This may cause unexpected results.\n" + + $"Original GetMaxBoardSize() result: {originalMaxBoardSize}\n" + + $"GetCurrentBoardSize() result: {currentBoardSize}" + ); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/AbstractBoard.cs.meta b/com.unity.ml-agents/Runtime/Integrations/Match3/AbstractBoard.cs.meta new file mode 100644 index 0000000000..42019368a1 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/AbstractBoard.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 6222defa70dc4c08aaeafd0be4e821d2 +timeCreated: 1600466051 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Actuator.cs b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Actuator.cs new file mode 100644 index 0000000000..9bd60bd571 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Actuator.cs @@ -0,0 +1,188 @@ +using Unity.MLAgents.Actuators; +using Debug = UnityEngine.Debug; + + +namespace Unity.MLAgents.Integrations.Match3 +{ + /// + /// Actuator for a Match3 game. 
It translates valid moves (defined by AbstractBoard.IsMoveValid())
+    /// into action masks, and applies the action to the board via AbstractBoard.MakeMove().
+    ///
+    public class Match3Actuator : IActuator, IBuiltInActuator
+    {
+        AbstractBoard m_Board;
+        System.Random m_Random;
+        ActionSpec m_ActionSpec;
+        bool m_ForceHeuristic;
+        BoardSize m_MaxBoardSize;
+
+        ///
+        /// Create a Match3Actuator.
+        ///
+        /// Whether the inference action should be ignored and the Agent's Heuristic
+        /// should be called. This should only be used for generating comparison stats of the Heuristic.
+        /// The seed used to initialize the random number generator.
+        ///
+        public Match3Actuator(AbstractBoard board,
+            bool forceHeuristic,
+            int seed,
+            string name)
+        {
+            m_Board = board;
+            m_MaxBoardSize = m_Board.GetMaxBoardSize();
+            Name = name;
+
+            m_ForceHeuristic = forceHeuristic;
+
+            var numMoves = Move.NumPotentialMoves(m_MaxBoardSize);
+            m_ActionSpec = ActionSpec.MakeDiscrete(numMoves);
+            m_Random = new System.Random(seed);
+        }
+
+        ///
+        public ActionSpec ActionSpec => m_ActionSpec;
+
+        ///
+        public void OnActionReceived(ActionBuffers actions)
+        {
+            m_Board.CheckBoardSizes(m_MaxBoardSize);
+            if (m_ForceHeuristic)
+            {
+                Heuristic(actions);
+            }
+            var moveIndex = actions.DiscreteActions[0];
+
+            Move move = Move.FromMoveIndex(moveIndex, m_MaxBoardSize);
+            m_Board.MakeMove(move);
+        }
+
+        ///
+        public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
+        {
+            var currentBoardSize = m_Board.GetCurrentBoardSize();
+            m_Board.CheckBoardSizes(m_MaxBoardSize);
+            const int branch = 0;
+            bool foundValidMove = false;
+            using (TimerStack.Instance.Scoped("WriteDiscreteActionMask"))
+            {
+                var numMoves = m_Board.NumMoves();
+
+                var currentMove = Move.FromMoveIndex(0, m_MaxBoardSize);
+                for (var i = 0; i < numMoves; i++)
+                {
+                    // Check that the move is allowed for the current boardSize (e.g. it won't move a piece out of
+                    // bounds), and that it's allowed by the game itself.
+                    if (currentMove.InRangeForBoard(currentBoardSize) && m_Board.IsMoveValid(currentMove))
+                    {
+                        foundValidMove = true;
+                    }
+                    else
+                    {
+                        actionMask.SetActionEnabled(branch, i, false);
+                    }
+                    currentMove.Next(m_MaxBoardSize);
+                }
+
+                if (!foundValidMove)
+                {
+                    // If all the moves are invalid and we mask all the actions out, this will cause an assert
+                    // later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,
+                    // (or log a warning if not) and leave the last action unmasked. This isn't great, but
+                    // an invalid move should be easier to handle than an exception.
+                    if (m_Board.OnNoValidMovesAction != null)
+                    {
+                        m_Board.OnNoValidMovesAction();
+                    }
+                    else
+                    {
+                        Debug.LogWarning(
+                            "No valid moves are available. The last action will be left unmasked, so " +
+                            "an invalid move will be passed to AbstractBoard.MakeMove()."
+                        );
+                    }
+                    actionMask.SetActionEnabled(branch, numMoves - 1, true);
+                }
+            }
+        }
+
+        ///
+        public string Name { get; }
+
+        ///
+        public void ResetData()
+        {
+        }
+
+        ///
+        public BuiltInActuatorType GetBuiltInActuatorType()
+        {
+            return BuiltInActuatorType.Match3Actuator;
+        }
+
+        ///
+        public void Heuristic(in ActionBuffers actionsOut)
+        {
+            var discreteActions = actionsOut.DiscreteActions;
+            discreteActions[0] = GreedyMove();
+        }
+
+        ///
+        /// Returns a valid move that gives the highest value for EvalMovePoints(). If multiple moves have the same
+        /// value, one of them will be chosen with uniform probability.
+        ///
+        ///
+        /// By default, EvalMovePoints() returns 1, so all valid moves are equally likely.
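// A subclassing sketch of this extension point; the scoring rule here is a made-up
// stand-in for a real game's match evaluation:
//
//     internal class ScoringMatch3Actuator : Match3Actuator
//     {
//         public ScoringMatch3Actuator(AbstractBoard board, int seed)
//             : base(board, forceHeuristic: false, seed: seed, name: "Scoring Match3 Actuator")
//         {
//         }
//
//         protected override int EvalMovePoints(Move move)
//         {
//             return 1 + move.Row;   // hypothetical: favor moves lower on the board
//         }
//     }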
Inherit from this class and + /// override EvalMovePoints() to use your game's scoring as a better estimate. + /// + /// + internal int GreedyMove() + { + var bestMoveIndex = 0; + var bestMovePoints = -1; + var numMovesAtCurrentScore = 0; + + foreach (var move in m_Board.ValidMoves()) + { + var movePoints = EvalMovePoints(move); + if (movePoints < bestMovePoints) + { + // Worse, skip + continue; + } + + if (movePoints > bestMovePoints) + { + // Better, keep + bestMovePoints = movePoints; + bestMoveIndex = move.MoveIndex; + numMovesAtCurrentScore = 1; + } + else + { + // Tied for best - use reservoir sampling to make sure we select from equal moves uniformly. + // See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm + numMovesAtCurrentScore++; + var randVal = m_Random.Next(0, numMovesAtCurrentScore); + if (randVal == 0) + { + // Keep the new one + bestMoveIndex = move.MoveIndex; + } + } + } + + return bestMoveIndex; + } + + /// + /// Method to be overridden when evaluating how many points a specific move will generate. + /// + /// The move to evaluate. + /// The number of points the move generates. + protected virtual int EvalMovePoints(Move move) + { + return 1; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Actuator.cs.meta b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Actuator.cs.meta new file mode 100644 index 0000000000..4052f5e51f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Actuator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 9083fa4c35dc499aa5a86d8e7447c7cf +timeCreated: 1600906373 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3ActuatorComponent.cs b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3ActuatorComponent.cs new file mode 100644 index 0000000000..8dfa7e496e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3ActuatorComponent.cs @@ -0,0 +1,81 @@ +using System; +using Unity.MLAgents.Actuators; +using UnityEngine; +using UnityEngine.Serialization; + +namespace Unity.MLAgents.Integrations.Match3 +{ + /// + /// Actuator component for a Match3 game. Generates a Match3Actuator at runtime. + /// + [AddComponentMenu("ML Agents/Match 3 Actuator", (int)MenuGroup.Actuators)] + public class Match3ActuatorComponent : ActuatorComponent + { + [HideInInspector, SerializeField, FormerlySerializedAs("ActuatorName")] + string m_ActuatorName = "Match3 Actuator"; + + /// + /// Name of the generated Match3Actuator object. + /// Note that changing this at runtime does not affect how the Agent sorts the actuators. + /// + public string ActuatorName + { + get => m_ActuatorName; + set => m_ActuatorName = value; + } + + [HideInInspector, SerializeField, FormerlySerializedAs("RandomSeed")] + int m_RandomSeed = -1; + + /// + /// A random seed used in the actuator's heuristic, if needed. + /// + public int RandomSeed + { + get => m_RandomSeed; + set => m_RandomSeed = value; + } + + [HideInInspector, SerializeField, FormerlySerializedAs("ForceHeuristic")] + [Tooltip("Force using the Agent's Heuristic() method to decide the action. This should only be used in testing.")] + bool m_ForceHeuristic; + + /// + /// Force using the Agent's Heuristic() method to decide the action. This should only be used in testing. 
+ /// + public bool ForceHeuristic + { + get => m_ForceHeuristic; + set => m_ForceHeuristic = value; + } + + /// + public override IActuator[] CreateActuators() + { + var board = GetComponent(); + if (!board) + { + return Array.Empty(); + } + + var seed = m_RandomSeed == -1 ? gameObject.GetInstanceID() : m_RandomSeed + 1; + return new IActuator[] { new Match3Actuator(board, m_ForceHeuristic, seed, m_ActuatorName) }; + } + + /// + public override ActionSpec ActionSpec + { + get + { + var board = GetComponent(); + if (board == null) + { + return ActionSpec.MakeContinuous(0); + } + + var numMoves = Move.NumPotentialMoves(board.GetMaxBoardSize()); + return ActionSpec.MakeDiscrete(numMoves); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3ActuatorComponent.cs.meta b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3ActuatorComponent.cs.meta new file mode 100644 index 0000000000..c592f8be7a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3ActuatorComponent.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 08e4b0da54cb4d56bfcbae22dd49ab8d +timeCreated: 1600906388 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Sensor.cs b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Sensor.cs new file mode 100644 index 0000000000..4252bd9b50 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Sensor.cs @@ -0,0 +1,350 @@ +using System; +using System.Collections.Generic; +using Unity.MLAgents.Sensors; +using UnityEngine; + +namespace Unity.MLAgents.Integrations.Match3 +{ + /// + /// Delegate that provides integer values at a given (x,y) coordinate. + /// + /// + /// + public delegate int GridValueProvider(int x, int y); + + /// + /// Type of observations to generate. + /// + /// + public enum Match3ObservationType + { + /// + /// Generate a one-hot encoding of the cell type for each cell on the board. If there are special types, + /// these will also be one-hot encoded. + /// + Vector, + + /// + /// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as + /// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded. + /// + UncompressedVisual, + + /// + /// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as + /// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded. + /// During training, these will be sent as a concatenated series of PNG images, with 3 channels per image. + /// + CompressedVisual + } + + /// + /// Sensor for Match3 games. Can generate either vector, compressed visual, + /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values. + /// + public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable + { + Match3ObservationType m_ObservationType; + ObservationSpec m_ObservationSpec; + string m_Name; + + AbstractBoard m_Board; + BoardSize m_MaxBoardSize; + GridValueProvider m_GridValues; + int m_OneHotSize; + + Texture2D m_ObservationTexture; + OneHotToTextureUtil m_TextureUtil; + + /// + /// Create a sensor for the GridValueProvider with the specified observation type. + /// + /// + /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling + /// the constructor directly. + /// + /// The abstract board. + /// The GridValueProvider, should be either board.GetCellType or board.GetSpecialType. 
+ /// The number of possible values that the GridValueProvider can return. + /// Whether to produce vector or visual observations + /// Name of the sensor. + public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name) + { + var maxBoardSize = board.GetMaxBoardSize(); + m_Name = name; + m_MaxBoardSize = maxBoardSize; + m_GridValues = gvp; + m_OneHotSize = oneHotSize; + m_Board = board; + + m_ObservationType = obsType; + m_ObservationSpec = obsType == Match3ObservationType.Vector + ? ObservationSpec.Vector(maxBoardSize.Rows * maxBoardSize.Columns * oneHotSize) + : ObservationSpec.Visual(maxBoardSize.Rows, maxBoardSize.Columns, oneHotSize); + } + + /// + /// Create a sensor that encodes the board cells as observations. + /// + /// The abstract board. + /// Whether to produce vector or visual observations + /// Name of the sensor. + /// + public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name) + { + var maxBoardSize = board.GetMaxBoardSize(); + return new Match3Sensor(board, board.GetCellType, maxBoardSize.NumCellTypes, obsType, name); + } + + /// + /// Create a sensor that encodes the cell special types as observations. Returns null if the board's + /// NumSpecialTypes is 0 (indicating the sensor isn't needed). + /// + /// The abstract board. + /// Whether to produce vector or visual observations + /// Name of the sensor. + /// + public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name) + { + var maxBoardSize = board.GetMaxBoardSize(); + if (maxBoardSize.NumSpecialTypes == 0) + { + return null; + } + var specialSize = maxBoardSize.NumSpecialTypes + 1; + return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name); + } + + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + public int Write(ObservationWriter writer) + { + m_Board.CheckBoardSizes(m_MaxBoardSize); + var currentBoardSize = m_Board.GetCurrentBoardSize(); + + int offset = 0; + var isVisual = m_ObservationType != Match3ObservationType.Vector; + + // This is equivalent to + // for (var r = 0; r < m_MaxBoardSize.Rows; r++) + // for (var c = 0; c < m_MaxBoardSize.Columns; c++) + // if (r < currentBoardSize.Rows && c < currentBoardSize.Columns) + // WriteOneHot + // else + // WriteZero + // but rearranged to avoid the branching. 
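+            // As a worked example under hypothetical sizes: with a 3x3 max board, a 2x2
+            // current board, and m_OneHotSize == 4, each of the two live rows writes 2
+            // one-hot blocks followed by 1 all-zero block, and the final padded row is
+            // all zeros, for a total of 3 * 3 * 4 == 36 floats per observation.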
+
+            for (var r = 0; r < currentBoardSize.Rows; r++)
+            {
+                for (var c = 0; c < currentBoardSize.Columns; c++)
+                {
+                    var val = m_GridValues(r, c);
+                    writer.WriteOneHot(offset, r, c, val, m_OneHotSize, isVisual);
+                    offset += m_OneHotSize;
+                }
+
+                for (var c = currentBoardSize.Columns; c < m_MaxBoardSize.Columns; c++)
+                {
+                    writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
+                    offset += m_OneHotSize;
+                }
+            }
+
+            for (var r = currentBoardSize.Rows; r < m_MaxBoardSize.Rows; r++)
+            {
+                for (var c = 0; c < m_MaxBoardSize.Columns; c++)
+                {
+                    writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
+                    offset += m_OneHotSize;
+                }
+            }
+
+            return offset;
+        }
+
+        ///
+        public byte[] GetCompressedObservation()
+        {
+            m_Board.CheckBoardSizes(m_MaxBoardSize);
+            var height = m_MaxBoardSize.Rows;
+            var width = m_MaxBoardSize.Columns;
+            if (ReferenceEquals(null, m_ObservationTexture))
+            {
+                m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
+            }
+
+            if (ReferenceEquals(null, m_TextureUtil))
+            {
+                m_TextureUtil = new OneHotToTextureUtil(height, width);
+            }
+            var bytesOut = new List<byte>();
+            var currentBoardSize = m_Board.GetCurrentBoardSize();
+
+            // Encode the cell types or special types as batches of PNGs
+            // This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could
+            // fit it in 2 images, but we'll use 3 total (2 PNGs for the 4 cell type channels, and 1 for
+            // the special types).
+            var numCellImages = (m_OneHotSize + 2) / 3;
+            for (var i = 0; i < numCellImages; i++)
+            {
+                m_TextureUtil.EncodeToTexture(
+                    m_GridValues,
+                    m_ObservationTexture,
+                    3 * i,
+                    currentBoardSize.Rows,
+                    currentBoardSize.Columns
+                );
+                bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
+            }
+
+            return bytesOut.ToArray();
+        }
+
+        ///
+        public void Update()
+        {
+        }
+
+        ///
+        public void Reset()
+        {
+        }
+
+        internal SensorCompressionType GetCompressionType()
+        {
+            return m_ObservationType == Match3ObservationType.CompressedVisual ?
+                SensorCompressionType.PNG :
+                SensorCompressionType.None;
+        }
+
+        ///
+        public CompressionSpec GetCompressionSpec()
+        {
+            return new CompressionSpec(GetCompressionType());
+        }
+
+        ///
+        public string GetName()
+        {
+            return m_Name;
+        }
+
+        ///
+        public BuiltInSensorType GetBuiltInSensorType()
+        {
+            return BuiltInSensorType.Match3Sensor;
+        }
+
+        ///
+        /// Clean up the owned Texture2D.
+        ///
+        public void Dispose()
+        {
+            if (!ReferenceEquals(null, m_ObservationTexture))
+            {
+                Utilities.DestroyTexture(m_ObservationTexture);
+                m_ObservationTexture = null;
+            }
+        }
+    }
+
+    ///
+    /// Utility class for converting a 2D array of ints representing a one-hot encoding into
+    /// a texture, suitable for conversion to PNGs for observations.
+    /// Works by encoding 3 values at a time as pixels in the texture, thus it should be
+    /// called (maxValue + 2) / 3 times, increasing the channelOffset by 3 each time.
+ /// + internal class OneHotToTextureUtil + { + Color[] m_Colors; + int m_MaxHeight; + int m_MaxWidth; + private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue }; + + public OneHotToTextureUtil(int maxHeight, int maxWidth) + { + m_Colors = new Color[maxHeight * maxWidth]; + m_MaxHeight = maxHeight; + m_MaxWidth = maxWidth; + } + + public void EncodeToTexture( + GridValueProvider gridValueProvider, + Texture2D texture, + int channelOffset, + int currentHeight, + int currentWidth + ) + { + var i = 0; + // There's an implicit flip converting to PNG from texture, so make sure we + // counteract that when forming the texture by iterating through h in reverse. + for (var h = m_MaxHeight - 1; h >= 0; h--) + { + for (var w = 0; w < m_MaxWidth; w++) + { + var colorVal = Color.black; + if (h < currentHeight && w < currentWidth) + { + int oneHotValue = gridValueProvider(h, w); + if (oneHotValue >= channelOffset && oneHotValue < channelOffset + 3) + { + colorVal = s_OneHotColors[oneHotValue - channelOffset]; + } + } + m_Colors[i++] = colorVal; + } + } + texture.SetPixels(m_Colors); + } + } + + /// + /// Utility methods for writing one-hot observations. + /// + internal static class ObservationWriterMatch3Extensions + { + public static void WriteOneHot(this ObservationWriter writer, int offset, int row, int col, int value, int oneHotSize, bool isVisual) + { + if (isVisual) + { + for (var i = 0; i < oneHotSize; i++) + { + writer[row, col, i] = (i == value) ? 1.0f : 0.0f; + } + } + else + { + for (var i = 0; i < oneHotSize; i++) + { + writer[offset] = (i == value) ? 1.0f : 0.0f; + offset++; + } + } + } + + public static void WriteZero(this ObservationWriter writer, int offset, int row, int col, int oneHotSize, bool isVisual) + { + if (isVisual) + { + for (var i = 0; i < oneHotSize; i++) + { + writer[row, col, i] = 0.0f; + } + } + else + { + for (var i = 0; i < oneHotSize; i++) + { + writer[offset] = 0.0f; + offset++; + } + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Sensor.cs.meta b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Sensor.cs.meta new file mode 100644 index 0000000000..b440cac0fc --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3Sensor.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 795ad5f211e344e5bf3049abd9499721 +timeCreated: 1600906663 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3SensorComponent.cs b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3SensorComponent.cs new file mode 100644 index 0000000000..8afd4d0edc --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3SensorComponent.cs @@ -0,0 +1,77 @@ +using System; +using Unity.MLAgents.Sensors; +using UnityEngine; +using UnityEngine.Serialization; + +namespace Unity.MLAgents.Integrations.Match3 +{ + /// + /// Sensor component for a Match3 game. + /// + [AddComponentMenu("ML Agents/Match 3 Sensor", (int)MenuGroup.Sensors)] + public class Match3SensorComponent : SensorComponent, IDisposable + { + [HideInInspector, SerializeField, FormerlySerializedAs("SensorName")] + string m_SensorName = "Match3 Sensor"; + + /// + /// Name of the generated Match3Sensor object. + /// Note that changing this at runtime does not affect how the Agent sorts the sensors. 
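// A wiring sketch, assuming the hypothetical ExampleBoard from the AbstractBoard
// notes above sits on the same GameObject:
//
//     var go = new GameObject("Match3Agent");
//     go.AddComponent<ExampleBoard>();
//     var sensorComponent = go.AddComponent<Match3SensorComponent>();
//     sensorComponent.SensorName = "Board Sensor";
//     sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
//     var actuatorComponent = go.AddComponent<Match3ActuatorComponent>();
//     actuatorComponent.ForceHeuristic = false;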
+ /// + public string SensorName + { + get => m_SensorName; + set => m_SensorName = value; + } + + [HideInInspector, SerializeField, FormerlySerializedAs("ObservationType")] + Match3ObservationType m_ObservationType = Match3ObservationType.Vector; + + /// + /// Type of observation to generate. + /// + public Match3ObservationType ObservationType + { + get => m_ObservationType; + set => m_ObservationType = value; + } + + private ISensor[] m_Sensors; + + /// + public override ISensor[] CreateSensors() + { + // Clean up any existing sensors + Dispose(); + + var board = GetComponent(); + if (!board) + { + return Array.Empty(); + } + var cellSensor = Match3Sensor.CellTypeSensor(board, m_ObservationType, m_SensorName + " (cells)"); + // This can be null if BoardSize.NumSpecialTypes is 0 + var specialSensor = Match3Sensor.SpecialTypeSensor(board, m_ObservationType, m_SensorName + " (special)"); + m_Sensors = specialSensor != null + ? new ISensor[] { cellSensor, specialSensor } + : new ISensor[] { cellSensor }; + return m_Sensors; + } + + /// + /// Clean up the sensors created by CreateSensors(). + /// + public void Dispose() + { + if (m_Sensors != null) + { + for (var i = 0; i < m_Sensors.Length; i++) + { + ((Match3Sensor)m_Sensors[i]).Dispose(); + } + + m_Sensors = null; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Match3SensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3SensorComponent.cs.meta new file mode 100644 index 0000000000..d2d2713eef --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Match3SensorComponent.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 530d2f105aa145bd8a00e021bdd925fd +timeCreated: 1600906676 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Move.cs b/com.unity.ml-agents/Runtime/Integrations/Match3/Move.cs new file mode 100644 index 0000000000..20bf0809bc --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Move.cs @@ -0,0 +1,278 @@ +using System; +using UnityEngine; + +namespace Unity.MLAgents.Integrations.Match3 +{ + /// + /// Directions for a Move. + /// + public enum Direction + { + /// + /// Move up (increasing row direction). + /// + Up, + + /// + /// Move down (decreasing row direction). + /// + Down, // -row direction + + /// + /// Move left (decreasing column direction). + /// + Left, // -column direction + + /// + /// Move right (increasing column direction). + /// + Right, // +column direction + } + + /// + /// Struct that encapsulates a swap of adjacent cells. + /// A Move can be constructed from either a starting row, column, and direction, + /// or from a "move index" between 0 and NumPotentialMoves()-1. + /// Moves are enumerated as the internal edges of the game grid. + /// Left/right moves come first. There are (maxCols - 1) * maxRows of these. + /// Up/down moves are next. There are (maxRows - 1) * maxCols of these. + /// + public struct Move + { + /// + /// Index of the move, from 0 to NumPotentialMoves-1. + /// + public int MoveIndex; + + /// + /// Row of the cell that will be moved. + /// + public int Row; + + /// + /// Column of the cell that will be moved. + /// + public int Column; + + /// + /// Direction that the cell will be moved. + /// + public Direction Direction; + + /// + /// Construct a Move from its move index and the board size. + /// This is useful for iterating through all the Moves on a board, or constructing + /// the Move corresponding to an Agent decision. 
+ /// + /// Must be between 0 and NumPotentialMoves(maxRows, maxCols). + /// + /// + /// + public static Move FromMoveIndex(int moveIndex, BoardSize maxBoardSize) + { + var maxRows = maxBoardSize.Rows; + var maxCols = maxBoardSize.Columns; + + if (moveIndex < 0 || moveIndex >= NumPotentialMoves(maxBoardSize)) + { + throw new ArgumentOutOfRangeException("moveIndex"); + } + Direction dir; + int row, col; + if (moveIndex < (maxCols - 1) * maxRows) + { + dir = Direction.Right; + col = moveIndex % (maxCols - 1); + row = moveIndex / (maxCols - 1); + } + else + { + dir = Direction.Up; + var offset = moveIndex - (maxCols - 1) * maxRows; + col = offset % maxCols; + row = offset / maxCols; + } + return new Move + { + MoveIndex = moveIndex, + Direction = dir, + Row = row, + Column = col + }; + } + + /// + /// Increment the Move to the next MoveIndex, and update the Row, Column, and Direction accordingly. + /// + /// + public void Next(BoardSize maxBoardSize) + { + var maxRows = maxBoardSize.Rows; + var maxCols = maxBoardSize.Columns; + + var switchoverIndex = (maxCols - 1) * maxRows; + + MoveIndex++; + if (MoveIndex < switchoverIndex) + { + Column++; + if (Column == maxCols - 1) + { + Row++; + Column = 0; + } + } + else if (MoveIndex == switchoverIndex) + { + // switch from moving right to moving up + Row = 0; + Column = 0; + Direction = Direction.Up; + } + else + { + Column++; + if (Column == maxCols) + { + Row++; + Column = 0; + } + } + } + + /// + /// Construct a Move from the row, column, direction, and board size. + /// + /// + /// + /// + /// + /// + public static Move FromPositionAndDirection(int row, int col, Direction dir, BoardSize maxBoardSize) + { + + // Check for out-of-bounds + if (row < 0 || row >= maxBoardSize.Rows) + { + throw new IndexOutOfRangeException($"row was {row}, but must be between 0 and {maxBoardSize.Rows - 1}."); + } + + if (col < 0 || col >= maxBoardSize.Columns) + { + throw new IndexOutOfRangeException($"col was {col}, but must be between 0 and {maxBoardSize.Columns - 1}."); + } + + // Check moves that would go out of bounds e.g. col == 0 and dir == Left + if ( + row == 0 && dir == Direction.Down || + row == maxBoardSize.Rows - 1 && dir == Direction.Up || + col == 0 && dir == Direction.Left || + col == maxBoardSize.Columns - 1 && dir == Direction.Right + ) + { + throw new IndexOutOfRangeException($"Cannot move cell at row={row} col={col} in Direction={dir}"); + } + + // Normalize - only consider Right and Up + if (dir == Direction.Left) + { + dir = Direction.Right; + col = col - 1; + } + else if (dir == Direction.Down) + { + dir = Direction.Up; + row = row - 1; + } + + int moveIndex; + if (dir == Direction.Right) + { + moveIndex = col + row * (maxBoardSize.Columns - 1); + } + else + { + var offset = (maxBoardSize.Columns - 1) * maxBoardSize.Rows; + moveIndex = offset + col + row * maxBoardSize.Columns; + } + + return new Move + { + Row = row, + Column = col, + Direction = dir, + MoveIndex = moveIndex, + }; + } + + /// + /// Check if the move is valid for the given board size. + /// This will be passed the return value from AbstractBoard.GetCurrentBoardSize(). + /// + /// + /// + public bool InRangeForBoard(BoardSize boardSize) + { + var (otherRow, otherCol) = OtherCell(); + // Get the maximum row and column this move would affect. + var maxMoveRow = Mathf.Max(Row, otherRow); + var maxMoveCol = Mathf.Max(Column, otherCol); + return maxMoveRow < boardSize.Rows && maxMoveCol < boardSize.Columns; + } + + /// + /// Get the other row and column that correspond to this move. 
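// A worked indexing sketch for a hypothetical 3x3 board:
//
//     var size = new BoardSize { Rows = 3, Columns = 3, NumCellTypes = 4, NumSpecialTypes = 0 };
//     var total = Move.NumPotentialMoves(size);   // 3 * (3 - 1) + (3 - 1) * 3 == 12
//
//     // Right-moves fill indices 0..5: col + row * (Columns - 1)
//     var right = Move.FromPositionAndDirection(1, 0, Direction.Right, size);   // MoveIndex == 2
//
//     // Up-moves fill indices 6..11: 6 + col + row * Columns
//     var up = Move.FromPositionAndDirection(0, 2, Direction.Up, size);         // MoveIndex == 8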
+ /// + /// + /// + public (int Row, int Column) OtherCell() + { + switch (Direction) + { + case Direction.Up: + return (Row + 1, Column); + case Direction.Down: + return (Row - 1, Column); + case Direction.Left: + return (Row, Column - 1); + case Direction.Right: + return (Row, Column + 1); + default: + throw new ArgumentOutOfRangeException(); + } + } + + /// + /// Get the opposite direction of this move. + /// + /// + /// + public Direction OtherDirection() + { + switch (Direction) + { + case Direction.Up: + return Direction.Down; + case Direction.Down: + return Direction.Up; + case Direction.Left: + return Direction.Right; + case Direction.Right: + return Direction.Left; + default: + throw new ArgumentOutOfRangeException(); + } + } + + /// + /// Return the number of potential moves for a board of the given size. + /// This is equivalent to the number of internal edges in the board. + /// + /// + /// + public static int NumPotentialMoves(BoardSize maxBoardSize) + { + return maxBoardSize.Rows * (maxBoardSize.Columns - 1) + (maxBoardSize.Rows - 1) * (maxBoardSize.Columns); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Integrations/Match3/Move.cs.meta b/com.unity.ml-agents/Runtime/Integrations/Match3/Move.cs.meta new file mode 100644 index 0000000000..1457c24b13 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Integrations/Match3/Move.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 41d6d7b9e07c4ef1ae075c74a906906b +timeCreated: 1600466100 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/MLAgentsSettings.cs b/com.unity.ml-agents/Runtime/MLAgentsSettings.cs new file mode 100644 index 0000000000..a86cb3635c --- /dev/null +++ b/com.unity.ml-agents/Runtime/MLAgentsSettings.cs @@ -0,0 +1,41 @@ +using UnityEngine; +using System.Runtime.CompilerServices; + + +[assembly: InternalsVisibleTo("Unity.ML-Agents.DevTests.Editor")] +namespace Unity.MLAgents +{ + internal class MLAgentsSettings : ScriptableObject + { + [SerializeField] + private bool m_ConnectTrainer = true; + [SerializeField] + private int m_EditorPort = 5004; + + public bool ConnectTrainer + { + get { return m_ConnectTrainer; } + set + { + m_ConnectTrainer = value; + OnChange(); + } + } + + public int EditorPort + { + get { return m_EditorPort; } + set + { + m_EditorPort = value; + OnChange(); + } + } + + internal void OnChange() + { + if (MLAgentsSettingsManager.Settings == this) + MLAgentsSettingsManager.ApplySettings(); + } + } +} diff --git a/com.unity.ml-agents/Runtime/MLAgentsSettings.cs.meta b/com.unity.ml-agents/Runtime/MLAgentsSettings.cs.meta new file mode 100644 index 0000000000..90c1507a50 --- /dev/null +++ b/com.unity.ml-agents/Runtime/MLAgentsSettings.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 71515ce028aaa4b4cb6bee13e96ef6f3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs b/com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs new file mode 100644 index 0000000000..2c54593497 --- /dev/null +++ b/com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs @@ -0,0 +1,92 @@ +using System; +using UnityEngine; +#if UNITY_EDITOR +using UnityEditor; +#else +using System.Linq; +#endif + +namespace Unity.MLAgents +{ +#if UNITY_EDITOR + [InitializeOnLoad] +#endif + internal static class MLAgentsSettingsManager + { + internal static event Action OnSettingsChange; + internal const string 
EditorBuildSettingsConfigKey = "com.unity.ml-agents.settings"; + private static MLAgentsSettings s_Settings; + + + // setter will trigger callback for refreshing editor UI if using editor + public static MLAgentsSettings Settings + { + get + { + if (s_Settings == null) + { + Initialize(); + } + return s_Settings; + } + set + { + Debug.Assert(value != null); +#if UNITY_EDITOR + if (!string.IsNullOrEmpty(AssetDatabase.GetAssetPath(value))) + { + EditorBuildSettings.AddConfigObject(EditorBuildSettingsConfigKey, value, true); + } +#endif + s_Settings = value; + ApplySettings(); + } + } + + static MLAgentsSettingsManager() + { + Initialize(); + } + + static void Initialize() + { +#if UNITY_EDITOR + InitializeInEditor(); +#else + InitializeInPlayer(); +#endif + } + +#if UNITY_EDITOR + internal static void InitializeInEditor() + { + var settings = ScriptableObject.CreateInstance(); + if (EditorBuildSettings.TryGetConfigObject(EditorBuildSettingsConfigKey, + out MLAgentsSettings settingsAsset)) + { + if (settingsAsset != null) + { + settings = settingsAsset; + } + } + Settings = settings; + } +#else + internal static void InitializeInPlayer() + { + Settings = Resources.FindObjectsOfTypeAll().FirstOrDefault() ?? ScriptableObject.CreateInstance(); + } +#endif + + internal static void ApplySettings() + { + OnSettingsChange?.Invoke(); + } + + internal static void Destroy() + { + s_Settings = null; + OnSettingsChange = null; + } + } +} diff --git a/com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs.meta b/com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs.meta new file mode 100644 index 0000000000..9a6f0c3a12 --- /dev/null +++ b/com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: be40451993af54e3c84c7113140fdf2c +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs new file mode 100644 index 0000000000..47aa61299e --- /dev/null +++ b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs @@ -0,0 +1,13 @@ +using System.Threading; + +namespace Unity.MLAgents +{ + internal static class MultiAgentGroupIdCounter + { + static int s_Counter; + public static int GetGroupId() + { + return Interlocked.Increment(ref s_Counter); + } + } +} diff --git a/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta new file mode 100644 index 0000000000..b4298cdc95 --- /dev/null +++ b/com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5661ffdb6c7704e84bc785572dcd5bd1 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Policies.meta b/com.unity.ml-agents/Runtime/Policies.meta new file mode 100644 index 0000000000..6353357f7b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 3c3d1b36de8f74c9e8ab29c8f23f58ab +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs new file mode 100644 index 
0000000000..96a15b50d8
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
@@ -0,0 +1,144 @@
+using Unity.Barracuda;
+using System.Collections.Generic;
+using System.Diagnostics;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Inference;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Policies
+{
+    /// <summary>
+    /// Where to perform inference.
+    /// </summary>
+    public enum InferenceDevice
+    {
+        /// <summary>
+        /// Default inference. This is currently the same as Burst, but may change in the future.
+        /// </summary>
+        Default = 0,
+
+        /// <summary>
+        /// GPU inference. Corresponds to WorkerFactory.Type.ComputePrecompiled in Barracuda.
+        /// </summary>
+        GPU = 1,
+
+        /// <summary>
+        /// CPU inference using Burst. Corresponds to WorkerFactory.Type.CSharpBurst in Barracuda.
+        /// </summary>
+        Burst = 2,
+
+        /// <summary>
+        /// CPU inference. Corresponds to WorkerFactory.Type.CSharp in Barracuda.
+        /// Burst is recommended instead; this is kept for legacy compatibility.
+        /// </summary>
+        CPU = 3,
+    }
+
+    /// <summary>
+    /// The Barracuda Policy uses a Barracuda Model to make decisions at
+    /// every step. It uses a ModelRunner that is shared across all
+    /// Barracuda Policies that use the same model and inference devices.
+    /// </summary>
+    internal class BarracudaPolicy : IPolicy
+    {
+        protected ModelRunner m_ModelRunner;
+        ActionBuffers m_LastActionBuffer;
+
+        int m_AgentId;
+        /// <summary>
+        /// Inference only: set to true if the action selection from model should be
+        /// deterministic.
+        /// </summary>
+        bool m_DeterministicInference;
+
+        /// <summary>
+        /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
+        /// </summary>
+        List<int[]> m_SensorShapes;
+        ActionSpec m_ActionSpec;
+
+        private string m_BehaviorName;
+
+        /// <summary>
+        /// List of actuators, only used for analytics
+        /// </summary>
+        private IList<IActuator> m_Actuators;
+
+        /// <summary>
+        /// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
+        /// and do additional deduplication in the analytics code.
+        /// </summary>
+        private bool m_AnalyticsSent;
+
+        /// <summary>
+        /// Instantiate a BarracudaPolicy with the necessary objects for it to run.
+        /// </summary>
+        /// <param name="actionSpec">The action spec of the behavior.</param>
+        /// <param name="actuators">The actuators used for this behavior.</param>
+        /// <param name="model">The Neural Network to use.</param>
+        /// <param name="inferenceDevice">Which device Barracuda will run on.</param>
+        /// <param name="behaviorName">The name of the behavior.</param>
+        /// <param name="deterministicInference">Inference only: set to true if the action selection from model should be
+        /// deterministic.</param>
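A note for readers of this new policy code: the enum above is surfaced through the public `BehaviorParameters` component added later in this diff. A minimal sketch of selecting a device from user code, assuming a `BehaviorParameters` on the same GameObject and a placeholder `NNModel` asset:

```csharp
using Unity.Barracuda;
using Unity.MLAgents.Policies;
using UnityEngine;

// Hypothetical setup helper: forces local Burst inference for this agent.
public class BurstInferenceSetter : MonoBehaviour
{
    public NNModel trainedModel; // placeholder asset assigned in the Inspector

    void Awake()
    {
        var behavior = GetComponent<BehaviorParameters>();
        behavior.Model = trainedModel;                      // setter reloads the policy
        behavior.InferenceDevice = InferenceDevice.Burst;   // CPU inference via Burst
        behavior.BehaviorType = BehaviorType.InferenceOnly; // never fall back to heuristic
    }
}
```

Since `Default` currently resolves to Burst, choosing `Burst` explicitly mainly documents intent.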
+ public BarracudaPolicy( + ActionSpec actionSpec, + IList actuators, + NNModel model, + InferenceDevice inferenceDevice, + string behaviorName, + bool deterministicInference = false + ) + { + var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice, deterministicInference); + m_ModelRunner = modelRunner; + m_BehaviorName = behaviorName; + m_ActionSpec = actionSpec; + m_Actuators = actuators; + m_DeterministicInference = deterministicInference; + } + + /// + public void RequestDecision(AgentInfo info, List sensors) + { + SendAnalytics(sensors); + m_AgentId = info.episodeId; + m_ModelRunner?.PutObservations(info, sensors); + } + + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + void SendAnalytics(IList sensors) + { + if (!m_AnalyticsSent) + { + m_AnalyticsSent = true; + Analytics.InferenceAnalytics.InferenceModelSet( + m_ModelRunner.Model, + m_BehaviorName, + m_ModelRunner.InferenceDevice, + sensors, + m_ActionSpec, + m_Actuators + ); + } + } + + /// + public ref readonly ActionBuffers DecideAction() + { + if (m_ModelRunner == null) + { + m_LastActionBuffer = ActionBuffers.Empty; + } + else + { + m_ModelRunner?.DecideBatch(); + m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId); + } + return ref m_LastActionBuffer; + } + + public void Dispose() + { + } + } +} diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs.meta b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs.meta new file mode 100644 index 0000000000..014a05302d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 8eb047b11855142d2be2cc458bef3264 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs new file mode 100644 index 0000000000..b0d369b910 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs @@ -0,0 +1,294 @@ +using Unity.Barracuda; +using System; +using UnityEngine; +using UnityEngine.Serialization; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors.Reflection; + +namespace Unity.MLAgents.Policies +{ + /// + /// Defines what type of behavior the Agent will be using + /// + [Serializable] + public enum BehaviorType + { + /// + /// The Agent will use the remote process for decision making. + /// if unavailable, will use inference and if no model is provided, will use + /// the heuristic. + /// + Default, + + /// + /// The Agent will always use its heuristic + /// + HeuristicOnly, + + /// + /// The Agent will always use inference with the provided + /// neural network model. + /// + InferenceOnly + } + + /// + /// Options for controlling how the Agent class is searched for s. + /// + public enum ObservableAttributeOptions + { + /// + /// All ObservableAttributes on the Agent will be ignored. This is the + /// default behavior. If there are no ObservableAttributes on the + /// Agent, this will result in the fastest initialization time. + /// + Ignore, + + /// + /// Only members on the declared class will be examined; members that are + /// inherited are ignored. This is a reasonable tradeoff between + /// performance and flexibility. 
+ /// + /// This corresponds to setting the + /// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1) + /// when examining the fields and properties of the Agent class instance. + /// + ExcludeInherited, + + /// + /// All members on the class will be examined. This can lead to slower + /// startup times. + /// + ExamineAll + } + + /// + /// A component for setting an instance's behavior and + /// brain properties. + /// + /// At runtime, this component generates the agent's policy objects + /// according to the settings you specified in the Editor. + [AddComponentMenu("ML Agents/Behavior Parameters", (int)MenuGroup.Default)] + public class BehaviorParameters : MonoBehaviour + { + [HideInInspector, SerializeField] + BrainParameters m_BrainParameters = new BrainParameters(); + + /// + /// Delegate for receiving events about Policy Updates. + /// + /// Whether or not the current policy is running in heuristic mode. + public delegate void PolicyUpdated(bool isInHeuristicMode); + + /// + /// Event that fires when an Agent's policy is updated. + /// + internal event PolicyUpdated OnPolicyUpdated; + + /// + /// The associated for this behavior. + /// + public BrainParameters BrainParameters + { + get { return m_BrainParameters; } + internal set { m_BrainParameters = value; } + } + + [HideInInspector, SerializeField] + NNModel m_Model; + + /// + /// The neural network model used when in inference mode. + /// This should not be set at runtime; use + /// to set it instead. + /// + public NNModel Model + { + get { return m_Model; } + set { m_Model = value; UpdateAgentPolicy(); } + } + + [HideInInspector, SerializeField] + InferenceDevice m_InferenceDevice = InferenceDevice.Default; + + /// + /// How inference is performed for this Agent's model. + /// This should not be set at runtime; use + /// to set it instead. + /// + public InferenceDevice InferenceDevice + { + get { return m_InferenceDevice; } + set { m_InferenceDevice = value; UpdateAgentPolicy(); } + } + + [HideInInspector, SerializeField] + BehaviorType m_BehaviorType; + + /// + /// The BehaviorType for the Agent. + /// + public BehaviorType BehaviorType + { + get { return m_BehaviorType; } + set { m_BehaviorType = value; UpdateAgentPolicy(); } + } + + [HideInInspector, SerializeField] + string m_BehaviorName = "My Behavior"; + + /// + /// The name of this behavior, which is used as a base name. See + /// for the full name. + /// This should not be set at runtime; use + /// to set it instead. + /// + public string BehaviorName + { + get { return m_BehaviorName; } + set { m_BehaviorName = value; UpdateAgentPolicy(); } + } + + /// + /// The team ID for this behavior. + /// + [HideInInspector, SerializeField, FormerlySerializedAs("m_TeamID")] + public int TeamId; + // TODO properties here instead of Agent + + [FormerlySerializedAs("m_useChildSensors")] + [HideInInspector] + [SerializeField] + [Tooltip("Use all Sensor components attached to child GameObjects of this Agent.")] + bool m_UseChildSensors = true; + + [HideInInspector] + [SerializeField] + [Tooltip("Use all Actuator components attached to child GameObjects of this Agent.")] + bool m_UseChildActuators = true; + + /// + /// Whether or not to use all the sensor components attached to child GameObjects of the agent. + /// Note that changing this after the Agent has been initialized will not have any effect. 
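To make the scanning options above concrete, here is a hedged sketch of an agent that exposes observations through `ObservableAttribute` members instead of overriding `CollectObservations`; the agent and member names are illustrative:

```csharp
using Unity.MLAgents;
using Unity.MLAgents.Sensors.Reflection;

// Hypothetical agent whose observations come from [Observable] members.
// For these to be collected, ObservableAttributeHandling on the agent's
// BehaviorParameters must be ExcludeInherited or ExamineAll; Ignore is the default.
public class ThermostatAgent : Agent
{
    [Observable]
    float m_Temperature; // field observed every step

    [Observable("target")]
    public float TargetTemperature { get; set; } // named property observation
}
```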
+ /// + public bool UseChildSensors + { + get { return m_UseChildSensors; } + set { m_UseChildSensors = value; } + } + + [HideInInspector] + [SerializeField] + [Tooltip("Set action selection to deterministic, Only applies to inference from within unity.")] + private bool m_DeterministicInference = false; + + /// + /// Whether to select actions deterministically during inference from the provided neural network. + /// + public bool DeterministicInference + { + get { return m_DeterministicInference; } + set { m_DeterministicInference = value; } + } + + /// + /// Whether or not to use all the actuator components attached to child GameObjects of the agent. + /// Note that changing this after the Agent has been initialized will not have any effect. + /// + public bool UseChildActuators + { + get { return m_UseChildActuators; } + set { m_UseChildActuators = value; } + } + + [HideInInspector, SerializeField] + ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore; + + /// + /// Determines how the Agent class is searched for s. + /// + public ObservableAttributeOptions ObservableAttributeHandling + { + get { return m_ObservableAttributeHandling; } + set { m_ObservableAttributeHandling = value; } + } + + /// + /// Returns the behavior name, concatenated with any other metadata (i.e. team id). + /// + public string FullyQualifiedBehaviorName + { + get { return m_BehaviorName + "?team=" + TeamId; } + } + + void Awake() + { + OnPolicyUpdated += mode => { }; + } + + internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorManager) + { + switch (m_BehaviorType) + { + case BehaviorType.HeuristicOnly: + return new HeuristicPolicy(actuatorManager, actionSpec); + case BehaviorType.InferenceOnly: + { + if (m_Model == null) + { + var behaviorType = BehaviorType.InferenceOnly.ToString(); + throw new UnityAgentsException( + $"Can't use Behavior Type {behaviorType} without a model. " + + "Either assign a model, or change to a different Behavior Type." + ); + } + return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName, m_DeterministicInference); + } + case BehaviorType.Default: + if (Academy.Instance.IsCommunicatorOn) + { + return new RemotePolicy(actionSpec, actuatorManager, FullyQualifiedBehaviorName); + } + if (m_Model != null) + { + return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName, m_DeterministicInference); + } + else + { + return new HeuristicPolicy(actuatorManager, actionSpec); + } + default: + return new HeuristicPolicy(actuatorManager, actionSpec); + } + } + + /// + /// Query the behavior parameters in order to see if the Agent is running in Heuristic Mode. + /// + /// true if the Agent is running in Heuristic mode. 
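Heuristic mode, as queried below, ultimately routes each decision to the user's `Agent.Heuristic` override through the `HeuristicPolicy`. An illustrative sketch (the agent, axes, and key binding are placeholders):

```csharp
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine;

// Hypothetical agent: this override fills the ActionBuffers whenever the
// HeuristicPolicy is active (HeuristicOnly, or Default with no trainer/model).
public class DriverAgent : Agent
{
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuous = actionsOut.ContinuousActions;
        continuous[0] = Input.GetAxis("Horizontal"); // steering
        continuous[1] = Input.GetAxis("Vertical");   // throttle

        var discrete = actionsOut.DiscreteActions;
        discrete[0] = Input.GetKey(KeyCode.Space) ? 1 : 0; // handbrake branch
    }
}
```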
+        public bool IsInHeuristicMode()
+        {
+            if (BehaviorType == BehaviorType.HeuristicOnly)
+            {
+                return true;
+            }
+
+            return BehaviorType == BehaviorType.Default &&
+                ReferenceEquals(Model, null) &&
+                (!Academy.IsInitialized ||
+                    Academy.IsInitialized &&
+                    !Academy.Instance.IsCommunicatorOn);
+        }
+
+        internal void UpdateAgentPolicy()
+        {
+            var agent = GetComponent<Agent>();
+            if (agent == null)
+            {
+                return;
+            }
+            agent.ReloadPolicy();
+            OnPolicyUpdated?.Invoke(IsInHeuristicMode());
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs.meta b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs.meta
new file mode 100644
index 0000000000..507c417e97
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 5d1c4e0b1822b495aa52bc52839ecb30
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
new file mode 100644
index 0000000000..882521a892
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
@@ -0,0 +1,192 @@
+using System;
+using UnityEngine;
+using UnityEngine.Serialization;
+using Unity.MLAgents.Actuators;
+
+namespace Unity.MLAgents.Policies
+{
+    /// <summary>
+    /// This is deprecated. Agents can now use both continuous and discrete actions together.
+    /// </summary>
+    [Obsolete("Continuous and discrete actions on the same Agent are now supported; see ActionSpec.")]
+    internal enum SpaceType
+    {
+        /// <summary>
+        /// Discrete action space: a fixed number of options are available.
+        /// </summary>
+        Discrete,
+
+        /// <summary>
+        /// Continuous action space: each action can take on a float value.
+        /// </summary>
+        Continuous
+    }
+
+    /// <summary>
+    /// Holds information about the brain. It defines the inputs and outputs of the
+    /// decision process.
+    /// </summary>
+    /// <remarks>
+    /// Set brain parameters for an <see cref="Agent"/> instance using the
+    /// <see cref="BehaviorParameters"/> component attached to the agent's [GameObject].
+    ///
+    /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
+    /// </remarks>
+    [Serializable]
+    public class BrainParameters : ISerializationCallbackReceiver
+    {
+        /// <summary>
+        /// The number of the observations that are added in
+        /// <see cref="Agent.CollectObservations(Sensors.VectorSensor)"/>
+        /// </summary>
+        /// <value>
+        /// The length of the vector containing observation values.
+        /// </value>
+        [FormerlySerializedAs("vectorObservationSize")]
+        public int VectorObservationSize = 1;
+
+        /// <summary>
+        /// Stacking refers to concatenating the observations across multiple frames. This field
+        /// indicates the number of frames to concatenate across.
+        /// </summary>
+        [FormerlySerializedAs("numStackedVectorObservations")]
+        [Range(1, 50)] public int NumStackedVectorObservations = 1;
+
+        [SerializeField]
+        internal ActionSpec m_ActionSpec = new ActionSpec(0, null);
+
+        /// <summary>
+        /// The specification of the Actions for the BrainParameters.
+        /// </summary>
+        public ActionSpec ActionSpec
+        {
+            get { return m_ActionSpec; }
+            set
+            {
+                m_ActionSpec.NumContinuousActions = value.NumContinuousActions;
+                m_ActionSpec.BranchSizes = value.BranchSizes;
+                SyncDeprecatedActionFields();
+            }
+        }
+
+        /// <summary>
+        /// (Deprecated) The number of possible actions.
+        /// </summary>
+        /// <remarks>The size specified is interpreted differently depending on whether
+        /// the agent uses the continuous or the discrete actions.
+        ///
+        /// For the continuous actions: the length of the float vector that represents
+        /// the action.
+        /// For the discrete actions: the number of branches.
+ /// + [Obsolete("VectorActionSize has been deprecated, please use ActionSpec instead.")] + [SerializeField] + [FormerlySerializedAs("vectorActionSize")] + internal int[] VectorActionSize = new[] { 1 }; + + /// + /// The list of strings describing what the actions correspond to. + /// + [FormerlySerializedAs("vectorActionDescriptions")] + public string[] VectorActionDescriptions; + + /// + /// (Deprecated) Defines if the action is discrete or continuous. + /// + [Obsolete("VectorActionSpaceType has been deprecated, please use ActionSpec instead.")] + [SerializeField] + [FormerlySerializedAs("vectorActionSpaceType")] + internal SpaceType VectorActionSpaceType = SpaceType.Discrete; + + [SerializeField] + [HideInInspector] + internal bool hasUpgradedBrainParametersWithActionSpec; + + /// + /// Deep clones the BrainParameter object. + /// + /// A new BrainParameter object with the same values as the original. + public BrainParameters Clone() + { + // Disable deprecation warnings so we can read/write the old fields. +#pragma warning disable CS0618 + return new BrainParameters + { + VectorObservationSize = VectorObservationSize, + NumStackedVectorObservations = NumStackedVectorObservations, + VectorActionDescriptions = (string[])VectorActionDescriptions.Clone(), + ActionSpec = new ActionSpec(ActionSpec.NumContinuousActions, ActionSpec.BranchSizes), + VectorActionSize = (int[])VectorActionSize.Clone(), + VectorActionSpaceType = VectorActionSpaceType, + }; +#pragma warning restore CS0618 + } + + /// + /// Propagate ActionSpec fields from deprecated fields + /// + private void UpdateToActionSpec() + { + // Disable deprecation warnings so we can read the old fields. +#pragma warning disable CS0618 + if (!hasUpgradedBrainParametersWithActionSpec + && m_ActionSpec.NumContinuousActions == 0 + && m_ActionSpec.NumDiscreteActions == 0) + { + if (VectorActionSpaceType == SpaceType.Continuous) + { + m_ActionSpec.NumContinuousActions = VectorActionSize[0]; + } + if (VectorActionSpaceType == SpaceType.Discrete) + { + m_ActionSpec.BranchSizes = (int[])VectorActionSize.Clone(); + } + } + hasUpgradedBrainParametersWithActionSpec = true; +#pragma warning restore CS0618 + } + + /// + /// Sync values in ActionSpec fields to deprecated fields + /// + private void SyncDeprecatedActionFields() + { + // Disable deprecation warnings so we can read the old fields. +#pragma warning disable CS0618 + + if (m_ActionSpec.NumContinuousActions == 0) + { + VectorActionSize = (int[])ActionSpec.BranchSizes.Clone(); + VectorActionSpaceType = SpaceType.Discrete; + } + else if (m_ActionSpec.NumDiscreteActions == 0) + { + VectorActionSize = new[] { m_ActionSpec.NumContinuousActions }; + VectorActionSpaceType = SpaceType.Continuous; + } + else + { + VectorActionSize = null; + } +#pragma warning restore CS0618 + } + + /// + /// Called by Unity immediately before serializing this object. + /// + public void OnBeforeSerialize() + { + UpdateToActionSpec(); + SyncDeprecatedActionFields(); + } + + /// + /// Called by Unity immediately after deserializing this object. 
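To make the old-to-new mapping in `UpdateToActionSpec`/`SyncDeprecatedActionFields` concrete, a small sketch using `ActionSpec`'s public factory helpers (the helpers are assumed from the Actuators API; the sizes are illustrative):

```csharp
using Unity.MLAgents.Actuators;

static class ActionSpecUpgradeSketch
{
    static void Demo()
    {
        // Legacy: VectorActionSpaceType.Continuous, VectorActionSize = [3]
        // New:    three continuous actions, no discrete branches.
        var continuous = ActionSpec.MakeContinuous(3);

        // Legacy: VectorActionSpaceType.Discrete, VectorActionSize = [3, 2]
        // New:    two discrete branches with 3 and 2 options respectively.
        var discrete = ActionSpec.MakeDiscrete(3, 2);

        // Hybrid specs have no legacy equivalent, which is why
        // SyncDeprecatedActionFields() nulls VectorActionSize for them.
        var hybrid = new ActionSpec(2, new[] { 3 });
    }
}
```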
+ /// + public void OnAfterDeserialize() + { + UpdateToActionSpec(); + SyncDeprecatedActionFields(); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs.meta b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs.meta new file mode 100644 index 0000000000..248b4d0f6d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 6108a41e9be04c238d7babaed4476134 +timeCreated: 1538758934 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs new file mode 100644 index 0000000000..8e5333874a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs @@ -0,0 +1,142 @@ +using System.Collections.Generic; +using System; +using System.Collections; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Policies +{ + /// + /// The Heuristic Policy uses a hard-coded Heuristic method + /// to take decisions each time the RequestDecision method is + /// called. + /// + internal class HeuristicPolicy : IPolicy + { + ActuatorManager m_ActuatorManager; + ActionBuffers m_ActionBuffers; + bool m_Done; + bool m_DecisionRequested; + + ObservationWriter m_ObservationWriter = new ObservationWriter(); + NullList m_NullList = new NullList(); + + + public HeuristicPolicy(ActuatorManager actuatorManager, ActionSpec actionSpec) + { + m_ActuatorManager = actuatorManager; + var numContinuousActions = actionSpec.NumContinuousActions; + var numDiscreteActions = actionSpec.NumDiscreteActions; + var continuousDecision = new ActionSegment(new float[numContinuousActions], 0, numContinuousActions); + var discreteDecision = new ActionSegment(new int[numDiscreteActions], 0, numDiscreteActions); + m_ActionBuffers = new ActionBuffers(continuousDecision, discreteDecision); + } + + /// + public void RequestDecision(AgentInfo info, List sensors) + { + StepSensors(sensors); + m_Done = info.done; + m_DecisionRequested = true; + } + + /// + public ref readonly ActionBuffers DecideAction() + { + if (!m_Done && m_DecisionRequested) + { + m_ActionBuffers.Clear(); + m_ActuatorManager.ApplyHeuristic(m_ActionBuffers); + } + m_DecisionRequested = false; + return ref m_ActionBuffers; + } + + public void Dispose() + { + } + + /// + /// Trivial implementation of the IList interface that does nothing. + /// This is only used for "writing" observations that we will discard. + /// + internal class NullList : IList + { + public IEnumerator GetEnumerator() + { + throw new NotImplementedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public void Add(float item) + { + } + + public void Clear() + { + } + + public bool Contains(float item) + { + return false; + } + + public void CopyTo(float[] array, int arrayIndex) + { + throw new NotImplementedException(); + } + + public bool Remove(float item) + { + return false; + } + + public int Count { get; } + public bool IsReadOnly { get; } + public int IndexOf(float item) + { + return -1; + } + + public void Insert(int index, float item) + { + } + + public void RemoveAt(int index) + { + } + + public float this[int index] + { + get { return 0.0f; } + set { } + } + } + + /// + /// Run ISensor.Write or ISensor.GetCompressedObservation for each sensor + /// The output is currently unused, but this makes the sensor usage consistent + /// between training and inference. 
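The sensors stepped by `StepSensors` below all implement the `ISensor` interface that recurs throughout this PR. A minimal hypothetical sensor, to show the surface being exercised:

```csharp
using Unity.MLAgents.Sensors;

// Hypothetical one-float sensor; names and values are illustrative.
public class BatteryLevelSensor : ISensor
{
    float m_Level = 1f;

    public ObservationSpec GetObservationSpec() => ObservationSpec.Vector(1);

    public int Write(ObservationWriter writer)
    {
        writer[0] = m_Level; // under HeuristicPolicy this lands in the NullList target
        return 1;
    }

    public byte[] GetCompressedObservation() => null;
    public CompressionSpec GetCompressionSpec() => CompressionSpec.Default();
    public void Update() { }
    public void Reset() { }
    public string GetName() => "BatteryLevel";
}
```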
+ /// + /// + void StepSensors(List sensors) + { + foreach (var sensor in sensors) + { + if (sensor.GetCompressionSpec().SensorCompressionType == SensorCompressionType.None) + { + m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0); + sensor.Write(m_ObservationWriter); + } + else + { + sensor.GetCompressedObservation(); + } + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs.meta b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs.meta new file mode 100644 index 0000000000..ae074f5727 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 8a55e3cea7fd643109e42f5f4c9a1425 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs new file mode 100644 index 0000000000..4079a1f25a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Policies +{ + /// + /// IPolicy is connected to a single Agent. Each time the agent needs + /// a decision, it will request a decision to the Policy. The decision + /// will not be taken immediately but will be taken before or when + /// DecideAction is called. + /// + internal interface IPolicy : IDisposable + { + /// + /// Signals the Brain that the Agent needs a Decision. The Policy + /// will make the decision at a later time to allow possible + /// batching of requests. + /// + /// + /// + void RequestDecision(AgentInfo info, List sensors); + + /// + /// Signals the Policy that if the Decision has not been taken yet, + /// it must be taken now. The Brain is expected to update the actions + /// of the Agents at this point the latest. + /// + ref readonly ActionBuffers DecideAction(); + } +} diff --git a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs.meta b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs.meta new file mode 100644 index 0000000000..f43c4ddc8b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 56e263dd566be41d6b81d0b46895a0dd +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs new file mode 100644 index 0000000000..faa8a37e60 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs @@ -0,0 +1,77 @@ +using System.Collections.Generic; +using System.Diagnostics; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Analytics; + + +namespace Unity.MLAgents.Policies +{ + /// + /// The Remote Policy only works when training. + /// When training your Agents, the RemotePolicy will be controlled by Python. 
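The request/decide split described in `IPolicy` above is what lets observations from many agents be batched before a single communicator or Barracuda call. From user code the cadence is driven by `Agent.RequestDecision()`; a simplified stand-in for the built-in `DecisionRequester` component:

```csharp
using Unity.MLAgents;
using UnityEngine;

// Sketch: queue a decision every N physics steps; the active policy
// (remote, Barracuda, or heuristic) answers during the Academy update.
public class SimpleDecisionRequester : MonoBehaviour
{
    public int decisionPeriod = 5; // illustrative default
    Agent m_Agent;
    int m_Steps;

    void Awake()
    {
        m_Agent = GetComponent<Agent>();
    }

    void FixedUpdate()
    {
        if (++m_Steps % decisionPeriod == 0)
        {
            m_Agent.RequestDecision(); // batched with other agents' requests
        }
        else
        {
            m_Agent.RequestAction();   // repeat the last decided action
        }
    }
}
```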
+ /// + internal class RemotePolicy : IPolicy + { + int m_AgentId; + string m_FullyQualifiedBehaviorName; + ActionSpec m_ActionSpec; + ActionBuffers m_LastActionBuffer; + bool m_AnalyticsSent; + + internal ICommunicator m_Communicator; + + /// + /// List of actuators, only used for analytics + /// + private IList m_Actuators; + + public RemotePolicy( + ActionSpec actionSpec, + IList actuators, + string fullyQualifiedBehaviorName) + { + m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName; + m_Communicator = Academy.Instance.Communicator; + m_Communicator?.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec); + m_ActionSpec = actionSpec; + m_Actuators = actuators; + } + + /// + public void RequestDecision(AgentInfo info, List sensors) + { + SendAnalytics(sensors); + m_AgentId = info.episodeId; + m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors); + } + + [Conditional("MLA_UNITY_ANALYTICS_MODULE")] + void SendAnalytics(IList sensors) + { + if (!m_AnalyticsSent) + { + m_AnalyticsSent = true; + TrainingAnalytics.RemotePolicyInitialized( + m_FullyQualifiedBehaviorName, + sensors, + m_ActionSpec, + m_Actuators + ); + } + } + + /// + public ref readonly ActionBuffers DecideAction() + { + m_Communicator?.DecideBatch(); + var actions = m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId); + m_LastActionBuffer = actions == null ? ActionBuffers.Empty : (ActionBuffers)actions; + return ref m_LastActionBuffer; + } + + public void Dispose() + { + } + } +} diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs.meta b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs.meta new file mode 100644 index 0000000000..08996fa8a2 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2f1ffc0e0bec14a1eaca4c709b3ba230 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/RecursionChecker.cs b/com.unity.ml-agents/Runtime/RecursionChecker.cs new file mode 100644 index 0000000000..ac411fa837 --- /dev/null +++ b/com.unity.ml-agents/Runtime/RecursionChecker.cs @@ -0,0 +1,35 @@ +using System; + +namespace Unity.MLAgents +{ + internal class RecursionChecker : IDisposable + { + private bool m_IsRunning; + private string m_MethodName; + + public RecursionChecker(string methodName) + { + m_MethodName = methodName; + } + + public IDisposable Start() + { + if (m_IsRunning) + { + throw new UnityAgentsException( + $"{m_MethodName} called recursively. " + + "This might happen if you call EnvironmentStep() or EndEpisode() from custom " + + "code such as CollectObservations() or OnActionReceived()." + ); + } + m_IsRunning = true; + return this; + } + + public void Dispose() + { + // Reset the flag when we're done (or if an exception occurred). 
+ m_IsRunning = false; + } + } +} diff --git a/com.unity.ml-agents/Runtime/RecursionChecker.cs.meta b/com.unity.ml-agents/Runtime/RecursionChecker.cs.meta new file mode 100644 index 0000000000..4b2363f809 --- /dev/null +++ b/com.unity.ml-agents/Runtime/RecursionChecker.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 49ebd06532b24078a6edda823aeff5d2 +timeCreated: 1602731302 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sampler.cs b/com.unity.ml-agents/Runtime/Sampler.cs new file mode 100644 index 0000000000..5135f368d5 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sampler.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using Unity.MLAgents.Inference.Utils; +using Random = System.Random; + +namespace Unity.MLAgents +{ + + /// + /// Takes a list of floats that encode a sampling distribution and returns the sampling function. + /// + internal static class SamplerFactory + { + + public static Func CreateUniformSampler(float min, float max, int seed) + { + Random distr = new Random(seed); + return () => min + (float)distr.NextDouble() * (max - min); + } + + public static Func CreateGaussianSampler(float mean, float stddev, int seed) + { + RandomNormal distr = new RandomNormal(seed, mean, stddev); + return () => (float)distr.NextDouble(); + } + + public static Func CreateMultiRangeUniformSampler(IList intervals, int seed) + { + //RNG + Random distr = new Random(seed); + // Will be used to normalize intervalFuncs + float sumIntervalSizes = 0; + //The number of intervals + int numIntervals = (intervals.Count / 2); + // List that will store interval lengths + float[] intervalSizes = new float[numIntervals]; + // List that will store uniform distributions + IList> intervalFuncs = new Func[numIntervals]; + // Collect all intervals and store as uniform distrus + // Collect all interval sizes + for (int i = 0; i < numIntervals; i++) + { + var min = intervals[2 * i]; + var max = intervals[2 * i + 1]; + var intervalSize = max - min; + sumIntervalSizes += intervalSize; + intervalSizes[i] = intervalSize; + intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize; + } + // Normalize interval lengths + for (int i = 0; i < numIntervals; i++) + { + intervalSizes[i] = intervalSizes[i] / sumIntervalSizes; + } + // Build cmf for intervals + for (int i = 1; i < numIntervals; i++) + { + intervalSizes[i] += intervalSizes[i - 1]; + } + Multinomial intervalDistr = new Multinomial(seed + 1); + float MultiRange() + { + int sampledInterval = intervalDistr.Sample(intervalSizes); + return intervalFuncs[sampledInterval].Invoke(); + } + return MultiRange; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sampler.cs.meta b/com.unity.ml-agents/Runtime/Sampler.cs.meta new file mode 100644 index 0000000000..950e28c5b6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sampler.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 39ce0ea5a8b2e47f696f6efc807029f6 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs b/com.unity.ml-agents/Runtime/SensorHelper.cs new file mode 100644 index 0000000000..45cf0fe935 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SensorHelper.cs @@ -0,0 +1,129 @@ +using Unity.Barracuda; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Utility methods related to implementations. 
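Looking back at `SamplerFactory.CreateMultiRangeUniformSampler` above: it normalizes interval lengths into a cumulative distribution, draws an interval, then samples uniformly inside it. The same two-stage scheme as a self-contained sketch, using only `System.Random` (the `Multinomial` and `RandomNormal` helpers in this diff are internal):

```csharp
using System;
using System.Collections.Generic;

public static class MultiRangeUniformSketch
{
    // intervals = [min0, max0, min1, max1, ...], as in SamplerFactory.
    public static Func<float> Create(IList<float> intervals, int seed)
    {
        var rng = new Random(seed);
        int numIntervals = intervals.Count / 2;
        var cmf = new float[numIntervals];
        float total = 0f;

        // Cumulative mass proportional to interval length, so wider
        // intervals are chosen proportionally more often.
        for (int i = 0; i < numIntervals; i++)
        {
            total += intervals[2 * i + 1] - intervals[2 * i];
            cmf[i] = total;
        }

        return () =>
        {
            // Stage 1: pick an interval by inverting the CMF.
            float u = (float)rng.NextDouble() * total;
            int k = 0;
            while (k < numIntervals - 1 && u > cmf[k]) k++;

            // Stage 2: sample uniformly within the chosen interval.
            float min = intervals[2 * k], max = intervals[2 * k + 1];
            return min + (float)rng.NextDouble() * (max - min);
        };
    }
}
```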
+ /// + public static class SensorHelper + { + /// + /// Generates the observations for the provided sensor, and returns true if they equal the + /// expected values. If they are unequal, errorMessage is also set. + /// This should not generally be used in production code. It is only intended for + /// simplifying unit tests. + /// + /// + /// + /// + /// + public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage) + { + var numExpected = expected.Length; + const float fill = -1337f; + var output = new float[numExpected]; + for (var i = 0; i < numExpected; i++) + { + output[i] = fill; + } + + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "Error setting output buffer."; + return false; + } + } + + ObservationWriter writer = new ObservationWriter(); + writer.SetTarget(output, sensor.GetObservationSpec(), 0); + + // Make sure ObservationWriter didn't touch anything + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have."; + return false; + } + } + + sensor.Write(writer); + for (var i = 0; i < output.Length; i++) + { + if (expected[i] != output[i]) + { + errorMessage = $"Expected and actual differed in position {i}. Expected: {expected[i]} Actual: {output[i]} "; + return false; + } + } + + errorMessage = null; + return true; + } + + /// + /// Generates the observations for the provided sensor, and returns true if they equal the + /// expected values. If they are unequal, errorMessage is also set. + /// This should not generally be used in production code. It is only intended for + /// simplifying unit tests. + /// + /// + /// + /// + /// + public static bool CompareObservation(ISensor sensor, float[,,] expected, out string errorMessage) + { + var tensorShape = new TensorShape(0, expected.GetLength(0), expected.GetLength(1), expected.GetLength(2)); + var numExpected = tensorShape.height * tensorShape.width * tensorShape.channels; + const float fill = -1337f; + var output = new float[numExpected]; + for (var i = 0; i < numExpected; i++) + { + output[i] = fill; + } + + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "Error setting output buffer."; + return false; + } + } + + ObservationWriter writer = new ObservationWriter(); + writer.SetTarget(output, sensor.GetObservationSpec(), 0); + + // Make sure ObservationWriter didn't touch anything + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have."; + return false; + } + } + + sensor.Write(writer); + for (var h = 0; h < tensorShape.height; h++) + { + for (var w = 0; w < tensorShape.width; w++) + { + for (var c = 0; c < tensorShape.channels; c++) + { + if (expected[h, w, c] != output[tensorShape.Index(0, h, w, c)]) + { + errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. 
" + + $"Expected: {expected[h, w, c]} Actual: {output[tensorShape.Index(0, h, w, c)]} "; + return false; + } + } + } + } + errorMessage = null; + return true; + } + } +} diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs.meta b/com.unity.ml-agents/Runtime/SensorHelper.cs.meta new file mode 100644 index 0000000000..c331abd0b6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SensorHelper.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7c1189c0af42c46f7b533350d49ad3e7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors.meta b/com.unity.ml-agents/Runtime/Sensors.meta new file mode 100644 index 0000000000..06dbaab148 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 9c79ae05164e94259bd28ad71dbd3afa +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs b/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs new file mode 100644 index 0000000000..1c290630cc --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs @@ -0,0 +1,263 @@ +using System; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// The grid perception strategy that uses box overlap to detect objects. + /// + internal class BoxOverlapChecker : IGridPerception + { + Vector3 m_CellScale; + Vector3Int m_GridSize; + bool m_RotateWithAgent; + LayerMask m_ColliderMask; + GameObject m_CenterObject; + GameObject m_AgentGameObject; + string[] m_DetectableTags; + int m_InitialColliderBufferSize; + int m_MaxColliderBufferSize; + + int m_NumCells; + Vector3 m_HalfCellScale; + Vector3 m_CellCenterOffset; + Vector3[] m_CellLocalPositions; + +#if MLA_UNITY_PHYSICS_MODULE + Collider[] m_ColliderBuffer; + + public event Action GridOverlapDetectedAll; + public event Action GridOverlapDetectedClosest; + public event Action GridOverlapDetectedDebug; +#endif + + public BoxOverlapChecker( + Vector3 cellScale, + Vector3Int gridSize, + bool rotateWithAgent, + LayerMask colliderMask, + GameObject centerObject, + GameObject agentGameObject, + string[] detectableTags, + int initialColliderBufferSize, + int maxColliderBufferSize) + { + m_CellScale = cellScale; + m_GridSize = gridSize; + m_RotateWithAgent = rotateWithAgent; + m_ColliderMask = colliderMask; + m_CenterObject = centerObject; + m_AgentGameObject = agentGameObject; + m_DetectableTags = detectableTags; + m_InitialColliderBufferSize = initialColliderBufferSize; + m_MaxColliderBufferSize = maxColliderBufferSize; + + m_NumCells = gridSize.x * gridSize.z; + m_HalfCellScale = new Vector3(cellScale.x / 2f, cellScale.y, cellScale.z / 2f); + m_CellCenterOffset = new Vector3((gridSize.x - 1f) / 2, 0, (gridSize.z - 1f) / 2); +#if MLA_UNITY_PHYSICS_MODULE + m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_InitialColliderBufferSize)]; +#endif + + InitCellLocalPositions(); + } + + public bool RotateWithAgent + { + get { return m_RotateWithAgent; } + set { m_RotateWithAgent = value; } + } + + public LayerMask ColliderMask + { + get { return m_ColliderMask; } + set { m_ColliderMask = value; } + } + + /// + /// Initializes the local location of the cells + /// + void InitCellLocalPositions() + { + m_CellLocalPositions = new Vector3[m_NumCells]; + + for (int i = 0; i < m_NumCells; i++) + { + 
m_CellLocalPositions[i] = GetCellLocalPosition(i); + } + } + + public Vector3 GetCellLocalPosition(int cellIndex) + { + float x = (cellIndex / m_GridSize.z - m_CellCenterOffset.x) * m_CellScale.x; + float z = (cellIndex % m_GridSize.z - m_CellCenterOffset.z) * m_CellScale.z; + return new Vector3(x, 0, z); + } + + public Vector3 GetCellGlobalPosition(int cellIndex) + { + if (m_RotateWithAgent) + { + return m_CenterObject.transform.TransformPoint(m_CellLocalPositions[cellIndex]); + } + else + { + return m_CellLocalPositions[cellIndex] + m_CenterObject.transform.position; + } + } + + public Quaternion GetGridRotation() + { + return m_RotateWithAgent ? m_CenterObject.transform.rotation : Quaternion.identity; + } + + public void Perceive() + { +#if MLA_UNITY_PHYSICS_MODULE + for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++) + { + var cellCenter = GetCellGlobalPosition(cellIndex); + var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation()); + + if (GridOverlapDetectedAll != null) + { + ParseCollidersAll(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedAll); + } + if (GridOverlapDetectedClosest != null) + { + ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedClosest); + } + } +#endif + } + + public void UpdateGizmo() + { +#if MLA_UNITY_PHYSICS_MODULE + for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++) + { + var cellCenter = GetCellGlobalPosition(cellIndex); + var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation()); + + ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedDebug); + } +#endif + } + +#if MLA_UNITY_PHYSICS_MODULE + /// + /// This method attempts to perform the Physics.OverlapBoxNonAlloc and will double the size of the Collider buffer + /// if the number of Colliders in the buffer after the call is equal to the length of the buffer. + /// + /// + /// + /// + /// + int BufferResizingOverlapBoxNonAlloc(Vector3 cellCenter, Vector3 halfCellScale, Quaternion rotation) + { + int numFound; + // Since we can only get a fixed number of results, requery + // until we're sure we can hold them all (or until we hit the max size). + while (true) + { + numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, m_ColliderMask); + if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < m_MaxColliderBufferSize) + { + m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_ColliderBuffer.Length * 2)]; + m_InitialColliderBufferSize = m_ColliderBuffer.Length; + } + else + { + break; + } + } + return numFound; + } + + /// + /// Parses the array of colliders found within a cell. 
Finds the closest gameobject to the agent root reference within the cell + /// + void ParseCollidersClosest(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action detectedAction) + { + GameObject closestColliderGo = null; + var minDistanceSquared = float.MaxValue; + + for (var i = 0; i < numFound; i++) + { + var currentColliderGo = foundColliders[i].gameObject; + + // Continue if the current collider go is the root reference + if (ReferenceEquals(currentColliderGo, m_AgentGameObject)) + { + continue; + } + + var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter); + var currentDistanceSquared = (closestColliderPoint - m_CenterObject.transform.position).sqrMagnitude; + + if (currentDistanceSquared >= minDistanceSquared) + { + continue; + } + + // Checks if our colliders contain a detectable object + var index = -1; + for (var ii = 0; ii < m_DetectableTags.Length; ii++) + { + if (currentColliderGo.CompareTag(m_DetectableTags[ii])) + { + index = ii; + break; + } + } + if (index > -1 && currentDistanceSquared < minDistanceSquared) + { + minDistanceSquared = currentDistanceSquared; + closestColliderGo = currentColliderGo; + } + } + + if (!ReferenceEquals(closestColliderGo, null)) + { + detectedAction.Invoke(closestColliderGo, cellIndex); + } + } + + /// + /// Parses all colliders in the array of colliders found within a cell. + /// + void ParseCollidersAll(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action detectedAction) + { + for (int i = 0; i < numFound; i++) + { + var currentColliderGo = foundColliders[i].gameObject; + if (!ReferenceEquals(currentColliderGo, m_AgentGameObject)) + { + detectedAction.Invoke(currentColliderGo, cellIndex); + } + } + } +#endif + + public void RegisterSensor(GridSensorBase sensor) + { +#if MLA_UNITY_PHYSICS_MODULE + if (sensor.GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders) + { + GridOverlapDetectedAll += sensor.ProcessDetectedObject; + } + else + { + GridOverlapDetectedClosest += sensor.ProcessDetectedObject; + } +#endif + } + + public void RegisterDebugSensor(GridSensorBase debugSensor) + { +#if MLA_UNITY_PHYSICS_MODULE + GridOverlapDetectedDebug += debugSensor.ProcessDetectedObject; +#endif + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs.meta b/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs.meta new file mode 100644 index 0000000000..1d20815c0e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e039296229a084578823e21bce9cf834 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs new file mode 100644 index 0000000000..09db85d224 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs @@ -0,0 +1,114 @@ +using System; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A Sensor that allows to observe a variable number of entities. + /// + public class BufferSensor : ISensor, IBuiltInSensor + { + private string m_Name; + private int m_MaxNumObs; + private int m_ObsSize; + float[] m_ObservationBuffer; + int m_CurrentNumObservables; + ObservationSpec m_ObservationSpec; + + + /// + /// Creates the BufferSensor. 
+        /// </summary>
+        /// <param name="maxNumberObs">The maximum number of observations to be appended to this BufferSensor.</param>
+        /// <param name="obsSize">The size of each observation appended to the BufferSensor.</param>
+        /// <param name="name">The name of the sensor.</param>
+        public BufferSensor(int maxNumberObs, int obsSize, string name)
+        {
+            m_Name = name;
+            m_MaxNumObs = maxNumberObs;
+            m_ObsSize = obsSize;
+            m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
+            m_CurrentNumObservables = 0;
+            m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize);
+        }
+
+        /// <inheritdoc/>
+        public ObservationSpec GetObservationSpec()
+        {
+            return m_ObservationSpec;
+        }
+
+        /// <summary>
+        /// Appends an observation to the buffer. If the buffer is full (the maximum number
+        /// of observations is reached) the observation will be ignored. The length of
+        /// the provided observation array must be equal to the observation size of
+        /// the buffer sensor.
+        /// </summary>
+        /// <param name="obs">The float array observation</param>
+        public void AppendObservation(float[] obs)
+        {
+            if (obs.Length != m_ObsSize)
+            {
+                throw new UnityAgentsException(
+                    "The BufferSensor was expecting an observation of size " +
+                    $"{m_ObsSize} but received {obs.Length} observations instead."
+                );
+            }
+            if (m_CurrentNumObservables >= m_MaxNumObs)
+            {
+                return;
+            }
+            for (int i = 0; i < obs.Length; i++)
+            {
+                m_ObservationBuffer[m_CurrentNumObservables * m_ObsSize + i] = obs[i];
+            }
+            m_CurrentNumObservables++;
+        }
+
+        /// <inheritdoc/>
+        public int Write(ObservationWriter writer)
+        {
+            for (int i = 0; i < m_ObsSize * m_MaxNumObs; i++)
+            {
+                writer[i] = m_ObservationBuffer[i];
+            }
+            return m_ObsSize * m_MaxNumObs;
+        }
+
+        /// <inheritdoc/>
+        public virtual byte[] GetCompressedObservation()
+        {
+            return null;
+        }
+
+        /// <inheritdoc/>
+        public void Update()
+        {
+            Reset();
+        }
+
+        /// <inheritdoc/>
+        public void Reset()
+        {
+            m_CurrentNumObservables = 0;
+            Array.Clear(m_ObservationBuffer, 0, m_ObservationBuffer.Length);
+        }
+
+        /// <inheritdoc/>
+        public CompressionSpec GetCompressionSpec()
+        {
+            return CompressionSpec.Default();
+        }
+
+        /// <inheritdoc/>
+        public string GetName()
+        {
+            return m_Name;
+        }
+
+        /// <inheritdoc/>
+        public BuiltInSensorType GetBuiltInSensorType()
+        {
+            return BuiltInSensorType.BufferSensor;
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs.meta
new file mode 100644
index 0000000000..327456d2ee
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 034f05c858e684e5498d9a548c9d1fc5
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
new file mode 100644
index 0000000000..b825b4a974
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
@@ -0,0 +1,70 @@
+using UnityEngine;
+
+namespace Unity.MLAgents.Sensors
+{
+
+    /// <summary>
+    /// A SensorComponent that creates a <see cref="BufferSensor"/>.
+    /// </summary>
+    [AddComponentMenu("ML Agents/Buffer Sensor", (int)MenuGroup.Sensors)]
+    public class BufferSensorComponent : SensorComponent
+    {
+
+        /// <summary>
+        /// Name of the generated <see cref="BufferSensor"/> object.
+        /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
+        /// </summary>
+        public string SensorName
+        {
+            get { return m_SensorName; }
+            set { m_SensorName = value; }
+        }
+        [HideInInspector, SerializeField]
+        private string m_SensorName = "BufferSensor";
+
+        /// <summary>
+        /// This is how many floats each entity will be represented with.
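Typical use of the component being defined here: append one fixed-size entry per entity each decision step. A hedged sketch (the enemy query and encoding are placeholders; `ObservableSize` would be 4 here):

```csharp
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;

// Hypothetical agent observing a variable number of enemies through a
// BufferSensorComponent attached to the same GameObject.
public class SquadAgent : Agent
{
    BufferSensorComponent m_Buffer;

    public override void Initialize()
    {
        m_Buffer = GetComponent<BufferSensorComponent>();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        foreach (var enemy in FindObjectsOfType<EnemyHealth>()) // placeholder query
        {
            var delta = enemy.transform.position - transform.position;
            // One 4-float entry per entity; entries past MaxNumObservables are dropped.
            m_Buffer.AppendObservation(new[] { delta.x, delta.y, delta.z, enemy.Value });
        }
    }
}

public class EnemyHealth : MonoBehaviour
{
    public float Value = 1f; // placeholder health encoding
}
```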
This number + /// is fixed and all entities must have the same representation. + /// + public int ObservableSize + { + get { return m_ObservableSize; } + set { m_ObservableSize = value; } + } + [HideInInspector, SerializeField] + private int m_ObservableSize; + + /// + /// This is the maximum number of entities the `BufferSensor` will be able to + /// collect. + /// + public int MaxNumObservables + { + get { return m_MaxNumObservables; } + set { m_MaxNumObservables = value; } + } + [HideInInspector, SerializeField] + private int m_MaxNumObservables; + + private BufferSensor m_Sensor; + + /// + public override ISensor[] CreateSensors() + { + m_Sensor = new BufferSensor(MaxNumObservables, ObservableSize, m_SensorName); + return new ISensor[] { m_Sensor }; + } + + /// + /// Appends an observation to the buffer. If the buffer is full (maximum number + /// of observation is reached) the observation will be ignored. the length of + /// the provided observation array must be equal to the observation size of + /// the buffer sensor. + /// + /// The float array observation + public void AppendObservation(float[] obs) + { + m_Sensor.AppendObservation(obs); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs.meta new file mode 100644 index 0000000000..69bee5ca08 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: dd8012d5925524537b27131fef517017 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs new file mode 100644 index 0000000000..ddf3d0000b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -0,0 +1,184 @@ +using System; +using UnityEngine; +using UnityEngine.Rendering; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A sensor that wraps a Camera object to generate visual observations for an agent. + /// + public class CameraSensor : ISensor, IBuiltInSensor, IDisposable + { + Camera m_Camera; + int m_Width; + int m_Height; + bool m_Grayscale; + string m_Name; + private ObservationSpec m_ObservationSpec; + SensorCompressionType m_CompressionType; + Texture2D m_Texture; + + /// + /// The Camera used for rendering the sensor observations. + /// + public Camera Camera + { + get { return m_Camera; } + set { m_Camera = value; } + } + + /// + /// The compression type used by the sensor. + /// + public SensorCompressionType CompressionType + { + get { return m_CompressionType; } + set { m_CompressionType = value; } + } + + /// + /// Creates and returns the camera sensor. + /// + /// Camera object to capture images from. + /// The width of the generated visual observation. + /// The height of the generated visual observation. + /// Whether to convert the generated image to grayscale or keep color. + /// The name of the camera sensor. + /// The compression to apply to the generated image. + /// The type of observation. + public CameraSensor( + Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default) + { + m_Camera = camera; + m_Width = width; + m_Height = height; + m_Grayscale = grayscale; + m_Name = name; + var channels = grayscale ? 
1 : 3; + m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType); + m_CompressionType = compression; + m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false); + } + + /// + /// Accessor for the name of the sensor. + /// + /// Sensor name. + public string GetName() + { + return m_Name; + } + + /// + /// Returns a description of the observations that will be generated by the sensor. + /// The shape will be h x w x 1 for grayscale and h x w x 3 for color. + /// The dimensions have translational equivariance along width and height, + /// and no property along the channels dimension. + /// + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + /// Generates a compressed image. This can be valuable in speeding-up training. + /// + /// Compressed image. + public byte[] GetCompressedObservation() + { + using (TimerStack.Instance.Scoped("CameraSensor.GetCompressedObservation")) + { + ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height); + // TODO support more types here, e.g. JPG + var compressed = m_Texture.EncodeToPNG(); + return compressed; + } + } + + /// + /// Writes out the generated, uncompressed image to the provided . + /// + /// Where the observation is written to. + /// + public int Write(ObservationWriter writer) + { + using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor")) + { + ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height); + var numWritten = writer.WriteTexture(m_Texture, m_Grayscale); + return numWritten; + } + } + + /// + public void Update() { } + + /// + public void Reset() { } + + /// + public CompressionSpec GetCompressionSpec() + { + return new CompressionSpec(m_CompressionType); + } + + /// + /// Renders a Camera instance to a 2D texture at the corresponding resolution. + /// + /// Camera. + /// Texture2D to render to. + /// Width of resulting 2D texture. + /// Height of resulting 2D texture. + public static void ObservationToTexture(Camera obsCamera, Texture2D texture2D, int width, int height) + { + if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null) + { + Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render."); + } + + var oldRec = obsCamera.rect; + obsCamera.rect = new Rect(0f, 0f, 1f, 1f); + var depth = 24; + var format = RenderTextureFormat.Default; + var readWrite = RenderTextureReadWrite.Default; + + var tempRt = + RenderTexture.GetTemporary(width, height, depth, format, readWrite); + + var prevActiveRt = RenderTexture.active; + var prevCameraRt = obsCamera.targetTexture; + + // render to offscreen texture (readonly from CPU side) + RenderTexture.active = tempRt; + obsCamera.targetTexture = tempRt; + + obsCamera.Render(); + + texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0); + + obsCamera.targetTexture = prevCameraRt; + obsCamera.rect = oldRec; + RenderTexture.active = prevActiveRt; + RenderTexture.ReleaseTemporary(tempRt); + } + + /// + public BuiltInSensorType GetBuiltInSensorType() + { + return BuiltInSensorType.CameraSensor; + } + + /// + /// Clean up the owned Texture2D. 
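For completeness, a sketch of constructing this sensor by hand, mirroring what `CameraSensorComponent` (the next file in this diff) automates; the resolution and names are illustrative:

```csharp
using Unity.MLAgents.Sensors;
using UnityEngine;

// Hypothetical manual setup: an 84x84 grayscale, PNG-compressed observation
// taken from the Camera on this GameObject.
public class ManualCameraCapture : MonoBehaviour
{
    CameraSensor m_Sensor;

    void Start()
    {
        m_Sensor = new CameraSensor(
            GetComponent<Camera>(), 84, 84,
            grayscale: true, name: "TopDownCam",
            compression: SensorCompressionType.PNG);
    }

    void OnDestroy()
    {
        m_Sensor?.Dispose(); // releases the owned Texture2D, as Dispose() below does
    }
}
```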
+ /// + public void Dispose() + { + if (!ReferenceEquals(null, m_Texture)) + { + Utilities.DestroyTexture(m_Texture); + m_Texture = null; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs.meta new file mode 100644 index 0000000000..1a0314b8f7 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 41cb6bf4b09974bf583f5b2fef0c08a7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs new file mode 100644 index 0000000000..f6b53f087e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs @@ -0,0 +1,179 @@ +using System; +using UnityEngine; +using UnityEngine.Serialization; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A SensorComponent that creates a . + /// + [AddComponentMenu("ML Agents/Camera Sensor", (int)MenuGroup.Sensors)] + public class CameraSensorComponent : SensorComponent, IDisposable + { + [HideInInspector, SerializeField, FormerlySerializedAs("camera")] + Camera m_Camera; + + CameraSensor m_Sensor; + + /// + /// Camera object that provides the data to the sensor. + /// + public Camera Camera + { + get { return m_Camera; } + set { m_Camera = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("sensorName")] + string m_SensorName = "CameraSensor"; + + /// + /// Name of the generated object. + /// Note that changing this at runtime does not affect how the Agent sorts the sensors. + /// + public string SensorName + { + get { return m_SensorName; } + set { m_SensorName = value; } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("width")] + int m_Width = 84; + + /// + /// Width of the generated observation. + /// Note that changing this after the sensor is created has no effect. + /// + public int Width + { + get { return m_Width; } + set { m_Width = value; } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("height")] + int m_Height = 84; + + /// + /// Height of the generated observation. + /// Note that changing this after the sensor is created has no effect. + /// + public int Height + { + get { return m_Height; } + set { m_Height = value; } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("grayscale")] + bool m_Grayscale; + + /// + /// Whether to generate grayscale images or color. + /// Note that changing this after the sensor is created has no effect. + /// + public bool Grayscale + { + get { return m_Grayscale; } + set { m_Grayscale = value; } + } + + [HideInInspector, SerializeField] + ObservationType m_ObservationType; + + /// + /// The type of the observation. + /// + public ObservationType ObservationType + { + get { return m_ObservationType; } + set { m_ObservationType = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField] + bool m_RuntimeCameraEnable; + + + /// + /// Controls the whether the camera sensor's attached camera + /// is enabled during runtime. Overrides the camera object enabled status. + /// Disabled for improved performance. Disabled by default. 
+ /// + public bool RuntimeCameraEnable + { + get { return m_RuntimeCameraEnable; } + set { m_RuntimeCameraEnable = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField] + [Range(1, 50)] + [Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")] + int m_ObservationStacks = 1; + + [HideInInspector, SerializeField, FormerlySerializedAs("compression")] + SensorCompressionType m_Compression = SensorCompressionType.PNG; + + /// + /// The compression type to use for the sensor. + /// + public SensorCompressionType CompressionType + { + get { return m_Compression; } + set { m_Compression = value; UpdateSensor(); } + } + + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + + void Start() + { + UpdateSensor(); + } + + /// + /// Creates the CameraSensor. + /// + /// The created object for this component. + public override ISensor[] CreateSensors() + { + Dispose(); + m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType); + + if (ObservationStacks != 1) + { + return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) }; + } + return new ISensor[] { m_Sensor }; + } + + /// + /// Update fields that are safe to change on the Sensor at runtime. + /// + internal void UpdateSensor() + { + if (m_Sensor != null) + { + m_Sensor.Camera = m_Camera; + m_Sensor.CompressionType = m_Compression; + m_Sensor.Camera.enabled = m_RuntimeCameraEnable; + } + } + + /// + /// Clean up the sensor created by CreateSensors(). + /// + public void Dispose() + { + if (!ReferenceEquals(m_Sensor, null)) + { + m_Sensor.Dispose(); + m_Sensor = null; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs.meta new file mode 100644 index 0000000000..307dc64952 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 282f342c2ab144bf38be65d4d0c4e07d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs new file mode 100644 index 0000000000..76e283a14a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs @@ -0,0 +1,113 @@ +using System.Linq; +namespace Unity.MLAgents.Sensors +{ + /// + /// The compression setting for visual/camera observations. + /// + public enum SensorCompressionType + { + /// + /// No compression. Data is preserved as float arrays. + /// + None, + + /// + /// PNG format. Data will be stored in binary format. + /// + PNG + } + + /// + /// A description of the compression used for observations. + /// + /// + /// Most ISensor implementations can't take advantage of compression, + /// and should return CompressionSpec.Default() from their ISensor.GetCompressionSpec() methods. + /// Visual observations, or multidimensional categorical observations (for example, image segmentation + /// or the piece types in a match-3 game board) can use PNG compression to reduce the amount of + /// data transferred between Unity and the trainer.
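+ /// For example, a simple custom sensor that opts out of compression would implement (sketch):
+ /// public CompressionSpec GetCompressionSpec()
+ /// {
+ ///     return CompressionSpec.Default(); // no compression; observations stay float arrays
+ /// }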
+ /// + public struct CompressionSpec + { + internal SensorCompressionType m_SensorCompressionType; + + /// + /// The compression type that the sensor will use for its observations. + /// + public SensorCompressionType SensorCompressionType + { + get => m_SensorCompressionType; + } + + internal int[] m_CompressedChannelMapping; + + /// + /// The mapping of the channels in compressed data to the actual channel after decompression. + /// + /// + /// The mapping is a list of integer indices with the same length as + /// the number of output observation layers (channels), including padding if there's any. + /// Each index indicates the actual channel the layer will go into. + /// Layers with the same index will be averaged, and layers with negative index will be dropped. + /// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1] + /// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1] + /// + public int[] CompressedChannelMapping + { + get => m_CompressedChannelMapping; + } + + /// + /// Return a CompressionSpec indicating possible compression. + /// + /// The compression type to use. + /// Optional mapping of the channels in compressed data to the + /// actual channel after decompression. + public CompressionSpec(SensorCompressionType sensorCompressionType, int[] compressedChannelMapping = null) + { + m_SensorCompressionType = sensorCompressionType; + m_CompressedChannelMapping = compressedChannelMapping; + } + + /// + /// Return a CompressionSpec indicating no compression. This is recommended for most sensors. + /// + /// + public static CompressionSpec Default() + { + return new CompressionSpec + { + m_SensorCompressionType = SensorCompressionType.None, + m_CompressedChannelMapping = null + }; + } + + /// + /// Return whether the compressed channel mapping is "trivial"; if so it doesn't need to be sent to the + /// trainer. + /// + /// + internal bool IsTrivialMapping() + { + var mapping = CompressedChannelMapping; + if (mapping == null) + { + return true; + } + // check if mapping equals zero mapping + if (mapping.Length == 3 && mapping.All(m => m == 0)) + { + return true; + } + // check if mapping equals identity mapping + for (var i = 0; i < mapping.Length; i++) + { + if (mapping[i] != i) + { + return false; + } + } + return true; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta new file mode 100644 index 0000000000..3bbac496d7 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 0ddff1d1b7ad4170acb1a10272d4a8c2 +timeCreated: 1616006929 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs new file mode 100644 index 0000000000..a37861e271 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs @@ -0,0 +1,348 @@ +using System; +using System.Collections.Generic; +using UnityEngine; +using UnityEngine.Profiling; + +namespace Unity.MLAgents.Sensors +{ + /// + /// The way the GridSensor processes detected colliders in a cell. + /// + public enum ProcessCollidersMethod + { + /// + /// Get data from all colliders detected in a cell + /// + ProcessAllColliders, + + /// + /// Get data from the collider closest to the agent + /// + ProcessClosestColliders + } + + /// + /// Grid-based sensor.
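+ /// A construction sketch (the tag names are illustrative assumptions, not part of this file):
+ /// var sensor = new GridSensorBase("GridSensor", new Vector3(1f, 0.01f, 1f),
+ ///     new Vector3Int(16, 1, 16), new[] { "Wall", "Goal" }, SensorCompressionType.None);
+ /// To record custom per-cell data, subclass this and override GetObjectData,
+ /// GetCellObservationSize and IsDataNormalized.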
+ /// + public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable + { + string m_Name; + Vector3 m_CellScale; + Vector3Int m_GridSize; + string[] m_DetectableTags; + SensorCompressionType m_CompressionType; + ObservationSpec m_ObservationSpec; + internal IGridPerception m_GridPerception; + + // Buffers + float[] m_PerceptionBuffer; + Color[] m_PerceptionColors; + Texture2D m_PerceptionTexture; + float[] m_CellDataBuffer; + + // Utility Constants Calculated on Init + int m_NumCells; + int m_CellObservationSize; + Vector3 m_CellCenterOffset; + + + /// + /// Create a GridSensorBase with the specified configuration. + /// + /// The sensor name + /// The scale of each cell in the grid + /// Number of cells on each side of the grid + /// Tags to be detected by the sensor + /// Compression type + public GridSensorBase( + string name, + Vector3 cellScale, + Vector3Int gridSize, + string[] detectableTags, + SensorCompressionType compression + ) + { + m_Name = name; + m_CellScale = cellScale; + m_GridSize = gridSize; + m_DetectableTags = detectableTags; + CompressionType = compression; + + if (m_GridSize.y != 1) + { + throw new UnityAgentsException("GridSensor only supports 2D grids."); + } + + m_NumCells = m_GridSize.x * m_GridSize.z; + m_CellObservationSize = GetCellObservationSize(); + m_ObservationSpec = ObservationSpec.Visual(m_GridSize.x, m_GridSize.z, m_CellObservationSize); + m_PerceptionTexture = new Texture2D(m_GridSize.x, m_GridSize.z, TextureFormat.RGB24, false); + + ResetPerceptionBuffer(); + } + + /// + /// The compression type used by the sensor. + /// + public SensorCompressionType CompressionType + { + get { return m_CompressionType; } + set + { + if (!IsDataNormalized() && value == SensorCompressionType.PNG) + { + Debug.LogWarning($"Compression type {value} is only supported with normalized data. " + + "The sensor will not compress the data."); + return; + } + m_CompressionType = value; + } + } + + internal float[] PerceptionBuffer + { + get { return m_PerceptionBuffer; } + } + + /// + /// The tags which the sensor detects. + /// + protected string[] DetectableTags + { + get { return m_DetectableTags; } + } + + /// + public void Reset() { } + + /// + /// Clears the perception buffer before loading in new data. + /// + public void ResetPerceptionBuffer() + { + if (m_PerceptionBuffer != null) + { + Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length); + Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length); + } + else + { + m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells]; + m_CellDataBuffer = new float[m_CellObservationSize]; + m_PerceptionColors = new Color[m_NumCells]; + } + } + + /// + public string GetName() + { + return m_Name; + } + + /// + public CompressionSpec GetCompressionSpec() + { + return new CompressionSpec(CompressionType); + } + + /// + public BuiltInSensorType GetBuiltInSensorType() + { + return BuiltInSensorType.GridSensor; + } + + /// + public byte[] GetCompressedObservation() + { + using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation")) + { + var allBytes = new List<byte>(); + var numImages = (m_CellObservationSize + 2) / 3; + for (int i = 0; i < numImages; i++) + { + var channelIndex = 3 * i; + GridValuesToTexture(channelIndex, Math.Min(3, m_CellObservationSize - channelIndex)); + allBytes.AddRange(m_PerceptionTexture.EncodeToPNG()); + } + + return allBytes.ToArray(); + } + } + + /// + /// Convert observation values to texture for PNG compression.
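+ /// Each PNG holds at most 3 channels, so a cell observation of size 5 is packed into
+ /// (5 + 2) / 3 = 2 textures, with the unused channels of the last texture left at zero.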
+ /// + void GridValuesToTexture(int channelIndex, int numChannelsToAdd) + { + for (int i = 0; i < m_NumCells; i++) + { + for (int j = 0; j < numChannelsToAdd; j++) + { + m_PerceptionColors[i][j] = m_PerceptionBuffer[i * m_CellObservationSize + channelIndex + j]; + } + } + m_PerceptionTexture.SetPixels(m_PerceptionColors); + } + + /// + /// Get the observation values of the detected game object. + /// Default is to record the detected tag index. + /// + /// This method can be overridden to encode the observation differently or get custom data from the object. + /// When overriding this method, GetCellObservationSize and IsDataNormalized + /// might also need to change accordingly. + /// + /// The game object that was detected within a certain cell + /// The index of the detectedObject's tag in the DetectableObjects list + /// The buffer to write the observation values. + /// The buffer size is configured by GetCellObservationSize. + /// + /// + /// Here is an example of overriding GetObjectData to get the velocity of a potential Rigidbody: + /// + /// protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer) + /// { + /// if (tagIndex == Array.IndexOf(DetectableTags, "RigidBodyObject")) + /// { + /// Rigidbody rigidbody = detectedObject.GetComponent<Rigidbody>(); + /// dataBuffer[0] = rigidbody.velocity.x; + /// dataBuffer[1] = rigidbody.velocity.y; + /// dataBuffer[2] = rigidbody.velocity.z; + /// } + /// } + /// + /// + protected virtual void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer) + { + dataBuffer[0] = tagIndex + 1; + } + + /// + /// Get the observation size for each cell. This will be the size of dataBuffer for GetObjectData. + /// If overriding GetObjectData, override this method as well to reflect the custom observation size. + /// + /// The observation size of each cell. + protected virtual int GetCellObservationSize() + { + return 1; + } + + /// + /// Whether the data is normalized within [0, 1]. The sensor can only use PNG compression if the data is normalized. + /// If overriding GetObjectData, override this method as well according to the custom observation values. + /// + /// Bool value indicating whether data is normalized. + protected virtual bool IsDataNormalized() + { + return false; + } + + /// + /// How detected colliders in a cell are processed. Defaults to processing only the collider closest to the agent. + /// If overriding GetObjectData, consider overriding this method as well when needed. + /// + /// ProcessCollidersMethod value indicating how detected colliders in a cell are processed. + protected internal virtual ProcessCollidersMethod GetProcessCollidersMethod() + { + return ProcessCollidersMethod.ProcessClosestColliders; + } + + /// + /// If using PNG compression, check if the values are normalized. + /// + void ValidateValues(float[] dataValues, GameObject detectedObject) + { + if (m_CompressionType != SensorCompressionType.PNG) + { + return; + } + + for (int j = 0; j < dataValues.Length; j++) + { + if (dataValues[j] < 0 || dataValues[j] > 1) + throw new UnityAgentsException($"When using compression type {m_CompressionType} the data value has to be normalized between 0-1. " + + $"Received value[{dataValues[j]}] for {detectedObject.name}"); + } + } + + /// + /// Collect data from the detected object if a detectable tag is matched.
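+ /// With ProcessAllColliders, the cell's current values are first copied into the data
+ /// buffer so GetObjectData can accumulate across colliders; otherwise the buffer is
+ /// cleared before each object is processed.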
+ /// + internal void ProcessDetectedObject(GameObject detectedObject, int cellIndex) + { + Profiler.BeginSample("GridSensor.ProcessDetectedObject"); + for (var i = 0; i < m_DetectableTags.Length; i++) + { + if (!ReferenceEquals(detectedObject, null) && detectedObject.CompareTag(m_DetectableTags[i])) + { + if (GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders) + { + Array.Copy(m_PerceptionBuffer, cellIndex * m_CellObservationSize, m_CellDataBuffer, 0, m_CellObservationSize); + } + else + { + Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length); + } + + GetObjectData(detectedObject, i, m_CellDataBuffer); + ValidateValues(m_CellDataBuffer, detectedObject); + Array.Copy(m_CellDataBuffer, 0, m_PerceptionBuffer, cellIndex * m_CellObservationSize, m_CellObservationSize); + break; + } + } + Profiler.EndSample(); + } + + /// + public void Update() + { + ResetPerceptionBuffer(); + using (TimerStack.Instance.Scoped("GridSensor.Update")) + { + if (m_GridPerception != null) + { + m_GridPerception.Perceive(); + } + } + } + + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + public int Write(ObservationWriter writer) + { + using (TimerStack.Instance.Scoped("GridSensor.Write")) + { + int index = 0; + for (var h = m_GridSize.z - 1; h >= 0; h--) + { + for (var w = 0; w < m_GridSize.x; w++) + { + for (var d = 0; d < m_CellObservationSize; d++) + { + writer[h, w, d] = m_PerceptionBuffer[index]; + index++; + } + } + } + return index; + } + } + + /// + /// Clean up the internal objects. + /// + public void Dispose() + { + if (!ReferenceEquals(null, m_PerceptionTexture)) + { + Utilities.DestroyTexture(m_PerceptionTexture); + m_PerceptionTexture = null; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs.meta b/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs.meta new file mode 100644 index 0000000000..623d17202d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2454efb6b02aa414dae2cb8573e87682 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs new file mode 100644 index 0000000000..f381dbd039 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs @@ -0,0 +1,310 @@ +using System.Collections.Generic; +using System.Linq; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A SensorComponent that creates a . + /// + [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)] + public class GridSensorComponent : SensorComponent + { + // dummy sensor only used for debug gizmo + GridSensorBase m_DebugSensor; + List m_Sensors; + internal IGridPerception m_GridPerception; + + [HideInInspector, SerializeField] + protected internal string m_SensorName = "GridSensor"; + /// + /// Name of the generated object. + /// Note that changing this at runtime does not affect how the Agent sorts the sensors. + /// + public string SensorName + { + get { return m_SensorName; } + set { m_SensorName = value; } + } + + [HideInInspector, SerializeField] + internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f); + + /// + /// The scale of each grid cell. + /// Note that changing this after the sensor is created has no effect. 
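+ /// A typical configuration sketch (the values and tag names are illustrative assumptions):
+ /// var grid = gameObject.AddComponent<GridSensorComponent>();
+ /// grid.CellScale = new Vector3(1f, 0.01f, 1f);
+ /// grid.GridSize = new Vector3Int(16, 1, 16);
+ /// grid.DetectableTags = new[] { "Wall", "Goal" };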
+ /// + public Vector3 CellScale + { + get { return m_CellScale; } + set { m_CellScale = value; } + } + + [HideInInspector, SerializeField] + internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16); + /// + /// The number of grid cells on each side. + /// Note that changing this after the sensor is created has no effect. + /// + public Vector3Int GridSize + { + get { return m_GridSize; } + set + { + if (value.y != 1) + { + m_GridSize = new Vector3Int(value.x, 1, value.z); + } + else + { + m_GridSize = value; + } + } + } + + [HideInInspector, SerializeField] + internal bool m_RotateWithAgent = true; + /// + /// Rotate the grid based on the direction the agent is facing. + /// + public bool RotateWithAgent + { + get { return m_RotateWithAgent; } + set { m_RotateWithAgent = value; } + } + + [HideInInspector, SerializeField] + internal GameObject m_AgentGameObject; + /// + /// The reference to the root of the agent. This is used to disambiguate objects with + /// the same tag as the agent. Defaults to the current GameObject. + /// + public GameObject AgentGameObject + { + get { return (m_AgentGameObject == null ? gameObject : m_AgentGameObject); } + set { m_AgentGameObject = value; } + } + + [HideInInspector, SerializeField] + internal string[] m_DetectableTags; + /// + /// List of tags that are detected. + /// Note that changing this after the sensor is created has no effect. + /// + public string[] DetectableTags + { + get { return m_DetectableTags; } + set { m_DetectableTags = value; } + } + + [HideInInspector, SerializeField] + internal LayerMask m_ColliderMask; + /// + /// The layer mask. + /// + public LayerMask ColliderMask + { + get { return m_ColliderMask; } + set { m_ColliderMask = value; } + } + + [HideInInspector, SerializeField] + internal int m_MaxColliderBufferSize = 500; + /// + /// The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words + /// the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell. + /// Note that changing this after the sensor is created has no effect. + /// + public int MaxColliderBufferSize + { + get { return m_MaxColliderBufferSize; } + set { m_MaxColliderBufferSize = value; } + } + + [HideInInspector, SerializeField] + internal int m_InitialColliderBufferSize = 4; + /// + /// The Estimated Max Number of Colliders to expect per cell. This number is used to + /// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc + /// Physics API. If the number of colliders found is >= InitialColliderBufferSize the array + /// will be resized to double its current size. The hard coded absolute size is 500. + /// Note that changing this after the sensor is created has no effect. + /// + public int InitialColliderBufferSize + { + get { return m_InitialColliderBufferSize; } + set { m_InitialColliderBufferSize = value; } + } + + [HideInInspector, SerializeField] + internal Color[] m_DebugColors; + /// + /// Array of Colors used for the grid gizmos. + /// + public Color[] DebugColors + { + get { return m_DebugColors; } + set { m_DebugColors = value; } + } + + [HideInInspector, SerializeField] + internal float m_GizmoYOffset = 0f; + /// + /// The height of the gizmos grid. + /// + public float GizmoYOffset + { + get { return m_GizmoYOffset; } + set { m_GizmoYOffset = value; } + } + + [HideInInspector, SerializeField] + internal bool m_ShowGizmos = false; + /// + /// Whether to show gizmos or not.
+ /// + public bool ShowGizmos + { + get { return m_ShowGizmos; } + set { m_ShowGizmos = value; } + } + + [HideInInspector, SerializeField] + internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG; + /// + /// The compression type to use for the sensor. + /// + public SensorCompressionType CompressionType + { + get { return m_CompressionType; } + set { m_CompressionType = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField] + [Range(1, 50)] + [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")] + internal int m_ObservationStacks = 1; + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + + /// + public override ISensor[] CreateSensors() + { + m_GridPerception = new BoxOverlapChecker( + m_CellScale, + m_GridSize, + m_RotateWithAgent, + m_ColliderMask, + gameObject, + AgentGameObject, + m_DetectableTags, + m_InitialColliderBufferSize, + m_MaxColliderBufferSize + ); + + // debug data is a positive int value and will trigger a data validation exception if SensorCompressionType is not None. + m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None); + m_GridPerception.RegisterDebugSensor(m_DebugSensor); + + m_Sensors = GetGridSensors().ToList(); + if (m_Sensors == null || m_Sensors.Count < 1) + { + throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors. " + + "If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor."); + } + + // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once + m_Sensors[0].m_GridPerception = m_GridPerception; + foreach (var sensor in m_Sensors) + { + m_GridPerception.RegisterSensor(sensor); + } + + if (ObservationStacks != 1) + { + var sensors = new ISensor[m_Sensors.Count]; + for (var i = 0; i < m_Sensors.Count; i++) + { + sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks); + } + return sensors; + } + else + { + return m_Sensors.ToArray(); + } + } + + /// + /// Get an array of GridSensors to be added to this component. + /// Override this method and return custom GridSensor implementations. + /// + /// Array of grid sensors to be added to the component. + protected virtual GridSensorBase[] GetGridSensors() + { + List<GridSensorBase> sensorList = new List<GridSensorBase>(); + var sensor = new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType); + sensorList.Add(sensor); + return sensorList.ToArray(); + } + + /// + /// Update fields that are safe to change on the Sensor at runtime.
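+ /// Only RotateWithAgent, ColliderMask and the per-sensor CompressionType are propagated
+ /// here; structural settings (cell scale, grid size, tags) require recreating the sensors.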
+ /// + internal void UpdateSensor() + { + if (m_Sensors != null) + { + m_GridPerception.RotateWithAgent = m_RotateWithAgent; + m_GridPerception.ColliderMask = m_ColliderMask; + foreach (var sensor in m_Sensors) + { + sensor.CompressionType = m_CompressionType; + } + } + } + + void OnDrawGizmos() + { + if (m_ShowGizmos) + { + if (m_GridPerception == null || m_DebugSensor == null) + { + return; + } + + m_DebugSensor.ResetPerceptionBuffer(); + m_GridPerception.UpdateGizmo(); + var cellColors = m_DebugSensor.PerceptionBuffer; + var rotation = m_GridPerception.GetGridRotation(); + + var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z); + var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0); + var oldGizmoMatrix = Gizmos.matrix; + for (var i = 0; i < m_DebugSensor.PerceptionBuffer.Length; i++) + { + var cellPosition = m_GridPerception.GetCellGlobalPosition(i); + var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale); + Gizmos.matrix = oldGizmoMatrix * cubeTransform; + var colorIndex = cellColors[i] - 1; + var debugRayColor = Color.white; + if (colorIndex > -1 && m_DebugColors.Length > colorIndex) + { + debugRayColor = m_DebugColors[(int)colorIndex]; + } + Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f); + Gizmos.DrawCube(Vector3.zero, Vector3.one); + } + + Gizmos.matrix = oldGizmoMatrix; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs.meta new file mode 100644 index 0000000000..8090977a3a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2a501962d056745d1a30e99146ee39fe +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs b/com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs new file mode 100644 index 0000000000..d69164d8c9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs @@ -0,0 +1,69 @@ +namespace Unity.MLAgents.Sensors +{ + /// + /// Identifiers for "built in" sensor types. + /// These are only used for analytics, and should not be used for any runtime decisions. + /// + /// NOTE: Do not renumber these, since the values are used for analytics. Renaming is allowed though. + /// + public enum BuiltInSensorType + { + /// + /// Default Sensor type if it cannot be determined. + /// + Unknown = 0, + /// + /// The Vector sensor used by the agent. + /// + VectorSensor = 1, + /// + /// The Stacking Sensor type. NOTE: StackingSensor actually returns the wrapped sensor's type. + /// + StackingSensor = 2, + /// + /// The RayPerception Sensor types, both 3D and 2D. + /// + RayPerceptionSensor = 3, + /// + /// The observable attribute sensor type. + /// + ReflectionSensor = 4, + /// + /// Sensors that use the Camera for observations. + /// + CameraSensor = 5, + /// + /// Sensors that use RenderTextures for observations. + /// + RenderTextureSensor = 6, + /// + /// Sensors that use buffers or tensors for observations. + /// + BufferSensor = 7, + /// + /// The sensors that observe properties of rigid bodies. + /// + PhysicsBodySensor = 8, + /// + /// The sensors that observe Match 3 boards. + /// + Match3Sensor = 9, + /// + /// Sensors that break down the world into a grid of colliders to observe an area at a pre-defined granularity. 
+ /// + GridSensor = 10 + } + + /// + /// Interface for sensors that are provided as part of ML-Agents. + /// User-implemented sensors don't need to use this interface. + /// + internal interface IBuiltInSensor + { + /// + /// Return the corresponding BuiltInSensorType for the sensor. + /// + /// A BuiltInSensorType corresponding to the sensor. + BuiltInSensorType GetBuiltInSensorType(); + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs.meta new file mode 100644 index 0000000000..93dd08d1f1 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c0c4a98bf1c941b381917cb65209beee +timeCreated: 1611096525 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs b/com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs new file mode 100644 index 0000000000..bbb981efa5 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs @@ -0,0 +1,62 @@ +using System; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// An interface for GridSensor perception that defines the grid cells and collider detecting strategies. + /// + internal interface IGridPerception + { + bool RotateWithAgent + { + get; + set; + } + + LayerMask ColliderMask + { + get; + set; + } + + /// Converts the index of the cell to the 3D point (y is zero) relative to grid center + /// Vector3 of the position of the center of the cell relative to grid center + /// The index of the cell + Vector3 GetCellLocalPosition(int cellIndex); + + /// + /// Converts the index of the cell to the 3D point (y is zero) in world space + /// based on the result from GetCellLocalPosition() + /// + /// Vector3 of the position of the center of the cell in world space + /// The index of the cell + Vector3 GetCellGlobalPosition(int cellIndex); + + Quaternion GetGridRotation(); + + /// + /// Perceive the latest grid status. Detect colliders for each cell, parse the collider arrays, + /// then trigger registered sensors to encode and update with the new grid status. + /// + void Perceive(); + + /// + /// Same as Perceive(), but only loads data for the debug gizmo. + /// + void UpdateGizmo(); + + /// + /// Register a sensor to this GridPerception to receive the grid perception results. + /// When the GridPerception perceives a new observation, registered sensors will be triggered + /// to encode the new observation and update their data. + /// + void RegisterSensor(GridSensorBase sensor); + + /// + /// Register an internal debug sensor. + /// Debug sensors will only be triggered when drawing debug gizmos.
+ /// + void RegisterDebugSensor(GridSensorBase debugSensor); + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs.meta b/com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs.meta new file mode 100644 index 0000000000..b08ed6449d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 87820d9eb927c4fa483dff9289d983f1 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs new file mode 100644 index 0000000000..ed93910fa0 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections.Generic; + +namespace Unity.MLAgents.Sensors +{ + /// + /// The Dimension property flags of the observations + /// + [Flags] + public enum DimensionProperty + { + /// + /// No properties specified. + /// + Unspecified = 0, + + /// + /// No Property of the observation in that dimension. Observation can be processed with + /// fully connected networks. + /// + None = 1, + + /// + /// Means it is suitable to do a convolution in this dimension. + /// + TranslationalEquivariance = 2, + + /// + /// Means that there can be a variable number of observations in this dimension. + /// The observations are unordered. + /// + VariableSize = 4, + } + + /// + /// The ObservationType enum of the Sensor. + /// + public enum ObservationType + { + /// + /// Collected observations are generic. + /// + Default = 0, + + /// + /// Collected observations contain goal information. + /// + GoalSignal = 1, + } + + /// + /// Sensor interface for generating observations. + /// + public interface ISensor + { + /// + /// Returns a description of the observations that will be generated by the sensor. + /// See for more details, and helper methods to create one. + /// + /// An object describing the observation. + ObservationSpec GetObservationSpec(); + + /// + /// Write the observation data directly to the . + /// Note that this (and ) may + /// be called multiple times per agent step, so should not mutate any internal state. + /// + /// Where the observations will be written to. + /// The number of elements written. + int Write(ObservationWriter writer); + + /// + /// Return a compressed representation of the observation. For small observations, + /// this should generally not be implemented. However, compressing large observations + /// (such as visual results) can significantly improve model training time. + /// + /// Compressed observation. + byte[] GetCompressedObservation(); + + /// + /// Update any internal state of the sensor. This is called once per each agent step. + /// + void Update(); + + /// + /// Resets the internal state of the sensor. This is called at the end of an Agent's episode. + /// Most implementations can leave this empty. + /// + void Reset(); + + /// + /// Return information on the compression type being used. If no compression is used, return + /// . + /// + /// An object describing the compression used by the sensor. + CompressionSpec GetCompressionSpec(); + + /// + /// Get the name of the sensor. This is used to ensure deterministic sorting of the sensors + /// on an Agent, so the naming must be consistent across all sensors and agents. + /// + /// The name of the sensor. 
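+ /// Names should be unique among an Agent's sensors; SensorUtils.SortSensors (below) sorts
+ /// by this name with StringComparison.InvariantCulture so the ordering is deterministic.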
+ string GetName(); + } + + + /// + /// Helper methods to be shared by all classes that implement . + /// + public static class SensorExtensions + { + /// + /// Get the total number of elements in the ISensor's observation (i.e. the product of the + /// shape elements). + /// + /// + /// + public static int ObservationSize(this ISensor sensor) + { + var obsSpec = sensor.GetObservationSpec(); + var count = 1; + for (var i = 0; i < obsSpec.Rank; i++) + { + count *= obsSpec.Shape[i]; + } + + return count; + } + } + + internal static class SensorUtils + { + internal static void SortSensors(List sensors) + { + // Use InvariantCulture to ensure consistent sorting between different culture settings. + sensors.Sort((x, y) => string.Compare(x.GetName(), y.GetName(), StringComparison.InvariantCulture)); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs.meta new file mode 100644 index 0000000000..d8ceedec70 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4bb5e09a94c6d4cd9a46c60b084e4952 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs new file mode 100644 index 0000000000..59c7f2eb4e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs @@ -0,0 +1,138 @@ +namespace Unity.MLAgents.Sensors +{ + /// + /// A description of the observations that an ISensor produces. + /// This includes the size of the observation, the properties of each dimension, and how the observation + /// should be used for training. + /// + public struct ObservationSpec + { + internal readonly InplaceArray m_Shape; + + /// + /// The size of the observations that will be generated. + /// For example, a sensor that observes the velocity of a rigid body (in 3D) would use [3]. + /// A sensor that returns an RGB image would use [Height, Width, 3]. + /// + public InplaceArray Shape + { + get => m_Shape; + } + + internal readonly InplaceArray m_DimensionProperties; + + /// + /// The properties of each dimensions of the observation. + /// The length of the array must be equal to the rank of the observation tensor. + /// + /// + /// It is generally recommended to use default values provided by helper functions, + /// as not all combinations of DimensionProperty may be supported by the trainer. + /// + public InplaceArray DimensionProperties + { + get => m_DimensionProperties; + } + + internal ObservationType m_ObservationType; + + /// + /// The type of the observation, e.g. whether they are generic or + /// help determine the goal for the Agent. + /// + public ObservationType ObservationType + { + get => m_ObservationType; + } + + /// + /// The number of dimensions of the observation. + /// + public int Rank + { + get { return Shape.Length; } + } + + /// + /// Construct an ObservationSpec for 1-D observations of the requested length. + /// + /// + /// + /// + public static ObservationSpec Vector(int length, ObservationType obsType = ObservationType.Default) + { + return new ObservationSpec( + new InplaceArray(length), + new InplaceArray(DimensionProperty.None), + obsType + ); + } + + /// + /// Construct an ObservationSpec for variable-length observations. 
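+ /// For example, a buffer of up to 20 entities with 6 floats each would be described,
+ /// as a sketch, by ObservationSpec.VariableLength(6, 20).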
+ /// + /// + /// + /// + public static ObservationSpec VariableLength(int obsSize, int maxNumObs) + { + var dimProps = new InplaceArray( + DimensionProperty.VariableSize, + DimensionProperty.None + ); + return new ObservationSpec( + new InplaceArray(obsSize, maxNumObs), + dimProps + ); + } + + /// + /// Construct an ObservationSpec for visual-like observations, e.g. observations + /// with a height, width, and possible multiple channels. + /// + /// + /// + /// + /// + /// + public static ObservationSpec Visual(int height, int width, int channels, ObservationType obsType = ObservationType.Default) + { + var dimProps = new InplaceArray( + DimensionProperty.TranslationalEquivariance, + DimensionProperty.TranslationalEquivariance, + DimensionProperty.None + ); + return new ObservationSpec( + new InplaceArray(height, width, channels), + dimProps, + obsType + ); + } + + /// + /// Create a general ObservationSpec from the shape, dimension properties, and observation type. + /// + /// + /// Note that not all combinations of DimensionProperty may be supported by the trainer. + /// shape and dimensionProperties must have the same size. + /// + /// + /// + /// + /// + public ObservationSpec( + InplaceArray shape, + InplaceArray dimensionProperties, + ObservationType observationType = ObservationType.Default + ) + { + if (shape.Length != dimensionProperties.Length) + { + throw new UnityAgentsException("shape and dimensionProperties must have the same length."); + } + m_Shape = shape; + m_DimensionProperties = dimensionProperties; + m_ObservationType = observationType; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta new file mode 100644 index 0000000000..691fdf6172 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: cc1734d60fd5485ead94247cb206aa35 +timeCreated: 1615412644 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs new file mode 100644 index 0000000000..02fbae8997 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs @@ -0,0 +1,321 @@ +using System; +using System.Collections.Generic; +using Unity.Barracuda; +using Unity.MLAgents.Inference; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Allows sensors to write to both TensorProxy and float arrays/lists. + /// + public class ObservationWriter + { + IList m_Data; + int m_Offset; + + TensorProxy m_Proxy; + int m_Batch; + + TensorShape m_TensorShape; + + public ObservationWriter() { } + + /// + /// Set the writer to write to an IList at the given channelOffset. + /// + /// Float array or list that will be written to. + /// ObservationSpec of the observation to be written + /// Offset from the start of the float data to write to. + internal void SetTarget(IList data, ObservationSpec observationSpec, int offset) + { + SetTarget(data, observationSpec.Shape, offset); + } + + /// + /// Set the writer to write to an IList at the given channelOffset. + /// + /// Float array or list that will be written to. + /// Shape of the observations to be written. + /// Offset from the start of the float data to write to. 
+ internal void SetTarget(IList data, InplaceArray shape, int offset) + { + m_Data = data; + m_Offset = offset; + m_Proxy = null; + m_Batch = 0; + + if (shape.Length == 1) + { + m_TensorShape = new TensorShape(m_Batch, shape[0]); + } + else if (shape.Length == 2) + { + m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] }); + } + else + { + m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]); + } + } + + /// + /// Set the writer to write to a TensorProxy at the given batch and channel offset. + /// + /// Tensor proxy that will be written to. + /// Batch index in the tensor proxy (i.e. the index of the Agent). + /// Offset from the start of the channel to write to. + internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset) + { + m_Proxy = tensorProxy; + m_Batch = batchIndex; + m_Offset = channelOffset; + m_Data = null; + m_TensorShape = m_Proxy.data.shape; + } + + /// + /// 1D write access at a specified index. Use AddList if possible instead. + /// + /// Index to write to. + public float this[int index] + { + set + { + if (m_Data != null) + { + m_Data[index + m_Offset] = value; + } + else + { + m_Proxy.data[m_Batch, index + m_Offset] = value; + } + } + } + + /// + /// 3D write access at the specified height, width, and channel. + /// + /// + /// + /// + public float this[int h, int w, int ch] + { + set + { + if (m_Data != null) + { + if (h < 0 || h >= m_TensorShape.height) + { + throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height - 1}]"); + } + if (w < 0 || w >= m_TensorShape.width) + { + throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width - 1}]"); + } + if (ch < 0 || ch >= m_TensorShape.channels) + { + throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels - 1}]"); + } + + var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset); + m_Data[index] = value; + } + else + { + m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value; + } + } + } + + /// + /// Write the list of floats. + /// + /// The actual list of floats to write. + /// Optional write offset to start writing from. + public void AddList(IList data, int writeOffset = 0) + { + if (m_Data != null) + { + for (var index = 0; index < data.Count; index++) + { + var val = data[index]; + m_Data[index + m_Offset + writeOffset] = val; + } + } + else + { + for (var index = 0; index < data.Count; index++) + { + var val = data[index]; + m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val; + } + } + } + + /// + /// Write the Vector3 components. + /// + /// The Vector3 to be written. + /// Optional write offset. + public void Add(Vector3 vec, int writeOffset = 0) + { + if (m_Data != null) + { + m_Data[m_Offset + writeOffset + 0] = vec.x; + m_Data[m_Offset + writeOffset + 1] = vec.y; + m_Data[m_Offset + writeOffset + 2] = vec.z; + } + else + { + m_Proxy.data[m_Batch, m_Offset + writeOffset + 0] = vec.x; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 1] = vec.y; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 2] = vec.z; + } + } + + /// + /// Write the Vector4 components. + /// + /// The Vector4 to be written. + /// Optional write offset. 
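+ /// A typical ISensor.Write sketch using these helpers (the names are illustrative):
+ /// writer.Add(transform.position); // writes 3 floats at offset 0
+ /// writer.Add(body.rotation, 3); // writes 4 floats starting at offset 3
+ /// return 7;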
+ public void Add(Vector4 vec, int writeOffset = 0) + { + if (m_Data != null) + { + m_Data[m_Offset + writeOffset + 0] = vec.x; + m_Data[m_Offset + writeOffset + 1] = vec.y; + m_Data[m_Offset + writeOffset + 2] = vec.z; + m_Data[m_Offset + writeOffset + 3] = vec.w; + } + else + { + m_Proxy.data[m_Batch, m_Offset + writeOffset + 0] = vec.x; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 1] = vec.y; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 2] = vec.z; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 3] = vec.w; + } + } + + /// + /// Write the Quaternion components. + /// + /// The Quaternion to be written. + /// Optional write offset. + + public void Add(Quaternion quat, int writeOffset = 0) + { + if (m_Data != null) + { + m_Data[m_Offset + writeOffset + 0] = quat.x; + m_Data[m_Offset + writeOffset + 1] = quat.y; + m_Data[m_Offset + writeOffset + 2] = quat.z; + m_Data[m_Offset + writeOffset + 3] = quat.w; + } + else + { + m_Proxy.data[m_Batch, m_Offset + writeOffset + 0] = quat.x; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 1] = quat.y; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 2] = quat.z; + m_Proxy.data[m_Batch, m_Offset + writeOffset + 3] = quat.w; + } + } + } + + /// + /// Provides extension methods for the ObservationWriter. + /// + public static class ObservationWriterExtension + { + /// + /// Writes a Texture2D into a ObservationWriter. + /// + /// + /// Writer to fill with Texture data. + /// + /// + /// The texture to be put into the tensor. + /// + /// + /// If set to true the textures will be converted to grayscale before + /// being stored in the tensor. + /// + /// The number of floats written + public static int WriteTexture( + this ObservationWriter obsWriter, + Texture2D texture, + bool grayScale) + { + if (texture.format == TextureFormat.RGB24) + { + return obsWriter.WriteTextureRGB24(texture, grayScale); + } + var width = texture.width; + var height = texture.height; + + var texturePixels = texture.GetPixels32(); + // During training, we convert from Texture to PNG before sending to the trainer, which has the + // effect of flipping the image. We need another flip here at inference time to match this. + for (var h = height - 1; h >= 0; h--) + { + for (var w = 0; w < width; w++) + { + var currentPixel = texturePixels[(height - h - 1) * width + w]; + + if (grayScale) + { + obsWriter[h, w, 0] = + (currentPixel.r + currentPixel.g + currentPixel.b) / 3f / 255.0f; + } + else + { + // For Color32, the r, g and b values are between 0 and 255. + obsWriter[h, w, 0] = currentPixel.r / 255.0f; + obsWriter[h, w, 1] = currentPixel.g / 255.0f; + obsWriter[h, w, 2] = currentPixel.b / 255.0f; + } + } + } + + return height * width * (grayScale ? 1 : 3); + } + + internal static int WriteTextureRGB24( + this ObservationWriter obsWriter, + Texture2D texture, + bool grayScale + ) + { + var width = texture.width; + var height = texture.height; + + var rawBytes = texture.GetRawTextureData(); + // During training, we convert from Texture to PNG before sending to the trainer, which has the + // effect of flipping the image. We need another flip here at inference time to match this. + for (var h = height - 1; h >= 0; h--) + { + for (var w = 0; w < width; w++) + { + var offset = (height - h - 1) * width + w; + var r = rawBytes[3 * offset]; + var g = rawBytes[3 * offset + 1]; + var b = rawBytes[3 * offset + 2]; + + if (grayScale) + { + obsWriter[h, w, 0] = (r + g + b) / 3f / 255.0f; + } + else + { + // For Color32, the r, g and b values are between 0 and 255. 
+ obsWriter[h, w, 0] = r / 255.0f; + obsWriter[h, w, 1] = g / 255.0f; + obsWriter[h, w, 2] = b / 255.0f; + } + } + } + + return height * width * (grayScale ? 1 : 3); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs.meta new file mode 100644 index 0000000000..62fc3b1aba --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 86bad2e6dded4a62853752a1713981f2 +timeCreated: 1572540197 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs b/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs new file mode 100644 index 0000000000..648c702d80 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs @@ -0,0 +1,59 @@ +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Grid-based sensor with one-hot observations. + /// + public class OneHotGridSensor : GridSensorBase + { + /// + /// Create a OneHotGridSensor with the specified configuration. + /// + /// The sensor name + /// The scale of each cell in the grid + /// Number of cells on each side of the grid + /// Tags to be detected by the sensor + /// Compression type + public OneHotGridSensor( + string name, + Vector3 cellScale, + Vector3Int gridSize, + string[] detectableTags, + SensorCompressionType compression + ) : base(name, cellScale, gridSize, detectableTags, compression) + { + } + + /// + protected override int GetCellObservationSize() + { + return DetectableTags == null ? 0 : DetectableTags.Length; + } + + /// + protected override bool IsDataNormalized() + { + return true; + } + + /// + protected internal override ProcessCollidersMethod GetProcessCollidersMethod() + { + return ProcessCollidersMethod.ProcessClosestColliders; + } + + /// + /// Get the one-hot representation of the detected game object's tag. + /// + /// The game object that was detected within a certain cell + /// The index of the detectedObject's tag in the DetectableObjects list + /// The buffer to write the observation values. + /// The buffer size is configured by . + /// + protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer) + { + dataBuffer[tagIndex] = 1; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs.meta new file mode 100644 index 0000000000..c21f87cba4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 914f5ab90be9e411d83642035abebc2c +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs new file mode 100644 index 0000000000..4cb67ec709 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs @@ -0,0 +1,519 @@ +using System; +using System.Collections.Generic; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Determines which dimensions the sensor will perform the casts in. + /// + public enum RayPerceptionCastType + { + /// + /// Cast in 2 dimensions, using Physics2D.CircleCast or Physics2D.RayCast. + /// + Cast2D, + + /// + /// Cast in 3 dimensions, using Physics.SphereCast or Physics.RayCast. 
+ /// + Cast3D, + } + + /// + /// Contains the elements that define a ray perception sensor. + /// + public struct RayPerceptionInput + { + /// + /// Length of the rays to cast. This will be scaled up or down based on the scale of the transform. + /// + public float RayLength; + + /// + /// List of tags which correspond to object types agent can see. + /// + public IReadOnlyList DetectableTags; + + /// + /// List of angles (in degrees) used to define the rays. + /// 90 degrees is considered "forward" relative to the game object. + /// + public IReadOnlyList Angles; + + /// + /// Starting height offset of ray from center of agent + /// + public float StartOffset; + + /// + /// Ending height offset of ray from center of agent. + /// + public float EndOffset; + + /// + /// Radius of the sphere to use for spherecasting. + /// If 0 or less, rays are used instead - this may be faster, especially for complex environments. + /// + public float CastRadius; + + /// + /// Transform of the GameObject. + /// + public Transform Transform; + + /// + /// Whether to perform the casts in 2D or 3D. + /// + public RayPerceptionCastType CastType; + + /// + /// Filtering options for the casts. + /// + public int LayerMask; + + /// + /// Returns the expected number of floats in the output. + /// + /// + public int OutputSize() + { + return ((DetectableTags?.Count ?? 0) + 2) * (Angles?.Count ?? 0); + } + + /// + /// Get the cast start and end points for the given ray index/ + /// + /// + /// A tuple of the start and end positions in world space. + public (Vector3 StartPositionWorld, Vector3 EndPositionWorld) RayExtents(int rayIndex) + { + var angle = Angles[rayIndex]; + Vector3 startPositionLocal, endPositionLocal; + if (CastType == RayPerceptionCastType.Cast3D) + { + startPositionLocal = new Vector3(0, StartOffset, 0); + endPositionLocal = PolarToCartesian3D(RayLength, angle); + endPositionLocal.y += EndOffset; + } + else + { + // Vector2s here get converted to Vector3s (and back to Vector2s for casting) + startPositionLocal = new Vector2(); + endPositionLocal = PolarToCartesian2D(RayLength, angle); + } + + var startPositionWorld = Transform.TransformPoint(startPositionLocal); + var endPositionWorld = Transform.TransformPoint(endPositionLocal); + + return (StartPositionWorld: startPositionWorld, EndPositionWorld: endPositionWorld); + } + + /// + /// Converts polar coordinate to cartesian coordinate. + /// + static internal Vector3 PolarToCartesian3D(float radius, float angleDegrees) + { + var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees); + var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees); + return new Vector3(x, 0f, z); + } + + /// + /// Converts polar coordinate to cartesian coordinate. + /// + static internal Vector2 PolarToCartesian2D(float radius, float angleDegrees) + { + var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees); + var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees); + return new Vector2(x, y); + } + } + + /// + /// Contains the data generated/produced from a ray perception sensor. + /// + public class RayPerceptionOutput + { + /// + /// Contains the data generated from a single ray of a ray perception sensor. + /// + public struct RayOutput + { + /// + /// Whether or not the ray hit anything. + /// + public bool HasHit; + + /// + /// Whether or not the ray hit an object whose tag is in the input's DetectableTags list. 
+ /// + public bool HitTaggedObject; + + /// + /// The index of the hit object's tag in the DetectableTags list, or -1 if there was no hit, or the + /// hit object has a different tag. + /// + public int HitTagIndex; + + /// + /// Normalized distance to the hit object. + /// + public float HitFraction; + + /// + /// The hit GameObject (or null if there was no hit). + /// + public GameObject HitGameObject; + + /// + /// Start position of the ray in world space. + /// + public Vector3 StartPositionWorld; + + /// + /// End position of the ray in world space. + /// + public Vector3 EndPositionWorld; + + /// + /// The scaled length of the ray. + /// + /// + /// If there is non-(1,1,1) scale, |EndPositionWorld - StartPositionWorld| will be different from + /// the input rayLength. + /// + public float ScaledRayLength + { + get + { + var rayDirection = EndPositionWorld - StartPositionWorld; + return rayDirection.magnitude; + } + } + + /// + /// The scaled size of the cast. + /// + /// + /// If there is non-(1,1,1) scale, the cast radius will be also be scaled. + /// + public float ScaledCastRadius; + + /// + /// Writes the ray output information to a subset of the float array. Each element in the rayAngles array + /// determines a sublist of data to the observation. The sublist contains the observation data for a single cast. + /// The list is composed of the following: + /// 1. A one-hot encoding for detectable tags. For example, if DetectableTags.Length = n, the + /// first n elements of the sublist will be a one-hot encoding of the detectableTag that was hit, or + /// all zeroes otherwise. + /// 2. The 'numDetectableTags' element of the sublist will be 1 if the ray missed everything, or 0 if it hit + /// something (detectable or not). + /// 3. The 'numDetectableTags+1' element of the sublist will contain the normalized distance to the object + /// hit, or 1.0 if nothing was hit. + /// + /// + /// + /// Output buffer. The size must be equal to (numDetectableTags+2) * RayOutputs.Length + public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer) + { + var bufferOffset = (numDetectableTags + 2) * rayIndex; + if (HitTaggedObject) + { + buffer[bufferOffset + HitTagIndex] = 1f; + } + buffer[bufferOffset + numDetectableTags] = HasHit ? 0f : 1f; + buffer[bufferOffset + numDetectableTags + 1] = HitFraction; + } + } + + /// + /// RayOutput for each ray that was cast. + /// + public RayOutput[] RayOutputs; + } + + /// + /// A sensor implementation that supports ray cast-based observations. + /// + public class RayPerceptionSensor : ISensor, IBuiltInSensor + { + float[] m_Observations; + ObservationSpec m_ObservationSpec; + string m_Name; + + RayPerceptionInput m_RayPerceptionInput; + RayPerceptionOutput m_RayPerceptionOutput; + + /// + /// Time.frameCount at the last time Update() was called. This is only used for display in gizmos. + /// + int m_DebugLastFrameCount; + + internal int DebugLastFrameCount + { + get { return m_DebugLastFrameCount; } + } + + /// + /// Creates the RayPerceptionSensor. + /// + /// The name of the sensor. + /// The inputs for the sensor. + public RayPerceptionSensor(string name, RayPerceptionInput rayInput) + { + m_Name = name; + m_RayPerceptionInput = rayInput; + + SetNumObservations(rayInput.OutputSize()); + + m_DebugLastFrameCount = Time.frameCount; + m_RayPerceptionOutput = new RayPerceptionOutput(); + } + + /// + /// The most recent raycast results. 
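+ /// A read-back sketch (the sensor reference is an assumption):
+ /// var rayOutputs = raySensor.RayPerceptionOutput.RayOutputs;
+ /// bool hitFirstTag = rayOutputs != null && rayOutputs[0].HitTagIndex == 0;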
+ ///
+ public RayPerceptionOutput RayPerceptionOutput
+ {
+ get { return m_RayPerceptionOutput; }
+ }
+
+ void SetNumObservations(int numObservations)
+ {
+ m_ObservationSpec = ObservationSpec.Vector(numObservations);
+ m_Observations = new float[numObservations];
+ }
+
+ internal void SetRayPerceptionInput(RayPerceptionInput rayInput)
+ {
+ // Note that changing the number of rays or tags doesn't directly call this,
+ // but changing them and then changing another field will.
+ if (m_RayPerceptionInput.OutputSize() != rayInput.OutputSize())
+ {
+ Debug.Log(
+ "Changing the number of tags or rays at runtime is not " +
+ "supported and may cause errors in training or inference."
+ );
+ // Changing the shape will probably break things downstream, but we can at least
+ // keep this consistent.
+ SetNumObservations(rayInput.OutputSize());
+ }
+ m_RayPerceptionInput = rayInput;
+ }
+
+ ///
+ /// Computes the ray perception observations and saves them to the provided
+ /// ObservationWriter.
+ ///
+ /// Where the ray perception observations are written to.
+ ///
+ public int Write(ObservationWriter writer)
+ {
+ using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
+ {
+ Array.Clear(m_Observations, 0, m_Observations.Length);
+ var numRays = m_RayPerceptionInput.Angles.Count;
+ var numDetectableTags = m_RayPerceptionInput.DetectableTags.Count;
+
+ // For each ray, write the information to the observation buffer
+ for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
+ {
+ m_RayPerceptionOutput.RayOutputs?[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
+ }
+
+ // Finally, add the observations to the ObservationWriter
+ writer.AddList(m_Observations);
+ }
+ return m_Observations.Length;
+ }
+
+ ///
+ public void Update()
+ {
+ m_DebugLastFrameCount = Time.frameCount;
+ var numRays = m_RayPerceptionInput.Angles.Count;
+
+ if (m_RayPerceptionOutput.RayOutputs == null || m_RayPerceptionOutput.RayOutputs.Length != numRays)
+ {
+ m_RayPerceptionOutput.RayOutputs = new RayPerceptionOutput.RayOutput[numRays];
+ }
+
+ // For each ray, do the casting and save the results.
+ for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
+ {
+ m_RayPerceptionOutput.RayOutputs[rayIndex] = PerceiveSingleRay(m_RayPerceptionInput, rayIndex);
+ }
+ }
+
+ ///
+ public void Reset() { }
+
+ ///
+ public ObservationSpec GetObservationSpec()
+ {
+ return m_ObservationSpec;
+ }
+
+ ///
+ public string GetName()
+ {
+ return m_Name;
+ }
+
+ ///
+ public virtual byte[] GetCompressedObservation()
+ {
+ return null;
+ }
+
+ ///
+ public CompressionSpec GetCompressionSpec()
+ {
+ return CompressionSpec.Default();
+ }
+
+ ///
+ public BuiltInSensorType GetBuiltInSensorType()
+ {
+ return BuiltInSensorType.RayPerceptionSensor;
+ }
+
+ ///
+ /// Evaluates the raycasts to be used as part of an observation of an agent.
+ ///
+ /// Input defining the rays that will be cast.
+ /// Output struct containing the raycast results.
+ public static RayPerceptionOutput Perceive(RayPerceptionInput input)
+ {
+ RayPerceptionOutput output = new RayPerceptionOutput();
+ output.RayOutputs = new RayPerceptionOutput.RayOutput[input.Angles.Count];
+
+ for (var rayIndex = 0; rayIndex < input.Angles.Count; rayIndex++)
+ {
+ output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
+ }
+
+ return output;
+ }
+
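+ // Illustrative usage sketch (not part of the original file; tag names and field values
+ // are assumptions). Perceive() can be called for one-off queries without creating a sensor:
+ //
+ //   var input = new RayPerceptionInput
+ //   {
+ //       RayLength = 20f,
+ //       DetectableTags = new[] { "wall", "goal" },
+ //       Angles = new[] { 60f, 90f, 120f },
+ //       CastRadius = 0.5f,
+ //       Transform = transform,
+ //       CastType = RayPerceptionCastType.Cast3D,
+ //       LayerMask = Physics.DefaultRaycastLayers
+ //   };
+ //   var output = RayPerceptionSensor.Perceive(input);
+ //   var sawGoal = output.RayOutputs[1].HitTagIndex == 1; // the 90-degree (forward) ray hit a "goal"
+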
+ ///
+ /// Evaluate the raycast results of a single ray from the RayPerceptionInput.
+ ///
+ ///
+ ///
+ ///
+ internal static RayPerceptionOutput.RayOutput PerceiveSingleRay(
+ RayPerceptionInput input,
+ int rayIndex
+ )
+ {
+ var unscaledRayLength = input.RayLength;
+ var unscaledCastRadius = input.CastRadius;
+
+ var extents = input.RayExtents(rayIndex);
+ var startPositionWorld = extents.StartPositionWorld;
+ var endPositionWorld = extents.EndPositionWorld;
+
+ var rayDirection = endPositionWorld - startPositionWorld;
+ // If there is non-unity scale, |rayDirection| will be different from rayLength.
+ // We want to use this transformed ray length for determining cast length, hit fraction etc.
+ // We also use it to scale up or down the sphere or circle radii.
+ var scaledRayLength = rayDirection.magnitude;
+ // Avoid 0/0 if unscaledRayLength is 0
+ var scaledCastRadius = unscaledRayLength > 0 ?
+ unscaledCastRadius * scaledRayLength / unscaledRayLength :
+ unscaledCastRadius;
+
+ // Do the cast and assign the hit information for each detectable tag.
+ var castHit = false;
+ var hitFraction = 1.0f;
+ GameObject hitObject = null;
+
+ if (input.CastType == RayPerceptionCastType.Cast3D)
+ {
+#if MLA_UNITY_PHYSICS_MODULE
+ RaycastHit rayHit;
+ if (scaledCastRadius > 0f)
+ {
+ castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
+ scaledRayLength, input.LayerMask);
+ }
+ else
+ {
+ castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
+ scaledRayLength, input.LayerMask);
+ }
+
+ // If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
+ // To avoid 0/0, set the fraction to 0.
+ hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
+ hitObject = castHit ? rayHit.collider.gameObject : null;
+#endif
+ }
+ else
+ {
+#if MLA_UNITY_PHYSICS2D_MODULE
+ RaycastHit2D rayHit;
+ if (scaledCastRadius > 0f)
+ {
+ rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
+ scaledRayLength, input.LayerMask);
+ }
+ else
+ {
+ rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, input.LayerMask);
+ }
+
+ castHit = rayHit;
+ hitFraction = castHit ? rayHit.fraction : 1.0f;
+ hitObject = castHit ? rayHit.collider.gameObject : null;
+#endif
+ }
+
+ var rayOutput = new RayPerceptionOutput.RayOutput
+ {
+ HasHit = castHit,
+ HitFraction = hitFraction,
+ HitTaggedObject = false,
+ HitTagIndex = -1,
+ HitGameObject = hitObject,
+ StartPositionWorld = startPositionWorld,
+ EndPositionWorld = endPositionWorld,
+ ScaledCastRadius = scaledCastRadius
+ };
+
+ if (castHit)
+ {
+ // Find the index of the tag of the object that was hit.
+ var numTags = input.DetectableTags?.Count ?? 0;
+ for (var i = 0; i < numTags; i++)
+ {
+ var tagsEqual = false;
+ try
+ {
+ var tag = input.DetectableTags[i];
+ if (!string.IsNullOrEmpty(tag))
+ {
+ tagsEqual = hitObject.CompareTag(tag);
+ }
+ }
+ catch (UnityException)
+ {
+ // If the tag is null, empty, or not a valid tag, just ignore it.
+ } + + if (tagsEqual) + { + rayOutput.HitTaggedObject = true; + rayOutput.HitTagIndex = i; + break; + } + } + } + + + return rayOutput; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs.meta new file mode 100644 index 0000000000..4c7247977c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 71417cdf8dd542e19ec22822b001b884 +timeCreated: 1573089052 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent2D.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent2D.cs new file mode 100644 index 0000000000..3abee3d46b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent2D.cs @@ -0,0 +1,17 @@ +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A component for 2D Ray Perception. + /// + [AddComponentMenu("ML Agents/Ray Perception Sensor 2D", (int)MenuGroup.Sensors)] + public class RayPerceptionSensorComponent2D : RayPerceptionSensorComponentBase + { + /// + public override RayPerceptionCastType GetCastType() + { + return RayPerceptionCastType.Cast2D; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent2D.cs.meta b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent2D.cs.meta new file mode 100644 index 0000000000..947a0904d3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent2D.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: f67c7e722ba14acd9153bb4488bff6e4 +timeCreated: 1573769662 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent3D.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent3D.cs new file mode 100644 index 0000000000..34bde48140 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent3D.cs @@ -0,0 +1,58 @@ +using UnityEngine; +using UnityEngine.Serialization; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A component for 3D Ray Perception. + /// + [AddComponentMenu("ML Agents/Ray Perception Sensor 3D", (int)MenuGroup.Sensors)] + public class RayPerceptionSensorComponent3D : RayPerceptionSensorComponentBase + { + [HideInInspector, SerializeField, FormerlySerializedAs("startVerticalOffset")] + [Range(-10f, 10f)] + [Tooltip("Ray start is offset up or down by this amount.")] + float m_StartVerticalOffset; + + /// + /// Ray start is offset up or down by this amount. + /// + public float StartVerticalOffset + { + get => m_StartVerticalOffset; + set { m_StartVerticalOffset = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("endVerticalOffset")] + [Range(-10f, 10f)] + [Tooltip("Ray end is offset up or down by this amount.")] + float m_EndVerticalOffset; + + /// + /// Ray end is offset up or down by this amount. 
+ ///
+ public float EndVerticalOffset
+ {
+ get => m_EndVerticalOffset;
+ set { m_EndVerticalOffset = value; UpdateSensor(); }
+ }
+
+ ///
+ public override RayPerceptionCastType GetCastType()
+ {
+ return RayPerceptionCastType.Cast3D;
+ }
+
+ ///
+ public override float GetStartVerticalOffset()
+ {
+ return StartVerticalOffset;
+ }
+
+ ///
+ public override float GetEndVerticalOffset()
+ {
+ return EndVerticalOffset;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent3D.cs.meta b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent3D.cs.meta
new file mode 100644
index 0000000000..51ec4e5b16
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponent3D.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 6bb6b867a41448888c1cd4f99643ad71
+timeCreated: 1573764567
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
new file mode 100644
index 0000000000..75a114382e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
@@ -0,0 +1,370 @@
+using System;
+using System.Collections.Generic;
+using UnityEngine;
+using UnityEngine.Serialization;
+
+namespace Unity.MLAgents.Sensors
+{
+ ///
+ /// A base class to support sensor components for raycast-based sensors.
+ ///
+ public abstract class RayPerceptionSensorComponentBase : SensorComponent
+ {
+ [HideInInspector, SerializeField, FormerlySerializedAs("sensorName")]
+ string m_SensorName = "RayPerceptionSensor";
+
+ ///
+ /// The name of the Sensor that this component wraps.
+ /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
+ ///
+ public string SensorName
+ {
+ get { return m_SensorName; }
+ set { m_SensorName = value; }
+ }
+
+ [SerializeField, FormerlySerializedAs("detectableTags")]
+ [Tooltip("List of tags in the scene to compare against.")]
+ List<string> m_DetectableTags;
+
+ ///
+ /// List of tags in the scene to compare against.
+ /// Note that this should not be changed at runtime.
+ ///
+ public List<string> DetectableTags
+ {
+ get { return m_DetectableTags; }
+ set { m_DetectableTags = value; }
+ }
+
+ [HideInInspector, SerializeField, FormerlySerializedAs("raysPerDirection")]
+ [Range(0, 50)]
+ [Tooltip("Number of rays to the left and right of center.")]
+ int m_RaysPerDirection = 3;
+
+ ///
+ /// Number of rays to the left and right of center.
+ /// Note that this should not be changed at runtime.
+ ///
+ public int RaysPerDirection
+ {
+ get { return m_RaysPerDirection; }
+ // Note: can't change at runtime
+ set { m_RaysPerDirection = value; }
+ }
+
+ [HideInInspector, SerializeField, FormerlySerializedAs("maxRayDegrees")]
+ [Range(0, 180)]
+ [Tooltip("Cone size for rays. Using 90 degrees will cast rays to the left and right. " +
+ "Greater than 90 degrees will go backwards.")]
+ float m_MaxRayDegrees = 70;
+
+ ///
+ /// Cone size for rays. Using 90 degrees will cast rays to the left and right.
+ /// Greater than 90 degrees will go backwards.
+ ///
+ public float MaxRayDegrees
+ {
+ get => m_MaxRayDegrees;
+ set { m_MaxRayDegrees = value; UpdateSensor(); }
+ }
+
+ [HideInInspector, SerializeField, FormerlySerializedAs("sphereCastRadius")]
+ [Range(0f, 10f)]
+ [Tooltip("Radius of sphere to cast. Set to zero for raycasts.")]
+ float m_SphereCastRadius = 0.5f;
+
+ ///
+ /// Radius of sphere to cast. Set to zero for raycasts.
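+ /// A non-zero radius makes each cast a sphere sweep that can catch near-misses; zero-radius raycasts are generally cheaper.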
+ /// + public float SphereCastRadius + { + get => m_SphereCastRadius; + set { m_SphereCastRadius = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("rayLength")] + [Range(1, 1000)] + [Tooltip("Length of the rays to cast.")] + float m_RayLength = 20f; + + /// + /// Length of the rays to cast. + /// + public float RayLength + { + get => m_RayLength; + set { m_RayLength = value; UpdateSensor(); } + } + + // The value of the default layers. + const int k_PhysicsDefaultLayers = -5; + [HideInInspector, SerializeField, FormerlySerializedAs("rayLayerMask")] + [Tooltip("Controls which layers the rays can hit.")] + LayerMask m_RayLayerMask = k_PhysicsDefaultLayers; + + /// + /// Controls which layers the rays can hit. + /// + public LayerMask RayLayerMask + { + get => m_RayLayerMask; + set { m_RayLayerMask = value; UpdateSensor(); } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("observationStacks")] + [Range(1, 50)] + [Tooltip("Number of raycast results that will be stacked before being fed to the neural network.")] + int m_ObservationStacks = 1; + + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + + [HideInInspector, SerializeField] + [Tooltip("Disable to provide the rays in left to right order. Warning: Alternating order will be deprecated, disable it to ensure compatibility with future versions of ML-Agents.")] + public bool m_AlternatingRayOrder = true; + + /// + /// Determines how the rays are ordered. By default the ordering is as follows: middle ray is first; + /// then alternates outward adding rays to the left and right. If set to false, then the rays are + /// ordered from left to right (viewed from above) which is more amenable to processing with + /// conv nets. + /// This property will be deprecated with the next major version update and the left to right ordering + /// will be used thereafter. + /// + public bool AlternatingRayOrder + { + get { return m_AlternatingRayOrder; } + set { m_AlternatingRayOrder = value; } + } + + /// + /// Color to code a ray that hits another object. + /// + [HideInInspector] + [SerializeField] + [Header("Debug Gizmos", order = 999)] + internal Color rayHitColor = Color.red; + + /// + /// Color to code a ray that avoid or misses all other objects. + /// + [HideInInspector] + [SerializeField] + internal Color rayMissColor = Color.white; + + [NonSerialized] + RayPerceptionSensor m_RaySensor; + + /// + /// Get the RayPerceptionSensor that was created. + /// + public RayPerceptionSensor RaySensor + { + get => m_RaySensor; + } + + /// + /// Returns the for the associated raycast sensor. + /// + /// + public abstract RayPerceptionCastType GetCastType(); + + /// + /// Returns the amount that the ray start is offset up or down by. + /// + /// + public virtual float GetStartVerticalOffset() + { + return 0f; + } + + /// + /// Returns the amount that the ray end is offset up or down by. + /// + /// + public virtual float GetEndVerticalOffset() + { + return 0f; + } + + /// + /// Returns an initialized raycast sensor. 
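+ /// If ObservationStacks is greater than 1, the created ray sensor is wrapped in a StackingSensor.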
+ /// + /// + public override ISensor[] CreateSensors() + { + var rayPerceptionInput = GetRayPerceptionInput(); + + m_RaySensor = new RayPerceptionSensor(m_SensorName, rayPerceptionInput); + + if (ObservationStacks != 1) + { + var stackingSensor = new StackingSensor(m_RaySensor, ObservationStacks); + return new ISensor[] { stackingSensor }; + } + + return new ISensor[] { m_RaySensor }; + } + + /// + /// Returns the specific ray angles given the number of rays per direction and the + /// cone size for the rays. + /// + /// Number of rays to the left and right of center. + /// + /// Cone size for rays. Using 90 degrees will cast rays to the left and right. + /// Greater than 90 degrees will go backwards. + /// Orders the rays starting with the centermost and alternating to the left and right. + /// Should be deprecated with a future major version release (doing so will break existing + /// models). + /// + /// + internal static float[] GetRayAnglesAlternating(int raysPerDirection, float maxRayDegrees) + { + // Example: + // { 90, 90 - delta, 90 + delta, 90 - 2*delta, 90 + 2*delta } + var anglesOut = new float[2 * raysPerDirection + 1]; + var delta = maxRayDegrees / raysPerDirection; + anglesOut[0] = 90f; + for (var i = 0; i < raysPerDirection; i++) + { + anglesOut[2 * i + 1] = 90 - (i + 1) * delta; + anglesOut[2 * i + 2] = 90 + (i + 1) * delta; + } + return anglesOut; + } + + /// + /// Returns the specific ray angles given the number of rays per direction and the + /// cone size for the rays. + /// + /// Number of rays to the left and right of center. + /// + /// Cone size for rays. Using 90 degrees will cast rays to the left and right. + /// Greater than 90 degrees will go backwards. + /// Orders the rays from the left-most to the right-most which makes using a convolution + /// in the model easier. + /// + /// + internal static float[] GetRayAngles(int raysPerDirection, float maxRayDegrees) + { + // Example: + // { 90 - 3*delta, 90 - 2*delta, ..., 90, 90 + delta, ..., 90 + 3*delta } + var anglesOut = new float[2 * raysPerDirection + 1]; + var delta = maxRayDegrees / raysPerDirection; + + for (var i = 0; i < 2 * raysPerDirection + 1; i++) + { + anglesOut[i] = 90 + (i - raysPerDirection) * delta; + } + + return anglesOut; + } + + /// + /// Get the RayPerceptionInput that is used by the . + /// + /// + public RayPerceptionInput GetRayPerceptionInput() + { + var rayAngles = m_AlternatingRayOrder ? 
+ GetRayAnglesAlternating(RaysPerDirection, MaxRayDegrees) : + GetRayAngles(RaysPerDirection, MaxRayDegrees); + + var rayPerceptionInput = new RayPerceptionInput(); + rayPerceptionInput.RayLength = RayLength; + rayPerceptionInput.DetectableTags = DetectableTags; + rayPerceptionInput.Angles = rayAngles; + rayPerceptionInput.StartOffset = GetStartVerticalOffset(); + rayPerceptionInput.EndOffset = GetEndVerticalOffset(); + rayPerceptionInput.CastRadius = SphereCastRadius; + rayPerceptionInput.Transform = transform; + rayPerceptionInput.CastType = GetCastType(); + rayPerceptionInput.LayerMask = RayLayerMask; + + return rayPerceptionInput; + } + + internal void UpdateSensor() + { + if (m_RaySensor != null) + { + var rayInput = GetRayPerceptionInput(); + m_RaySensor.SetRayPerceptionInput(rayInput); + } + } + + internal int SensorObservationAge() + { + if (m_RaySensor != null) + { + return Time.frameCount - m_RaySensor.DebugLastFrameCount; + } + + return 0; + } + + void OnDrawGizmosSelected() + { + if (m_RaySensor?.RayPerceptionOutput?.RayOutputs != null) + { + // If we have cached debug info from the sensor, draw that. + // Draw "old" observations in a lighter color. + // Since the agent may not step every frame, this helps de-emphasize "stale" hit information. + var alpha = Mathf.Pow(.5f, SensorObservationAge()); + + foreach (var rayInfo in m_RaySensor.RayPerceptionOutput.RayOutputs) + { + DrawRaycastGizmos(rayInfo, alpha); + } + } + else + { + var rayInput = GetRayPerceptionInput(); + // We don't actually need the tags here, since they don't affect the display of the rays. + // Additionally, the user might be in the middle of typing the tag name when this is called, + // and there's no way to turn off the "Tag ... is not defined" error logs. + // So just don't use any tags here. + rayInput.DetectableTags = null; + for (var rayIndex = 0; rayIndex < rayInput.Angles.Count; rayIndex++) + { + var rayOutput = RayPerceptionSensor.PerceiveSingleRay(rayInput, rayIndex); + DrawRaycastGizmos(rayOutput); + } + } + } + + /// + /// Draw the debug information from the sensor (if available). + /// + void DrawRaycastGizmos(RayPerceptionOutput.RayOutput rayOutput, float alpha = 1.0f) + { + var startPositionWorld = rayOutput.StartPositionWorld; + var endPositionWorld = rayOutput.EndPositionWorld; + var rayDirection = endPositionWorld - startPositionWorld; + rayDirection *= rayOutput.HitFraction; + + // hit fraction ^2 will shift "far" hits closer to the hit color + var lerpT = rayOutput.HitFraction * rayOutput.HitFraction; + var color = Color.Lerp(rayHitColor, rayMissColor, lerpT); + color.a *= alpha; + Gizmos.color = color; + Gizmos.DrawRay(startPositionWorld, rayDirection); + + // Draw the hit point as a sphere. If using rays to cast (0 radius), use a small sphere. 
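+ // (ScaledCastRadius is 0 when plain raycasts are used, so clamp to a small minimum to keep the hit marker visible.)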
+ if (rayOutput.HasHit) + { + var hitRadius = Mathf.Max(rayOutput.ScaledCastRadius, .05f); + Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs.meta b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs.meta new file mode 100644 index 0000000000..97f40e582f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 45243967d8c0419b953c02bccb7c2768 +timeCreated: 1573087062 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection.meta new file mode 100644 index 0000000000..fb7288f717 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 08ece3d7e9bb94089a9d59c6f269ab0a +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs new file mode 100644 index 0000000000..606656ecb3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a boolean field or property of an object, and returns + /// that as an observation. + /// + internal class BoolReflectionSensor : ReflectionSensorBase + { + public BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + { } + + internal override void WriteReflectedField(ObservationWriter writer) + { + var boolVal = (System.Boolean)GetReflectedValue(); + writer[0] = boolVal ? 1.0f : 0.0f; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta new file mode 100644 index 0000000000..5cac420f11 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: be795c90750a6420d93f569b69ddc1ba +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs new file mode 100644 index 0000000000..2d92df0369 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs @@ -0,0 +1,61 @@ +using System; + +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class EnumReflectionSensor : ReflectionSensorBase + { + Array m_Values; + bool m_IsFlags; + + internal EnumReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, GetEnumObservationSize(reflectionSensorInfo.GetMemberType())) + { + var memberType = reflectionSensorInfo.GetMemberType(); + m_Values = Enum.GetValues(memberType); + m_IsFlags = memberType.IsDefined(typeof(FlagsAttribute), false); + } + + internal override void WriteReflectedField(ObservationWriter writer) + { + // Write the enum value as a one-hot encoding. + // Note that unknown enum values will record all 0's. + // Flags will get treated as a sequence of bools. 
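+ // Worked example (illustrative): enum Weapon { Sword, Bow, Axe } with value Bow
+ // writes [0, 1, 0]; a [Flags] enum with (Sword | Axe) set writes [1, 0, 1].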
+ var enumValue = (Enum)GetReflectedValue(); + + int i = 0; + foreach (var val in m_Values) + { + if (m_IsFlags) + { + if (enumValue.HasFlag((Enum)val)) + { + writer[i] = 1.0f; + } + else + { + writer[i] = 0.0f; + } + } + else + { + if (val.Equals(enumValue)) + { + writer[i] = 1.0f; + } + else + { + writer[i] = 0.0f; + } + } + i++; + } + } + + internal static int GetEnumObservationSize(Type t) + { + var values = Enum.GetValues(t); + // Account for all enum values + return values.Length; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs.meta new file mode 100644 index 0000000000..d42cce5521 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7d86e42cede474ec28dc3b1ef1c7a63c +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs new file mode 100644 index 0000000000..a488a9ed5c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a float field or property of an object, and returns + /// that as an observation. + /// + internal class FloatReflectionSensor : ReflectionSensorBase + { + public FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + { } + + internal override void WriteReflectedField(ObservationWriter writer) + { + var floatVal = (System.Single)GetReflectedValue(); + writer[0] = floatVal; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta new file mode 100644 index 0000000000..2de8b18c7c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 51ed837d5b7cd44349287ac8066120fc +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs new file mode 100644 index 0000000000..6c1b10f45a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps an integer field or property of an object, and returns + /// that as an observation. 
+ ///
+ internal class IntReflectionSensor : ReflectionSensorBase
+ {
+ public IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
+ : base(reflectionSensorInfo, 1)
+ { }
+
+ internal override void WriteReflectedField(ObservationWriter writer)
+ {
+ var intVal = (System.Int32)GetReflectedValue();
+ writer[0] = intVal;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta
new file mode 100644
index 0000000000..a07726937f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 5cae4c843cc074d11a549aaa3904c898
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
new file mode 100644
index 0000000000..a88528c873
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
@@ -0,0 +1,285 @@
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using UnityEngine;
+
+namespace Unity.MLAgents.Sensors.Reflection
+{
+ ///
+ /// Specify that a field or property should be used to generate observations for an Agent.
+ /// For each field or property that uses ObservableAttribute, a corresponding
+ /// will be created during Agent initialization, and this
+ /// sensor will read the values during training and inference.
+ ///
+ ///
+ /// ObservableAttribute is intended to make initial setup of an Agent easier. Because it
+ /// uses reflection to read the values of fields and properties at runtime, this may
+ /// be much slower than reading the values directly. If the performance of
+ /// ObservableAttribute is an issue, you can get the same functionality by overriding
+ /// or creating a custom
+ /// implementation to read the values without reflection.
+ ///
+ /// Note that you do not need to adjust the VectorObservationSize in
+ /// when adding ObservableAttribute
+ /// to fields or properties.
+ ///
+ ///
+ /// This sample class will produce two observations, one for the m_Health field, and one
+ /// for the HealthPercent property.
+ ///
+ /// using Unity.MLAgents;
+ /// using Unity.MLAgents.Sensors.Reflection;
+ ///
+ /// public class MyAgent : Agent
+ /// {
+ /// [Observable]
+ /// int m_Health;
+ ///
+ /// int m_MaxHealth = 100;
+ ///
+ /// [Observable]
+ /// float HealthPercent
+ /// {
+ /// get => 100.0f * m_Health / (float)m_MaxHealth;
+ /// }
+ /// }
+ ///
+ ///
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
+ public class ObservableAttribute : Attribute
+ {
+ string m_Name;
+ int m_NumStackedObservations;
+
+ ///
+ /// Default binding flags used for reflection of members and properties.
+ ///
+ const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
+
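+ // Size bookkeeping (illustrative): a float contributes 1 observation, a Vector3
+ // contributes 3, a Quaternion 4; an enum contributes one float per defined value
+ // (see EnumReflectionSensor.GetEnumObservationSize). Stacking multiplies these.
+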
+ ///
+ /// Supported types and their observation sizes and corresponding sensor type.
+ ///
+ static Dictionary<Type, (int, Type)> s_TypeToSensorInfo = new Dictionary<Type, (int, Type)>()
+ {
+ {typeof(int), (1, typeof(IntReflectionSensor))},
+ {typeof(bool), (1, typeof(BoolReflectionSensor))},
+ {typeof(float), (1, typeof(FloatReflectionSensor))},
+
+ {typeof(Vector2), (2, typeof(Vector2ReflectionSensor))},
+ {typeof(Vector3), (3, typeof(Vector3ReflectionSensor))},
+ {typeof(Vector4), (4, typeof(Vector4ReflectionSensor))},
+ {typeof(Quaternion), (4, typeof(QuaternionReflectionSensor))},
+ };
+
+ ///
+ /// ObservableAttribute constructor.
+ ///
+ /// Optional override for the sensor name. Note that all sensors for an Agent
+ /// must have a unique name.
+ /// Number of frames to concatenate observations from.
+ public ObservableAttribute(string name = null, int numStackedObservations = 1)
+ {
+ m_Name = name;
+ m_NumStackedObservations = numStackedObservations;
+ }
+
+ ///
+ /// Yields (FieldInfo, ObservableAttribute) pairs for all fields that have an ObservableAttribute.
+ ///
+ /// Object being reflected
+ /// Whether to exclude inherited fields or not.
+ ///
+ static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited)
+ {
+ // TODO cache these (and properties) by type, so that we only have to reflect once.
+ var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
+ var fields = o.GetType().GetFields(bindingFlags);
+ foreach (var field in fields)
+ {
+ var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute));
+ if (attr != null)
+ {
+ yield return (field, attr);
+ }
+ }
+ }
+
+ ///
+ /// Yields (PropertyInfo, ObservableAttribute) pairs for all properties that have an ObservableAttribute.
+ ///
+ /// Object being reflected
+ /// Whether to exclude inherited properties or not.
+ ///
+ static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited)
+ {
+ var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
+ var properties = o.GetType().GetProperties(bindingFlags);
+ foreach (var prop in properties)
+ {
+ var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute));
+ if (attr != null)
+ {
+ yield return (prop, attr);
+ }
+ }
+ }
+
+ ///
+ /// Creates sensors for each field and property with ObservableAttribute.
+ ///
+ /// Object being reflected
+ /// Whether to exclude inherited members or not.
+ ///
+ internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited)
+ {
+ var sensorsOut = new List<ISensor>();
+ foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
+ {
+ var sensor = CreateReflectionSensor(o, field, null, attr);
+ if (sensor != null)
+ {
+ sensorsOut.Add(sensor);
+ }
+ }
+
+ foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
+ {
+ if (!prop.CanRead)
+ {
+ // Skip unreadable properties.
+ continue;
+ }
+ var sensor = CreateReflectionSensor(o, null, prop, attr);
+ if (sensor != null)
+ {
+ sensorsOut.Add(sensor);
+ }
+ }
+
+ return sensorsOut;
+ }
+
+ ///
+ /// Create the ISensor for either the field or property on the provided object.
+ /// If the data type is unsupported, or the property is write-only, returns null.
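+ /// Enum-typed members get an EnumReflectionSensor (one-hot); other supported types are looked up in s_TypeToSensorInfo.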
+ /// + /// + /// + /// + /// + /// + /// + static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) + { + string memberName; + string declaringTypeName; + Type memberType; + if (fieldInfo != null) + { + declaringTypeName = fieldInfo.DeclaringType.Name; + memberName = fieldInfo.Name; + memberType = fieldInfo.FieldType; + } + else + { + declaringTypeName = propertyInfo.DeclaringType.Name; + memberName = propertyInfo.Name; + memberType = propertyInfo.PropertyType; + } + + if (!s_TypeToSensorInfo.ContainsKey(memberType) && !memberType.IsEnum) + { + // For unsupported types, return null and we'll filter them out later. + return null; + } + + string sensorName; + if (string.IsNullOrEmpty(observableAttribute.m_Name)) + { + sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}"; + } + else + { + sensorName = observableAttribute.m_Name; + } + + var reflectionSensorInfo = new ReflectionSensorInfo + { + Object = o, + FieldInfo = fieldInfo, + PropertyInfo = propertyInfo, + ObservableAttribute = observableAttribute, + SensorName = sensorName + }; + + ISensor sensor = null; + if (memberType.IsEnum) + { + sensor = new EnumReflectionSensor(reflectionSensorInfo); + } + else + { + var (_, sensorType) = s_TypeToSensorInfo[memberType]; + sensor = (ISensor)Activator.CreateInstance(sensorType, reflectionSensorInfo); + } + + // Wrap the base sensor in a StackingSensor if we're using stacking. + if (observableAttribute.m_NumStackedObservations > 1) + { + return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations); + } + + return sensor; + } + + /// + /// Gets the sum of the observation sizes of the Observable fields and properties on an object. + /// Also appends errors to the errorsOut array. 
+ /// + /// + /// + /// + /// + internal static int GetTotalObservationSize(object o, bool excludeInherited, List errorsOut) + { + int sizeOut = 0; + foreach (var (field, attr) in GetObservableFields(o, excludeInherited)) + { + if (s_TypeToSensorInfo.ContainsKey(field.FieldType)) + { + var (obsSize, _) = s_TypeToSensorInfo[field.FieldType]; + sizeOut += obsSize * attr.m_NumStackedObservations; + } + else if (field.FieldType.IsEnum) + { + sizeOut += EnumReflectionSensor.GetEnumObservationSize(field.FieldType); + } + else + { + errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}"); + } + } + + foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited)) + { + if (!prop.CanRead) + { + errorsOut.Add($"Observable property {prop.Name} is write-only."); + } + else if (s_TypeToSensorInfo.ContainsKey(prop.PropertyType)) + { + var (obsSize, _) = s_TypeToSensorInfo[prop.PropertyType]; + sizeOut += obsSize * attr.m_NumStackedObservations; + } + else if (prop.PropertyType.IsEnum) + { + sizeOut += EnumReflectionSensor.GetEnumObservationSize(prop.PropertyType); + } + else + { + errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}"); + } + } + + return sizeOut; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta new file mode 100644 index 0000000000..41659283da --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a75086dc66a594baea6b8b2935f5dacf +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs new file mode 100644 index 0000000000..5cd92ee68f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a quaternion field or property of an object, and returns + /// that as an observation. 
+ /// + internal class QuaternionReflectionSensor : ReflectionSensorBase + { + public QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 4) + { } + + internal override void WriteReflectedField(ObservationWriter writer) + { + var quatVal = (UnityEngine.Quaternion)GetReflectedValue(); + writer.Add(quatVal); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta new file mode 100644 index 0000000000..f3970e6b51 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d38241d74074d459bb4590f7f5d16c80 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs new file mode 100644 index 0000000000..9a0219146e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -0,0 +1,111 @@ +using System; +using System.Reflection; + +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Construction info for a ReflectionSensorBase. + /// + internal struct ReflectionSensorInfo + { + public object Object; + + public FieldInfo FieldInfo; + public PropertyInfo PropertyInfo; + public ObservableAttribute ObservableAttribute; + public string SensorName; + + public Type GetMemberType() + { + return FieldInfo != null ? FieldInfo.FieldType : PropertyInfo.PropertyType; + } + } + + /// + /// Abstract base class for reflection-based sensors. + /// + internal abstract class ReflectionSensorBase : ISensor, IBuiltInSensor + { + protected object m_Object; + + // Exactly one of m_FieldInfo and m_PropertyInfo should be non-null. + protected FieldInfo m_FieldInfo; + protected PropertyInfo m_PropertyInfo; + + // Not currently used, but might want later. + protected ObservableAttribute m_ObservableAttribute; + + // Cached sensor names and shapes. + string m_SensorName; + ObservationSpec m_ObservationSpec; + int m_NumFloats; + + public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) + { + m_Object = reflectionSensorInfo.Object; + m_FieldInfo = reflectionSensorInfo.FieldInfo; + m_PropertyInfo = reflectionSensorInfo.PropertyInfo; + m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; + m_SensorName = reflectionSensorInfo.SensorName; + m_ObservationSpec = ObservationSpec.Vector(size); + m_NumFloats = size; + } + + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + public int Write(ObservationWriter writer) + { + WriteReflectedField(writer); + return m_NumFloats; + } + + internal abstract void WriteReflectedField(ObservationWriter writer); + + /// + /// Get either the reflected field, or return the reflected property. + /// This should be used by implementations in their WriteReflectedField() method. + /// + /// + protected object GetReflectedValue() + { + return m_FieldInfo != null ? 
+ m_FieldInfo.GetValue(m_Object) : + m_PropertyInfo.GetMethod.Invoke(m_Object, null); + } + + /// + public byte[] GetCompressedObservation() + { + return null; + } + + /// + public void Update() { } + + /// + public void Reset() { } + + /// + public CompressionSpec GetCompressionSpec() + { + return CompressionSpec.Default(); + } + + /// + public string GetName() + { + return m_SensorName; + } + + /// + public BuiltInSensorType GetBuiltInSensorType() + { + return BuiltInSensorType.ReflectionSensor; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta new file mode 100644 index 0000000000..cef19bb598 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6b68d855fb94a45fbbeb0dbe968a35f8 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs new file mode 100644 index 0000000000..85c6dea8ef --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs @@ -0,0 +1,20 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a Vector2 field or property of an object, and returns + /// that as an observation. + /// + internal class Vector2ReflectionSensor : ReflectionSensorBase + { + public Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 2) + { } + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector2)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta new file mode 100644 index 0000000000..2b78c25ffe --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: da06ff33f6f2d409cbf240cffa2ba0be +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs new file mode 100644 index 0000000000..8fa28b73ee --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a Vector3 field or property of an object, and returns + /// that as an observation. 
+ /// + internal class Vector3ReflectionSensor : ReflectionSensorBase + { + public Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 3) + { } + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector3)GetReflectedValue(); + writer.Add(vecVal); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta new file mode 100644 index 0000000000..771b690b07 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e756976ec2a0943cfbc0f97a6550a85b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs new file mode 100644 index 0000000000..76b5c39c27 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a Vector4 field or property of an object, and returns + /// that as an observation. + /// + internal class Vector4ReflectionSensor : ReflectionSensorBase + { + public Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 4) + { } + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector4)GetReflectedValue(); + writer.Add(vecVal); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta new file mode 100644 index 0000000000..3d938af6c8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 01d93aaa1b42b47b8960d303d7c498d3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs new file mode 100644 index 0000000000..678e722220 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs @@ -0,0 +1,130 @@ +using System; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Sensor class that wraps a [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) instance. + /// + public class RenderTextureSensor : ISensor, IBuiltInSensor, IDisposable + { + RenderTexture m_RenderTexture; + bool m_Grayscale; + string m_Name; + private ObservationSpec m_ObservationSpec; + SensorCompressionType m_CompressionType; + Texture2D m_Texture; + + /// + /// The compression type used by the sensor. + /// + public SensorCompressionType CompressionType + { + get { return m_CompressionType; } + set { m_CompressionType = value; } + } + + + /// + /// Initializes the sensor. + /// + /// The [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) + /// instance to wrap. + /// Whether to convert it to grayscale or not. + /// Name of the sensor. 
+ /// Compression method for the render texture. + /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html + public RenderTextureSensor( + RenderTexture renderTexture, bool grayscale, string name, SensorCompressionType compressionType) + { + m_RenderTexture = renderTexture; + var width = renderTexture != null ? renderTexture.width : 0; + var height = renderTexture != null ? renderTexture.height : 0; + m_Grayscale = grayscale; + m_Name = name; + m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3); + m_CompressionType = compressionType; + m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false); + } + + /// + public string GetName() + { + return m_Name; + } + + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + public byte[] GetCompressedObservation() + { + using (TimerStack.Instance.Scoped("RenderTextureSensor.GetCompressedObservation")) + { + ObservationToTexture(m_RenderTexture, m_Texture); + // TODO support more types here, e.g. JPG + var compressed = m_Texture.EncodeToPNG(); + return compressed; + } + } + + /// + public int Write(ObservationWriter writer) + { + using (TimerStack.Instance.Scoped("RenderTextureSensor.Write")) + { + ObservationToTexture(m_RenderTexture, m_Texture); + var numWritten = writer.WriteTexture(m_Texture, m_Grayscale); + return numWritten; + } + } + + /// + public void Update() { } + + /// + public void Reset() { } + + /// + public CompressionSpec GetCompressionSpec() + { + return new CompressionSpec(m_CompressionType); + } + + /// + public BuiltInSensorType GetBuiltInSensorType() + { + return BuiltInSensorType.RenderTextureSensor; + } + + /// + /// Converts a RenderTexture to a 2D texture. + /// + /// RenderTexture. + /// Texture2D to render to. + public static void ObservationToTexture(RenderTexture obsTexture, Texture2D texture2D) + { + var prevActiveRt = RenderTexture.active; + RenderTexture.active = obsTexture; + + texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0); + texture2D.Apply(); + RenderTexture.active = prevActiveRt; + } + + /// + /// Clean up the owned Texture2D. + /// + public void Dispose() + { + if (!ReferenceEquals(null, m_Texture)) + { + Utilities.DestroyTexture(m_Texture); + m_Texture = null; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs.meta new file mode 100644 index 0000000000..28a1dff767 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 958f1f6bb9058405cae3c03266ad9899 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs new file mode 100644 index 0000000000..8e58617dd3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs @@ -0,0 +1,120 @@ +using System; +using UnityEngine; +using UnityEngine.Serialization; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Component that wraps a . 
+ /// + [AddComponentMenu("ML Agents/Render Texture Sensor", (int)MenuGroup.Sensors)] + public class RenderTextureSensorComponent : SensorComponent, IDisposable + { + RenderTextureSensor m_Sensor; + + /// + /// The [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) instance + /// that the associated wraps. + /// + [HideInInspector, SerializeField, FormerlySerializedAs("renderTexture")] + RenderTexture m_RenderTexture; + + /// + /// Stores the [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) + /// associated with this sensor. + /// + public RenderTexture RenderTexture + { + get { return m_RenderTexture; } + set { m_RenderTexture = value; } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("sensorName")] + string m_SensorName = "RenderTextureSensor"; + + /// + /// Name of the generated . + /// Note that changing this at runtime does not affect how the Agent sorts the sensors. + /// + public string SensorName + { + get { return m_SensorName; } + set { m_SensorName = value; } + } + + [HideInInspector, SerializeField, FormerlySerializedAs("grayscale")] + bool m_Grayscale; + + /// + /// Whether the RenderTexture observation should be converted to grayscale or not. + /// Note that changing this after the sensor is created has no effect. + /// + public bool Grayscale + { + get { return m_Grayscale; } + set { m_Grayscale = value; } + } + + [HideInInspector, SerializeField] + [Range(1, 50)] + [Tooltip("Number of frames that will be stacked before being fed to the neural network.")] + int m_ObservationStacks = 1; + + [HideInInspector, SerializeField, FormerlySerializedAs("compression")] + SensorCompressionType m_Compression = SensorCompressionType.PNG; + + /// + /// Compression type for the render texture observation. + /// + public SensorCompressionType CompressionType + { + get { return m_Compression; } + set { m_Compression = value; UpdateSensor(); } + } + + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + + /// + public override ISensor[] CreateSensors() + { + Dispose(); + m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression); + if (ObservationStacks != 1) + { + return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) }; + } + return new ISensor[] { m_Sensor }; + } + + /// + /// Update fields that are safe to change on the Sensor at runtime. + /// + internal void UpdateSensor() + { + if (m_Sensor != null) + { + m_Sensor.CompressionType = m_Compression; + } + } + + /// + /// Clean up the sensor created by CreateSensors(). 
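+ /// Calling it more than once is safe; later calls are no-ops.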
+ ///
+ public void Dispose()
+ {
+ if (!ReferenceEquals(null, m_Sensor))
+ {
+ m_Sensor.Dispose();
+ m_Sensor = null;
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs.meta
new file mode 100644
index 0000000000..542ca3e278
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 132e1194facb64429b007ea1edf562d0
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
new file mode 100644
index 0000000000..4ddcaabb74
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
@@ -0,0 +1,17 @@
+using UnityEngine;
+
+namespace Unity.MLAgents.Sensors
+{
+ ///
+ /// Editor components for creating Sensors. Generally an ISensor implementation should have a
+ /// corresponding SensorComponent to create it.
+ ///
+ public abstract class SensorComponent : MonoBehaviour
+ {
+ ///
+ /// Create the ISensors. This is called by the Agent when it is initialized.
+ ///
+ /// Created ISensor objects.
+ public abstract ISensor[] CreateSensors();
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs.meta
new file mode 100644
index 0000000000..5576281e12
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 4f1dad589959a4b598d09e54f61fbe02
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
new file mode 100644
index 0000000000..9d45e673ed
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
@@ -0,0 +1,55 @@
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace Unity.MLAgents.Sensors
+{
+ public class SensorShapeValidator
+ {
+ List<ObservationSpec> m_SensorShapes;
+
+ ///
+ /// Check that the given list of sensors has the same shapes as the previously validated ones.
+ /// If this is the first list of sensors being checked, its sensor shapes are saved.
+ ///
+ public void ValidateSensors(List<ISensor> sensors)
+ {
+ if (m_SensorShapes == null)
+ {
+ m_SensorShapes = new List<ObservationSpec>(sensors.Count);
+ // First agent, save the sensor sizes
+ foreach (var sensor in sensors)
+ {
+ m_SensorShapes.Add(sensor.GetObservationSpec());
+ }
+ }
+ else
+ {
+ // Check for compatibility with the other Agents' Sensors
+ if (m_SensorShapes.Count != sensors.Count)
+ {
+ Debug.AssertFormat(
+ m_SensorShapes.Count == sensors.Count,
+ "Number of Sensors must match. {0} != {1}",
+ m_SensorShapes.Count,
+ sensors.Count
+ );
+ }
+ for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
+ {
+ var cachedSpec = m_SensorShapes[i];
+ var sensorSpec = sensors[i].GetObservationSpec();
+ if (cachedSpec.Shape != sensorSpec.Shape)
+ {
+ Debug.AssertFormat(
+ cachedSpec.Shape == sensorSpec.Shape,
+ "Sensor shapes must match. {0} != {1}",
{0} != {1}", + cachedSpec.Shape, + sensorSpec.Shape + ); + + } + } + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs.meta b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs.meta new file mode 100644 index 0000000000..6ce44665cc --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a7b5a4560ee254be497321527f92c174 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs new file mode 100644 index 0000000000..710c58a821 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -0,0 +1,300 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using UnityEngine; +using Unity.Barracuda; + +namespace Unity.MLAgents.Sensors +{ + /// + /// Sensor that wraps around another Sensor to provide temporal stacking. + /// Conceptually, consecutive observations are stored left-to-right, which is how they're output + /// For example, 4 stacked sets of observations would be output like + /// | t = now - 3 | t = now -3 | t = now - 2 | t = now | + /// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation. + /// Currently, observations are stacked on the last dimension. + /// + public class StackingSensor : ISensor, IBuiltInSensor + { + /// + /// The wrapped sensor. + /// + ISensor m_WrappedSensor; + + /// + /// Number of stacks to save + /// + int m_NumStackedObservations; + int m_UnstackedObservationSize; + + string m_Name; + private ObservationSpec m_ObservationSpec; + private ObservationSpec m_WrappedSpec; + + /// + /// Buffer of previous observations + /// + float[][] m_StackedObservations; + + byte[][] m_StackedCompressedObservations; + + int m_CurrentIndex; + ObservationWriter m_LocalWriter = new ObservationWriter(); + + byte[] m_EmptyCompressedObservation; + int[] m_CompressionMapping; + TensorShape m_tensorShape; + + /// + /// Initializes the sensor. + /// + /// The wrapped sensor. + /// Number of stacked observations to keep. + public StackingSensor(ISensor wrapped, int numStackedObservations) + { + // TODO ensure numStackedObservations > 1 + m_WrappedSensor = wrapped; + m_NumStackedObservations = numStackedObservations; + + m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; + + m_WrappedSpec = wrapped.GetObservationSpec(); + + m_UnstackedObservationSize = wrapped.ObservationSize(); + + // Set up the cached observation spec for the StackingSensor + var newShape = m_WrappedSpec.Shape; + // TODO support arbitrary stacking dimension + newShape[newShape.Length - 1] *= numStackedObservations; + m_ObservationSpec = new ObservationSpec( + newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType + ); + + // Initialize uncompressed buffer anyway in case python trainer does not + // support the compression mapping and has to fall back to uncompressed obs. 
+ m_StackedObservations = new float[numStackedObservations][]; + for (var i = 0; i < numStackedObservations; i++) + { + m_StackedObservations[i] = new float[m_UnstackedObservationSize]; + } + + if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None) + { + m_StackedCompressedObservations = new byte[numStackedObservations][]; + m_EmptyCompressedObservation = CreateEmptyPNG(); + for (var i = 0; i < numStackedObservations; i++) + { + m_StackedCompressedObservations[i] = m_EmptyCompressedObservation; + } + m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped); + } + + if (m_WrappedSpec.Rank != 1) + { + var wrappedShape = m_WrappedSpec.Shape; + m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]); + } + } + + /// + public int Write(ObservationWriter writer) + { + // First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one. + m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0); + m_WrappedSensor.Write(m_LocalWriter); + + // Now write the saved observations (oldest first) + var numWritten = 0; + if (m_WrappedSpec.Rank == 1) + { + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + writer.AddList(m_StackedObservations[obsIndex], numWritten); + numWritten += m_UnstackedObservationSize; + } + } + else + { + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + for (var h = 0; h < m_WrappedSpec.Shape[0]; h++) + { + for (var w = 0; w < m_WrappedSpec.Shape[1]; w++) + { + for (var c = 0; c < m_WrappedSpec.Shape[2]; c++) + { + writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)]; + } + } + } + } + numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations; + } + + return numWritten; + } + + /// + /// Updates the index of the "current" buffer. + /// + public void Update() + { + m_WrappedSensor.Update(); + m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations; + } + + /// + public void Reset() + { + m_WrappedSensor.Reset(); + // Zero out the buffer. 
+ for (var i = 0; i < m_NumStackedObservations; i++) + { + Array.Clear(m_StackedObservations[i], 0, m_StackedObservations[i].Length); + } + if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None) + { + for (var i = 0; i < m_NumStackedObservations; i++) + { + m_StackedCompressedObservations[i] = m_EmptyCompressedObservation; + } + } + } + + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + public string GetName() + { + return m_Name; + } + + /// + public byte[] GetCompressedObservation() + { + var compressed = m_WrappedSensor.GetCompressedObservation(); + m_StackedCompressedObservations[m_CurrentIndex] = compressed; + + int bytesLength = 0; + foreach (byte[] compressedObs in m_StackedCompressedObservations) + { + bytesLength += compressedObs.Length; + } + + byte[] outputBytes = new byte[bytesLength]; + int offset = 0; + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + Buffer.BlockCopy(m_StackedCompressedObservations[obsIndex], + 0, outputBytes, offset, m_StackedCompressedObservations[obsIndex].Length); + offset += m_StackedCompressedObservations[obsIndex].Length; + } + + return outputBytes; + } + + /// + public CompressionSpec GetCompressionSpec() + { + var wrappedSpec = m_WrappedSensor.GetCompressionSpec(); + return new CompressionSpec(wrappedSpec.SensorCompressionType, m_CompressionMapping); + } + + /// + /// Create an empty PNG for initializing the buffer for stacking. + /// + internal byte[] CreateEmptyPNG() + { + var shape = m_WrappedSpec.Shape; + int height = shape[0]; + int width = shape[1]; + var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); + Color32[] resetColorArray = texture2D.GetPixels32(); + Color32 black = new Color32(0, 0, 0, 0); + for (int i = 0; i < resetColorArray.Length; i++) + { + resetColorArray[i] = black; + } + texture2D.SetPixels32(resetColorArray); + texture2D.Apply(); + return texture2D.EncodeToPNG(); + } + + /// + /// Construct stacked CompressedChannelMapping. + /// + internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSensor) + { + // Get CompressedChannelMapping of the wrapped sensor. If the + // wrapped sensor doesn't have one, use default mapping. + // Default mapping: {0, 0, 0} for grayscale, identity mapping {0, 1, ..., n-1} otherwise. + int[] wrappedMapping = null; + int wrappedNumChannel = m_WrappedSpec.Shape[2]; + + wrappedMapping = wrappedSensor.GetCompressionSpec().CompressedChannelMapping; + if (wrappedMapping == null) + { + if (wrappedNumChannel == 1) + { + wrappedMapping = new[] { 0, 0, 0 }; + } + else + { + wrappedMapping = Enumerable.Range(0, wrappedNumChannel).ToArray(); + } + } + + // Construct stacked mapping using the mapping of wrapped sensor. + // First pad the wrapped mapping to multiple of 3, then repeat + // and add offset to each copy to form the stacked mapping. + int paddedMapLength = (wrappedMapping.Length + 2) / 3 * 3; + var compressionMapping = new int[paddedMapLength * m_NumStackedObservations]; + for (var i = 0; i < m_NumStackedObservations; i++) + { + var offset = wrappedNumChannel * i; + for (var j = 0; j < paddedMapLength; j++) + { + if (j < wrappedMapping.Length) + { + compressionMapping[j + paddedMapLength * i] = wrappedMapping[j] >= 0 ?
wrappedMapping[j] + offset : -1; + } + else + { + compressionMapping[j + paddedMapLength * i] = -1; + } + } + } + return compressionMapping; + } + + /// + public BuiltInSensorType GetBuiltInSensorType() + { + IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor; + return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown; + } + + /// + /// Returns the stacked observations as a read-only collection. + /// + /// The stacked observations as a read-only collection. + internal ReadOnlyCollection<float> GetStackedObservations() + { + List<float> observations = new List<float>(); + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + observations.AddRange(m_StackedObservations[obsIndex].ToList()); + } + return observations.AsReadOnly(); + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs.meta new file mode 100644 index 0000000000..f0289542ff --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 8b7a6e88d47d4438ad67e1862566462c +timeCreated: 1572299581 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs new file mode 100644 index 0000000000..d4bd0507c4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs @@ -0,0 +1,218 @@ +using System.Collections.Generic; +using System.Collections.ObjectModel; +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A sensor implementation for vector observations. + /// + public class VectorSensor : ISensor, IBuiltInSensor + { + // TODO use float[] instead + // TODO allow setting float[] + List<float> m_Observations; + ObservationSpec m_ObservationSpec; + string m_Name; + + /// + /// Initializes the sensor. + /// + /// Number of vector observations. + /// Name of the sensor. + /// + public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default) + { + if (string.IsNullOrEmpty(name)) + { + name = $"VectorSensor_size{observationSize}"; + if (observationType != ObservationType.Default) + { + name += $"_{observationType.ToString()}"; + } + } + + m_Observations = new List<float>(observationSize); + m_Name = name; + m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType); + } + + /// + public int Write(ObservationWriter writer) + { + var expectedObservations = m_ObservationSpec.Shape[0]; + if (m_Observations.Count > expectedObservations) + { + // Too many observations, truncate + Debug.LogWarningFormat( + "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.", + m_Observations.Count, expectedObservations + ); + m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations); + } + else if (m_Observations.Count < expectedObservations) + { + // Not enough observations; pad with zeros. + Debug.LogWarningFormat( + "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.", + m_Observations.Count, expectedObservations + ); + for (int i = m_Observations.Count; i < expectedObservations; i++) + { + m_Observations.Add(0); + } + } + writer.AddList(m_Observations); + return expectedObservations; + } + + /// + /// Returns a read-only view of the observations that were added. + /// + /// A read-only view of the observations list.
+ internal ReadOnlyCollection<float> GetObservations() + { + return m_Observations.AsReadOnly(); + } + + /// + public void Update() + { + Clear(); + } + + /// + public void Reset() + { + Clear(); + } + + /// + public ObservationSpec GetObservationSpec() + { + return m_ObservationSpec; + } + + /// + public string GetName() + { + return m_Name; + } + + /// + public virtual byte[] GetCompressedObservation() + { + return null; + } + + /// + public CompressionSpec GetCompressionSpec() + { + return CompressionSpec.Default(); + } + + /// + public BuiltInSensorType GetBuiltInSensorType() + { + return BuiltInSensorType.VectorSensor; + } + + void Clear() + { + m_Observations.Clear(); + } + + void AddFloatObs(float obs) + { + Utilities.DebugCheckNanAndInfinity(obs, nameof(obs), nameof(AddFloatObs)); + m_Observations.Add(obs); + } + + // Compatibility methods with Agent observation. These should be removed eventually. + + /// + /// Adds a float observation to the vector observations of the agent. + /// + /// Observation. + public void AddObservation(float observation) + { + AddFloatObs(observation); + } + + /// + /// Adds an integer observation to the vector observations of the agent. + /// + /// Observation. + public void AddObservation(int observation) + { + AddFloatObs(observation); + } + + /// + /// Adds a Vector3 observation to the vector observations of the agent. + /// + /// Observation. + public void AddObservation(Vector3 observation) + { + AddFloatObs(observation.x); + AddFloatObs(observation.y); + AddFloatObs(observation.z); + } + + /// + /// Adds a Vector2 observation to the vector observations of the agent. + /// + /// Observation. + public void AddObservation(Vector2 observation) + { + AddFloatObs(observation.x); + AddFloatObs(observation.y); + } + + /// + /// Adds a list or array of float observations to the vector observations of the agent. + /// + /// Observation. + public void AddObservation(IList<float> observation) + { + for (var i = 0; i < observation.Count; i++) + { + AddFloatObs(observation[i]); + } + } + + /// + /// Adds a quaternion observation to the vector observations of the agent. + /// + /// Observation. + public void AddObservation(Quaternion observation) + { + AddFloatObs(observation.x); + AddFloatObs(observation.y); + AddFloatObs(observation.z); + AddFloatObs(observation.w); + } + + /// + /// Adds a boolean observation to the vector observation of the agent. + /// + /// Observation. + public void AddObservation(bool observation) + { + AddFloatObs(observation ? 1f : 0f); + } + + /// + /// Adds a one-hot encoding observation. + /// + /// The index of this observation. + /// The upper limit on the value observation can take (exclusive). + public void AddOneHotObservation(int observation, int range) + { + for (var i = 0; i < range; i++) + { + AddFloatObs(i == observation ?
1.0f : 0.0f); + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs.meta new file mode 100644 index 0000000000..277ef0d59e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e3966c9961b343108808d91a4d140a68 +timeCreated: 1572300800 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs new file mode 100644 index 0000000000..26deb7434f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs @@ -0,0 +1,87 @@ +using UnityEngine; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A SensorComponent that creates a VectorSensor. + /// + [AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)] + public class VectorSensorComponent : SensorComponent + { + /// + /// Name of the generated VectorSensor object. + /// Note that changing this at runtime does not affect how the Agent sorts the sensors. + /// + public string SensorName + { + get { return m_SensorName; } + set { m_SensorName = value; } + } + [HideInInspector, SerializeField] + private string m_SensorName = "VectorSensor"; + + /// + /// The number of float observations in the VectorSensor. + /// + public int ObservationSize + { + get { return m_ObservationSize; } + set { m_ObservationSize = value; } + } + + [HideInInspector, SerializeField] + int m_ObservationSize; + + [HideInInspector, SerializeField] + ObservationType m_ObservationType; + + VectorSensor m_Sensor; + + /// + /// The type of the observation. + /// + public ObservationType ObservationType + { + get { return m_ObservationType; } + set { m_ObservationType = value; } + } + + [HideInInspector, SerializeField] + [Range(1, 50)] + [Tooltip("Number of frames that will be stacked before being fed to the neural network.")] + int m_ObservationStacks = 1; + + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + + /// + /// Creates a VectorSensor.
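A short usage sketch for the VectorSensor API above (illustrative only; the name and values are hypothetical):

```csharp
// A VectorSensor declared with size 5. If fewer floats are added before
// Write(), the remainder is zero-padded (with a warning); extras are truncated.
var sensor = new VectorSensor(5, "ExampleVector");
sensor.AddObservation(0.5f);                 // 1 float
sensor.AddObservation(new Vector3(1, 2, 3)); // 3 floats
sensor.AddObservation(true);                 // 1 float (1.0f)
// Exactly 5 floats were provided, so Write() emits them as-is.
```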
+ /// + /// + public override ISensor[] CreateSensors() + { + m_Sensor = new VectorSensor(m_ObservationSize, m_SensorName, m_ObservationType); + if (ObservationStacks != 1) + { + return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) }; + } + return new ISensor[] { m_Sensor }; + } + + /// + /// Returns the underlying VectorSensor + /// + /// + public VectorSensor GetSensor() + { + return m_Sensor; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta new file mode 100644 index 0000000000..c867a60f2b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 38b7cc1f5819445aa85e9a9b054552dc +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels.meta b/com.unity.ml-agents/Runtime/SideChannels.meta new file mode 100644 index 0000000000..6bff982a90 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 9de9d822922c6454ca88483e2b9eeeac +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs new file mode 100644 index 0000000000..7bea193dd5 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs @@ -0,0 +1,76 @@ +using System; +using UnityEngine; + +namespace Unity.MLAgents.SideChannels +{ + + /// + /// Side channel that supports modifying attributes specific to the Unity Engine. + /// + internal class EngineConfigurationChannel : SideChannel + { + internal enum ConfigurationType : int + { + ScreenResolution = 0, + QualityLevel = 1, + TimeScale = 2, + TargetFrameRate = 3, + CaptureFrameRate = 4 + } + + const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7"; + + /// + /// Initializes the side channel. The constructor is internal because only one instance is + /// supported at a time, and is created by the Academy. 
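A minimal sketch of pairing the component above with per-step gameplay data (illustrative only; GetSensor() returns null until CreateSensors() has run, hence the null-conditional call, and the health field is hypothetical):

```csharp
using Unity.MLAgents.Sensors;
using UnityEngine;

public class HealthObserver : MonoBehaviour
{
    VectorSensorComponent m_Sensor;
    public float health = 100f;

    void Awake()
    {
        m_Sensor = GetComponent<VectorSensorComponent>();
        m_Sensor.ObservationSize = 1;
    }

    void FixedUpdate()
    {
        // The underlying VectorSensor is cleared after each write,
        // so push fresh values every step.
        m_Sensor.GetSensor()?.AddObservation(health);
    }
}
```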
+ /// + internal EngineConfigurationChannel() + { + ChannelId = new Guid(k_EngineConfigId); + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + var messageType = (ConfigurationType)msg.ReadInt32(); + switch (messageType) + { + case ConfigurationType.ScreenResolution: + var width = msg.ReadInt32(); + var height = msg.ReadInt32(); + Screen.SetResolution(width, height, false); + break; + case ConfigurationType.QualityLevel: + var qualityLevel = msg.ReadInt32(); + QualitySettings.SetQualityLevel(qualityLevel, true); + break; + case ConfigurationType.TimeScale: + var timeScale = msg.ReadFloat32(); + + // There's an upper limit for the timeScale in the editor (but not in the player) + // Always ensure that timeScale >= 1 also, +#if UNITY_EDITOR + const float maxTimeScale = 100f; +#else + const float maxTimeScale = float.PositiveInfinity; +#endif + timeScale = Mathf.Clamp(timeScale, 1, maxTimeScale); + Time.timeScale = timeScale; + break; + case ConfigurationType.TargetFrameRate: + var targetFrameRate = msg.ReadInt32(); + Application.targetFrameRate = targetFrameRate; + break; + case ConfigurationType.CaptureFrameRate: + var captureFrameRate = msg.ReadInt32(); + Time.captureFramerate = captureFrameRate; + break; + default: + Debug.LogWarning( + "Unknown engine configuration received from Python. Make sure" + + " your Unity and Python versions are compatible."); + break; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs.meta new file mode 100644 index 0000000000..8f6335e9b0 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 18ccdf3ce76784f2db68016fa284c33f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs new file mode 100644 index 0000000000..9b2842b300 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs @@ -0,0 +1,143 @@ +using System.Collections.Generic; +using System; +using UnityEngine; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Lists the different data types supported. + /// + internal enum EnvironmentDataTypes + { + Float = 0, + Sampler = 1 + } + + /// + /// The types of distributions from which to sample reset parameters. + /// + internal enum SamplerType + { + /// + /// Samples a reset parameter from a uniform distribution. + /// + Uniform = 0, + + /// + /// Samples a reset parameter from a Gaussian distribution. + /// + Gaussian = 1, + + /// + /// Samples a reset parameter from a MultiRangeUniform distribution. + /// + MultiRangeUniform = 2 + + } + + /// + /// A side channel that manages the environment parameter values from Python. Currently + /// limited to parameters of type float. + /// + internal class EnvironmentParametersChannel : SideChannel + { + Dictionary> m_Parameters = new Dictionary>(); + Dictionary> m_RegisteredActions = + new Dictionary>(); + + const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400"; + + /// + /// Initializes the side channel. The constructor is internal because only one instance is + /// supported at a time, and is created by the Academy. 
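A sketch of the wire format the switch above consumes: an int32 configuration tag followed by its payload. This is illustrative only (the enum is internal, and OutgoingMessage appears later in this diff); it mirrors what the Python trainer would send for a 20x time scale:

```csharp
using (var msg = new OutgoingMessage())
{
    msg.WriteInt32(2);     // ConfigurationType.TimeScale
    msg.WriteFloat32(20f); // clamped to [1, 100] in the editor on receipt
    // msg.ToByteArray() is the payload that would reach OnMessageReceived.
}
```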
+ /// + internal EnvironmentParametersChannel() + { + ChannelId = new Guid(k_EnvParamsId); + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + var key = msg.ReadString(); + var type = msg.ReadInt32(); + if ((int)EnvironmentDataTypes.Float == type) + { + var value = msg.ReadFloat32(); + + m_Parameters[key] = () => value; + + Action action; + m_RegisteredActions.TryGetValue(key, out action); + action?.Invoke(value); + } + else if ((int)EnvironmentDataTypes.Sampler == type) + { + int seed = msg.ReadInt32(); + int samplerType = msg.ReadInt32(); + Func sampler = () => 0.0f; + if ((int)SamplerType.Uniform == samplerType) + { + float min = msg.ReadFloat32(); + float max = msg.ReadFloat32(); + sampler = SamplerFactory.CreateUniformSampler(min, max, seed); + } + else if ((int)SamplerType.Gaussian == samplerType) + { + float mean = msg.ReadFloat32(); + float stddev = msg.ReadFloat32(); + + sampler = SamplerFactory.CreateGaussianSampler(mean, stddev, seed); + } + else if ((int)SamplerType.MultiRangeUniform == samplerType) + { + IList intervals = msg.ReadFloatList(); + sampler = SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed); + } + else + { + Debug.LogWarning("EnvironmentParametersChannel received an unknown data type."); + } + m_Parameters[key] = sampler; + } + else + { + Debug.LogWarning("EnvironmentParametersChannel received an unknown data type."); + } + } + + /// + /// Returns the parameter value associated with the provided key. Returns the default + /// value if one doesn't exist. + /// + /// Parameter key. + /// Default value to return. + /// + public float GetWithDefault(string key, float defaultValue) + { + Func valueOut; + bool hasKey = m_Parameters.TryGetValue(key, out valueOut); + return hasKey ? valueOut.Invoke() : defaultValue; + } + + /// + /// Registers a callback for the associated parameter key. Will overwrite any existing + /// actions for this parameter key. + /// + /// The parameter key. + /// The callback. + public void RegisterCallback(string key, Action action) + { + m_RegisteredActions[key] = action; + } + + /// + /// Returns all parameter keys that have a registered value. + /// + /// + public IList ListParameters() + { + return new List(m_Parameters.Keys); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta new file mode 100644 index 0000000000..f118b1f99f --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a849760d5bec946b884984e35c66fcfa +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs new file mode 100644 index 0000000000..f4c293547b --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs @@ -0,0 +1,97 @@ +using System.Collections.Generic; +using System; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Side channel that is comprised of a collection of float variables. 
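A usage sketch from the agent's side, assuming the public Academy.Instance.EnvironmentParameters accessor (defined elsewhere in this package) that wraps the channel above; the "gravity" key is hypothetical:

```csharp
public override void OnEpisodeBegin()
{
    // Falls back to the default when Python has not set the parameter.
    var gravity = Academy.Instance.EnvironmentParameters.GetWithDefault("gravity", -9.81f);
    Physics.gravity = new Vector3(0f, gravity, 0f);
}
```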
+ /// + public class FloatPropertiesChannel : SideChannel + { + Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>(); + Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>(); + const string k_FloatPropertiesDefaultId = "60ccf7d0-4f7e-11ea-b238-784f4387d1f7"; + + /// + /// Initializes the side channel with the provided channel ID. + /// + /// ID for the side channel. + public FloatPropertiesChannel(Guid channelId = default(Guid)) + { + if (channelId == default(Guid)) + { + ChannelId = new Guid(k_FloatPropertiesDefaultId); + } + else + { + ChannelId = channelId; + } + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + var key = msg.ReadString(); + var value = msg.ReadFloat32(); + + m_FloatProperties[key] = value; + + Action<float> action; + m_RegisteredActions.TryGetValue(key, out action); + action?.Invoke(value); + } + + /// + /// Sets one of the float properties of the environment. This data will be sent to Python. + /// + /// The string identifier of the property. + /// The float value of the property. + public void Set(string key, float value) + { + m_FloatProperties[key] = value; + using (var msgOut = new OutgoingMessage()) + { + msgOut.WriteString(key); + msgOut.WriteFloat32(value); + QueueMessageToSend(msgOut); + } + + Action<float> action; + m_RegisteredActions.TryGetValue(key, out action); + action?.Invoke(value); + } + + /// + /// Get an Environment property with a default value. If there is a value for this property, + /// it will be returned, otherwise, the default value will be returned. + /// + /// The string identifier of the property. + /// The default value of the property. + /// + public float GetWithDefault(string key, float defaultValue) + { + float valueOut; + bool hasKey = m_FloatProperties.TryGetValue(key, out valueOut); + return hasKey ? valueOut : defaultValue; + } + + /// + /// Registers an action to be performed every time the property is changed. + /// + /// The string identifier of the property. + /// The action that will be performed. Takes a float as input. + public void RegisterCallback(string key, Action<float> action) + { + m_RegisteredActions[key] = action; + } + + /// + /// Returns a list of all the string identifiers of the properties currently present. + /// + /// The list of string identifiers. + public IList<string> Keys() + { + return new List<string>(m_FloatProperties.Keys); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs.meta new file mode 100644 index 0000000000..d4b87eb1e4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 452f8b3c01c4642aba645dcf0b6bfc6e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs b/com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs new file mode 100644 index 0000000000..90425955de --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs @@ -0,0 +1,127 @@ +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System; +using System.IO; +using System.Text; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Utility class for reading the data sent to the SideChannel.
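A minimal sketch of exchanging float properties with Python over the default channel ID (illustrative only; SideChannelManager is defined later in this diff, and the keys are hypothetical):

```csharp
var channel = new FloatPropertiesChannel();
SideChannelManager.RegisterSideChannel(channel);

channel.RegisterCallback("difficulty", d => Debug.Log($"difficulty -> {d}"));
channel.Set("score", 42f); // queued and sent alongside the next simulation step
var difficulty = channel.GetWithDefault("difficulty", 1f);
```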
+ /// + public class IncomingMessage : IDisposable + { + byte[] m_Data; + Stream m_Stream; + BinaryReader m_Reader; + + /// + /// Construct an IncomingMessage from the byte array. + /// + /// + public IncomingMessage(byte[] data) + { + m_Data = data; + m_Stream = new MemoryStream(data); + m_Reader = new BinaryReader(m_Stream); + } + + /// + /// Read a boolean value from the message. + /// + /// Default value to use if the end of the message is reached. + /// + public bool ReadBoolean(bool defaultValue = false) + { + return CanReadMore() ? m_Reader.ReadBoolean() : defaultValue; + } + + /// + /// Read an integer value from the message. + /// + /// Default value to use if the end of the message is reached. + /// + public int ReadInt32(int defaultValue = 0) + { + return CanReadMore() ? m_Reader.ReadInt32() : defaultValue; + } + + /// + /// Read a float value from the message. + /// + /// Default value to use if the end of the message is reached. + /// + public float ReadFloat32(float defaultValue = 0.0f) + { + return CanReadMore() ? m_Reader.ReadSingle() : defaultValue; + } + + /// + /// Read a string value from the message. + /// + /// Default value to use if the end of the message is reached. + /// + public string ReadString(string defaultValue = default) + { + if (!CanReadMore()) + { + return defaultValue; + } + + var strLength = ReadInt32(); + var str = Encoding.ASCII.GetString(m_Reader.ReadBytes(strLength)); + return str; + } + + /// + /// Reads a list of floats from the message. The length of the list is stored in the message. + /// + /// Default value to use if the end of the message is reached. + /// + public IList ReadFloatList(IList defaultValue = default) + { + if (!CanReadMore()) + { + return defaultValue; + } + + var len = ReadInt32(); + var output = new float[len]; + for (var i = 0; i < len; i++) + { + output[i] = ReadFloat32(); + } + + return output; + } + + /// + /// Gets the original data of the message. Note that this will return all of the data, + /// even if part of it has already been read. + /// + /// + public byte[] GetRawBytes() + { + return m_Data; + } + + /// + /// Clean up the internal storage. + /// + public void Dispose() + { + m_Reader?.Dispose(); + m_Stream?.Dispose(); + } + + /// + /// Whether or not there is more data left in the stream that can be read. + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + bool CanReadMore() + { + return m_Stream.Position < m_Stream.Length; + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs.meta new file mode 100644 index 0000000000..f70c658d1a --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: c8043cec65aeb4ec09db1d25ad694328 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs b/com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs new file mode 100644 index 0000000000..70a7948e22 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs @@ -0,0 +1,110 @@ +using System.Collections.Generic; +using System; +using System.IO; +using System.Text; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Utility class for forming the data that is sent to the SideChannel. 
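A round-trip sketch of the two message classes (OutgoingMessage is defined just below): fields must be written in exactly the order the receiver reads them, and each Read* call falls back to its default once the stream is exhausted, so short messages never throw:

```csharp
byte[] payload;
using (var msgOut = new OutgoingMessage())
{
    msgOut.WriteString("targets");
    msgOut.WriteInt32(3);
    msgOut.WriteFloatList(new[] { 0.1f, 0.2f, 0.3f });
    payload = msgOut.ToByteArray(); // internal, but available within the package
}

using (var msgIn = new IncomingMessage(payload))
{
    var name = msgIn.ReadString();                    // "targets"
    var count = msgIn.ReadInt32();                    // 3
    var values = msgIn.ReadFloatList();               // 0.1, 0.2, 0.3
    var done = msgIn.ReadBoolean(defaultValue: true); // stream exhausted -> true
}
```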
+ /// + public class OutgoingMessage : IDisposable + { + BinaryWriter m_Writer; + MemoryStream m_Stream; + + /// + /// Create a new empty OutgoingMessage. + /// + public OutgoingMessage() + { + m_Stream = new MemoryStream(); + m_Writer = new BinaryWriter(m_Stream); + } + + /// + /// Clean up the internal storage. + /// + public void Dispose() + { + m_Writer?.Dispose(); + m_Stream?.Dispose(); + } + + /// + /// Write a boolean value to the message. + /// + /// + public void WriteBoolean(bool b) + { + m_Writer.Write(b); + } + + /// + /// Write an integer value to the message. + /// + /// + public void WriteInt32(int i) + { + m_Writer.Write(i); + } + + /// + /// Write a float value to the message. + /// + /// + public void WriteFloat32(float f) + { + m_Writer.Write(f); + } + + /// + /// Write a string value to the message. + /// + /// + public void WriteString(string s) + { + var stringEncoded = Encoding.ASCII.GetBytes(s); + m_Writer.Write(stringEncoded.Length); + m_Writer.Write(stringEncoded); + } + + /// + /// Write a list or array of floats to the message. + /// + /// + public void WriteFloatList(IList<float> floatList) + { + WriteInt32(floatList.Count); + foreach (var f in floatList) + { + WriteFloat32(f); + } + } + + /// + /// Overwrite the message with a specific byte array. + /// + /// + public void SetRawBytes(byte[] data) + { + // Reset first. Set the length to zero so that if there's more data than we're going to + // write, we don't have any of the original data. + m_Stream.Seek(0, SeekOrigin.Begin); + m_Stream.SetLength(0); + + // Then append the data. Increase the capacity if needed (but don't shrink it). + m_Stream.Capacity = (m_Stream.Capacity < data.Length) ? data.Length : m_Stream.Capacity; + m_Stream.Write(data, 0, data.Length); + } + + /// + /// Read the byte array of the message. + /// + /// + internal byte[] ToByteArray() + { + return m_Stream.ToArray(); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs.meta new file mode 100644 index 0000000000..348b80de80 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/OutgoingMessage.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 1a007135a9a1e49849eb2d295f4c3879 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs new file mode 100644 index 0000000000..133832447b --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs @@ -0,0 +1,70 @@ +using System.Collections.Generic; +using System; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Side channel for managing raw bytes of data. It is up to the clients of this side channel + /// to interpret the messages. + /// + public class RawBytesChannel : SideChannel + { + List<byte[]> m_MessagesReceived = new List<byte[]>(); + + /// + /// RawBytesChannel provides a way to exchange raw byte arrays between Unity and Python. + /// + /// The identifier for the RawBytesChannel. Must be + /// the same on Python and Unity. + public RawBytesChannel(Guid channelId) + { + ChannelId = channelId; + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + m_MessagesReceived.Add(msg.GetRawBytes()); + } + + /// + /// Sends the byte array message to the Python side channel.
The message will be sent + /// alongside the simulation step. + /// + /// The byte array of data to send to Python. + public void SendRawBytes(byte[] data) + { + using (var msg = new OutgoingMessage()) + { + msg.SetRawBytes(data); + QueueMessageToSend(msg); + } + } + + /// + /// Gets the messages that were sent by Python since the last call to + /// GetAndClearReceivedMessages. + /// + /// A list of byte array messages that Python has sent. + public IList<byte[]> GetAndClearReceivedMessages() + { + var result = new List<byte[]>(); + result.AddRange(m_MessagesReceived); + m_MessagesReceived.Clear(); + return result; + } + + /// + /// Gets the messages that were sent by Python since the last call to + /// GetAndClearReceivedMessages. Note that the messages received will not + /// be cleared with a call to GetReceivedMessages. + /// + /// A list of byte array messages that Python has sent. + public IList<byte[]> GetReceivedMessages() + { + var result = new List<byte[]>(); + result.AddRange(m_MessagesReceived); + return result; + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs.meta new file mode 100644 index 0000000000..90a49234ba --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 40b01e9cdbfd94865b54ebeb4e5aeaa5 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs new file mode 100644 index 0000000000..250e638b0f --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs @@ -0,0 +1,69 @@ +using System.Collections.Generic; +using System; +using UnityEngine; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Side channels provide an alternative mechanism of sending/receiving data from Unity + /// to Python that is outside of the traditional machine learning loop. ML-Agents provides + /// some specific implementations of side channels, but users can create their own. + /// + /// To create your own, you'll need to create two new mirrored classes, one in Unity (by + /// extending SideChannel) and another in Python by extending a Python class + /// also called SideChannel. Then, within your project, use + /// SideChannelManager.RegisterSideChannel and + /// SideChannelManager.UnregisterSideChannel to register and unregister your + /// custom side channel. + /// + public abstract class SideChannel + { + // The list of messages (byte arrays) that need to be sent to Python via the communicator. + // Should only ever be read and cleared by an ICommunicator object. + internal List<byte[]> MessageQueue = new List<byte[]>(); + + /// + /// A unique identifier for the SideChannel. Ensures that there is only ever one side channel + /// of each type, and that the Unity side channels are linked to their Python equivalents. + /// + /// The unique identifier of the SideChannel. + public Guid ChannelId + { + get; + protected set; + } + + internal void ProcessMessage(byte[] msg) + { + try + { + using (var incomingMsg = new IncomingMessage(msg)) + { + OnMessageReceived(incomingMsg); + } + } + catch (Exception ex) + { + // Catch all errors in the side channel processing, so that a single + // bad SideChannel implementation doesn't take everything down with it.
+ Debug.LogError($"Error processing SideChannel message: {ex}.\nThe message will be skipped."); + } + } + + /// + /// Is called by the communicator every time a message is received from Python by the SideChannel. + /// Can be called multiple times per simulation step if multiple messages were sent. + /// + /// The incoming message. + protected abstract void OnMessageReceived(IncomingMessage msg); + + /// + /// Queues a message to be sent to Python during the next simulation step. + /// + /// The byte array of data to be sent to Python. + protected void QueueMessageToSend(OutgoingMessage msg) + { + MessageQueue.Add(msg.ToByteArray()); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs.meta new file mode 100644 index 0000000000..c668b0187f --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 77b7d19dd6ce343eeba907540b5a2286 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs b/com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs new file mode 100644 index 0000000000..bdcc596b85 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs @@ -0,0 +1,244 @@ +using System; +using System.Collections.Generic; +using UnityEngine; +using System.IO; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Collection of static utilities for managing the registering/unregistering of + /// and the sending/receiving of messages for all the channels. + /// + public static class SideChannelManager + { + static Dictionary s_RegisteredChannels = new Dictionary(); + + struct CachedSideChannelMessage + { + public Guid ChannelId; + public byte[] Message; + } + + static readonly Queue s_CachedMessages = + new Queue(); + + /// + /// Register a side channel to begin sending and receiving messages. This method is + /// available for environments that have custom side channels. All built-in side + /// channels within the ML-Agents Toolkit are managed internally and do not need to + /// be explicitly registered/unregistered. A side channel may only be registered once. + /// + /// The side channel to register. + public static void RegisterSideChannel(SideChannel sideChannel) + { + var channelId = sideChannel.ChannelId; + if (s_RegisteredChannels.ContainsKey(channelId)) + { + throw new UnityAgentsException( + $"A side channel with id {channelId} is already registered. " + + "You cannot register multiple side channels of the same id."); + } + + // Process any messages that we've already received for this channel ID. + var numMessages = s_CachedMessages.Count; + for (var i = 0; i < numMessages; i++) + { + var cachedMessage = s_CachedMessages.Dequeue(); + if (channelId == cachedMessage.ChannelId) + { + sideChannel.ProcessMessage(cachedMessage.Message); + } + else + { + s_CachedMessages.Enqueue(cachedMessage); + } + } + s_RegisteredChannels.Add(channelId, sideChannel); + } + + /// + /// Unregister a side channel to stop sending and receiving messages. This method is + /// available for environments that have custom side channels. All built-in side + /// channels within the ML-Agents Toolkit are managed internally and do not need to + /// be explicitly registered/unregistered. 
Unregistering a side channel that has already + /// been unregistered (or never registered in the first place) has no negative side effects. + /// Note that unregistering a side channel may not stop the Python side + /// from sending messages, but it does mean that sent messages will not result in a call + /// to OnMessageReceived. Furthermore, + /// those messages will not be buffered and will, in essence, be lost. + /// + /// The side channel to unregister. + public static void UnregisterSideChannel(SideChannel sideChannel) + { + if (s_RegisteredChannels.ContainsKey(sideChannel.ChannelId)) + { + s_RegisteredChannels.Remove(sideChannel.ChannelId); + } + } + + /// + /// Unregisters all the side channels from the communicator. + /// + internal static void UnregisterAllSideChannels() + { + s_RegisteredChannels = new Dictionary<Guid, SideChannel>(); + } + + /// + /// Returns the SideChannel of Type T if there is one registered, or null if none is registered. + /// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary. + /// + /// + /// + internal static T GetSideChannel<T>() where T : SideChannel + { + foreach (var sc in s_RegisteredChannels.Values) + { + if (sc.GetType() == typeof(T)) + { + return (T)sc; + } + } + return null; + } + + /// + /// Grabs the messages that the registered side channels will send to Python at the current step + /// into a single byte array. + /// + /// + internal static byte[] GetSideChannelMessage() + { + return GetSideChannelMessage(s_RegisteredChannels); + } + + /// + /// Grabs the messages that the registered side channels will send to Python at the current step + /// into a single byte array. + /// + /// A dictionary of channel ID to channel. + /// + internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels) + { + if (!HasOutgoingMessages(sideChannels)) + { + // Early out so that we don't create the MemoryStream or BinaryWriter. + // This is the most common case. + return Array.Empty<byte>(); + } + + using (var memStream = new MemoryStream()) + { + using (var binaryWriter = new BinaryWriter(memStream)) + { + foreach (var sideChannel in sideChannels.Values) + { + var messageList = sideChannel.MessageQueue; + foreach (var message in messageList) + { + binaryWriter.Write(sideChannel.ChannelId.ToByteArray()); + binaryWriter.Write(message.Length); + binaryWriter.Write(message); + } + sideChannel.MessageQueue.Clear(); + } + return memStream.ToArray(); + } + } + } + + /// + /// Check whether any of the side channels have queued messages. + /// + /// + /// + static bool HasOutgoingMessages(Dictionary<Guid, SideChannel> sideChannels) + { + foreach (var sideChannel in sideChannels.Values) + { + var messageList = sideChannel.MessageQueue; + if (messageList.Count > 0) + { + return true; + } + } + + return false; + } + + /// + /// Separates the data received from Python into individual messages for each registered side channel. + /// + /// The byte array of data received from Python. + internal static void ProcessSideChannelData(byte[] dataReceived) + { + ProcessSideChannelData(s_RegisteredChannels, dataReceived); + } + + /// + /// Separates the data received from Python into individual messages for each registered side channel. + /// + /// A dictionary of channel ID to channel. + /// The byte array of data received from Python.
+ internal static void ProcessSideChannelData(Dictionary sideChannels, byte[] dataReceived) + { + while (s_CachedMessages.Count != 0) + { + var cachedMessage = s_CachedMessages.Dequeue(); + if (sideChannels.ContainsKey(cachedMessage.ChannelId)) + { + sideChannels[cachedMessage.ChannelId].ProcessMessage(cachedMessage.Message); + } + else + { + Debug.Log(string.Format( + "Unknown side channel data received. Channel Id is " + + ": {0}", cachedMessage.ChannelId)); + } + } + + if (dataReceived.Length == 0) + { + return; + } + using (var memStream = new MemoryStream(dataReceived)) + { + using (var binaryReader = new BinaryReader(memStream)) + { + while (memStream.Position < memStream.Length) + { + Guid channelId = Guid.Empty; + byte[] message = null; + try + { + channelId = new Guid(binaryReader.ReadBytes(16)); + var messageLength = binaryReader.ReadInt32(); + message = binaryReader.ReadBytes(messageLength); + } + catch (Exception ex) + { + throw new UnityAgentsException( + "There was a problem reading a message in a SideChannel. Please make sure the " + + "version of MLAgents in Unity is compatible with the Python version. Original error : " + + ex.Message); + } + if (sideChannels.ContainsKey(channelId)) + { + sideChannels[channelId].ProcessMessage(message); + } + else + { + // Don't recognize this ID, but cache it in case the SideChannel that can handle + // it is registered before the next call to ProcessSideChannelData. + s_CachedMessages.Enqueue(new CachedSideChannelMessage + { + ChannelId = channelId, + Message = message + }); + } + } + } + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs.meta new file mode 100644 index 0000000000..251cc14632 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ccc0d134445f947349c68a6d07e3cdc2 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs new file mode 100644 index 0000000000..dbe1b5aeec --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs @@ -0,0 +1,43 @@ +using System; +namespace Unity.MLAgents.SideChannels +{ + /// + /// A Side Channel for sending data. + /// + internal class StatsSideChannel : SideChannel + { + const string k_StatsSideChannelDefaultId = "a1d8f7b7-cec8-50f9-b78b-d3e165a78520"; + + /// + /// Initializes the side channel. The constructor is internal because only one instance is + /// supported at a time. + /// + internal StatsSideChannel() + { + ChannelId = new Guid(k_StatsSideChannelDefaultId); + } + + /// + /// Add a stat value for reporting. + /// + /// The stat name. + /// The stat value. + /// How multiple values should be treated. 
+ public void AddStat(string key, float value, StatAggregationMethod aggregationMethod) + { + using (var msg = new OutgoingMessage()) + { + msg.WriteString(key); + msg.WriteFloat32(value); + msg.WriteInt32((int)aggregationMethod); + QueueMessageToSend(msg); + } + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + throw new UnityAgentsException("StatsSideChannel should never receive messages."); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta new file mode 100644 index 0000000000..ebc11e7092 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 83a07fdb9e8f04536908a51447dfe548 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs new file mode 100644 index 0000000000..0c880e4be9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs @@ -0,0 +1,52 @@ +using System; +using Unity.MLAgents.Analytics; +using Unity.MLAgents.CommunicatorObjects; + +namespace Unity.MLAgents.SideChannels +{ + /// + /// Side Channel implementation for recording which training features are being used. + /// + internal class TrainingAnalyticsSideChannel : SideChannel + { + const string k_TrainingAnalyticsConfigId = "b664a4a9-d86f-5a5f-95cb-e8353a7e8356"; + + /// + /// Initializes the side channel. The constructor is internal because only one instance is + /// supported at a time, and is created by the Academy. + /// + internal TrainingAnalyticsSideChannel() + { + ChannelId = new Guid(k_TrainingAnalyticsConfigId); + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + Google.Protobuf.WellKnownTypes.Any anyMessage = null; + try + { + anyMessage = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(msg.GetRawBytes()); + } + catch (Google.Protobuf.InvalidProtocolBufferException) + { + // Bad message, nothing we can do about it, so just ignore. + return; + } + + if (anyMessage.Is(TrainingEnvironmentInitialized.Descriptor)) + { + var envInitProto = anyMessage.Unpack(); + var envInitEvent = envInitProto.ToTrainingEnvironmentInitializedEvent(); + TrainingAnalytics.TrainingEnvironmentInitialized(envInitEvent); + } + else if (anyMessage.Is(TrainingBehaviorInitialized.Descriptor)) + { + var behaviorInitProto = anyMessage.Unpack(); + var behaviorTrainingEvent = behaviorInitProto.ToTrainingBehaviorInitializedEvent(); + TrainingAnalytics.TrainingBehaviorInitialized(behaviorTrainingEvent); + } + // Don't do anything for unknown types, since the user probably can't do anything about it. 
+ } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta new file mode 100644 index 0000000000..757d0d0d4f --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 13c87198bbd54b40a0b93308eb37933e +timeCreated: 1608337471 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs new file mode 100644 index 0000000000..c5fe6ce835 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs @@ -0,0 +1,145 @@ +using System; +using System.Linq; +using System.Collections.Generic; + +namespace Unity.MLAgents +{ + /// + /// A basic class implementation of MultiAgentGroup. + /// + public class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable + { + readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId(); + HashSet m_Agents = new HashSet(); + + /// + /// Disposes of the SimpleMultiAgentGroup. + /// + public virtual void Dispose() + { + while (m_Agents.Count > 0) + { + UnregisterAgent(m_Agents.First()); + } + } + + /// + public virtual void RegisterAgent(Agent agent) + { + if (!m_Agents.Contains(agent)) + { + agent.SetMultiAgentGroup(this); + m_Agents.Add(agent); + agent.OnAgentDisabled += UnregisterAgent; + } + } + + /// + public virtual void UnregisterAgent(Agent agent) + { + if (m_Agents.Contains(agent)) + { + agent.SetMultiAgentGroup(null); + m_Agents.Remove(agent); + agent.OnAgentDisabled -= UnregisterAgent; + } + } + + /// + public int GetId() + { + return m_Id; + } + + /// + /// Get list of all agents currently registered to this MultiAgentGroup. + /// + /// + /// List of agents registered to the MultiAgentGroup. + /// + public IReadOnlyCollection GetRegisteredAgents() + { + return m_Agents; + } + + /// + /// Increments the group rewards for all agents in this MultiAgentGroup. + /// + /// + /// This function increases or decreases the group rewards by a given amount for all agents + /// in the group. Use to set the group reward assigned + /// to the current step with a specific value rather than increasing or decreasing it. + /// + /// A positive group reward indicates the whole group's accomplishments or desired behaviors. + /// Every agent in the group will receive the same group reward no matter whether the + /// agent's act directly leads to the reward. Group rewards are meant to reinforce agents + /// to act in the group's best interest instead of individual ones. + /// Group rewards are treated differently than individual agent rewards during training, so + /// calling AddGroupReward() is not equivalent to calling agent.AddReward() on each agent in the group. + /// + /// Incremental group reward value. + public void AddGroupReward(float reward) + { + foreach (var agent in m_Agents) + { + agent.AddGroupReward(reward); + } + } + + /// + /// Set the group rewards for all agents in this MultiAgentGroup. + /// + /// + /// This function replaces any group rewards given during the current step for all agents in the group. + /// Use to incrementally change the group reward rather than + /// overriding it. + /// + /// A positive group reward indicates the whole group's accomplishments or desired behaviors. + /// Every agent in the group will receive the same group reward no matter whether the + /// agent's act directly leads to the reward. 
+        /// Group rewards are meant to encourage agents to act in the group's best interest
+        /// instead of individual ones.
+        /// Group rewards are treated differently than individual agent rewards during training, so
+        /// calling SetGroupReward() is not equivalent to calling agent.SetReward() on each agent in the group.
+        /// </remarks>
+        /// <param name="reward">The new value of the group reward.</param>
+        public void SetGroupReward(float reward)
+        {
+            foreach (var agent in m_Agents)
+            {
+                agent.SetGroupReward(reward);
+            }
+        }
+
+        /// <summary>
+        /// End episodes for all agents in this MultiAgentGroup.
+        /// </summary>
+        /// <remarks>
+        /// This should be used when the episode can no longer continue, such as when the group
+        /// reaches the goal or fails at the task.
+        /// </remarks>
+        public void EndGroupEpisode()
+        {
+            foreach (var agent in m_Agents)
+            {
+                agent.EndEpisode();
+            }
+        }
+
+        /// <summary>
+        /// Indicate that the episode is over but not due to the "fault" of the group.
+        /// This has the same end result as calling <see cref="EndGroupEpisode"/>, but has a
+        /// slightly different effect on training.
+        /// </summary>
+        /// <remarks>
+        /// This should be used when the episode could continue, but has gone on for
+        /// a sufficient number of steps, such as if the environment hits some maximum number of steps.
+        /// </remarks>
+        public void GroupEpisodeInterrupted()
+        {
+            foreach (var agent in m_Agents)
+            {
+                agent.EpisodeInterrupted();
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta
new file mode 100644
index 0000000000..33b0a0559e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 3454e3c3c70964dca93b63ee4b650095
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/StatsRecorder.cs b/com.unity.ml-agents/Runtime/StatsRecorder.cs
new file mode 100644
index 0000000000..d7250862b9
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/StatsRecorder.cs
@@ -0,0 +1,80 @@
+using Unity.MLAgents.SideChannels;
+
+namespace Unity.MLAgents
+{
+    /// <summary>
+    /// Determines the behavior of how multiple stats within the same summary period are combined.
+    /// </summary>
+    public enum StatAggregationMethod
+    {
+        /// <summary>
+        /// Values within the summary period are averaged before reporting.
+        /// </summary>
+        Average = 0,
+
+        /// <summary>
+        /// Only the most recent value is reported.
+        /// To avoid conflicts when training with multiple concurrent environments, only
+        /// stats from worker index 0 will be tracked.
+        /// </summary>
+        MostRecent = 1,
+
+        /// <summary>
+        /// Values within the summary period are summed up before reporting.
+        /// </summary>
+        Sum = 2,
+
+        /// <summary>
+        /// Values within the summary period are reported as a histogram.
+        /// </summary>
+        Histogram = 3
+    }
+
+    /// <summary>
+    /// Add stats (key-value pairs) for reporting. These values will be sent to a StatsReporter
+    /// instance, which means the values will appear in the TensorBoard summary, as well as trainer
+    /// gauges. You can nest stats in TensorBoard by adding "/" in the name (e.g. "Agent/Health"
+    /// and "Agent/Wallet"). Note that stats are only written to TensorBoard every summary_frequency
+    /// steps (a trainer configuration). If a stat is received multiple times within that period,
+    /// the values will be aggregated using the provided <see cref="StatAggregationMethod"/>.
+    /// </summary>
+    public sealed class StatsRecorder
+    {
+        /// <summary>
+        /// The side channel used to send the stat values to the trainer.
+        /// </summary>
+        readonly StatsSideChannel m_Channel;
+
+        /// <summary>
+        /// Constructor.
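+        /// Not directly constructible by user code; use Academy.Instance.StatsRecorder to get
+        /// the active instance. A minimal usage sketch (the stat name "Agent/Health" below is
+        /// illustrative, not part of the API):
+        ///
+        ///     Academy.Instance.StatsRecorder.Add("Agent/Health", health, StatAggregationMethod.Average);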
+        /// </summary>
+        internal StatsRecorder()
+        {
+            m_Channel = new StatsSideChannel();
+            SideChannelManager.RegisterSideChannel(m_Channel);
+        }
+
+        /// <summary>
+        /// Add a stat value for reporting.
+        /// </summary>
+        /// <param name="key">
+        /// The stat name. You can nest stats in TensorBoard by using "/" in the name.
+        /// </param>
+        /// <param name="value">
+        /// The stat value.
+        /// </param>
+        /// <param name="aggregationMethod">
+        /// How multiple values sent in the same summary window should be treated.
+        /// </param>
+        public void Add(
+            string key,
+            float value,
+            StatAggregationMethod aggregationMethod = StatAggregationMethod.Average)
+        {
+            m_Channel.AddStat(key, value, aggregationMethod);
+        }
+
+        internal void Dispose()
+        {
+            SideChannelManager.UnregisterSideChannel(m_Channel);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/StatsRecorder.cs.meta b/com.unity.ml-agents/Runtime/StatsRecorder.cs.meta
new file mode 100644
index 0000000000..bfc4addbb1
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/StatsRecorder.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: d9add8900e8a746e6a4cb410cb27d664
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Timer.cs b/com.unity.ml-agents/Runtime/Timer.cs
new file mode 100644
index 0000000000..35a6660e75
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Timer.cs
@@ -0,0 +1,533 @@
+// Compile with: csc CRefTest.cs -doc:Results.xml
+#if UNITY_EDITOR || UNITY_STANDALONE
+#define MLA_SUPPORTED_TRAINING_PLATFORM
+#endif
+using System;
+using UnityEngine;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using UnityEngine.Profiling;
+using System.Runtime.Serialization;
+using System.Runtime.Serialization.Json;
+using UnityEngine.SceneManagement;
+
+namespace Unity.MLAgents
+{
+    [DataContract]
+    internal class TimerNode
+    {
+        static string s_Separator = ".";
+        static double s_TicksToSeconds = 1e-7;  // 100 ns per tick
+
+        /// <summary>
+        /// Full name of the node. This is the node's parent's full name concatenated with this
+        /// node's name.
+        /// </summary>
+        string m_FullName;
+
+        /// <summary>
+        /// Child nodes, indexed by name.
+        /// </summary>
+        [DataMember(Name = "children", Order = 999)]
+        Dictionary<string, TimerNode> m_Children;
+
+        /// <summary>
+        /// Custom sampler used to add timings to the profiler.
+        /// </summary>
+        CustomSampler m_Sampler;
+
+        /// <summary>
+        /// Number of total ticks elapsed for this node.
+        /// </summary>
+        long m_TotalTicks;
+
+        /// <summary>
+        /// If the node is currently running, the time (in ticks) when the node was started.
+        /// If the node is not running, this is set to 0.
+        /// </summary>
+        long m_TickStart;
+
+        /// <summary>
+        /// Number of times the corresponding code block has been called.
+        /// </summary>
+        [DataMember(Name = "count")]
+        int m_NumCalls;
+
+        /// <summary>
+        /// The total recorded ticks for the timer node, plus the currently elapsed ticks
+        /// if the timer is still running (i.e. if m_TickStart is non-zero).
+        /// </summary>
+        public long CurrentTicks
+        {
+            get
+            {
+                var currentTicks = m_TotalTicks;
+                if (m_TickStart != 0)
+                {
+                    currentTicks += (DateTime.Now.Ticks - m_TickStart);
+                }
+
+                return currentTicks;
+            }
+        }
+
+        /// <summary>
+        /// Total elapsed seconds.
+        /// </summary>
+        [DataMember(Name = "total")]
+        public double TotalSeconds
+        {
+            get { return CurrentTicks * s_TicksToSeconds; }
+            set { }  // Serialization needs this, but unused.
+        }
+
+        /// <summary>
+        /// Total seconds spent in this block, excluding its children.
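+        /// For example, if a node has accumulated 10 seconds in total and its children account
+        /// for 7 of those seconds, SelfSeconds reports the remaining 3 seconds.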
+        /// </summary>
+        [DataMember(Name = "self")]
+        public double SelfSeconds
+        {
+            get
+            {
+                long totalChildTicks = 0;
+                if (m_Children != null)
+                {
+                    foreach (var child in m_Children.Values)
+                    {
+                        totalChildTicks += child.m_TotalTicks;
+                    }
+                }
+
+                var selfTicks = Mathf.Max(0, CurrentTicks - totalChildTicks);
+                return selfTicks * s_TicksToSeconds;
+            }
+            set { }  // Serialization needs this, but unused.
+        }
+
+        public IReadOnlyDictionary<string, TimerNode> Children
+        {
+            get { return m_Children; }
+        }
+
+        public int NumCalls
+        {
+            get { return m_NumCalls; }
+        }
+
+        public TimerNode(string name, bool isRoot = false)
+        {
+            m_FullName = name;
+            if (isRoot)
+            {
+                // The root node is considered always running. This means that when we output stats, it'll
+                // have a sensible value for total time (the running time since reset).
+                // The root node doesn't have a sampler since that could interfere with the profiler.
+                m_NumCalls = 1;
+                m_TickStart = DateTime.Now.Ticks;
+            }
+            else
+            {
+                m_Sampler = CustomSampler.Create(m_FullName);
+            }
+        }
+
+        /// <summary>
+        /// Start timing a block of code.
+        /// </summary>
+        public void Begin()
+        {
+            m_Sampler?.Begin();
+            m_TickStart = DateTime.Now.Ticks;
+        }
+
+        /// <summary>
+        /// Stop timing a block of code, and increment internal counts.
+        /// </summary>
+        public void End()
+        {
+            var elapsed = DateTime.Now.Ticks - m_TickStart;
+            m_TotalTicks += elapsed;
+            m_TickStart = 0;
+            m_NumCalls++;
+            m_Sampler?.End();
+        }
+
+        /// <summary>
+        /// Return a child node for the given name.
+        /// The children dictionary will be created if it does not already exist, and
+        /// a new Node will be created if it's not already in the dictionary.
+        /// Note that these allocations only happen once for a given timed block.
+        /// </summary>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public TimerNode GetChild(string name)
+        {
+            // Lazily create the children dictionary.
+            if (m_Children == null)
+            {
+                m_Children = new Dictionary<string, TimerNode>();
+            }
+
+            if (!m_Children.ContainsKey(name))
+            {
+                var childFullName = m_FullName + s_Separator + name;
+                var newChild = new TimerNode(childFullName);
+                m_Children[name] = newChild;
+                return newChild;
+            }
+
+            return m_Children[name];
+        }
+
+        /// <summary>
+        /// Recursively form a string representing the current timer information.
+        /// </summary>
+        /// <param name="parentName"></param>
+        /// <param name="level"></param>
+        /// <returns></returns>
+        public string DebugGetTimerString(string parentName = "", int level = 0)
+        {
+            var indent = new string(' ', 2 * level);  // TODO generalize
+            var shortName = (level == 0) ? m_FullName : m_FullName.Replace(parentName + s_Separator, "");
+            string timerString;
+            if (level == 0)
+            {
+                timerString = $"{shortName}(root)\n";
+            }
+            else
+            {
+                timerString = $"{indent}{shortName}\t\traw={TotalSeconds} rawCount={m_NumCalls}\n";
+            }
+
+            // TODO use StringBuilder? might be overkill since this is only debugging code?
+            if (m_Children != null)
+            {
+                foreach (var c in m_Children.Values)
+                {
+                    timerString += c.DebugGetTimerString(m_FullName, level + 1);
+                }
+            }
+            return timerString;
+        }
+    }
+
+    [DataContract]
+    internal class RootNode : TimerNode
+    {
+        // Timer output format version
+        internal const string k_TimerFormatVersion = "0.1.0";
+
+        [DataMember(Name = "metadata", Order = 0)]
+        Dictionary<string, string> m_Metadata = new Dictionary<string, string>();
+
+        /// <summary>
+        /// Gauge Nodes to measure arbitrary values.
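+        /// Typically written via TimerStack.Instance.SetGauge (defined below); a short sketch,
+        /// where the gauge name is illustrative:
+        ///
+        ///     TimerStack.Instance.SetGauge("Environment.CumulativeReward", cumulativeReward);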
+        /// </summary>
+        [DataMember(Name = "gauges", EmitDefaultValue = false)]
+        Dictionary<string, GaugeNode> m_Gauges = new Dictionary<string, GaugeNode>();
+
+        public RootNode(string name = "root") : base(name, true)
+        {
+            m_Metadata.Add("timer_format_version", k_TimerFormatVersion);
+            m_Metadata.Add("start_time_seconds", $"{DateTimeOffset.Now.ToUnixTimeSeconds()}");
+            m_Metadata.Add("unity_version", Application.unityVersion);
+            m_Metadata.Add("command_line_arguments", String.Join(" ", GetCleanedCommandLineArguments()));
+        }
+
+        /// <summary>
+        /// Removes license information (access tokens, hub session ids, licensing IPC names)
+        /// from the environment's command line arguments.
+        /// </summary>
+        /// <returns>Cleaned list of command line arguments.</returns>
+        private static List<string> GetCleanedCommandLineArguments()
+        {
+            List<string> commandLineArgs = Environment.GetCommandLineArgs().ToList();
+            List<int> toRemoveIndices = new List<int> { };
+            for (var i = 0; i < commandLineArgs.Count; i++)
+            {
+                if (commandLineArgs[i].Contains("accessToken") ||
+                    commandLineArgs[i].Contains("hubSessionId") ||
+                    commandLineArgs[i].Contains("licensingIpc"))
+                {
+                    // Remove both the flag and the value that follows it.
+                    toRemoveIndices.Add(i);
+                    toRemoveIndices.Add(i + 1);
+                }
+            }
+            // remove in reverse order
+            for (var i = toRemoveIndices.Count() - 1; i >= 0; i--)
+            {
+                commandLineArgs.RemoveAt(toRemoveIndices[i]);
+            }
+            return commandLineArgs;
+        }
+
+        public void AddMetadata(string key, string value)
+        {
+            m_Metadata[key] = value;
+        }
+
+        public Dictionary<string, GaugeNode> Gauges
+        {
+            get { return m_Gauges; }
+        }
+
+        public Dictionary<string, string> Metadata
+        {
+            get { return m_Metadata; }
+        }
+    }
+
+    /// <summary>
+    /// Tracks the most recent value of a metric. This is analogous to gauges in statsd and Prometheus.
+    /// </summary>
+    [DataContract]
+    internal class GaugeNode
+    {
+        const float k_SmoothingFactor = .25f;  // weight for exponential moving average.
+
+        /// <summary>
+        /// The most recent value that the gauge was set to.
+        /// </summary>
+        [DataMember]
+        public float value;
+
+        /// <summary>
+        /// The smallest value that has been seen for the gauge since it was created.
+        /// </summary>
+        [DataMember(Name = "min")]
+        public float minValue;
+
+        /// <summary>
+        /// The largest value that has been seen for the gauge since it was created.
+        /// </summary>
+        [DataMember(Name = "max")]
+        public float maxValue;
+
+        /// <summary>
+        /// The exponential moving average of the gauge value. This will take all values into account,
+        /// but weights older values less as more values are added.
+        /// </summary>
+        [DataMember(Name = "weightedAverage")]
+        public float weightedAverage;
+
+        /// <summary>
+        /// The running average of all gauge values.
+        /// </summary>
+        [DataMember]
+        public float runningAverage;
+
+        /// <summary>
+        /// The number of times the gauge has been updated.
+        /// </summary>
+        [DataMember]
+        public uint count;
+
+        public GaugeNode(float value)
+        {
+            this.value = value;
+            weightedAverage = value;
+            runningAverage = value;
+            minValue = value;
+            maxValue = value;
+            count = 1;
+        }
+
+        public void Update(float newValue)
+        {
+            ++count;
+            minValue = Mathf.Min(minValue, newValue);
+            maxValue = Mathf.Max(maxValue, newValue);
+            // update exponential moving average
+            weightedAverage = (k_SmoothingFactor * newValue) + ((1f - k_SmoothingFactor) * weightedAverage);
+            value = newValue;
+
+            // Update running average - see https://www.johndcook.com/blog/standard_deviation/ for formula.
+            runningAverage = runningAverage + (newValue - runningAverage) / count;
+        }
+    }
+
+    /// <summary>
+    /// A "stack" of timers that allows for lightweight hierarchical profiling of long-running processes.
+    /// <example>
+    /// Example usage:
+    /// <code>
+    /// using(TimerStack.Instance.Scoped("foo"))
+    /// {
+    ///     doSomeWork();
+    ///     for (int i=0; i<5; i++)
+    ///     {
+    ///         using(TimerStack.Instance.Scoped("bar"))
+    ///         {
+    ///             doSomeMoreWork();
+    ///         }
+    ///     }
+    /// }
+    /// </code>
+    /// </example>
+    /// </summary>
+    /// <remarks>
+    /// This implements the Singleton pattern (solution 4) as described in
+    /// https://csharpindepth.com/articles/singleton
+    /// </remarks>
+    internal class TimerStack : IDisposable
+    {
+        static readonly TimerStack k_Instance = new TimerStack();
+
+        Stack<TimerNode> m_Stack;
+        RootNode m_RootNode;
+        Dictionary<string, string> m_Metadata;
+
+        // Explicit static constructor to tell C# compiler
+        // not to mark type as beforefieldinit
+        static TimerStack()
+        {
+        }
+
+        TimerStack()
+        {
+            Reset();
+        }
+
+        /// <summary>
+        /// Resets the timer stack and the root node.
+        /// </summary>
+        /// <param name="name">Name of the root node.</param>
+        public void Reset(string name = "root")
+        {
+            m_Stack = new Stack<TimerNode>();
+            m_RootNode = new RootNode(name);
+            m_Stack.Push(m_RootNode);
+        }
+
+        /// <summary>
+        /// The singleton instance.
+        /// </summary>
+        public static TimerStack Instance
+        {
+            get { return k_Instance; }
+        }
+
+        internal RootNode RootNode
+        {
+            get { return m_RootNode; }
+        }
+
+        /// <summary>
+        /// Updates the referenced gauge in the root node with the provided value.
+        /// </summary>
+        /// <param name="name">The name of the Gauge to modify.</param>
+        /// <param name="value">The value to update the Gauge with.</param>
+        public void SetGauge(string name, float value)
+        {
+            if (!float.IsNaN(value))
+            {
+                GaugeNode gauge;
+                if (m_RootNode.Gauges.TryGetValue(name, out gauge))
+                {
+                    gauge.Update(value);
+                }
+                else
+                {
+                    m_RootNode.Gauges[name] = new GaugeNode(value);
+                }
+            }
+        }
+
+        public void AddMetadata(string key, string value)
+        {
+            m_RootNode.AddMetadata(key, value);
+        }
+
+        void Push(string name)
+        {
+            var current = m_Stack.Peek();
+            var next = current.GetChild(name);
+            m_Stack.Push(next);
+            next.Begin();
+        }
+
+        void Pop()
+        {
+            var node = m_Stack.Pop();
+            node.End();
+        }
+
+        /// <summary>
+        /// Start a scoped timer. This should be used with the "using" statement.
+        /// </summary>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public TimerStack Scoped(string name)
+        {
+            Push(name);
+            return this;
+        }
+
+        /// <summary>
+        /// Closes the current scoped timer. This should never be called directly, only
+        /// at the end of a "using" statement.
+        /// Note that the instance is not actually disposed of; this is just to allow it to be used
+        /// conveniently with "using".
+        /// </summary>
+        public void Dispose()
+        {
+            Pop();
+        }
+
+        /// <summary>
+        /// Get a string representation of the timers.
+        /// Potentially slow so call sparingly.
+        /// </summary>
+        /// <returns></returns>
+        internal string DebugGetTimerString()
+        {
+            return m_RootNode.DebugGetTimerString();
+        }
+
+        /// <summary>
+        /// Save the timers in JSON format to the provided filename.
+        /// If the filename is null, a default one will be used.
+        /// </summary>
+        /// <param name="filename"></param>
+        public void SaveJsonTimers(string filename = null)
+        {
+#if MLA_SUPPORTED_TRAINING_PLATFORM
+            try
+            {
+                if (filename == null)
+                {
+                    var activeScene = SceneManager.GetActiveScene();
+                    var timerDir = Path.Combine(Application.dataPath, "ML-Agents", "Timers");
+                    Directory.CreateDirectory(timerDir);
+
+                    filename = Path.Combine(timerDir, $"{activeScene.name}_timers.json");
+                }
+
+                var fs = new FileStream(filename, FileMode.Create, FileAccess.Write);
+                SaveJsonTimers(fs);
+                fs.Close();
+            }
+            catch (SystemException)
+            {
+                // We may not have write access to the directory.
+                Debug.LogWarning($"Unable to save timers to file {filename}");
+            }
+#endif
+        }
+
+        /// <summary>
+        /// Write the timers in JSON format to the provided stream.
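+        /// A minimal calling sketch (assumes the caller owns and disposes the stream):
+        ///
+        ///     using (var fs = File.Create("timers.json"))
+        ///     {
+        ///         TimerStack.Instance.SaveJsonTimers(fs);
+        ///     }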
+        /// </summary>
+        /// <param name="stream"></param>
+        public void SaveJsonTimers(Stream stream)
+        {
+            // Add some final metadata info
+            AddMetadata("scene_name", SceneManager.GetActiveScene().name);
+            AddMetadata("end_time_seconds", $"{DateTimeOffset.Now.ToUnixTimeSeconds()}");
+
+            var jsonSettings = new DataContractJsonSerializerSettings();
+            jsonSettings.UseSimpleDictionaryFormat = true;
+            var ser = new DataContractJsonSerializer(typeof(RootNode), jsonSettings);
+            ser.WriteObject(stream, m_RootNode);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Timer.cs.meta b/com.unity.ml-agents/Runtime/Timer.cs.meta
new file mode 100644
index 0000000000..e28315908d
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Timer.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: d268f7dfcc74c47939e1fc520adb8d81
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef b/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef
new file mode 100755
index 0000000000..80c8c62950
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef
@@ -0,0 +1,38 @@
+{
+    "name": "Unity.ML-Agents",
+    "rootNamespace": "",
+    "references": [
+        "Unity.Barracuda",
+        "Unity.ML-Agents.CommunicatorObjects",
+        "Unity.Mathematics"
+    ],
+    "includePlatforms": [],
+    "excludePlatforms": [],
+    "allowUnsafeCode": false,
+    "overrideReferences": false,
+    "precompiledReferences": [
+        "System.IO.Abstractions.dll",
+        "Google.Protobuf.dll",
+        "Grpc.Core.dll"
+    ],
+    "autoReferenced": true,
+    "defineConstraints": [],
+    "versionDefines": [
+        {
+            "name": "com.unity.modules.unityanalytics",
+            "expression": "1.0.0",
+            "define": "MLA_UNITY_ANALYTICS_MODULE"
+        },
+        {
+            "name": "com.unity.modules.physics",
+            "expression": "1.0.0",
+            "define": "MLA_UNITY_PHYSICS_MODULE"
+        },
+        {
+            "name": "com.unity.modules.physics2d",
+            "expression": "1.0.0",
+            "define": "MLA_UNITY_PHYSICS2D_MODULE"
+        }
+    ],
+    "noEngineReferences": false
+}
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef.meta b/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef.meta
new file mode 100644
index 0000000000..21cbeb9793
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 85e0054f8e64b47309646d35f8851f81
+AssemblyDefinitionImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/UnityAgentsException.cs b/com.unity.ml-agents/Runtime/UnityAgentsException.cs
new file mode 100644
index 0000000000..0abc1ace88
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/UnityAgentsException.cs
@@ -0,0 +1,32 @@
+using System;
+
+namespace Unity.MLAgents
+{
+    /// <summary>
+    /// Contains exceptions specific to ML-Agents.
+    /// </summary>
+    [Serializable]
+    public class UnityAgentsException : Exception
+    {
+        /// <summary>
+        /// When a UnityAgentsException is thrown, the timeScale is set to 0.
+        /// The simulation will end since no steps will be taken.
+        /// </summary>
+        /// <param name="message">The exception message</param>
+        public UnityAgentsException(string message) : base(message)
+        {
+        }
+
+        /// <summary>
+        /// A constructor is needed for serialization when an exception propagates
+        /// from a remoting server to the client.
+        /// </summary>
+        /// <param name="info">Data for serializing/de-serializing</param>
+        /// <param name="context">Describes the source and destination of the serialized stream</param>
+        protected UnityAgentsException(
+            System.Runtime.Serialization.SerializationInfo info,
+            System.Runtime.Serialization.StreamingContext context)
+            : base(info, context)  // Forward to the base Exception so the message round-trips.
+        {
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/UnityAgentsException.cs.meta b/com.unity.ml-agents/Runtime/UnityAgentsException.cs.meta
new file mode 100755
index 0000000000..f72768b1df
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/UnityAgentsException.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: e63e4a66d820245778f9a2abfa5b68e0
+timeCreated: 1504131359
+licenseType: Free
+MonoImporter:
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Utilities.cs b/com.unity.ml-agents/Runtime/Utilities.cs
new file mode 100644
index 0000000000..e9d4425048
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Utilities.cs
@@ -0,0 +1,60 @@
+using System;
+using System.Diagnostics;
+using UnityEngine;
+
+namespace Unity.MLAgents
+{
+    internal static class Utilities
+    {
+        /// <summary>
+        /// Calculates the cumulative sum of an integer array. The result array will be one element
+        /// larger than the input array since it has a padded 0 at the beginning.
+        /// If the input is [a, b, c], the result will be [0, a, a+b, a+b+c].
+        /// </summary>
+        /// <param name="input">
+        /// Input array whose elements will be cumulatively added.
+        /// </param>
+        /// <returns>The cumulative sum of the input array.</returns>
+        internal static int[] CumSum(int[] input)
+        {
+            var runningSum = 0;
+            var result = new int[input.Length + 1];
+            for (var actionIndex = 0; actionIndex < input.Length; actionIndex++)
+            {
+                runningSum += input[actionIndex];
+                result[actionIndex + 1] = runningSum;
+            }
+            return result;
+        }
+
+        /// <summary>
+        /// Safely destroy a texture. This has to be used differently in unit tests.
+        /// </summary>
+        /// <param name="texture"></param>
+        internal static void DestroyTexture(Texture2D texture)
+        {
+            if (Application.isEditor)
+            {
+                // Edit Mode tests complain if we use Destroy()
+                UnityEngine.Object.DestroyImmediate(texture);
+            }
+            else
+            {
+                UnityEngine.Object.Destroy(texture);
+            }
+        }
+
+        [Conditional("DEBUG")]
+        internal static void DebugCheckNanAndInfinity(float value, string valueCategory, string caller)
+        {
+            if (float.IsNaN(value))
+            {
+                throw new ArgumentException($"NaN {valueCategory} passed to {caller}.");
+            }
+            if (float.IsInfinity(value))
+            {
+                throw new ArgumentException($"Infinity {valueCategory} passed to {caller}.");
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Utilities.cs.meta b/com.unity.ml-agents/Runtime/Utilities.cs.meta
new file mode 100644
index 0000000000..872088ea83
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Utilities.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 0e664c25f496478c9c26df6688379f7e
+timeCreated: 1537468595
\ No newline at end of file
diff --git a/com.unity.ml-agents/Samples/3DBall/.sample.json b/com.unity.ml-agents/Samples/3DBall/.sample.json
new file mode 100644
index 0000000000..7055a5b220
--- /dev/null
+++ b/com.unity.ml-agents/Samples/3DBall/.sample.json
@@ -0,0 +1 @@
+{"displayName":"3D Ball","description":"The 3D Ball sample is a simple environment that is great for jumping into ML-Agents to see how things work."}
diff --git a/com.unity.ml-agents/Samples/3DBall/3DBall.unitypackage b/com.unity.ml-agents/Samples/3DBall/3DBall.unitypackage
new file mode 100644
index 0000000000..eb4043d19d
Binary files /dev/null and b/com.unity.ml-agents/Samples/3DBall/3DBall.unitypackage differ
diff --git a/com.unity.ml-agents/Tests.meta b/com.unity.ml-agents/Tests.meta
new file mode 100644
index 0000000000..ffec42d32e
--- /dev/null
+++ b/com.unity.ml-agents/Tests.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 2715dc4ceb2c345df9ba92d799ae72ff
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/.tests.json b/com.unity.ml-agents/Tests/.tests.json
new file mode 100755
index 0000000000..327abb29e5
--- /dev/null
+++ b/com.unity.ml-agents/Tests/.tests.json
@@ -0,0 +1,3 @@
+{
+    "createSeparatePackage": false
+}
diff --git a/com.unity.ml-agents/Tests/Editor.meta b/com.unity.ml-agents/Tests/Editor.meta
new file mode 100644
index 0000000000..9a45234ac0
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: b779986bc10d84f2aa8601a3f1c763ff
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/AcademyTests.cs b/com.unity.ml-agents/Tests/Editor/AcademyTests.cs
new file mode 100644
index 0000000000..61225acdab
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/AcademyTests.cs
@@ -0,0 +1,52 @@
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+using UnityEngine;
+
+namespace Unity.MLAgents.Tests
+{
+    [TestFixture]
+    public class AcademyTests
+    {
+        [Test]
+        public void TestPackageVersion()
+        {
+            var packageInfo = UnityEditor.PackageManager.PackageInfo.FindForAssembly(typeof(Agent).Assembly);
+            Assert.AreEqual("com.unity.ml-agents", packageInfo.name);
+            Assert.AreEqual(Academy.k_PackageVersion, packageInfo.version);
+        }
+
+        class RecursiveAgent : Agent
+        {
+            int m_collectObsCount;
+            public override void CollectObservations(VectorSensor sensor)
+            {
+                m_collectObsCount++;
+                if (m_collectObsCount == 1)
+                {
+                    // NEVER DO THIS IN REAL CODE!
+                    Academy.Instance.EnvironmentStep();
+                }
+            }
+        }
+
+        [Test]
+        public void TestRecursiveStepThrows()
+        {
+            var gameObj = new GameObject();
+            var agent = gameObj.AddComponent<RecursiveAgent>();
+            agent.Awake();
+            agent.LazyInitialize();
+            agent.RequestDecision();
+
+            Assert.Throws<UnityAgentsException>(() =>
+            {
+                Academy.Instance.EnvironmentStep();
+            });
+
+            // Make sure the Academy reset to a good state and is still steppable.
+            Academy.Instance.EnvironmentStep();
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/AcademyTests.cs.meta b/com.unity.ml-agents/Tests/Editor/AcademyTests.cs.meta
new file mode 100644
index 0000000000..fa65ced67d
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/AcademyTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: f434773fe0f1b41c5b3f446fa0adece4
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators.meta b/com.unity.ml-agents/Tests/Editor/Actuators.meta
new file mode 100644
index 0000000000..5c6399dc6c
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: c7e705f7d549e43c6be18ae809cd6f54
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
new file mode 100644
index 0000000000..817a9c1fc3
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
@@ -0,0 +1,63 @@
+using System;
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+
+namespace Unity.MLAgents.Tests.Actuators
+{
+    [TestFixture]
+    public class ActionSegmentTests
+    {
+        [Test]
+        public void TestConstruction()
+        {
+            var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
+            Assert.Throws<ArgumentOutOfRangeException>(
+                () => new ActionSegment<float>(floatArray, 100, 1));
+
+            var segment = new ActionSegment<float>(Array.Empty<float>(), 0, 0);
+            Assert.AreEqual(segment, ActionSegment<float>.Empty);
+        }
+
+        [Test]
+        public void TestIndexing()
+        {
+            var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
+            for (var i = 0; i < floatArray.Length; i++)
+            {
+                var start = 0 + i;
+                var length = floatArray.Length - i;
+                var actionSegment = new ActionSegment<float>(floatArray, start, length);
+                for (var j = 0; j < actionSegment.Length; j++)
+                {
+                    Assert.AreEqual(actionSegment[j], floatArray[start + j]);
+                }
+            }
+        }
+
+        [Test]
+        public void TestEnumerator()
+        {
+            var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
+            for (var i = 0; i < floatArray.Length; i++)
+            {
+                var start = 0 + i;
+                var length = floatArray.Length - i;
+                var actionSegment = new ActionSegment<float>(floatArray, start, length);
+                var j = 0;
+                foreach (var item in actionSegment)
+                {
+                    Assert.AreEqual(item, floatArray[start + j++]);
+                }
+            }
+        }
+
+        [Test]
+        public void TestNullConstructor()
+        {
+            var actionSegment = new ActionSegment<float>(null);
+            Assert.IsTrue(actionSegment.Length == 0);
+            Assert.IsTrue(actionSegment.Array == Array.Empty<float>());
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
new file mode 100644
index 0000000000..2332580c17
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 18cb6d052fba43a2b7437d87c0d9abad
+timeCreated: 1596486604
\ No newline
at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActionSpecTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSpecTests.cs new file mode 100644 index 0000000000..09dfd33670 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSpecTests.cs @@ -0,0 +1,37 @@ +using System.Collections.Generic; +using System.Linq; +using NUnit.Framework; +using Unity.MLAgents.Actuators; + +namespace Unity.MLAgents.Tests.Actuators +{ + [TestFixture] + public class ActionSpecTests + { + [Test] + public void ActionSpecCombineTest() + { + var as0 = new ActionSpec(3, new[] { 3, 2, 1 }); + var as1 = new ActionSpec(1, new[] { 35, 122, 1, 3, 8, 3 }); + + var as0NumCon = 3; + var as0NumDis = as0.NumDiscreteActions; + var as1NumCon = 1; + var as1NumDis = as1.NumDiscreteActions; + var branchSizes = new List(); + branchSizes.AddRange(as0.BranchSizes); + branchSizes.AddRange(as1.BranchSizes); + + var asc = ActionSpec.Combine(as0, as1); + + Assert.AreEqual(as0NumCon + as1NumCon, asc.NumContinuousActions); + Assert.AreEqual(as0NumDis + as1NumDis, asc.NumDiscreteActions); + Assert.IsTrue(branchSizes.ToArray().SequenceEqual(asc.BranchSizes)); + + as0 = new ActionSpec(3); + as1 = new ActionSpec(1); + asc = ActionSpec.Combine(as0, as1); + Assert.IsEmpty(asc.BranchSizes); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActionSpecTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSpecTests.cs.meta new file mode 100644 index 0000000000..18ebcbb881 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSpecTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 99d76ec04c944b75bc6b85abfff4ac4e +timeCreated: 1613680505 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs new file mode 100644 index 0000000000..1c486af483 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs @@ -0,0 +1,136 @@ +using System.Collections.Generic; +using NUnit.Framework; +using Unity.MLAgents.Actuators; + +namespace Unity.MLAgents.Tests.Actuators +{ + [TestFixture] + public class ActuatorDiscreteActionMaskTests + { + [Test] + public void Construction() + { + var masker = new ActuatorDiscreteActionMask(new List(), 0, 0); + Assert.IsNotNull(masker); + } + + [Test] + public void NullMask() + { + var masker = new ActuatorDiscreteActionMask(new List(), 0, 0); + var mask = masker.GetMask(); + Assert.IsNull(mask); + } + + [Test] + public void FirstBranchMask() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); + var mask = masker.GetMask(); + Assert.IsNull(mask); + masker.SetActionEnabled(0, 1, false); + masker.SetActionEnabled(0, 2, false); + masker.SetActionEnabled(0, 3, false); + mask = masker.GetMask(); + Assert.IsFalse(mask[0]); + Assert.IsTrue(mask[1]); + Assert.IsTrue(mask[2]); + Assert.IsTrue(mask[3]); + Assert.IsFalse(mask[4]); + Assert.AreEqual(mask.Length, 15); + } + + [Test] + public void CanOverwriteMask() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); + masker.SetActionEnabled(0, 1, false); + var mask = masker.GetMask(); + Assert.IsTrue(mask[1]); + + masker.SetActionEnabled(0, 1, true); + 
Assert.IsFalse(mask[1]); + } + + [Test] + public void SecondBranchMask() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new[] { actuator1 }, 15, 3); + masker.SetActionEnabled(1, 1, false); + masker.SetActionEnabled(1, 2, false); + masker.SetActionEnabled(1, 3, false); + var mask = masker.GetMask(); + Assert.IsFalse(mask[0]); + Assert.IsFalse(mask[4]); + Assert.IsTrue(mask[5]); + Assert.IsTrue(mask[6]); + Assert.IsTrue(mask[7]); + Assert.IsFalse(mask[8]); + Assert.IsFalse(mask[9]); + } + + [Test] + public void MaskReset() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); + masker.SetActionEnabled(1, 1, false); + masker.SetActionEnabled(1, 2, false); + masker.SetActionEnabled(1, 3, false); + masker.ResetMask(); + var mask = masker.GetMask(); + for (var i = 0; i < 15; i++) + { + Assert.IsFalse(mask[i]); + } + } + + [Test] + public void ThrowsError() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); + Assert.Catch( + () => masker.SetActionEnabled(0, 5, false)); + Assert.Catch( + () => masker.SetActionEnabled(1, 5, false)); + masker.SetActionEnabled(2, 5, false); + Assert.Catch( + () => masker.SetActionEnabled(3, 1, false)); + masker.GetMask(); + masker.ResetMask(); + masker.SetActionEnabled(0, 0, false); + masker.SetActionEnabled(0, 1, false); + masker.SetActionEnabled(0, 2, false); + masker.SetActionEnabled(0, 3, false); + Assert.Catch( + () => masker.GetMask()); + } + + [Test] + public void MultipleMaskEdit() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); + masker.SetActionEnabled(0, 0, false); + masker.SetActionEnabled(0, 1, false); + masker.SetActionEnabled(0, 3, false); + masker.SetActionEnabled(2, 1, false); + var mask = masker.GetMask(); + for (var i = 0; i < 15; i++) + { + if ((i == 0) || (i == 1) || (i == 3) || (i == 10)) + { + Assert.IsTrue(mask[i]); + } + else + { + Assert.IsFalse(mask[i]); + } + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta new file mode 100644 index 0000000000..a5dd1f3ad9 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: b9f5f87049d04d8bba39d193a3ab2f5a +timeCreated: 1596491682 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs new file mode 100644 index 0000000000..429d334828 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs @@ -0,0 +1,353 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using UnityEngine; +using UnityEngine.TestTools; +using Assert = UnityEngine.Assertions.Assert; + +namespace Unity.MLAgents.Tests.Actuators +{ + [TestFixture] + public class ActuatorManagerTests + { + [Test] + public void TestEnsureBufferSizeContinuous() + { + var manager = new 
ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeContinuous(10), "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeContinuous(2), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + var actuator1ActionSpaceDef = actuator1.ActionSpec; + var actuator2ActionSpaceDef = actuator2.ActionSpec; + manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, + actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions, + actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes, + actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions); + + manager.UpdateActions(new ActionBuffers(new[] + { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty())); + + Assert.IsTrue(12 == manager.NumContinuousActions); + Assert.IsTrue(0 == manager.NumDiscreteActions); + Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes); + Assert.IsTrue(12 == manager.StoredActions.ContinuousActions.Length); + Assert.IsTrue(0 == manager.StoredActions.DiscreteActions.Length); + } + + [Test] + public void TestEnsureBufferDiscrete() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4 }), "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 1, 1 }), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + var actuator1ActionSpaceDef = actuator1.ActionSpec; + var actuator2ActionSpaceDef = actuator2.ActionSpec; + manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, + actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions, + actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes, + actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions); + + manager.UpdateActions(new ActionBuffers(Array.Empty(), + new[] { 0, 1, 2, 3, 4, 5, 6 })); + + Assert.IsTrue(0 == manager.NumContinuousActions); + Assert.IsTrue(7 == manager.NumDiscreteActions); + Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes); + Assert.IsTrue(0 == manager.StoredActions.ContinuousActions.Length); + Assert.IsTrue(7 == manager.StoredActions.DiscreteActions.Length); + } + + [Test] + public void TestAllowMixedActions() + { + // Make sure discrete + continuous actuators are allowed. 
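+ // For reference: actuator1 contributes 4 discrete branches whose sizes sum to 1+2+3+4 = 10,
+ // and actuator2 contributes 3 continuous actions, which is where the (3, 10, 4) arguments
+ // passed to ReadyActuatorsForExecution below come from.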
+ var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4 }), "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4); + } + + [Test] + public void TestFailOnSameActuatorName() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1"); + manager.Add(actuator1); + manager.Add(actuator2); + manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4); + LogAssert.Expect(LogType.Assert, "Actuator names must be unique."); + } + + [Test] + public void TestExecuteActionsDiscrete() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4 }), "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 1, 1 }), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + + var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6 }; + manager.UpdateActions(new ActionBuffers(Array.Empty(), + discreteActionBuffer)); + + manager.ExecuteActions(); + var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions; + var actuator2Actions = actuator2.LastActionBuffer.DiscreteActions; + TestSegmentEquality(actuator1Actions, discreteActionBuffer); TestSegmentEquality(actuator2Actions, discreteActionBuffer); + } + + [Test] + public void TestExecuteActionsContinuous() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + + var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f }; + manager.UpdateActions(new ActionBuffers(continuousActionBuffer, + Array.Empty())); + + manager.ExecuteActions(); + var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions; + var actuator2Actions = actuator2.LastActionBuffer.ContinuousActions; + TestSegmentEquality(actuator1Actions, continuousActionBuffer); + TestSegmentEquality(actuator2Actions, continuousActionBuffer); + } + + static void TestSegmentEquality(ActionSegment actionSegment, T[] actionBuffer) + where T : struct + { + Assert.IsFalse(actionSegment.Length == 0); + for (var i = 0; i < actionSegment.Length; i++) + { + var action = actionSegment[i]; + Assert.AreEqual(action, actionBuffer[actionSegment.Offset + i]); + } + } + + [Test] + public void TestUpdateActionsContinuous() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f }; + manager.UpdateActions(new ActionBuffers(continuousActionBuffer, + Array.Empty())); + + Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer)); + } + + [Test] + public void TestUpdateActionsDiscrete() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), "actuator2"); + manager.Add(actuator1); + 
manager.Add(actuator2); + var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5 }; + manager.UpdateActions(new ActionBuffers(Array.Empty(), + discreteActionBuffer)); + + Debug.Log(manager.StoredActions.DiscreteActions); + Debug.Log(discreteActionBuffer); + Assert.IsTrue(manager.StoredActions.DiscreteActions.SequenceEqual(discreteActionBuffer)); + } + + [Test] + public void TestRemove() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), "actuator2"); + + manager.Add(actuator1); + manager.Add(actuator2); + Assert.IsTrue(manager.NumDiscreteActions == 6); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12); + + manager.Remove(actuator2); + + Assert.IsTrue(manager.NumDiscreteActions == 3); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6); + + manager.Remove(null); + + Assert.IsTrue(manager.NumDiscreteActions == 3); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6); + + manager.RemoveAt(0); + Assert.IsTrue(manager.NumDiscreteActions == 0); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0); + } + + [Test] + public void TestClear() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + + Assert.IsTrue(manager.NumDiscreteActions == 6); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12); + + manager.Clear(); + + Assert.IsTrue(manager.NumDiscreteActions == 0); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0); + } + + [Test] + public void TestIndexSet() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4 }), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), "actuator2"); + manager.Add(actuator1); + Assert.IsTrue(manager.NumDiscreteActions == 4); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10); + manager[0] = actuator2; + Assert.IsTrue(manager.NumDiscreteActions == 3); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6); + } + + [Test] + public void TestInsert() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4 }), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), "actuator2"); + manager.Add(actuator1); + Assert.IsTrue(manager.NumDiscreteActions == 4); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10); + manager.Insert(0, actuator2); + Assert.IsTrue(manager.NumDiscreteActions == 7); + Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 16); + Assert.IsTrue(manager.IndexOf(actuator2) == 0); + } + + [Test] + public void TestResetData() + { + var manager = new ActuatorManager(); + var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), + "actuator1"); + var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); + manager.Add(actuator1); + manager.Add(actuator2); + var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f }; + manager.UpdateActions(new ActionBuffers(continuousActionBuffer, + Array.Empty())); + + Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer)); + Assert.IsTrue(manager.NumContinuousActions == 6); + manager.ResetData(); + + 
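+ // ResetData is expected to zero-fill the stored continuous buffer while preserving its length: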
Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f })); + } + + [Test] + public void TestWriteDiscreteActionMask() + { + var manager = new ActuatorManager(2); + var va1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), "name"); + var va2 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 3, 2, 1 }), "name1"); + manager.Add(va1); + manager.Add(va2); + + var groundTruthMask = new[] + { + false, + true, false, + false, true, true, + true, false, true, + false, true, + false + }; + + va1.Masks = new[] + { + Array.Empty(), + new[] { 0 }, + new[] { 1, 2 } + }; + + va2.Masks = new[] + { + new[] {0, 2}, + new[] {1}, + Array.Empty() + }; + manager.WriteActionMask(); + Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask())); + } + + [Test] + public void TestHeuristic() + { + var manager = new ActuatorManager(2); + var va1 = new TestActuator(ActionSpec.MakeDiscrete(1, 2, 3), "name"); + var va2 = new TestActuator(ActionSpec.MakeDiscrete(3, 2, 1, 8), "name1"); + manager.Add(va1); + manager.Add(va2); + + var actionBuf = new ActionBuffers(Array.Empty(), new[] { 0, 0, 0, 0, 0, 0, 0 }); + manager.ApplyHeuristic(actionBuf); + + Assert.IsTrue(va1.m_HeuristicCalled); + Assert.AreEqual(va1.m_DiscreteBufferSize, 3); + Assert.IsTrue(va2.m_HeuristicCalled); + Assert.AreEqual(va2.m_DiscreteBufferSize, 4); + } + + + /// + /// Test that sensors sort by name consistently across culture settings. + /// Example strings and cultures taken from + /// https://docs.microsoft.com/en-us/globalization/locale/sorting-and-string-comparison + /// + /// + [TestCase("da-DK")] + [TestCase("en-US")] + public void TestSortActuators(string culture) + { + List actuators = new List(); + var actuator0 = new TestActuator(ActionSpec.MakeContinuous(2), "Apple"); + var actuator1 = new TestActuator(ActionSpec.MakeContinuous(2), "Æble"); + actuators.Add(actuator0); + actuators.Add(actuator1); + + var originalCulture = CultureInfo.CurrentCulture; + CultureInfo.CurrentCulture = new CultureInfo(culture); + ActuatorManager.SortActuators(actuators); + CultureInfo.CurrentCulture = originalCulture; + + Assert.AreEqual(actuator1, actuators[0]); + Assert.AreEqual(actuator0, actuators[1]); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta new file mode 100644 index 0000000000..4946ff19fb --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: d48ba72f0ac64d7db0af22c9d82b11d8 +timeCreated: 1596494279 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs new file mode 100644 index 0000000000..f4d32708f1 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs @@ -0,0 +1,49 @@ +using Unity.MLAgents.Actuators; +namespace Unity.MLAgents.Tests.Actuators +{ + internal class TestActuator : IActuator + { + public ActionBuffers LastActionBuffer; + public int[][] Masks; + public bool m_HeuristicCalled; + public int m_DiscreteBufferSize; + + public TestActuator(ActionSpec actuatorSpace, string name) + { + ActionSpec = actuatorSpace; + + Name = name; + } + + public void OnActionReceived(ActionBuffers actionBuffers) + { + LastActionBuffer = actionBuffers; + } + + public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) + { + + for (var i = 0; i < 
Masks.Length; i++) + { + foreach (var actionIndex in Masks[i]) + { + actionMask.SetActionEnabled(i, actionIndex, false); + } + } + } + + public ActionSpec ActionSpec { get; } + + public string Name { get; } + + public void ResetData() + { + } + + public void Heuristic(in ActionBuffers actionBuffersOut) + { + m_HeuristicCalled = true; + m_DiscreteBufferSize = actionBuffersOut.DiscreteActions.Length; + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta new file mode 100644 index 0000000000..57e13a0e26 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: fa950d7b175749bfa287fd8761dd831f +timeCreated: 1596665978 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs new file mode 100644 index 0000000000..7fe52951c8 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs @@ -0,0 +1,117 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using Assert = UnityEngine.Assertions.Assert; + +namespace Unity.MLAgents.Tests.Actuators +{ + [TestFixture] + public class VectorActuatorTests + { + class TestActionReceiver : IActionReceiver, IHeuristicProvider + { + public ActionBuffers LastActionBuffers; + public int Branch; + public IList Mask; + public ActionSpec ActionSpec { get; } + public bool HeuristicCalled; + + public void OnActionReceived(ActionBuffers actionBuffers) + { + LastActionBuffers = actionBuffers; + } + + public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) + { + foreach (var actionIndex in Mask) + { + actionMask.SetActionEnabled(Branch, actionIndex, false); + } + } + + public void Heuristic(in ActionBuffers actionBuffersOut) + { + HeuristicCalled = true; + } + } + + [Test] + public void TestConstruct() + { + var ar = new TestActionReceiver(); + var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); + + Assert.IsTrue(va.ActionSpec.NumDiscreteActions == 3); + Assert.IsTrue(va.ActionSpec.SumOfDiscreteBranchSizes == 6); + Assert.IsTrue(va.ActionSpec.NumContinuousActions == 0); + + var va1 = new VectorActuator(ar, ActionSpec.MakeContinuous(4), "name"); + + Assert.IsTrue(va1.ActionSpec.NumContinuousActions == 4); + Assert.IsTrue(va1.ActionSpec.SumOfDiscreteBranchSizes == 0); + Assert.AreEqual(va1.Name, "name-Continuous"); + } + + [Test] + public void TestOnActionReceived() + { + var ar = new TestActionReceiver(); + var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); + + var discreteActions = new[] { 0, 1, 1 }; + var ab = new ActionBuffers(ActionSegment.Empty, + new ActionSegment(discreteActions, 0, 3)); + + va.OnActionReceived(ab); + + Assert.AreEqual(ar.LastActionBuffers, ab); + va.ResetData(); + Assert.AreEqual(va.ActionBuffers.ContinuousActions, ActionSegment.Empty); + Assert.AreEqual(va.ActionBuffers.DiscreteActions, ActionSegment.Empty); + } + + [Test] + public void TestResetData() + { + var ar = new TestActionReceiver(); + var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); + + var discreteActions = new[] { 0, 1, 1 }; + var ab = new ActionBuffers(ActionSegment.Empty, + new ActionSegment(discreteActions, 0, 3)); + + va.OnActionReceived(ab); + } + + [Test] + public void TestWriteDiscreteActionMask() + { + var ar 
= new TestActionReceiver(); + var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); + var bdam = new ActuatorDiscreteActionMask(new[] { va }, 6, 3); + + var groundTruthMask = new[] { false, true, false, false, true, true }; + + ar.Branch = 1; + ar.Mask = new[] { 0 }; + va.WriteDiscreteActionMask(bdam); + ar.Branch = 2; + ar.Mask = new[] { 1, 2 }; + va.WriteDiscreteActionMask(bdam); + + Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask())); + } + + [Test] + public void TestHeuristic() + { + var ar = new TestActionReceiver(); + var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); + + va.Heuristic(new ActionBuffers(Array.Empty(), va.ActionSpec.BranchSizes)); + Assert.IsTrue(ar.HeuristicCalled); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta new file mode 100644 index 0000000000..2a5a86efd0 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c2b191d2929f49adab0769705d49d86a +timeCreated: 1596580289 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Analytics.meta b/com.unity.ml-agents/Tests/Editor/Analytics.meta new file mode 100644 index 0000000000..473f2be08f --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Analytics.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: adbf291ff40848a296523d69a5be65a5 +timeCreated: 1607379470 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs new file mode 100644 index 0000000000..7686517b6b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs @@ -0,0 +1,105 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using Unity.MLAgents.Sensors; +using UnityEngine; +using Unity.Barracuda; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Analytics; +using UnityEditor; + + +namespace Unity.MLAgents.Tests.Analytics +{ + [TestFixture] + public class InferenceAnalyticsTests + { + const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx"; + NNModel continuousONNXModel; + Test3DSensorComponent sensor_21_20_3; + Test3DSensorComponent sensor_20_22_3; + + ActionSpec GetContinuous2vis8vec2actionActionSpec() + { + return ActionSpec.MakeContinuous(2); + } + + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + + continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel)); + var go = new GameObject("SensorA"); + sensor_21_20_3 = go.AddComponent(); + sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3); + sensor_20_22_3 = go.AddComponent(); + sensor_20_22_3.Sensor = new Test3DSensor("SensorB", 20, 22, 3); + } + + [Test] + public void TestModelEvent() + { + var sensors = new List { sensor_21_20_3.Sensor, sensor_20_22_3.Sensor }; + var behaviorName = "continuousModel"; + var actionSpec = GetContinuous2vis8vec2actionActionSpec(); + + var vectorActuator = new VectorActuator(null, actionSpec, "test'"); + var actuators = new IActuator[] { vectorActuator }; + + var continuousEvent = InferenceAnalytics.GetEventForModel( + continuousONNXModel, behaviorName, + InferenceDevice.CPU, sensors, actionSpec, + actuators + ); + 
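+ // GetEventForModel packages model metadata (hash, action spec, observation specs, actuator
+ // infos) into a single analytics event; the asserts below spot-check those fields.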
+ // The behavior name should be hashed, not pass-through. + Assert.AreNotEqual(behaviorName, continuousEvent.BehaviorName); + + Assert.AreEqual(2, continuousEvent.ActionSpec.NumContinuousActions); + Assert.AreEqual(0, continuousEvent.ActionSpec.NumDiscreteActions); + Assert.AreEqual(2, continuousEvent.ObservationSpecs.Count); + Assert.AreEqual(3, continuousEvent.ObservationSpecs[0].DimensionInfos.Length); + Assert.AreEqual(20, continuousEvent.ObservationSpecs[0].DimensionInfos[0].Size); + Assert.AreEqual(0, continuousEvent.ObservationSpecs[0].ObservationType); + Assert.AreEqual((int)DimensionProperty.TranslationalEquivariance, continuousEvent.ObservationSpecs[0].DimensionInfos[0].Flags); + Assert.AreEqual((int)DimensionProperty.None, continuousEvent.ObservationSpecs[0].DimensionInfos[2].Flags); + Assert.AreEqual("None", continuousEvent.ObservationSpecs[0].CompressionType); + Assert.AreEqual(Test3DSensor.k_BuiltInSensorType, continuousEvent.ObservationSpecs[0].BuiltInSensorType); + Assert.AreEqual((int)BuiltInActuatorType.VectorActuator, continuousEvent.ActuatorInfos[0].BuiltInActuatorType); + Assert.AreNotEqual(null, continuousEvent.ModelHash); + + // Make sure nested fields get serialized + var jsonString = JsonUtility.ToJson(continuousEvent, true); + Assert.IsTrue(jsonString.Contains("ObservationSpecs")); + Assert.IsTrue(jsonString.Contains("ActionSpec")); + Assert.IsTrue(jsonString.Contains("NumDiscreteActions")); + Assert.IsTrue(jsonString.Contains("SensorName")); + Assert.IsTrue(jsonString.Contains("Flags")); + Assert.IsTrue(jsonString.Contains("ActuatorInfos")); + } + + [Test] + public void TestBarracudaPolicy() + { + // Explicitly request decisions for a policy so we get code coverage on the event sending + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + var sensors = new List { sensor_21_20_3.Sensor, sensor_20_22_3.Sensor }; + var policy = new BarracudaPolicy( + GetContinuous2vis8vec2actionActionSpec(), + Array.Empty(), + continuousONNXModel, + InferenceDevice.CPU, + "testBehavior" + ); + policy.RequestDecision(new AgentInfo(), sensors); + } + Academy.Instance.Dispose(); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs.meta new file mode 100644 index 0000000000..20f024f03b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 9f054f620b8b468bbd8ccf7d2cc14ccd +timeCreated: 1607379491 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs new file mode 100644 index 0000000000..0487a7a524 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs @@ -0,0 +1,96 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Analytics; +using Unity.MLAgents.Policies; +using UnityEditor; + +namespace Unity.MLAgents.Tests.Analytics +{ + [TestFixture] + public class TrainingAnalyticsTests + { + [TestCase("foo?team=42", ExpectedResult = "foo")] + [TestCase("foo", ExpectedResult = "foo")] + [TestCase("foo?bar?team=1337", ExpectedResult = "foo?bar")] + public string TestParseBehaviorName(string fullyQualifiedBehaviorName) + { + return TrainingAnalytics.ParseBehaviorName(fullyQualifiedBehaviorName); + } + + [Test] + public void 
TestRemotePolicyEvent() + { + var behaviorName = "testBehavior"; + var sensor1 = new Test3DSensor("SensorA", 21, 20, 3); + var sensor2 = new Test3DSensor("SensorB", 20, 22, 3); + var sensors = new List<ISensor> { sensor1, sensor2 }; + + var actionSpec = ActionSpec.MakeContinuous(2); + + var vectorActuator = new VectorActuator(null, actionSpec, "test'"); + var actuators = new IActuator[] { vectorActuator }; + + var remotePolicyEvent = TrainingAnalytics.GetEventForRemotePolicy(behaviorName, sensors, actionSpec, actuators); + + // The behavior name should be hashed, not pass-through. + Assert.AreNotEqual(behaviorName, remotePolicyEvent.BehaviorName); + + Assert.AreEqual(2, remotePolicyEvent.ObservationSpecs.Count); + Assert.AreEqual(3, remotePolicyEvent.ObservationSpecs[0].DimensionInfos.Length); + Assert.AreEqual(20, remotePolicyEvent.ObservationSpecs[0].DimensionInfos[0].Size); + Assert.AreEqual(0, remotePolicyEvent.ObservationSpecs[0].ObservationType); + Assert.AreEqual("None", remotePolicyEvent.ObservationSpecs[0].CompressionType); + Assert.AreEqual(Test3DSensor.k_BuiltInSensorType, remotePolicyEvent.ObservationSpecs[0].BuiltInSensorType); + + Assert.AreEqual(2, remotePolicyEvent.ActionSpec.NumContinuousActions); + Assert.AreEqual(0, remotePolicyEvent.ActionSpec.NumDiscreteActions); + + Assert.AreEqual(2, remotePolicyEvent.ActuatorInfos[0].NumContinuousActions); + Assert.AreEqual(0, remotePolicyEvent.ActuatorInfos[0].NumDiscreteActions); + } + + [Test] + public void TestRemotePolicy() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + var actionSpec = ActionSpec.MakeContinuous(3); + var policy = new RemotePolicy(actionSpec, Array.Empty<IActuator>(), "TestBehavior?team=42"); + policy.RequestDecision(new AgentInfo(), new List<ISensor>()); + } + + Academy.Instance.Dispose(); + } + + [TestCase("a name we expect to hash", ExpectedResult = "d084a8b6da6a6a1c097cdc9ffea95e1546da4647352113ed77cbe7b4192e6d73")] + [TestCase("another_name", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")] + [TestCase("0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")] + public string TestTrainingBehaviorInitialized(string stringToMaybeHash) + { + var tbiEvent = new TrainingBehaviorInitializedEvent(); + tbiEvent.BehaviorName = stringToMaybeHash; + tbiEvent.Config = "{}"; + + var sanitizedEvent = TrainingAnalytics.SanitizeTrainingBehaviorInitializedEvent(tbiEvent); + return sanitizedEvent.BehaviorName; + } + + [Test] + public void TestEnableAnalytics() + { +#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE + Assert.IsTrue(EditorAnalytics.enabled == TrainingAnalytics.EnableAnalytics()); +#else + Assert.IsFalse(TrainingAnalytics.EnableAnalytics()); +#endif + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta new file mode 100644 index 0000000000..df394c157a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 70b8f1544bc34b4e8f1bc1068c64f01c +timeCreated: 1610419546 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Areas.meta b/com.unity.ml-agents/Tests/Editor/Areas.meta new file mode 100644 index 0000000000..42901a0e6b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Areas.meta @@ -0,0 +1,3 @@ 
+fileFormatVersion: 2 +guid: d32a102dc1f004c33b05a30190a9d039 +timeCreated: 1632841906 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs new file mode 100644 index 0000000000..f1046ebf68 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs @@ -0,0 +1,61 @@ +using System.Linq; +using NUnit.Framework; +using Unity.Mathematics; +using Unity.MLAgents.Areas; +using UnityEngine; + +namespace Unity.MLAgents.Tests.Areas +{ + [TestFixture] + public class TrainingAreaReplicatorTests + { + private TrainingAreaReplicator m_Replicator; + + [SetUp] + public void Setup() + { + var gameObject = new GameObject(); + var trainingArea = new GameObject(); + trainingArea.name = "MyTrainingArea"; + m_Replicator = gameObject.AddComponent<TrainingAreaReplicator>(); + m_Replicator.baseArea = trainingArea; + } + + private static object[] NumAreasCases = + { + new object[] {1}, + new object[] {2}, + new object[] {5}, + new object[] {7}, + new object[] {8}, + new object[] {64}, + new object[] {63}, + }; + + [TestCaseSource(nameof(NumAreasCases))] + public void TestComputeGridSize(int numAreas) + { + m_Replicator.numAreas = numAreas; + m_Replicator.Awake(); + m_Replicator.OnEnable(); + var m_CorrectGridSize = int3.zero; + var m_RootNumAreas = Mathf.Pow(numAreas, 1.0f / 3.0f); + m_CorrectGridSize.x = Mathf.CeilToInt(m_RootNumAreas); + m_CorrectGridSize.y = Mathf.CeilToInt(m_RootNumAreas); + m_CorrectGridSize.z = Mathf.CeilToInt((float)numAreas / (m_CorrectGridSize.x * m_CorrectGridSize.y)); + Assert.GreaterOrEqual(m_Replicator.GridSize.x * m_Replicator.GridSize.y * m_Replicator.GridSize.z, m_Replicator.numAreas); + Assert.AreEqual(m_CorrectGridSize, m_Replicator.GridSize); + } + + [Test] + public void TestAddEnvironments() + { + m_Replicator.numAreas = 10; + m_Replicator.Awake(); + m_Replicator.OnEnable(); + var trainingAreas = Resources.FindObjectsOfTypeAll<GameObject>().Where(obj => obj.name == m_Replicator.TrainingAreaName); + Assert.AreEqual(10, trainingAreas.Count()); + + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta new file mode 100644 index 0000000000..4ebc4ba4d1 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 391a03d82068e44b5bba0ca55215b0c7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs new file mode 100644 index 0000000000..78f5891785 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs @@ -0,0 +1,76 @@ +using NUnit.Framework; +using Unity.Barracuda; +using Unity.MLAgents.Actuators; +using UnityEngine; +using Unity.MLAgents.Policies; +using UnityEditor; +using UnityEngine.TestTools; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class BehaviorParameterTests : IHeuristicProvider + { + const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx"; + public void Heuristic(in ActionBuffers actionsOut) + { + // No-op + } + + [Test] + public void TestNoModelInferenceOnlyThrows() + { + var gameObj = 
new GameObject(); + var bp = gameObj.AddComponent<BehaviorParameters>(); + bp.BehaviorType = BehaviorType.InferenceOnly; + var actionSpec = new ActionSpec(); + + Assert.Throws<UnityAgentsException>(() => + { + bp.GeneratePolicy(actionSpec, new ActuatorManager()); + }); + } + + [Test] + public void TestIsInHeuristicMode() + { + var gameObj = new GameObject(); + var bp = gameObj.AddComponent<BehaviorParameters>(); + bp.Model = null; + gameObj.AddComponent<Agent>(); + bp.BehaviorType = BehaviorType.HeuristicOnly; + Assert.IsTrue(bp.IsInHeuristicMode()); + + bp.BehaviorType = BehaviorType.Default; + Assert.IsTrue(bp.IsInHeuristicMode()); + + bp.Model = ScriptableObject.CreateInstance<NNModel>(); + Assert.IsFalse(bp.IsInHeuristicMode()); + } + + [Test] + public void TestPolicyUpdateEventFired() + { + var gameObj = new GameObject(); + var bp = gameObj.AddComponent<BehaviorParameters>(); + gameObj.AddComponent<Agent>().LazyInitialize(); + bp.OnPolicyUpdated += delegate (bool isInHeuristicMode) { Debug.Log($"OnPolicyChanged:{isInHeuristicMode}"); }; + bp.BehaviorType = BehaviorType.HeuristicOnly; + LogAssert.Expect(LogType.Log, $"OnPolicyChanged:{true}"); + + bp.BehaviorType = BehaviorType.Default; + LogAssert.Expect(LogType.Log, $"OnPolicyChanged:{true}"); + + Assert.Throws<UnityAgentsException>(() => + { + bp.BehaviorType = BehaviorType.InferenceOnly; + }); + + bp.Model = AssetDatabase.LoadAssetAtPath<NNModel>(k_continuousONNXPath); + LogAssert.Expect(LogType.Log, $"OnPolicyChanged:{false}"); + + bp.BehaviorType = BehaviorType.HeuristicOnly; + LogAssert.Expect(LogType.Log, $"OnPolicyChanged:{true}"); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs.meta b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs.meta new file mode 100644 index 0000000000..0656104d8c --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 877266b9e1bfe4330a68ab5f2da1836b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Communicator.meta b/com.unity.ml-agents/Tests/Editor/Communicator.meta new file mode 100644 index 0000000000..170d273a27 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: da8f640243c749388a0329393c8fce64 +timeCreated: 1586386315 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs new file mode 100644 index 0000000000..34bd3caaea --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs @@ -0,0 +1,276 @@ +using System; +using System.Text.RegularExpressions; +using Google.Protobuf; +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Demonstrations; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors; + +using Unity.MLAgents.Analytics; +using Unity.MLAgents.CommunicatorObjects; +using UnityEngine; +using UnityEngine.TestTools; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class GrpcExtensionsTests + { + [SetUp] + public void SetUp() + { + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities(); + } + + [Test] + public void TestDefaultBrainParametersToProto() + { + // Should be able to convert a default instance to proto. 
+ var brain = new BrainParameters(); + brain.ToProto("foo", false); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + brain.ToProto("foo", false); + } + + [Test] + public void TestDefaultActionSpecToProto() + { + // Should be able to convert a default instance to proto. + var actionSpec = new ActionSpec(); + actionSpec.ToBrainParametersProto("foo", false); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + actionSpec.ToBrainParametersProto("foo", false); + + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities(); + // Continuous + actionSpec = ActionSpec.MakeContinuous(3); + actionSpec.ToBrainParametersProto("foo", false); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + actionSpec.ToBrainParametersProto("foo", false); + + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities(); + + // Discrete + actionSpec = ActionSpec.MakeDiscrete(1, 2, 3); + actionSpec.ToBrainParametersProto("foo", false); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + actionSpec.ToBrainParametersProto("foo", false); + } + + [Test] + public void ToBrainParameters() + { + // Should be able to convert a default instance to proto. + var actionSpec = new ActionSpec(); + actionSpec.ToBrainParametersProto("foo", false).ToBrainParameters(); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + actionSpec.ToBrainParametersProto("foo", false).ToBrainParameters(); + + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities(); + // Continuous + actionSpec = ActionSpec.MakeContinuous(3); + actionSpec.ToBrainParametersProto("foo", false).ToBrainParameters(); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + actionSpec.ToBrainParametersProto("foo", false).ToBrainParameters(); + + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities(); + + // Discrete + actionSpec = ActionSpec.MakeDiscrete(1, 2, 3); + actionSpec.ToBrainParametersProto("foo", false).ToBrainParameters(); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + HybridActions = false + }; + actionSpec.ToBrainParametersProto("foo", false).ToBrainParameters(); + } + + [Test] + public void TestDefaultAgentInfoToProto() + { + // Should be able to convert a default instance to proto. 
+ var agentInfo = new AgentInfo(); + var pairProto = agentInfo.ToInfoActionPairProto(); + pairProto.AgentInfo.Observations.Add(new ObservationProto + { + CompressedData = ByteString.Empty, + CompressionType = CompressionTypeProto.None, + FloatData = new ObservationProto.Types.FloatData(), + ObservationType = ObservationTypeProto.Default, + Name = "Sensor" + }); + pairProto.AgentInfo.Observations[0].Shape.Add(0); + pairProto.GetObservationSummaries(); + agentInfo.ToAgentInfoProto(); + agentInfo.groupId = 1; + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + MultiAgentGroups = false + }; + agentInfo.ToAgentInfoProto(); + LogAssert.Expect(LogType.Warning, new Regex(".+")); + Academy.Instance.TrainerCapabilities = new UnityRLCapabilities + { + BaseRLCapabilities = true, + MultiAgentGroups = true + }; + agentInfo.ToAgentInfoProto(); + } + + [Test] + public void TestDefaultDemonstrationMetaDataToProto() + { + // Should be able to convert a default instance to proto. + var demoMetaData = new DemonstrationMetaData(); + demoMetaData.ToProto(); + } + + class DummySensor : ISensor + { + public ObservationSpec ObservationSpec; + public SensorCompressionType CompressionType; + + public ObservationSpec GetObservationSpec() + { + return ObservationSpec; + } + + public int Write(ObservationWriter writer) + { + return 0; + } + + public byte[] GetCompressedObservation() + { + return new byte[] { 13, 37 }; + } + + public void Update() { } + + public void Reset() { } + + public CompressionSpec GetCompressionSpec() + { + return new CompressionSpec(CompressionType); + } + + public string GetName() + { + return "Dummy"; + } + } + + [Test] + public void TestGetObservationProtoCapabilities() + { + // Shape, compression type, concatenatedPngObservations, expect throw + var variants = new[] + { + // Vector observations + (new[] {3}, SensorCompressionType.None, false, false), + // Uncompressed floats + (new[] {4, 4, 3}, SensorCompressionType.None, false, false), + // Compressed floats, 3 channels + (new[] {4, 4, 3}, SensorCompressionType.PNG, false, true), + + // Compressed floats, >3 channels + (new[] {4, 4, 4}, SensorCompressionType.PNG, false, false), // Unsupported - results in uncompressed + (new[] {4, 4, 4}, SensorCompressionType.PNG, true, true), // Supported compressed + }; + + foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants) + { + var inplaceShape = InplaceArray.FromList(shape); + var dummySensor = new DummySensor(); + var obsWriter = new ObservationWriter(); + + if (shape.Length == 1) + { + dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]); + } + else if (shape.Length == 3) + { + dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]); + } + else + { + throw new ArgumentOutOfRangeException(); + } + dummySensor.CompressionType = compressionType; + obsWriter.SetTarget(new float[128], inplaceShape, 0); + + var caps = new UnityRLCapabilities + { + ConcatenatedPngObservations = supportsMultiPngObs + }; + Academy.Instance.TrainerCapabilities = caps; + + + var obsProto = dummySensor.GetObservationProto(obsWriter); + if (expectCompressed) + { + Assert.Greater(obsProto.CompressedData.Length, 0); + Assert.AreEqual(obsProto.FloatData, null); + } + else + { + Assert.Greater(obsProto.FloatData.Data.Count, 0); + Assert.AreEqual(obsProto.CompressedData.Length, 0); + } + } + } + + [Test] + public void TestDefaultTrainingEvents() + { + var trainingEnvInit = new TrainingEnvironmentInitialized + { 
+ PythonVersion = "test", + }; + var trainingEnvInitEvent = trainingEnvInit.ToTrainingEnvironmentInitializedEvent(); + Assert.AreEqual(trainingEnvInit.PythonVersion, trainingEnvInitEvent.TrainerPythonVersion); + + var trainingBehavInit = new TrainingBehaviorInitialized + { + BehaviorName = "testBehavior", + ExtrinsicRewardEnabled = true, + CuriosityRewardEnabled = true, + + RecurrentEnabled = true, + SelfPlayEnabled = true, + }; + var trainingBehavInitEvent = trainingBehavInit.ToTrainingBehaviorInitializedEvent(); + Assert.AreEqual(trainingBehavInit.BehaviorName, trainingBehavInitEvent.BehaviorName); + + Assert.AreEqual(RewardSignals.Extrinsic | RewardSignals.Curiosity, trainingBehavInitEvent.RewardSignalFlags); + Assert.AreEqual(TrainingFeatures.Recurrent | TrainingFeatures.SelfPlay, trainingBehavInitEvent.TrainingFeatureFlags); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta new file mode 100644 index 0000000000..411f1cd45e --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7aa28d0e370064c18bb8a913417ad21d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs new file mode 100644 index 0000000000..4f62672f71 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs @@ -0,0 +1,42 @@ +using NUnit.Framework; +using UnityEngine.TestTools; + +namespace Unity.MLAgents.Tests.Communicator +{ + [TestFixture] + public class RpcCommunicatorTests + { + + [Test] + public void TestCheckCommunicationVersionsAreCompatible() + { + var unityVerStr = "1.0.0"; + var pythonVerStr = "1.0.0"; + + Assert.IsTrue(RpcCommunicator.CheckCommunicationVersionsAreCompatible(unityVerStr, + pythonVerStr)); + LogAssert.NoUnexpectedReceived(); + + pythonVerStr = "1.1.0"; + Assert.IsTrue(RpcCommunicator.CheckCommunicationVersionsAreCompatible(unityVerStr, + pythonVerStr)); + LogAssert.NoUnexpectedReceived(); + + unityVerStr = "2.0.0"; + Assert.IsFalse(RpcCommunicator.CheckCommunicationVersionsAreCompatible(unityVerStr, + pythonVerStr)); + + unityVerStr = "0.15.0"; + pythonVerStr = "0.15.0"; + Assert.IsTrue(RpcCommunicator.CheckCommunicationVersionsAreCompatible(unityVerStr, + pythonVerStr)); + unityVerStr = "0.16.0"; + Assert.IsFalse(RpcCommunicator.CheckCommunicationVersionsAreCompatible(unityVerStr, + pythonVerStr)); + unityVerStr = "1.15.0"; + Assert.IsFalse(RpcCommunicator.CheckCommunicationVersionsAreCompatible(unityVerStr, + pythonVerStr)); + + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs.meta new file mode 100644 index 0000000000..1d0689e5cb --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 251fab8dff424abb95b2b381c7c924c3 +timeCreated: 1586386329 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs new file mode 100644 index 0000000000..fd3ef3acee --- 
/dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs @@ -0,0 +1,23 @@ +using System.Text.RegularExpressions; +using NUnit.Framework; +using UnityEngine; +using UnityEngine.TestTools; + +namespace Unity.MLAgents.Tests.Communicator +{ + [TestFixture] + public class UnityRLCapabilitiesTests + { + [Test] + public void TestWarnOnPythonMissingBaseRLCapabilities() + { + var caps = new UnityRLCapabilities(); + Assert.False(caps.WarnOnPythonMissingBaseRLCapabilities()); + LogAssert.NoUnexpectedReceived(); + caps = new UnityRLCapabilities(false); + Assert.True(caps.WarnOnPythonMissingBaseRLCapabilities()); + LogAssert.Expect(LogType.Warning, new Regex(".+")); + } + + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs.meta new file mode 100644 index 0000000000..fc7b8cec7a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/UnityRLCapabilitiesTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e6a3e82911b84029a446dcfd2d8af520 +timeCreated: 1587695055 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs new file mode 100644 index 0000000000..55060f5645 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs @@ -0,0 +1,150 @@ +using NUnit.Framework; +using UnityEngine; +using System.IO.Abstractions.TestingHelpers; +using System.Reflection; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.CommunicatorObjects; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Demonstrations; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Utils.Tests; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class DemonstrationTests + { + const string k_DemoDirectory = "Assets/Demonstrations/"; + const string k_ExtensionType = ".demo"; + const string k_DemoName = "Test"; + + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + [Test] + public void TestSanitization() + { + const string dirtyString = "abc1234567&!@"; + const string knownCleanString = "abc123"; + var cleanString = DemonstrationRecorder.SanitizeName(dirtyString, 6); + Assert.AreNotEqual(dirtyString, cleanString); + Assert.AreEqual(cleanString, knownCleanString); + } + + [Test] + public void TestStoreInitialize() + { + var fileSystem = new MockFileSystem(); + + var gameobj = new GameObject("gameObj"); + + var bp = gameobj.AddComponent<BehaviorParameters>(); + bp.BrainParameters.VectorObservationSize = 3; + bp.BrainParameters.NumStackedVectorObservations = 2; + bp.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" }; + bp.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2); + + gameobj.AddComponent<TestAgent>(); + + Assert.IsFalse(fileSystem.Directory.Exists(k_DemoDirectory)); + + var demoRec = gameobj.AddComponent<DemonstrationRecorder>(); + demoRec.Record = true; + demoRec.DemonstrationName = k_DemoName; + demoRec.DemonstrationDirectory = k_DemoDirectory; + var demoWriter = demoRec.LazyInitialize(fileSystem); + + Assert.IsTrue(fileSystem.Directory.Exists(k_DemoDirectory)); + Assert.IsTrue(fileSystem.FileExists(k_DemoDirectory + k_DemoName + k_ExtensionType)); + + var agentInfo = new AgentInfo + { + reward = 1f, + discreteActionMasks = new[] { false, true }, + done = true, + episodeId = 5, + maxStepReached = true, + storedActions = new ActionBuffers(null, new[] { 0, 1 }), + }; + + + demoWriter.Record(agentInfo, 
new System.Collections.Generic.List<ISensor>()); + demoRec.Close(); + + // Make sure close can be called multiple times + demoWriter.Close(); + demoRec.Close(); + + // Make sure trying to write after closing doesn't raise an error. + demoWriter.Record(agentInfo, new System.Collections.Generic.List<ISensor>()); + } + + public class ObservationAgent : TestAgent + { + public override void CollectObservations(VectorSensor sensor) + { + collectObservationsCalls += 1; + sensor.AddObservation(1f); + sensor.AddObservation(2f); + sensor.AddObservation(3f); + } + } + + [Test] + public void TestAgentWrite() + { + var agentGo1 = new GameObject("TestAgent"); + var bpA = agentGo1.AddComponent<BehaviorParameters>(); + bpA.BrainParameters.VectorObservationSize = 3; + bpA.BrainParameters.NumStackedVectorObservations = 1; + bpA.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" }; + bpA.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2); + + agentGo1.AddComponent<ObservationAgent>(); + var agent1 = agentGo1.GetComponent<ObservationAgent>(); + + agentGo1.AddComponent<DemonstrationRecorder>(); + var demoRecorder = agentGo1.GetComponent<DemonstrationRecorder>(); + var fileSystem = new MockFileSystem(); + demoRecorder.DemonstrationDirectory = k_DemoDirectory; + demoRecorder.DemonstrationName = "TestBrain"; + demoRecorder.Record = true; + demoRecorder.LazyInitialize(fileSystem); + + var agentEnableMethod = typeof(Agent).GetMethod("OnEnable", + BindingFlags.Instance | BindingFlags.NonPublic); + var agentSendInfo = typeof(Agent).GetMethod("SendInfo", + BindingFlags.Instance | BindingFlags.NonPublic); + + agentEnableMethod?.Invoke(agent1, new object[] { }); + + // Step the agent + agent1.RequestDecision(); + agentSendInfo?.Invoke(agent1, new object[] { }); + + demoRecorder.Close(); + + // Read back the demo file and make sure observations were written + var reader = fileSystem.File.OpenRead("Assets/Demonstrations/TestBrain.demo"); + reader.Seek(DemonstrationWriter.MetaDataBytes + 1, 0); + BrainParametersProto.Parser.ParseDelimitedFrom(reader); + + var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo; + var obs = agentInfoProto.Observations[2]; // skip dummy sensors + { + var vecObs = obs.FloatData.Data; + Assert.AreEqual(bpA.BrainParameters.VectorObservationSize, vecObs.Count); + for (var i = 0; i < vecObs.Count; i++) + { + Assert.AreEqual((float)i + 1, vecObs[i]); + } + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs.meta b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs.meta new file mode 100644 index 0000000000..434861ff9b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4c5a970f5b6be4b57b3bd7a5f84c3623 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Inference.meta b/com.unity.ml-agents/Tests/Editor/Inference.meta new file mode 100644 index 0000000000..1427471ab0 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 7b8fc3bc69d3a4cd9a66ad334f944fb2 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs new file mode 100644 index 0000000000..58e6c9a95b --- /dev/null +++ 
b/com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs @@ -0,0 +1,86 @@ +using System.Collections.Generic; +using Unity.Barracuda; +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; + +namespace Unity.MLAgents.Tests +{ + + public class DiscreteActionOutputApplierTest + { + [Test] + public void TestDiscreteApply() + { + var actionSpec = ActionSpec.MakeDiscrete(3, 2); + + var applier = new DiscreteActionOutputApplier(actionSpec, 2020, null); + var agentIds = new List<int> { 42, 1337 }; + var actionBuffers = new Dictionary<int, ActionBuffers>(); + actionBuffers[42] = new ActionBuffers(actionSpec); + actionBuffers[1337] = new ActionBuffers(actionSpec); + + var actionTensor = new TensorProxy + { + data = new Tensor( + 2, + 2, + new[] + { + 2.0f, // Agent 0, branch 0 + 1.0f, // Agent 0, branch 1 + 0.0f, // Agent 1, branch 0 + 0.0f // Agent 1, branch 1 + }), + shape = new long[] { 2, 2 }, + valueType = TensorProxy.TensorType.FloatingPoint + }; + + applier.Apply(actionTensor, agentIds, actionBuffers); + Assert.AreEqual(2, actionBuffers[42].DiscreteActions[0]); + Assert.AreEqual(1, actionBuffers[42].DiscreteActions[1]); + + Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[0]); + Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[1]); + } + } + + public class LegacyDiscreteActionOutputApplierTest + { + [Test] + public void TestDiscreteApply() + { + var actionSpec = ActionSpec.MakeDiscrete(3, 2); + const float smallLogProb = -1000.0f; + const float largeLogProb = -1.0f; + + var logProbs = new TensorProxy + { + data = new Tensor( + 2, + 5, + new[] + { + smallLogProb, smallLogProb, largeLogProb, // Agent 0, branch 0 + smallLogProb, largeLogProb, // Agent 0, branch 1 + largeLogProb, smallLogProb, smallLogProb, // Agent 1, branch 0 + largeLogProb, smallLogProb, // Agent 1, branch 1 + }), + valueType = TensorProxy.TensorType.FloatingPoint + }; + + var applier = new LegacyDiscreteActionOutputApplier(actionSpec, 2020, null); + var agentIds = new List<int> { 42, 1337 }; + var actionBuffers = new Dictionary<int, ActionBuffers>(); + actionBuffers[42] = new ActionBuffers(actionSpec); + actionBuffers[1337] = new ActionBuffers(actionSpec); + + applier.Apply(logProbs, agentIds, actionBuffers); + Assert.AreEqual(2, actionBuffers[42].DiscreteActions[0]); + Assert.AreEqual(1, actionBuffers[42].DiscreteActions[1]); + + Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[0]); + Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[1]); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs.meta new file mode 100644 index 0000000000..ac93a50aab --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: aa4c4ceac5f246a0b341958724ecd752 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs new file mode 100644 index 0000000000..e679254908 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs @@ -0,0 +1,200 @@ +using System.Collections.Generic; +using NUnit.Framework; +using Unity.Barracuda; +using 
Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; + +namespace Unity.MLAgents.Tests +{ + public class EditModeTestInternalBrainTensorApplier + { + class TestAgent : Agent + { + } + + [Test] + public void Construction() + { + var actionSpec = new ActionSpec(); + var alloc = new TensorCachingAllocator(); + var mem = new Dictionary<int, List<float>>(); + var tensorGenerator = new TensorApplier(actionSpec, 0, alloc, mem); + Assert.IsNotNull(tensorGenerator); + alloc.Dispose(); + } + + [Test] + public void ApplyContinuousActionOutput() + { + var actionSpec = ActionSpec.MakeContinuous(3); + var inputTensor = new TensorProxy() + { + shape = new long[] { 2, 3 }, + data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 }) + }; + + var applier = new ContinuousActionOutputApplier(actionSpec); + + var agentIds = new List<int>() { 0, 1 }; + // Dictionary from AgentId to Action + var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } }; + + applier.Apply(inputTensor, agentIds, actionDict); + + + Assert.AreEqual(actionDict[0].ContinuousActions[0], 1); + Assert.AreEqual(actionDict[0].ContinuousActions[1], 2); + Assert.AreEqual(actionDict[0].ContinuousActions[2], 3); + + Assert.AreEqual(actionDict[1].ContinuousActions[0], 4); + Assert.AreEqual(actionDict[1].ContinuousActions[1], 5); + Assert.AreEqual(actionDict[1].ContinuousActions[2], 6); + } + + [Test] + public void ApplyDiscreteActionOutputLegacy() + { + var actionSpec = ActionSpec.MakeDiscrete(2, 3); + var inputTensor = new TensorProxy() + { + shape = new long[] { 2, 5 }, + data = new Tensor( + 2, + 5, + new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f }) + }; + var alloc = new TensorCachingAllocator(); + var applier = new LegacyDiscreteActionOutputApplier(actionSpec, 0, alloc); + + var agentIds = new List<int>() { 0, 1 }; + // Dictionary from AgentId to Action + var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } }; + + + applier.Apply(inputTensor, agentIds, actionDict); + + Assert.AreEqual(actionDict[0].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[0].DiscreteActions[1], 1); + + Assert.AreEqual(actionDict[1].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[1].DiscreteActions[1], 2); + alloc.Dispose(); + } + + [Test] + public void ApplyDiscreteActionOutput() + { + var actionSpec = ActionSpec.MakeDiscrete(2, 3); + var inputTensor = new TensorProxy() + { + shape = new long[] { 2, 2 }, + data = new Tensor( + 2, + 2, + new[] { 1f, 1f, 1f, 2f }), + }; + var alloc = new TensorCachingAllocator(); + var applier = new DiscreteActionOutputApplier(actionSpec, 0, alloc); + + var agentIds = new List<int>() { 0, 1 }; + // Dictionary from AgentId to Action + var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } }; + + + applier.Apply(inputTensor, agentIds, actionDict); + + Assert.AreEqual(actionDict[0].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[0].DiscreteActions[1], 1); + + Assert.AreEqual(actionDict[1].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[1].DiscreteActions[1], 2); + alloc.Dispose(); + } + + [Test] + public void ApplyHybridActionOutputLegacy() + { + var actionSpec = new ActionSpec(3, new[] { 2, 3 }); + var continuousInputTensor = new TensorProxy() + { + shape = new long[] { 2, 3 }, + data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 }) + }; + var discreteInputTensor = new TensorProxy() + { + shape = new long[] { 2, 8 }, + data = new Tensor( + 2, + 5, + new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f }) + }; + var 
continuousApplier = new ContinuousActionOutputApplier(actionSpec); + var alloc = new TensorCachingAllocator(); + var discreteApplier = new LegacyDiscreteActionOutputApplier(actionSpec, 0, alloc); + + var agentIds = new List<int>() { 0, 1 }; + // Dictionary from AgentId to Action + var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } }; + + + continuousApplier.Apply(continuousInputTensor, agentIds, actionDict); + discreteApplier.Apply(discreteInputTensor, agentIds, actionDict); + + Assert.AreEqual(actionDict[0].ContinuousActions[0], 1); + Assert.AreEqual(actionDict[0].ContinuousActions[1], 2); + Assert.AreEqual(actionDict[0].ContinuousActions[2], 3); + Assert.AreEqual(actionDict[0].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[0].DiscreteActions[1], 1); + + Assert.AreEqual(actionDict[1].ContinuousActions[0], 4); + Assert.AreEqual(actionDict[1].ContinuousActions[1], 5); + Assert.AreEqual(actionDict[1].ContinuousActions[2], 6); + Assert.AreEqual(actionDict[1].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[1].DiscreteActions[1], 2); + alloc.Dispose(); + } + + [Test] + public void ApplyHybridActionOutput() + { + var actionSpec = new ActionSpec(3, new[] { 2, 3 }); + var continuousInputTensor = new TensorProxy() + { + shape = new long[] { 2, 3 }, + data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 }) + }; + var discreteInputTensor = new TensorProxy() + { + shape = new long[] { 2, 2 }, + data = new Tensor( + 2, + 2, + new[] { 1f, 1f, 1f, 2f }), + }; + var continuousApplier = new ContinuousActionOutputApplier(actionSpec); + var alloc = new TensorCachingAllocator(); + var discreteApplier = new DiscreteActionOutputApplier(actionSpec, 0, alloc); + + var agentIds = new List<int>() { 0, 1 }; + // Dictionary from AgentId to Action + var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } }; + + + continuousApplier.Apply(continuousInputTensor, agentIds, actionDict); + discreteApplier.Apply(discreteInputTensor, agentIds, actionDict); + + Assert.AreEqual(actionDict[0].ContinuousActions[0], 1); + Assert.AreEqual(actionDict[0].ContinuousActions[1], 2); + Assert.AreEqual(actionDict[0].ContinuousActions[2], 3); + Assert.AreEqual(actionDict[0].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[0].DiscreteActions[1], 1); + + Assert.AreEqual(actionDict[1].ContinuousActions[0], 4); + Assert.AreEqual(actionDict[1].ContinuousActions[1], 5); + Assert.AreEqual(actionDict[1].ContinuousActions[2], 6); + Assert.AreEqual(actionDict[1].DiscreteActions[0], 1); + Assert.AreEqual(actionDict[1].DiscreteActions[1], 2); + alloc.Dispose(); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs.meta b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs.meta new file mode 100644 index 0000000000..98212413d0 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: be419f7ed5c24b24a6f2636d3b107535 +timeCreated: 1537915674 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorGenerator.cs b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorGenerator.cs new file mode 100644 index 0000000000..579a204c4b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorGenerator.cs @@ -0,0 +1,191 @@ +using System.Collections.Generic; +using Unity.Barracuda; +using 
NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Utils.Tests; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class EditModeTestInternalBrainTensorGenerator + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + static List<TestAgent> GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore) + { + var goA = new GameObject("goA"); + var bpA = goA.AddComponent<BehaviorParameters>(); + bpA.BrainParameters.VectorObservationSize = 3; + bpA.BrainParameters.NumStackedVectorObservations = 1; + bpA.ObservableAttributeHandling = observableAttributeOptions; + var agentA = goA.AddComponent<TestAgent>(); + + var goB = new GameObject("goB"); + var bpB = goB.AddComponent<BehaviorParameters>(); + bpB.BrainParameters.VectorObservationSize = 3; + bpB.BrainParameters.NumStackedVectorObservations = 1; + bpB.ObservableAttributeHandling = observableAttributeOptions; + var agentB = goB.AddComponent<TestAgent>(); + + var agents = new List<TestAgent> { agentA, agentB }; + foreach (var agent in agents) + { + agent.LazyInitialize(); + } + agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3)); + agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6)); + + var infoA = new AgentInfo + { + storedActions = new ActionBuffers(null, new[] { 1, 2 }), + discreteActionMasks = null, + }; + + var infoB = new AgentInfo + { + storedActions = new ActionBuffers(null, new[] { 3, 4 }), + discreteActionMasks = new[] { true, false, false, false, false }, + }; + + + agentA._Info = infoA; + agentB._Info = infoB; + return agents; + } + + [Test] + public void Construction() + { + var alloc = new TensorCachingAllocator(); + var mem = new Dictionary<int, List<float>>(); + var tensorGenerator = new TensorGenerator(0, alloc, mem); + Assert.IsNotNull(tensorGenerator); + alloc.Dispose(); + } + + [Test] + public void GenerateBatchSize() + { + var inputTensor = new TensorProxy(); + var alloc = new TensorCachingAllocator(); + const int batchSize = 4; + var generator = new BatchSizeGenerator(alloc); + generator.Generate(inputTensor, batchSize, null); + Assert.IsNotNull(inputTensor.data); + Assert.AreEqual(inputTensor.data[0], batchSize); + alloc.Dispose(); + } + + [Test] + public void GenerateSequenceLength() + { + var inputTensor = new TensorProxy(); + var alloc = new TensorCachingAllocator(); + const int batchSize = 4; + var generator = new SequenceLengthGenerator(alloc); + generator.Generate(inputTensor, batchSize, null); + Assert.IsNotNull(inputTensor.data); + Assert.AreEqual(inputTensor.data[0], 1); + alloc.Dispose(); + } + + [Test] + public void GenerateVectorObservation() + { + var inputTensor = new TensorProxy + { + shape = new long[] { 2, 4 } + }; + const int batchSize = 4; + var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll); + var alloc = new TensorCachingAllocator(); + var generator = new ObservationGenerator(alloc); + generator.AddSensorIndex(0); // ObservableAttribute (size 1) + generator.AddSensorIndex(1); // TestSensor (size 0) + generator.AddSensorIndex(2); // TestSensor (size 0) + generator.AddSensorIndex(3); // VectorSensor (size 3) + var agent0 = agentInfos[0]; + var agent1 = agentInfos[1]; + var inputs = new List<AgentInfoSensorsPair> + { + new AgentInfoSensorsPair {agentInfo = agent0._Info, sensors = agent0.sensors}, + new AgentInfoSensorsPair {agentInfo = agent1._Info, sensors = agent1.sensors}, + }; + generator.Generate(inputTensor, batchSize, inputs); + 
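// Each agent's observations should occupy its own batch row: agent A's (1, 2, 3) in row 0, agent B's (4, 5, 6) in row 1, as the assertions below check. + 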
Assert.IsNotNull(inputTensor.data); + Assert.AreEqual(inputTensor.data[0, 1], 1); + Assert.AreEqual(inputTensor.data[0, 3], 3); + Assert.AreEqual(inputTensor.data[1, 1], 4); + Assert.AreEqual(inputTensor.data[1, 3], 6); + alloc.Dispose(); + } + + [Test] + public void GeneratePreviousActionInput() + { + var inputTensor = new TensorProxy + { + shape = new long[] { 2, 2 }, + valueType = TensorProxy.TensorType.Integer + }; + const int batchSize = 4; + var agentInfos = GetFakeAgents(); + var alloc = new TensorCachingAllocator(); + var generator = new PreviousActionInputGenerator(alloc); + var agent0 = agentInfos[0]; + var agent1 = agentInfos[1]; + var inputs = new List<AgentInfoSensorsPair> + { + new AgentInfoSensorsPair {agentInfo = agent0._Info, sensors = agent0.sensors}, + new AgentInfoSensorsPair {agentInfo = agent1._Info, sensors = agent1.sensors}, + }; + generator.Generate(inputTensor, batchSize, inputs); + Assert.IsNotNull(inputTensor.data); + Assert.AreEqual(inputTensor.data[0, 0], 1); + Assert.AreEqual(inputTensor.data[0, 1], 2); + Assert.AreEqual(inputTensor.data[1, 0], 3); + Assert.AreEqual(inputTensor.data[1, 1], 4); + alloc.Dispose(); + } + + [Test] + public void GenerateActionMaskInput() + { + var inputTensor = new TensorProxy + { + shape = new long[] { 2, 5 }, + valueType = TensorProxy.TensorType.FloatingPoint + }; + const int batchSize = 4; + var agentInfos = GetFakeAgents(); + var alloc = new TensorCachingAllocator(); + var generator = new ActionMaskInputGenerator(alloc); + + var agent0 = agentInfos[0]; + var agent1 = agentInfos[1]; + var inputs = new List<AgentInfoSensorsPair> + { + new AgentInfoSensorsPair {agentInfo = agent0._Info, sensors = agent0.sensors}, + new AgentInfoSensorsPair {agentInfo = agent1._Info, sensors = agent1.sensors}, + }; + + generator.Generate(inputTensor, batchSize, inputs); + Assert.IsNotNull(inputTensor.data); + Assert.AreEqual(inputTensor.data[0, 0], 1); + Assert.AreEqual(inputTensor.data[0, 4], 1); + Assert.AreEqual(inputTensor.data[1, 0], 0); + Assert.AreEqual(inputTensor.data[1, 4], 1); + alloc.Dispose(); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorGenerator.cs.meta b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorGenerator.cs.meta new file mode 100644 index 0000000000..bffd9e8144 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorGenerator.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d2d2076c51c414ac7a91f8fbf15d4f7c +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs new file mode 100644 index 0000000000..da802a38d5 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs @@ -0,0 +1,236 @@ +using System; +using System.Linq; +using NUnit.Framework; +using UnityEngine; +using UnityEditor; +using Unity.Barracuda; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; +using Unity.MLAgents.Policies; +using System.Collections.Generic; + +namespace Unity.MLAgents.Tests +{ + public class FloatThresholdComparer : IEqualityComparer<float> + { + private readonly float _threshold; + public FloatThresholdComparer(float threshold) + { + _threshold = threshold; + } + + public bool Equals(float x, float y) + { + return Math.Abs(x - y) < _threshold; + } + + public int 
GetHashCode(float f) + { + throw new NotImplementedException("Unable to generate a hash code for threshold floats, do not use this method"); + } + } + + [TestFixture] + public class ModelRunnerTest + { + const string k_hybrid_ONNX_recurr_v2 = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx"; + + const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx"; + const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx"; + const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx"; + const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn"; + const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn"; + // models with deterministic action tensors + private const string k_deterministic_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx"; + private const string k_deterministic_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx"; + + NNModel hybridONNXModelV2; + NNModel continuousONNXModel; + NNModel discreteONNXModel; + NNModel hybridONNXModel; + NNModel continuousNNModel; + NNModel discreteNNModel; + NNModel deterministicDiscreteNNModel; + NNModel deterministicContinuousNNModel; + Test3DSensorComponent sensor_21_20_3; + Test3DSensorComponent sensor_20_22_3; + + + ActionSpec GetContinuous2vis8vec2actionActionSpec() + { + return ActionSpec.MakeContinuous(2); + } + + ActionSpec GetDiscrete1vis0vec_2_3action_recurrModelActionSpec() + { + return ActionSpec.MakeDiscrete(2, 3); + } + + ActionSpec GetHybrid0vis53vec_3c_2dActionSpec() + { + return new ActionSpec(3, new[] { 2 }); + } + + [SetUp] + public void SetUp() + { + hybridONNXModelV2 = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybrid_ONNX_recurr_v2, typeof(NNModel)); + + continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel)); + discreteONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteONNXPath, typeof(NNModel)); + hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel)); + continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel)); + discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel)); + deterministicDiscreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deterministic_discreteNNPath, typeof(NNModel)); + deterministicContinuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deterministic_continuousNNPath, typeof(NNModel)); + var go = new GameObject("SensorA"); + sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>(); + sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3); + sensor_20_22_3 = go.AddComponent<Test3DSensorComponent>(); + sensor_20_22_3.Sensor = new Test3DSensor("SensorB", 20, 22, 3); + } + + [Test] + public void TestModelExist() + { + Assert.IsNotNull(continuousONNXModel); + Assert.IsNotNull(discreteONNXModel); + Assert.IsNotNull(hybridONNXModel); + Assert.IsNotNull(continuousNNModel); + Assert.IsNotNull(discreteNNModel); + Assert.IsNotNull(hybridONNXModelV2); + Assert.IsNotNull(deterministicDiscreteNNModel); + Assert.IsNotNull(deterministicContinuousNNModel); + } + + [Test] + 
public void TestCreation() + { + var inferenceDevice = InferenceDevice.Burst; + var modelRunner = new ModelRunner(continuousONNXModel, GetContinuous2vis8vec2actionActionSpec(), inferenceDevice); + modelRunner.Dispose(); + Assert.Throws<UnityAgentsException>(() => + { + // Cannot load a model trained with 1.x that has an LSTM + modelRunner = new ModelRunner(discreteONNXModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec(), inferenceDevice); + modelRunner.Dispose(); + }); + modelRunner = new ModelRunner(hybridONNXModel, GetHybrid0vis53vec_3c_2dActionSpec(), inferenceDevice); + modelRunner.Dispose(); + modelRunner = new ModelRunner(continuousNNModel, GetContinuous2vis8vec2actionActionSpec(), inferenceDevice); + modelRunner.Dispose(); + + Assert.Throws<UnityAgentsException>(() => + { + // Cannot load a model trained with 1.x that has an LSTM + modelRunner = new ModelRunner(discreteNNModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec(), inferenceDevice); + modelRunner.Dispose(); + }); + // This one was trained with 2.0 so it should not raise an error: + modelRunner = new ModelRunner(hybridONNXModelV2, new ActionSpec(2, new[] { 2, 3 }), inferenceDevice); + modelRunner.Dispose(); + + // V2.0 Model that has serialized deterministic action tensors, discrete + modelRunner = new ModelRunner(deterministicDiscreteNNModel, new ActionSpec(0, new[] { 7 }), inferenceDevice); + modelRunner.Dispose(); + // V2.0 Model that has serialized deterministic action tensors, continuous + modelRunner = new ModelRunner(deterministicContinuousNNModel, + GetContinuous2vis8vec2actionActionSpec(), inferenceDevice, + deterministicInference: true); + modelRunner.Dispose(); + } + + [Test] + public void TestHasModel() + { + var modelRunner = new ModelRunner(continuousONNXModel, GetContinuous2vis8vec2actionActionSpec(), InferenceDevice.CPU); + Assert.True(modelRunner.HasModel(continuousONNXModel, InferenceDevice.CPU)); + Assert.False(modelRunner.HasModel(continuousONNXModel, InferenceDevice.GPU)); + Assert.False(modelRunner.HasModel(discreteONNXModel, InferenceDevice.CPU)); + modelRunner.Dispose(); + } + + [Test] + public void TestRunModel() + { + var actionSpec = GetContinuous2vis8vec2actionActionSpec(); + var modelRunner = new ModelRunner(continuousONNXModel, actionSpec, InferenceDevice.Burst); + var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8"); + var info1 = new AgentInfo(); + info1.episodeId = 1; + modelRunner.PutObservations(info1, new[] { + sensor_8, + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] }.ToList()); + var info2 = new AgentInfo(); + info2.episodeId = 2; + modelRunner.PutObservations(info2, new[] { + sensor_8, + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] }.ToList()); + + modelRunner.DecideBatch(); + + Assert.IsFalse(modelRunner.GetAction(1).Equals(ActionBuffers.Empty)); + Assert.IsFalse(modelRunner.GetAction(2).Equals(ActionBuffers.Empty)); + Assert.IsTrue(modelRunner.GetAction(3).Equals(ActionBuffers.Empty)); + Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).DiscreteActions.Length); + modelRunner.Dispose(); + } + + + [Test] + public void TestRunModel_stochastic() + { + var actionSpec = GetContinuous2vis8vec2actionActionSpec(); + // deterministicInference = false by default + var modelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst); + var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8"); + var info1 = new AgentInfo(); + var obs = new[] + { + sensor_8, + sensor_21_20_3.CreateSensors()[0], + 
sensor_20_22_3.CreateSensors()[0] + }.ToList(); + info1.episodeId = 1; + modelRunner.PutObservations(info1, obs); + modelRunner.DecideBatch(); + var stochAction1 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone(); + + modelRunner.PutObservations(info1, obs); + modelRunner.DecideBatch(); + var stochAction2 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone(); + // Stochastic action selection should output randomly different action values with same obs + Assert.IsFalse(Enumerable.SequenceEqual(stochAction1, stochAction2, new FloatThresholdComparer(0.001f))); + modelRunner.Dispose(); + } + [Test] + public void TestRunModel_deterministic() + { + var actionSpec = GetContinuous2vis8vec2actionActionSpec(); + var modelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst); + var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8"); + var info1 = new AgentInfo(); + var obs = new[] + { + sensor_8, + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }.ToList(); + var deterministicModelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst, + deterministicInference: true); + info1.episodeId = 1; + deterministicModelRunner.PutObservations(info1, obs); + deterministicModelRunner.DecideBatch(); + var deterministicAction1 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone(); + + deterministicModelRunner.PutObservations(info1, obs); + deterministicModelRunner.DecideBatch(); + var deterministicAction2 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone(); + // Deterministic action selection should output the same action every time + Assert.IsTrue(Enumerable.SequenceEqual(deterministicAction1, deterministicAction2, new FloatThresholdComparer(0.001f))); + // Dispose both runners so no inference resources leak between tests + deterministicModelRunner.Dispose(); + modelRunner.Dispose(); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs.meta new file mode 100644 index 0000000000..273d3bb579 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a2b924bac7d86467183ad9dc1436e550 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs new file mode 100644 index 0000000000..ae9d52d74b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs @@ -0,0 +1,558 @@ +using System.Linq; +using NUnit.Framework; +using UnityEngine; +using UnityEditor; +using Unity.Barracuda; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Inference; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Policies; + +namespace Unity.MLAgents.Tests +{ + public class Test3DSensorComponent : SensorComponent + { + public ISensor Sensor; + + public override ISensor[] CreateSensors() + { + return new ISensor[] { Sensor }; + } + } + + public class Test3DSensor : ISensor, IBuiltInSensor + { + int m_Width; + int m_Height; + int m_Channels; + string m_Name; + // Dummy value for the IBuiltInSensor interface + public const int k_BuiltInSensorType = -42; + + public Test3DSensor(string name, int width, int height, int channels) + { + m_Width = width; + m_Height = height; + m_Channels = channels; + m_Name = 
name; + } + + public ObservationSpec GetObservationSpec() + { + return ObservationSpec.Visual(m_Height, m_Width, m_Channels); + } + + public int Write(ObservationWriter writer) + { + for (int i = 0; i < m_Width * m_Height * m_Channels; i++) + { + writer[i] = 0.0f; + } + return m_Width * m_Height * m_Channels; + } + + public byte[] GetCompressedObservation() + { + return new byte[0]; + } + + public void Update() { } + public void Reset() { } + + public CompressionSpec GetCompressionSpec() + { + return CompressionSpec.Default(); + } + + public string GetName() + { + return m_Name; + } + + public BuiltInSensorType GetBuiltInSensorType() + { + return (BuiltInSensorType)k_BuiltInSensorType; + } + } + + [TestFixture] + public class ParameterLoaderTest + { + const string k_discrete_ONNX_v2 = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx"; + const string k_hybrid_ONNX_recurr_v2 = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx"; + + + // ONNX model with continuous/discrete action output (support hybrid action) + const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx"; + const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx"; + const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx"; + // NN model with single action output (deprecated, does not support hybrid action). + // Same BrainParameters settings as the corresponding ONNX model. + const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn"; + const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn"; + + NNModel rank2ONNXModel; + NNModel hybridRecurrV2Model; + NNModel continuousONNXModel; + NNModel discreteONNXModel; + NNModel hybridONNXModel; + NNModel continuousNNModel; + NNModel discreteNNModel; + Test3DSensorComponent sensor_21_20_3; + Test3DSensorComponent sensor_20_22_3; + BufferSensor sensor_23_20; + VectorSensor sensor_8; + VectorSensor sensor_10; + + BrainParameters GetContinuous2vis8vec2actionBrainParameters() + { + var validBrainParameters = new BrainParameters(); + validBrainParameters.VectorObservationSize = 8; + validBrainParameters.NumStackedVectorObservations = 1; + validBrainParameters.ActionSpec = ActionSpec.MakeContinuous(2); + return validBrainParameters; + } + + BrainParameters GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters() + { + var validBrainParameters = new BrainParameters(); + validBrainParameters.VectorObservationSize = 0; + validBrainParameters.NumStackedVectorObservations = 1; + validBrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 3); + return validBrainParameters; + } + + BrainParameters GetHybridBrainParameters() + { + var validBrainParameters = new BrainParameters(); + validBrainParameters.VectorObservationSize = 53; + validBrainParameters.NumStackedVectorObservations = 1; + validBrainParameters.ActionSpec = new ActionSpec(3, new[] { 2 }); + return validBrainParameters; + } + + BrainParameters GetRank2BrainParameters() + { + var validBrainParameters = new BrainParameters(); + validBrainParameters.VectorObservationSize = 4; + validBrainParameters.NumStackedVectorObservations = 2; + validBrainParameters.ActionSpec = 
+        BrainParameters GetRank2BrainParameters()
+        {
+            var validBrainParameters = new BrainParameters();
+            validBrainParameters.VectorObservationSize = 4;
+            validBrainParameters.NumStackedVectorObservations = 2;
+            validBrainParameters.ActionSpec = ActionSpec.MakeDiscrete(3, 3, 3);
+            return validBrainParameters;
+        }
+
+        BrainParameters GetRecurrHybridBrainParameters()
+        {
+            var validBrainParameters = new BrainParameters();
+            validBrainParameters.VectorObservationSize = 8;
+            validBrainParameters.NumStackedVectorObservations = 1;
+            validBrainParameters.ActionSpec = new ActionSpec(2, new int[] { 2, 3 });
+            return validBrainParameters;
+        }
+
+        [SetUp]
+        public void SetUp()
+        {
+            continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
+            discreteONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteONNXPath, typeof(NNModel));
+            hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel));
+            continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel));
+            discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel));
+            rank2ONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discrete_ONNX_v2, typeof(NNModel));
+            hybridRecurrV2Model = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybrid_ONNX_recurr_v2, typeof(NNModel));
+            var go = new GameObject("SensorA");
+            sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>();
+            sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);
+            sensor_20_22_3 = go.AddComponent<Test3DSensorComponent>();
+            sensor_20_22_3.Sensor = new Test3DSensor("SensorB", 20, 22, 3);
+            sensor_23_20 = new BufferSensor(20, 23, "BufferSensor");
+            sensor_8 = new VectorSensor(8, "VectorSensor8");
+            sensor_10 = new VectorSensor(10, "VectorSensor10");
+        }
+
+        [Test]
+        public void TestModelExist()
+        {
+            Assert.IsNotNull(continuousONNXModel);
+            Assert.IsNotNull(discreteONNXModel);
+            Assert.IsNotNull(hybridONNXModel);
+            Assert.IsNotNull(continuousNNModel);
+            Assert.IsNotNull(discreteNNModel);
+            Assert.IsNotNull(rank2ONNXModel);
+            Assert.IsNotNull(hybridRecurrV2Model);
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestGetInputTensorsContinuous(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
+            var inputNames = model.GetInputNames();
+            // Model should contain 3 inputs: vector, visual 1 and visual 2
+            Assert.AreEqual(3, inputNames.Count());
+            Assert.Contains(TensorNames.VectorObservationPlaceholder, inputNames);
+            Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "0", inputNames);
+            Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "1", inputNames);
+
+            Assert.AreEqual(2, model.GetNumVisualInputs());
+
+            // Test the case where the model is null
+            model = null;
+            Assert.AreEqual(0, model.GetInputTensors().Count);
+            Assert.AreEqual(0, model.GetNumVisualInputs());
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestGetInputTensorsDiscrete(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
+            var inputNames = model.GetInputNames();
+            // Model should contain 2 inputs: recurrent and visual 1
+
+            Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "0", inputNames);
+            // TODO: There are some memory tensors as well
+        }
+
+        [Test]
+        public void TestGetInputTensorsHybrid()
+        {
+            var model = ModelLoader.Load(hybridONNXModel);
+            var inputNames = model.GetInputNames();
+            Assert.Contains(TensorNames.VectorObservationPlaceholder, inputNames);
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestGetOutputTensorsContinuous(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? 
ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel); + var outputNames = model.GetOutputNames(); + var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.ContinuousActionOutput; + Assert.Contains(actionOutputName, outputNames); + Assert.AreEqual(1, outputNames.Count()); + + model = null; + Assert.AreEqual(0, model.GetOutputNames().Count()); + } + + [TestCase(true)] + [TestCase(false)] + public void TestGetOutputTensorsDiscrete(bool useDeprecatedNNModel) + { + var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel); + var outputNames = model.GetOutputNames(); + var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.DiscreteActionOutput; + Assert.Contains(actionOutputName, outputNames); + // TODO : There are some memory tensors as well + } + + [Test] + public void TestGetOutputTensorsHybrid() + { + var model = ModelLoader.Load(hybridONNXModel); + var outputNames = model.GetOutputNames(); + + Assert.AreEqual(2, outputNames.Count()); + Assert.Contains(TensorNames.ContinuousActionOutput, outputNames); + Assert.Contains(TensorNames.DiscreteActionOutput, outputNames); + + model = null; + Assert.AreEqual(0, model.GetOutputNames().Count()); + } + + [Test] + public void TestCheckModelRank2() + { + var model = ModelLoader.Load(rank2ONNXModel); + var validBrainParameters = GetRank2BrainParameters(); + + var errors = BarracudaModelParamLoader.CheckModel( + model, validBrainParameters, + new ISensor[] { sensor_23_20, sensor_10, sensor_8 }, new ActuatorComponent[0] + ); + Assert.AreEqual(0, errors.Count()); // There should not be any errors + + errors = BarracudaModelParamLoader.CheckModel( + model, validBrainParameters, + new ISensor[] { sensor_23_20, sensor_10 }, new ActuatorComponent[0] + ); + Assert.AreNotEqual(0, errors.Count()); // Wrong number of sensors + + errors = BarracudaModelParamLoader.CheckModel( + model, validBrainParameters, + new ISensor[] { new BufferSensor(20, 40, "BufferSensor"), sensor_10, sensor_8 }, new ActuatorComponent[0] + ); + Assert.AreNotEqual(0, errors.Count()); // Wrong buffer sensor size + + errors = BarracudaModelParamLoader.CheckModel( + model, validBrainParameters, + new ISensor[] { sensor_23_20, sensor_10, sensor_10 }, new ActuatorComponent[0] + ); + Assert.AreNotEqual(0, errors.Count()); // Wrong vector sensor size + } + + [TestCase(true)] + [TestCase(false)] + public void TestCheckModelValidContinuous(bool useDeprecatedNNModel) + { + var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel); + var validBrainParameters = GetContinuous2vis8vec2actionBrainParameters(); + + var errors = BarracudaModelParamLoader.CheckModel( + model, validBrainParameters, + new ISensor[] + { + new VectorSensor(8), + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); + Assert.AreEqual(0, errors.Count()); // There should not be any errors + } + + [TestCase(true)] + [TestCase(false)] + public void TestCheckModelValidDiscrete(bool useDeprecatedNNModel) + { + var model = useDeprecatedNNModel ? 
ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
+            var validBrainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
+
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, validBrainParameters,
+                new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, new ActuatorComponent[0]
+            );
+            foreach (var e in errors)
+            {
+                Debug.Log(e.Message);
+            }
+            Assert.Greater(errors.Count(), 0); // There should be an error since LSTM v1.x is not supported
+        }
+
+        [Test]
+        public void TestCheckModelValidRecurrent()
+        {
+            var model = ModelLoader.Load(hybridRecurrV2Model);
+            var numErrors = 0; // A model trained with v2 should not raise errors
+            var validBrainParameters = GetRecurrHybridBrainParameters();
+
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, validBrainParameters,
+                new ISensor[] { sensor_8 }, new ActuatorComponent[0]
+            );
+            Assert.AreEqual(numErrors, errors.Count()); // There should not be any errors
+
+            var invalidBrainParameters = GetRecurrHybridBrainParameters();
+            invalidBrainParameters.ActionSpec = new ActionSpec(1, new int[] { 2, 3 });
+            errors = BarracudaModelParamLoader.CheckModel(
+                model, invalidBrainParameters,
+                new ISensor[] { sensor_8 }, new ActuatorComponent[0]
+            );
+            Assert.AreEqual(1, errors.Count()); // 1 continuous action instead of 2
+
+            invalidBrainParameters.ActionSpec = new ActionSpec(2, new int[] { 3, 2 });
+            errors = BarracudaModelParamLoader.CheckModel(
+                model, invalidBrainParameters,
+                new ISensor[] { sensor_8 }, new ActuatorComponent[0]
+            );
+            Assert.AreEqual(1, errors.Count()); // Discrete action branches flipped
+        }
+
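+        // As a reading aid: ActionSpec(2, new int[] { 2, 3 }) above means 2 continuous
+        // actions plus two discrete branches with 2 and 3 choices respectively, so both
+        // halves of the hybrid action space get validated.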
+        [Test]
+        public void TestCheckModelValidHybrid()
+        {
+            var model = ModelLoader.Load(hybridONNXModel);
+            var validBrainParameters = GetHybridBrainParameters();
+
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, validBrainParameters,
+                new ISensor[]
+                {
+                    new VectorSensor(validBrainParameters.VectorObservationSize)
+                }, new ActuatorComponent[0]
+            );
+            Assert.AreEqual(0, errors.Count()); // There should not be any errors
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
+
+            var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
+            brainParameters.VectorObservationSize = 9; // Invalid observation
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters,
+                new ISensor[]
+                {
+                    sensor_21_20_3.CreateSensors()[0],
+                    sensor_20_22_3.CreateSensors()[0]
+                },
+                new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+
+            brainParameters = GetContinuous2vis8vec2actionBrainParameters();
+            brainParameters.NumStackedVectorObservations = 2; // Invalid stacking
+            errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters,
+                new ISensor[]
+                {
+                    sensor_21_20_3.CreateSensors()[0],
+                    sensor_20_22_3.CreateSensors()[0]
+                },
+                new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestCheckModelThrowsVectorObservationDiscrete(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
+
+            var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
+            brainParameters.VectorObservationSize = 1; // Invalid observation
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters, new ISensor[]
+                {
+                    sensor_21_20_3.CreateSensors()[0]
+                },
+                new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+        }
+
+        [Test]
+        public void TestCheckModelThrowsVectorObservationHybrid()
+        {
+            var model = ModelLoader.Load(hybridONNXModel);
+
+            var brainParameters = GetHybridBrainParameters();
+            brainParameters.VectorObservationSize = 9; // Invalid observation
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters,
+                new ISensor[] { }, new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+
+            brainParameters = GetContinuous2vis8vec2actionBrainParameters();
+            brainParameters.NumStackedVectorObservations = 2; // Invalid stacking
+            errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters,
+                new ISensor[] { }, new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestCheckModelThrowsActionContinuous(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
+
+            var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
+            brainParameters.ActionSpec = ActionSpec.MakeContinuous(3); // Invalid action
+            var errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters, new ISensor[]
+                {
+                    sensor_21_20_3.CreateSensors()[0],
+                    sensor_20_22_3.CreateSensors()[0]
+                },
+                new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+
+            brainParameters = GetContinuous2vis8vec2actionBrainParameters();
+            brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3); // Invalid SpaceType
+            errors = BarracudaModelParamLoader.CheckModel(
+                model, brainParameters, new ISensor[]
+                {
+                    sensor_21_20_3.CreateSensors()[0],
+                    sensor_20_22_3.CreateSensors()[0]
+                },
+                new ActuatorComponent[0]
+            );
+            Assert.Greater(errors.Count(), 0);
+        }
+
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestCheckModelThrowsActionDiscrete(bool useDeprecatedNNModel)
+        {
+            var model = useDeprecatedNNModel ? 
ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel); + + var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters(); + brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3, 3); // Invalid action + var errors = BarracudaModelParamLoader.CheckModel( + model, brainParameters, + new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, + new ActuatorComponent[0] + ); + Assert.Greater(errors.Count(), 0); + + brainParameters = GetContinuous2vis8vec2actionBrainParameters(); + brainParameters.ActionSpec = ActionSpec.MakeContinuous(2); // Invalid SpaceType + errors = BarracudaModelParamLoader.CheckModel( + model, + brainParameters, + new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, + new ActuatorComponent[0] + ); + Assert.Greater(errors.Count(), 0); + } + + [Test] + public void TestCheckModelThrowsActionHybrid() + { + var model = ModelLoader.Load(hybridONNXModel); + + var brainParameters = GetHybridBrainParameters(); + brainParameters.ActionSpec = new ActionSpec(3, new[] { 3 }); // Invalid discrete action size + var errors = BarracudaModelParamLoader.CheckModel( + model, + brainParameters, + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); + Assert.Greater(errors.Count(), 0); + + brainParameters = GetContinuous2vis8vec2actionBrainParameters(); + brainParameters.ActionSpec = ActionSpec.MakeDiscrete(2); // Missing continuous action + errors = BarracudaModelParamLoader.CheckModel( + model, + brainParameters, + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); + Assert.Greater(errors.Count(), 0); + } + + [Test] + public void TestCheckModelThrowsNoModel() + { + var brainParameters = GetContinuous2vis8vec2actionBrainParameters(); + var errors = BarracudaModelParamLoader.CheckModel( + null, + brainParameters, + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); + Assert.Greater(errors.Count(), 0); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs.meta new file mode 100644 index 0000000000..25bff938ef --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: edd38d6ad78c8456d80f0a90bcb2e1b7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Inference/TensorUtilsTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/TensorUtilsTest.cs new file mode 100644 index 0000000000..5673b7ddef --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Inference/TensorUtilsTest.cs @@ -0,0 +1,135 @@ +using System; +using NUnit.Framework; +using Unity.Barracuda; +using Unity.MLAgents.Inference; +using Unity.MLAgents.Inference.Utils; + +namespace Unity.MLAgents.Tests +{ + public class TensorUtilsTest + { + [TestCase(4, TestName = "TestResizeTensor_4D")] + [TestCase(8, TestName = "TestResizeTensor_8D")] + public void TestResizeTensor(int dimension) + { + var alloc = new TensorCachingAllocator(); + var height = 64; + var width = 84; + var channels = 3; + + // Set shape to {1, ..., height, width, channels} + // For 8D, the ... 
are all 1's
+            var shape = new long[dimension];
+            for (var i = 0; i < dimension; i++)
+            {
+                shape[i] = 1;
+            }
+
+            shape[dimension - 3] = height;
+            shape[dimension - 2] = width;
+            shape[dimension - 1] = channels;
+
+            var intShape = new int[dimension];
+            for (var i = 0; i < dimension; i++)
+            {
+                intShape[i] = (int)shape[i];
+            }
+
+            var tensorProxy = new TensorProxy
+            {
+                valueType = TensorProxy.TensorType.Integer,
+                data = new Tensor(intShape),
+                shape = shape,
+            };
+
+            // These should be invariant after the resize.
+            Assert.AreEqual(height, tensorProxy.data.shape.height);
+            Assert.AreEqual(width, tensorProxy.data.shape.width);
+            Assert.AreEqual(channels, tensorProxy.data.shape.channels);
+
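+            // ResizeTensor reallocates the tensor data for the new batch size (42 here) and
+            // is expected to preserve the trailing height/width/channel dimensions, which the
+            // asserts above and below pin down.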
+            TensorUtils.ResizeTensor(tensorProxy, 42, alloc);
+
+            Assert.AreEqual(height, tensorProxy.shape[dimension - 3]);
+            Assert.AreEqual(width, tensorProxy.shape[dimension - 2]);
+            Assert.AreEqual(channels, tensorProxy.shape[dimension - 1]);
+
+            Assert.AreEqual(height, tensorProxy.data.shape.height);
+            Assert.AreEqual(width, tensorProxy.data.shape.width);
+            Assert.AreEqual(channels, tensorProxy.data.shape.channels);
+
+            alloc.Dispose();
+        }
+
+        [Test]
+        public void RandomNormalTestTensorInt()
+        {
+            var rn = new RandomNormal(1982);
+            var t = new TensorProxy
+            {
+                valueType = TensorProxy.TensorType.Integer
+            };
+
+            Assert.Throws<NotImplementedException>(
+                () => TensorUtils.FillTensorWithRandomNormal(t, rn));
+        }
+
+        [Test]
+        public void RandomNormalTestDataNull()
+        {
+            var rn = new RandomNormal(1982);
+            var t = new TensorProxy
+            {
+                valueType = TensorProxy.TensorType.FloatingPoint
+            };
+
+            Assert.Throws<ArgumentNullException>(
+                () => TensorUtils.FillTensorWithRandomNormal(t, rn));
+        }
+
+        [Test]
+        public void RandomNormalTestTensor()
+        {
+            var rn = new RandomNormal(1982);
+            var t = new TensorProxy
+            {
+                valueType = TensorProxy.TensorType.FloatingPoint,
+                data = new Tensor(1, 3, 4, 2)
+            };
+
+            TensorUtils.FillTensorWithRandomNormal(t, rn);
+
+            var reference = new[]
+            {
+                -0.4315872f,
+                -1.11074f,
+                0.3414804f,
+                -1.130287f,
+                0.1413168f,
+                -0.5105762f,
+                -0.3027347f,
+                -0.2645015f,
+                1.225356f,
+                -0.02921959f,
+                0.3716498f,
+                -1.092338f,
+                0.9561074f,
+                -0.5018106f,
+                1.167787f,
+                -0.7763879f,
+                -0.07491868f,
+                0.5396146f,
+                -0.1377991f,
+                0.3331701f,
+                0.06144788f,
+                0.9520947f,
+                1.088157f,
+                -1.177194f,
+            };
+
+            for (var i = 0; i < t.data.length; i++)
+            {
+                Assert.AreEqual(t.data[i], reference[i], 0.0001);
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Inference/TensorUtilsTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Inference/TensorUtilsTest.cs.meta
new file mode 100644
index 0000000000..4141d49504
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Inference/TensorUtilsTest.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 0a700a7c6187a433ca44d60d243bb0cd
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
new file mode 100644
index 0000000000..2e8f21cf28
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
@@ -0,0 +1,190 @@
+using System;
+using System.Collections;
+using NUnit.Framework;
+
+
+namespace Unity.MLAgents.Tests
+{
+    [TestFixture]
+    public class InplaceArrayTests
+    {
+        class LengthCases : IEnumerable
+        {
+            public IEnumerator GetEnumerator()
+            {
+                yield return 1;
+                yield return 2;
+                yield return 3;
+                yield return 4;
+            }
+        }
+
+        private InplaceArray<int> GetTestArray(int length)
+        {
+            switch (length)
+            {
+                case 1:
+                    return new InplaceArray<int>(11);
+                case 2:
+                    return new InplaceArray<int>(11, 22);
+                case 3:
+                    return new InplaceArray<int>(11, 22, 33);
+                case 4:
+                    return new InplaceArray<int>(11, 22, 33, 44);
+                default:
+                    throw new ArgumentException("bad test!");
+            }
+        }
+
+        private InplaceArray<int> GetZeroArray(int length)
+        {
+            switch (length)
+            {
+                case 1:
+                    return new InplaceArray<int>(0);
+                case 2:
+                    return new InplaceArray<int>(0, 0);
+                case 3:
+                    return new InplaceArray<int>(0, 0, 0);
+                case 4:
+                    return new InplaceArray<int>(0, 0, 0, 0);
+                default:
+                    throw new ArgumentException("bad test!");
+            }
+        }
+
+        [Test]
+        public void TestInplaceArrayCtor()
+        {
+            var a1 = new InplaceArray<int>(11);
+            Assert.AreEqual(1, a1.Length);
+            Assert.AreEqual(11, a1[0]);
+
+            var a2 = new InplaceArray<int>(11, 22);
+            Assert.AreEqual(2, a2.Length);
+            Assert.AreEqual(11, a2[0]);
+            Assert.AreEqual(22, a2[1]);
+
+            var a3 = new InplaceArray<int>(11, 22, 33);
+            Assert.AreEqual(3, a3.Length);
+            Assert.AreEqual(11, a3[0]);
+            Assert.AreEqual(22, a3[1]);
+            Assert.AreEqual(33, a3[2]);
+
+            var a4 = new InplaceArray<int>(11, 22, 33, 44);
+            Assert.AreEqual(4, a4.Length);
+            Assert.AreEqual(11, a4[0]);
+            Assert.AreEqual(22, a4[1]);
+            Assert.AreEqual(33, a4[2]);
+            Assert.AreEqual(44, a4[3]);
+        }
+
+        [TestCaseSource(typeof(LengthCases))]
+        public void TestInplaceGetSet(int length)
+        {
+            var original = GetTestArray(length);
+
+            for (var i = 0; i < original.Length; i++)
+            {
+                var modified = original;
+                modified[i] = 0;
+                for (var j = 0; j < original.Length; j++)
+                {
+                    if (i == j)
+                    {
+                        // This is the one we overwrote
+                        Assert.AreEqual(0, modified[j]);
+                    }
+                    else
+                    {
+                        // Other elements should be unchanged
+                        Assert.AreEqual(original[j], modified[j]);
+                    }
+                }
+            }
+        }
+
+        [TestCaseSource(typeof(LengthCases))]
+        public void TestInvalidAccess(int length)
+        {
+            var tmp = 0;
+            var a = GetTestArray(length);
+            // get
+            Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[-1]; });
+            Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[length]; });
+
+            // set
+            Assert.Throws<IndexOutOfRangeException>(() => { a[-1] = 0; });
+            Assert.Throws<IndexOutOfRangeException>(() => { a[length] = 0; });
+
+            // Make sure temp is used
+            Assert.AreEqual(0, tmp);
+        }
+
+        [Test]
+        public void TestOperatorEqualsDifferentLengths()
+        {
+            // Check that arrays of different lengths are never equal (even if they have 0s in all elements)
+            for (var l1 = 1; l1 <= 4; l1++)
+            {
+                var a1 = GetZeroArray(l1);
+                for (var l2 = 1; l2 <= 4; l2++)
+                {
+                    var a2 = GetZeroArray(l2);
+                    if (l1 == l2)
+                    {
+                        Assert.AreEqual(a1, a2);
+                        Assert.IsTrue(a1 == a2);
+                    }
+                    else
+                    {
+                        Assert.AreNotEqual(a1, a2);
+                        Assert.IsTrue(a1 != a2);
+                    }
+                }
+            }
+        }
+
+        [TestCaseSource(typeof(LengthCases))]
+        public void TestOperatorEquals(int length)
+        {
+            for (var index = 0; index < length; index++)
+            {
+                var a1 = GetTestArray(length);
+                var a2 = GetTestArray(length);
+                Assert.AreEqual(a1, a2);
+                Assert.IsTrue(a1 == a2);
+
+                a1[index] = 42;
+                Assert.AreNotEqual(a1, a2);
+                Assert.IsTrue(a1 != a2);
+
+                a2[index] = 42;
+                Assert.AreEqual(a1, a2);
+                Assert.IsTrue(a1 == a2);
+            }
+        }
+
+        [Test]
+        public void TestToString()
+        {
+            Assert.AreEqual("[1]", new InplaceArray<int>(1).ToString());
+            Assert.AreEqual("[1, 2]", new InplaceArray<int>(1, 2).ToString());
+            Assert.AreEqual("[1, 2, 3]", new InplaceArray<int>(1, 2, 3).ToString());
+            Assert.AreEqual("[1, 2, 3, 4]", new InplaceArray<int>(1, 2, 3, 4).ToString());
+        }
+
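+        // InplaceArray is a fixed-capacity value type (at most 4 elements), so FromList
+        // should round-trip a plain array into an equal value, e.g.
+        // InplaceArray<int>.FromList(new[] { 11, 22 }) == new InplaceArray<int>(11, 22).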
+        [TestCaseSource(typeof(LengthCases))]
+        public void TestFromList(int length)
+        {
+            var intArray = new int[length];
+            for (var i = 0; i < length; i++)
+            {
+                intArray[i] = (i + 1) * 11; // 11, 22, etc.
+            }
+
+            var converted = InplaceArray<int>.FromList(intArray);
+            Assert.AreEqual(GetTestArray(length), converted);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
new file mode 100644
index 0000000000..227738d65f
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 8e1cdc27e533749fabc04b3cdeb93501
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations.meta b/com.unity.ml-agents/Tests/Editor/Integrations.meta
new file mode 100644
index 0000000000..395f71ca51
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 7e89e6f6ab7e4c9397958c0320bd5931
+timeCreated: 1618359633
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3.meta
new file mode 100644
index 0000000000..f710b9aa06
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 77b0212dde404f7c8ce9aac13bd550b8
+timeCreated: 1601332716
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/AbstractBoardTests.cs b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/AbstractBoardTests.cs
new file mode 100644
index 0000000000..b8e9337c1e
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/AbstractBoardTests.cs
@@ -0,0 +1,217 @@
+using System;
+using System.Collections.Generic;
+using UnityEngine;
+using NUnit.Framework;
+using Unity.MLAgents.Integrations.Match3;
+
+namespace Unity.MLAgents.Tests.Integrations.Match3
+{
+    internal class StringBoard : AbstractBoard
+    {
+        internal int MaxRows;
+        internal int MaxColumns;
+        internal int NumCellTypes;
+        internal int NumSpecialTypes;
+        public int CurrentRows;
+        public int CurrentColumns;
+
+        public override BoardSize GetMaxBoardSize()
+        {
+            return new BoardSize
+            {
+                Rows = MaxRows,
+                Columns = MaxColumns,
+                NumCellTypes = NumCellTypes,
+                NumSpecialTypes = NumSpecialTypes
+            };
+        }
+
+        public override BoardSize GetCurrentBoardSize()
+        {
+            return new BoardSize
+            {
+                Rows = CurrentRows,
+                Columns = CurrentColumns,
+                NumCellTypes = NumCellTypes,
+                NumSpecialTypes = NumSpecialTypes
+            };
+        }
+
+        private string[] m_Board;
+        private string[] m_Special;
+
+        /// <summary>
+        /// Convert a string like "000\n010\n000" to a board representation.
+        /// Row 0 is considered the bottom row.
+        /// </summary>
+        /// <param name="newBoard"></param>
+        public void SetBoard(string newBoard)
+        {
+            m_Board = newBoard.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
+            MaxRows = m_Board.Length;
+            MaxColumns = m_Board[0].Length;
+            CurrentRows = MaxRows;
+            CurrentColumns = MaxColumns;
+            NumCellTypes = 0;
+            for (var r = 0; r < MaxRows; r++)
+            {
+                for (var c = 0; c < MaxColumns; c++)
+                {
+                    NumCellTypes = Mathf.Max(NumCellTypes, 1 + GetCellType(r, c));
+                }
+            }
+        }
+
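+        // Example: SetBoard("000\n010\n000") yields a 3x3 board whose only non-zero cell
+        // is at row 1, column 1 (rows count from the bottom), so GetCellType(1, 1) == 1
+        // and NumCellTypes becomes 2 (cell values 0 and 1).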
+        public void SetSpecial(string newSpecial)
+        {
+            m_Special = newSpecial.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
+            Debug.Assert(MaxRows == m_Special.Length);
+            Debug.Assert(MaxColumns == m_Special[0].Length);
+            NumSpecialTypes = 0;
+            for (var r = 0; r < MaxRows; r++)
+            {
+                for (var c = 0; c < MaxColumns; c++)
+                {
+                    NumSpecialTypes = Mathf.Max(NumSpecialTypes, GetSpecialType(r, c));
+                }
+            }
+        }
+
+        public override bool MakeMove(Move m)
+        {
+            return true;
+        }
+
+        public override bool IsMoveValid(Move m)
+        {
+            return SimpleIsMoveValid(m);
+        }
+
+        public override int GetCellType(int row, int col)
+        {
+            if (row >= CurrentRows || col >= CurrentColumns)
+            {
+                throw new IndexOutOfRangeException("Tried to get celltype out of bounds");
+            }
+
+            var character = m_Board[m_Board.Length - 1 - row][col];
+            return (character - '0');
+        }
+
+        public override int GetSpecialType(int row, int col)
+        {
+            if (row >= CurrentRows || col >= CurrentColumns)
+            {
+                throw new IndexOutOfRangeException("Tried to get specialtype out of bounds");
+            }
+
+            var character = m_Special[m_Board.Length - 1 - row][col];
+            return (character - '0');
+        }
+    }
+
+    public class AbstractBoardTests
+    {
+        [Test]
+        public void TestBoardInit()
+        {
+            var boardString =
+                @"000
+                  000
+                  010";
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+            board.SetBoard(boardString);
+
+            var boardSize = board.GetMaxBoardSize();
+
+            Assert.AreEqual(3, boardSize.Rows);
+            Assert.AreEqual(3, boardSize.Columns);
+            Assert.AreEqual(2, boardSize.NumCellTypes);
+            for (var r = 0; r < 3; r++)
+            {
+                for (var c = 0; c < 3; c++)
+                {
+                    var expected = (r == 0 && c == 1) ? 1 : 0;
+                    Assert.AreEqual(expected, board.GetCellType(r, c));
+                }
+            }
+        }
+
+        internal static List<Move> GetValidMoves4x4(bool fullBoard, BoardSize boardSize)
+        {
+            var validMoves = new List<Move>
+            {
+                Move.FromPositionAndDirection(2, 1, Direction.Down, boardSize), // equivalent to (1, 1, Up)
+                Move.FromPositionAndDirection(1, 1, Direction.Down, boardSize),
+                Move.FromPositionAndDirection(1, 1, Direction.Left, boardSize),
+                Move.FromPositionAndDirection(1, 1, Direction.Right, boardSize),
+                Move.FromPositionAndDirection(0, 1, Direction.Left, boardSize),
+            };
+
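+            // A swap is stored under a single canonical MoveIndex, so it can be written from
+            // either end: (2, 1, Down) and (1, 1, Up) exchange the same pair of cells and
+            // share an index (see MoveTests.TestMoveEquivalence).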
+            if (fullBoard)
+            {
+                // This would move out of range on the small board
+                // Equivalent to (3, 1, Down)
+                validMoves.Add(Move.FromPositionAndDirection(2, 1, Direction.Up, boardSize));
+
+                // These moves require matching with a cell that's off the small board, so they're invalid
+                // (even though the move itself doesn't go out of range).
+                validMoves.Add(Move.FromPositionAndDirection(2, 1, Direction.Left, boardSize)); // Equivalent to (2, 0, Right)
+                validMoves.Add(Move.FromPositionAndDirection(2, 1, Direction.Right, boardSize));
+            }
+
+            return validMoves;
+        }
+
+        [TestCase(true, TestName = "Full Board")]
+        [TestCase(false, TestName = "Small Board")]
+        public void TestCheckValidMoves(bool fullBoard)
+        {
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+
+            var boardString =
+                @"0105
+                  1024
+                  0203
+                  2022";
+            board.SetBoard(boardString);
+            var boardSize = board.GetMaxBoardSize();
+            if (!fullBoard)
+            {
+                board.CurrentRows -= 1;
+            }
+
+            var validMoves = GetValidMoves4x4(fullBoard, boardSize);
+
+            foreach (var m in validMoves)
+            {
+                Assert.IsTrue(board.IsMoveValid(m));
+            }
+
+            // Run through all moves and make sure those are the only valid ones
+            HashSet<int> validIndices = new HashSet<int>();
+            foreach (var m in validMoves)
+            {
+                validIndices.Add(m.MoveIndex);
+            }
+
+            // Make sure iterating over AllMoves is OK with the smaller board
+            foreach (var move in board.AllMoves())
+            {
+                var expected = validIndices.Contains(move.MoveIndex);
+                Assert.AreEqual(expected, board.IsMoveValid(move), $"({move.Row}, {move.Column}, {move.Direction})");
+            }
+
+            HashSet<int> validIndicesFromIterator = new HashSet<int>();
+            foreach (var move in board.ValidMoves())
+            {
+                validIndicesFromIterator.Add(move.MoveIndex);
+            }
+            Assert.IsTrue(validIndices.SetEquals(validIndicesFromIterator));
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/AbstractBoardTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/AbstractBoardTests.cs.meta
new file mode 100644
index 0000000000..79da98cb7a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/AbstractBoardTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: a6d0404471364cd5b0b86ef72e6fe653
+timeCreated: 1601332740
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3ActuatorTests.cs b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3ActuatorTests.cs
new file mode 100644
index 0000000000..2beea36b32
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3ActuatorTests.cs
@@ -0,0 +1,207 @@
+using System.Collections.Generic;
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Integrations.Match3;
+using UnityEngine;
+
+namespace Unity.MLAgents.Tests.Integrations.Match3
+{
+    internal class SimpleBoard : AbstractBoard
+    {
+        public int Rows;
+        public int Columns;
+        public int NumCellTypes;
+        public int NumSpecialTypes;
+
+        public int LastMoveIndex;
+        public bool MovesAreValid = true;
+
+        public bool CallbackCalled;
+
+        public override BoardSize GetMaxBoardSize()
+        {
+            return new BoardSize
+            {
+                Rows = Rows,
+                Columns = Columns,
+                NumCellTypes = NumCellTypes,
+                NumSpecialTypes = NumSpecialTypes
+            };
+        }
+
+        public override int GetCellType(int row, int col)
+        {
+            return 0;
+        }
+
+        public override int GetSpecialType(int row, int col)
+        {
+            return 0;
+        }
+
+        public override bool IsMoveValid(Move m)
+        {
+            return MovesAreValid;
+        }
+
+        public override bool MakeMove(Move m)
+        {
+            LastMoveIndex = m.MoveIndex;
+            return MovesAreValid;
+        }
+
+        public void Callback()
+        {
+            CallbackCalled = true;
+        }
+    }
+
+    public class Match3ActuatorTests
+    {
+        [SetUp]
+        public void SetUp()
+        {
+            if (Academy.IsInitialized)
+            {
+                Academy.Instance.Dispose();
+            }
+        }
+
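+        // Disposing any existing Academy first means each test's board and actuator register
+        // against a fresh Academy instance rather than state leaked from an earlier test.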
+        [TestCase(true)]
+        [TestCase(false)]
+        public void TestValidMoves(bool movesAreValid)
+        {
+            // Check that a board with no valid moves doesn't raise an exception.
+            var gameObj = new GameObject();
+            var board = gameObj.AddComponent<SimpleBoard>();
+            var agent = gameObj.AddComponent<Agent>();
+            gameObj.AddComponent<Match3ActuatorComponent>();
+
+            board.Rows = 5;
+            board.Columns = 5;
+            board.NumCellTypes = 5;
+            board.NumSpecialTypes = 0;
+
+            board.MovesAreValid = movesAreValid;
+            board.OnNoValidMovesAction = board.Callback;
+            board.LastMoveIndex = -1;
+
+            agent.LazyInitialize();
+            agent.RequestDecision();
+            Academy.Instance.EnvironmentStep();
+
+            if (movesAreValid)
+            {
+                Assert.IsFalse(board.CallbackCalled);
+            }
+            else
+            {
+                Assert.IsTrue(board.CallbackCalled);
+            }
+            Assert.AreNotEqual(-1, board.LastMoveIndex);
+        }
+
+        [Test]
+        public void TestActionSpec()
+        {
+            var gameObj = new GameObject();
+            var board = gameObj.AddComponent<SimpleBoard>();
+            var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
+
+            board.Rows = 5;
+            board.Columns = 5;
+            board.NumCellTypes = 5;
+            board.NumSpecialTypes = 0;
+
+            var actionSpec = actuator.ActionSpec;
+            Assert.AreEqual(1, actionSpec.NumDiscreteActions);
+            Assert.AreEqual(board.NumMoves(), actionSpec.BranchSizes[0]);
+        }
+
+        [Test]
+        public void TestActionSpecNullBoard()
+        {
+            var gameObj = new GameObject();
+            var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
+
+            var actionSpec = actuator.ActionSpec;
+            Assert.AreEqual(0, actionSpec.NumDiscreteActions);
+            Assert.AreEqual(0, actionSpec.NumContinuousActions);
+        }
+
+        public class HashSetActionMask : IDiscreteActionMask
+        {
+            public HashSet<int>[] HashSets;
+            public HashSetActionMask(ActionSpec spec)
+            {
+                HashSets = new HashSet<int>[spec.NumDiscreteActions];
+                for (var i = 0; i < spec.NumDiscreteActions; i++)
+                {
+                    HashSets[i] = new HashSet<int>();
+                }
+            }
+
+            public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
+            {
+                var hashSet = HashSets[branch];
+                if (isEnabled)
+                {
+                    hashSet.Remove(actionIndex);
+                }
+                else
+                {
+                    hashSet.Add(actionIndex);
+                }
+            }
+        }
+
+        [TestCase(true, TestName = "Full Board")]
+        [TestCase(false, TestName = "Small Board")]
+        public void TestMasking(bool fullBoard)
+        {
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+
+            var boardString =
+                @"0105
+                  1024
+                  0203
+                  2022";
+            board.SetBoard(boardString);
+            var boardSize = board.GetMaxBoardSize();
+            if (!fullBoard)
+            {
+                board.CurrentRows -= 1;
+            }
+
+            var validMoves = AbstractBoardTests.GetValidMoves4x4(fullBoard, boardSize);
+
+            var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
+            var actuator = actuatorComponent.CreateActuators()[0];
+
+            var masks = new HashSetActionMask(actuator.ActionSpec);
+            actuator.WriteDiscreteActionMask(masks);
+
+            // Run through all moves and make sure those are the only valid ones
+            HashSet<int> validIndices = new HashSet<int>();
+            foreach (var m in validMoves)
+            {
+                validIndices.Add(m.MoveIndex);
+            }
+
+            // Valid moves and masked moves should be disjoint
+            Assert.IsFalse(validIndices.Overlaps(masks.HashSets[0]));
+            // And they should add up to all the potential moves
+            Assert.AreEqual(validIndices.Count + masks.HashSets[0].Count, board.NumMoves());
+        }
+
+        [Test]
+        public void TestNoBoardReturnsEmptyActuators()
+        {
+            var gameObj = new GameObject("board");
+            var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
+            var actuators = actuatorComponent.CreateActuators();
+            Assert.AreEqual(0, actuators.Length);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3ActuatorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3ActuatorTests.cs.meta
new file mode 100644
index 0000000000..3731b4e758
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3ActuatorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 2edf24df24ac426085cb31a94d063683
+timeCreated: 1603392289
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3SensorTests.cs b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3SensorTests.cs
new file mode 100644
index 0000000000..a376bd3523
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3SensorTests.cs
@@ -0,0 +1,408 @@
+using System.Collections.Generic;
+using System.IO;
+using System.Reflection;
+using NUnit.Framework;
+using UnityEngine;
+using Unity.MLAgents.Integrations.Match3;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests.Integrations.Match3
+{
+    public class Match3SensorTests
+    {
+        // Whether the expected PNG data should be written to a file.
+        // Only set this to true if the compressed observation format changes.
+        private bool WritePNGDataToFile = false;
+        private const string k_CellObservationPng = "match3obs_";
+        private const string k_SpecialObservationPng = "match3obs_special_";
+        private const string k_Suffix2x2 = "2x2_";
+
+        [TestCase(true, TestName = "Full Board")]
+        [TestCase(false, TestName = "Small Board")]
+        public void TestVectorObservations(bool fullBoard)
+        {
+            var boardString =
+                @"000
+                  000
+                  010";
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+            board.SetBoard(boardString);
+            if (!fullBoard)
+            {
+                board.CurrentRows = 2;
+                board.CurrentColumns = 2;
+            }
+
+            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
+            sensorComponent.ObservationType = Match3ObservationType.Vector;
+            var sensor = sensorComponent.CreateSensors()[0];
+
+            var expectedShape = new InplaceArray<int>(3 * 3 * 2);
+            Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+
+            float[] expectedObs;
+
+            if (fullBoard)
+            {
+                expectedObs = new float[]
+                {
+                    1, 0, /* 0 */ 0, 1, /* 1 */ 1, 0, /* 0 */
+                    1, 0, /* 0 */ 1, 0, /* 0 */ 1, 0, /* 0 */
+                    1, 0, /* 0 */ 1, 0, /* 0 */ 1, 0, /* 0 */
+                };
+            }
+            else
+            {
+                expectedObs = new float[]
+                {
+                    1, 0, /* 0 */ 0, 1, /* 1 */ 0, 0, /* empty */
+                    1, 0, /* 0 */ 1, 0, /* 0 */ 0, 0, /* empty */
+                    0, 0, /* empty */ 0, 0, /* empty */ 0, 0, /* empty */
+                };
+            }
+            SensorTestHelper.CompareObservation(sensor, expectedObs);
+        }
+
+        [Test]
+        public void TestVectorObservationsSpecial()
+        {
+            var boardString =
+                @"000
+                  000
+                  010";
+            var specialString =
+                @"010
+                  200
+                  000";
+
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+            board.SetBoard(boardString);
+            board.SetSpecial(specialString);
+
+            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
+            sensorComponent.ObservationType = Match3ObservationType.Vector;
+            var sensors = sensorComponent.CreateSensors();
+            var cellSensor = sensors[0];
+            var specialSensor = sensors[1];
+
+            {
+                var expectedShape = new InplaceArray<int>(3 * 3 * 2);
+                Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape);
+
+                var expectedObs = new float[]
+                {
+                    1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
+                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+                };
+                SensorTestHelper.CompareObservation(cellSensor, expectedObs);
+            }
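+            // The special-type sensor one-hot encodes NumSpecialTypes + 1 = 3 channels per
+            // cell (the special string above uses values 0-2), hence the 3 * 3 * 3 shape.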
+            {
+                var expectedShape = new InplaceArray<int>(3 * 3 * 3);
+                Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape);
+
+                var expectedObs = new float[]
+                {
+                    1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
+                    0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
+                    1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */
+                };
+                SensorTestHelper.CompareObservation(specialSensor, expectedObs);
+            }
+        }
+
+        [TestCase(true, TestName = "Full Board")]
+        [TestCase(false, TestName = "Small Board")]
+        public void TestVisualObservations(bool fullBoard)
+        {
+            var boardString =
+                @"000
+                  000
+                  010";
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+            board.SetBoard(boardString);
+            if (!fullBoard)
+            {
+                board.CurrentRows = 2;
+                board.CurrentColumns = 2;
+            }
+
+            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
+            sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
+            var sensor = sensorComponent.CreateSensors()[0];
+
+            var expectedShape = new InplaceArray<int>(3, 3, 2);
+            Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+
+            Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);
+
+            float[] expectedObs;
+            float[,,] expectedObs3D;
+
+            if (fullBoard)
+            {
+                expectedObs = new float[]
+                {
+                    1, 0, /**/ 0, 1, /**/ 1, 0,
+                    1, 0, /**/ 1, 0, /**/ 1, 0,
+                    1, 0, /**/ 1, 0, /**/ 1, 0,
+                };
+
+                expectedObs3D = new float[,,]
+                {
+                    {{1, 0}, {0, 1}, {1, 0}},
+                    {{1, 0}, {1, 0}, {1, 0}},
+                    {{1, 0}, {1, 0}, {1, 0}},
+                };
+            }
+            else
+            {
+                expectedObs = new float[]
+                {
+                    1, 0, /* 0 */ 0, 1, /* 1 */ 0, 0, /* empty */
+                    1, 0, /* 0 */ 1, 0, /* 0 */ 0, 0, /* empty */
+                    0, 0, /* empty */ 0, 0, /* empty */ 0, 0, /* empty */
+                };
+                expectedObs3D = new float[,,]
+                {
+                    {{1, 0}, {0, 1}, {0, 0}},
+                    {{1, 0}, {1, 0}, {0, 0}},
+                    {{0, 0}, {0, 0}, {0, 0}},
+                };
+            }
+            SensorTestHelper.CompareObservation(sensor, expectedObs);
+            SensorTestHelper.CompareObservation(sensor, expectedObs3D);
+        }
+
+        [Test]
+        public void TestVisualObservationsSpecial()
+        {
+            var boardString =
+                @"000
+                  000
+                  010";
+            var specialString =
+                @"010
+                  200
+                  000";
+
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+            board.SetBoard(boardString);
+            board.SetSpecial(specialString);
+
+            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
+            sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
+            var sensors = sensorComponent.CreateSensors();
+            var cellSensor = sensors[0];
+            var specialSensor = sensors[1];
+
+            {
+                var expectedShape = new InplaceArray<int>(3, 3, 2);
+                Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape);
+
+                Assert.AreEqual(SensorCompressionType.None, cellSensor.GetCompressionSpec().SensorCompressionType);
+
+                var expectedObs = new float[]
+                {
+                    1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
+                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+                };
+                SensorTestHelper.CompareObservation(cellSensor, expectedObs);
+
+                var expectedObs3D = new float[,,]
+                {
+                    {{1, 0}, {0, 1}, {1, 0}},
+                    {{1, 0}, {1, 0}, {1, 0}},
+                    {{1, 0}, {1, 0}, {1, 0}},
+                };
+                SensorTestHelper.CompareObservation(cellSensor, expectedObs3D);
+            }
+            {
+                var expectedShape = new InplaceArray<int>(3, 3, 3);
+                Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape);
+
+                Assert.AreEqual(SensorCompressionType.None, specialSensor.GetCompressionSpec().SensorCompressionType);
+
+                var expectedObs = new float[]
+                {
+                    1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
+                    0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
+                    1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */
+                };
+                SensorTestHelper.CompareObservation(specialSensor, expectedObs);
+
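+                // The 3D form is indexed [row, column, channel]; e.g. cell (1, 0) holds
+                // special type 2 ("200" is the middle row of specialString), which one-hot
+                // encodes to {0, 0, 1} below.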
+                var expectedObs3D = new float[,,]
+                {
+                    {{1, 0, 0}, {1, 0, 0}, {1, 0, 0}},
+                    {{0, 0, 1}, {1, 0, 0}, {1, 0, 0}},
+                    {{1, 0, 0}, {0, 1, 0}, {1, 0, 0}},
+                };
+                SensorTestHelper.CompareObservation(specialSensor, expectedObs3D);
+            }
+
+            // Test that Dispose() cleans up the component and its sensors
+            sensorComponent.Dispose();
+
+            var flags = BindingFlags.Instance | BindingFlags.NonPublic;
+            var componentSensors = (ISensor[])typeof(Match3SensorComponent).GetField("m_Sensors", flags).GetValue(sensorComponent);
+            Assert.IsNull(componentSensors);
+            var cellTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
+            Assert.IsNull(cellTexture);
+            var specialTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(specialSensor);
+            Assert.IsNull(specialTexture);
+        }
+
+        [TestCase(true, false, TestName = "Full Board, No Special")]
+        [TestCase(false, false, TestName = "Small Board, No Special")]
+        [TestCase(true, true, TestName = "Full Board, Special")]
+        [TestCase(false, true, TestName = "Small Board, Special")]
+        public void TestCompressedVisualObservationsSpecial(bool fullBoard, bool useSpecial)
+        {
+            var boardString =
+                @"003
+                  000
+                  010";
+            var specialString =
+                @"014
+                  200
+                  000";
+
+            var gameObj = new GameObject("board");
+            var board = gameObj.AddComponent<StringBoard>();
+            board.SetBoard(boardString);
+            var paths = new List<string> { k_CellObservationPng };
+            if (useSpecial)
+            {
+                board.SetSpecial(specialString);
+                paths.Add(k_SpecialObservationPng);
+            }
+
+            if (!fullBoard)
+            {
+                // Shrink the board, and change the paths we're using for the ground truth PNGs
+                board.CurrentRows = 2;
+                board.CurrentColumns = 2;
+                for (var i = 0; i < paths.Count; i++)
+                {
+                    paths[i] = paths[i] + k_Suffix2x2;
+                }
+            }
+
+            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
+            sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
+            var sensors = sensorComponent.CreateSensors();
+
+            var expectedNumChannels = new[] { 4, 5 };
+
+            for (var i = 0; i < paths.Count; i++)
+            {
+                var sensor = sensors[i];
+                var expectedShape = new InplaceArray<int>(3, 3, expectedNumChannels[i]);
+                Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+
+                Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);
+
+                var pngData = sensor.GetCompressedObservation();
+                if (WritePNGDataToFile)
+                {
+                    // Enable this if the format of the observation changes
+                    SavePNGs(pngData, paths[i]);
+                }
+
+                var expectedPng = LoadPNGs(paths[i], 2);
+                Assert.AreEqual(expectedPng, pngData);
+            }
+        }
+
+        /// <summary>
+        /// Helper method for un-concatenating PNG observations.
+        /// </summary>
+        /// <param name="concatenated"></param>
+        /// <returns></returns>
+        List<byte[]> SplitPNGs(byte[] concatenated)
+        {
+            var pngsOut = new List<byte[]>();
+            var pngHeader = new byte[] { 137, 80, 78, 71, 13, 10, 26, 10 };
+
+            var current = new List<byte>();
+            for (var i = 0; i < concatenated.Length; i++)
+            {
+                current.Add(concatenated[i]);
+
+                // Check if the header starts at the next position
+                // If so, we'll start a new output array.
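+                // (The 8-byte sequence 137 80 78 71 13 10 26 10 is the PNG signature
+                // "\x89PNG\r\n\x1a\n"; this helper assumes it only appears at the start of
+                // each concatenated PNG.)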
+                var headerIsNext = false;
+                if (i + 1 < concatenated.Length - pngHeader.Length)
+                {
+                    for (var j = 0; j < pngHeader.Length; j++)
+                    {
+                        if (concatenated[i + 1 + j] != pngHeader[j])
+                        {
+                            break;
+                        }
+
+                        if (j == pngHeader.Length - 1)
+                        {
+                            headerIsNext = true;
+                        }
+                    }
+                }
+
+                if (headerIsNext)
+                {
+                    pngsOut.Add(current.ToArray());
+                    current = new List<byte>();
+                }
+            }
+            pngsOut.Add(current.ToArray());
+
+            return pngsOut;
+        }
+
+        void SavePNGs(byte[] concatenatedPngData, string pathPrefix)
+        {
+            var splitPngs = SplitPNGs(concatenatedPngData);
+
+            for (var i = 0; i < splitPngs.Count; i++)
+            {
+                var pngData = splitPngs[i];
+                var path = $"Packages/com.unity.ml-agents/Tests/Editor/Integrations/Match3/{pathPrefix}{i}.png";
+                using (var sw = File.Create(path))
+                {
+                    foreach (var b in pngData)
+                    {
+                        sw.WriteByte(b);
+                    }
+                }
+            }
+        }
+
+        byte[] LoadPNGs(string pathPrefix, int numExpected)
+        {
+            var bytesOut = new List<byte>();
+            for (var i = 0; i < numExpected; i++)
+            {
+                var path = $"Packages/com.unity.ml-agents/Tests/Editor/Integrations/Match3/{pathPrefix}{i}.png";
+                var res = File.ReadAllBytes(path);
+                bytesOut.AddRange(res);
+            }
+
+            return bytesOut.ToArray();
+        }
+
+        [Test]
+        public void TestNoBoardReturnsEmptySensors()
+        {
+            var gameObj = new GameObject("board");
+            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
+            var sensors = sensorComponent.CreateSensors();
+            Assert.AreEqual(0, sensors.Length);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3SensorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3SensorTests.cs.meta
new file mode 100644
index 0000000000..38a1a4d010
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/Match3SensorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: dfe94a9d6e994f408cb97d07dd44c994
+timeCreated: 1603493723
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/MoveTests.cs b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/MoveTests.cs
new file mode 100644
index 0000000000..8174e7e5db
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/MoveTests.cs
@@ -0,0 +1,66 @@
+using System;
+using NUnit.Framework;
+using Unity.MLAgents.Integrations.Match3;
+
+namespace Unity.MLAgents.Tests.Integrations.Match3
+{
+    public class MoveTests
+    {
+        [Test]
+        public void TestMoveEquivalence()
+        {
+            var board10x10 = new BoardSize { Rows = 10, Columns = 10 };
+            var moveUp = Move.FromPositionAndDirection(1, 1, Direction.Up, board10x10);
+            var moveDown = Move.FromPositionAndDirection(2, 1, Direction.Down, board10x10);
+            Assert.AreEqual(moveUp.MoveIndex, moveDown.MoveIndex);
+
+            var moveRight = Move.FromPositionAndDirection(1, 1, Direction.Right, board10x10);
+            var moveLeft = Move.FromPositionAndDirection(1, 2, Direction.Left, board10x10);
+            Assert.AreEqual(moveRight.MoveIndex, moveLeft.MoveIndex);
+        }
+
+        [Test]
+        public void TestNext()
+        {
+            var maxRows = 8;
+            var maxCols = 13;
+            var boardSize = new BoardSize
+            {
+                Rows = maxRows,
+                Columns = maxCols
+            };
+            // Make sure using Next agrees with FromMoveIndex.
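+            // An 8x13 board has 8 * 12 horizontal swaps plus 7 * 13 vertical swaps,
+            // i.e. 96 + 91 = 187 potential moves, assuming NumPotentialMoves counts each
+            // swap exactly once.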
+            var advanceMove = Move.FromMoveIndex(0, boardSize);
+            for (var moveIndex = 0; moveIndex < Move.NumPotentialMoves(boardSize); moveIndex++)
+            {
+                var moveFromIndex = Move.FromMoveIndex(moveIndex, boardSize);
+                Assert.AreEqual(advanceMove.MoveIndex, moveFromIndex.MoveIndex);
+                Assert.AreEqual(advanceMove.Row, moveFromIndex.Row);
+                Assert.AreEqual(advanceMove.Column, moveFromIndex.Column);
+                Assert.AreEqual(advanceMove.Direction, moveFromIndex.Direction);
+
+                advanceMove.Next(boardSize);
+            }
+        }
+
+        // These are off the board
+        [TestCase(-1, 5, Direction.Up)]
+        [TestCase(10, 5, Direction.Up)]
+        [TestCase(5, -1, Direction.Up)]
+        [TestCase(5, 10, Direction.Up)]
+        // These are on the board but would move off
+        [TestCase(0, 5, Direction.Down)]
+        [TestCase(9, 5, Direction.Up)]
+        [TestCase(5, 0, Direction.Left)]
+        [TestCase(5, 9, Direction.Right)]
+        public void TestInvalidMove(int row, int col, Direction dir)
+        {
+            var board10x10 = new BoardSize { Rows = 10, Columns = 10 };
+            Assert.Throws<IndexOutOfRangeException>(() =>
+            {
+                Move.FromPositionAndDirection(row, col, dir, board10x10);
+            });
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/MoveTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/MoveTests.cs.meta
new file mode 100644
index 0000000000..e016865fcb
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/MoveTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 42981032af6f4241ae20fe24e898f60b
+timeCreated: 1601336681
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_0.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_0.png
new file mode 100644
index 0000000000..0743d0bc13
Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_0.png differ
diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_0.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_0.png.meta
new file mode 100644
index 0000000000..9faafa2e61
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_0.png.meta
@@ -0,0 +1,92 @@
+fileFormatVersion: 2
+guid: d91bf3f1eb13c4361bfb8bb61b94a71a
+TextureImporter:
+  internalIDToNameTable: []
+  externalObjects: {}
+  serializedVersion: 11
+  mipmaps:
+    mipMapMode: 0
+    enableMipMap: 1
+    sRGBTexture: 1
+    linearTexture: 0
+    fadeOut: 0
+    borderMipMap: 0
+    mipMapsPreserveCoverage: 0
+    alphaTestReferenceValue: 0.5
+    mipMapFadeDistanceStart: 1
+    mipMapFadeDistanceEnd: 3
+  bumpmap:
+    convertToNormalMap: 0
+    externalNormalMap: 0
+    heightScale: 0.25
+    normalMapFilter: 0
+  isReadable: 0
+  streamingMipmaps: 0
+  streamingMipmapsPriority: 0
+  grayScaleToAlpha: 0
+  generateCubemap: 6
+  cubemapConvolution: 0
+  seamlessCubemap: 0
+  textureFormat: 1
+  maxTextureSize: 2048
+  textureSettings:
+    serializedVersion: 2
+    filterMode: -1
+    aniso: -1
+    mipBias: -100
+    wrapU: -1
+    wrapV: -1
+    wrapW: -1
+  nPOTScale: 1
+  lightmap: 0
+  compressionQuality: 50
+  spriteMode: 0
+  spriteExtrude: 1
+  spriteMeshType: 1
+  alignment: 0
+  spritePivot: {x: 0.5, y: 0.5}
+  spritePixelsToUnits: 100
+  spriteBorder: {x: 0, y: 0, z: 0, w: 0}
+  spriteGenerateFallbackPhysicsShape: 1
+  alphaUsage: 1
+  alphaIsTransparency: 0
+  spriteTessellationDetail: -1
+  textureType: 0
+  textureShape: 1
+  singleChannelComponent: 0
+  maxTextureSizeSet: 0
+  compressionQualitySet: 0
+  textureFormatSet: 0
+  applyGammaDecoding: 0
+  platformSettings:
+  - serializedVersion: 3
+    buildTarget: DefaultTexturePlatform
+    maxTextureSize: 2048
+    
resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_1.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_1.png new file mode 100644 index 0000000000..afd50b9af8 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_1.png differ diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_1.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_1.png.meta new file mode 100644 index 0000000000..adbae821a0 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_1.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: dba7ae2c8d38c4c20b14412feb512a73 +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_0.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_0.png new file mode 100644 index 0000000000..678315a87e Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_0.png differ diff --git 
a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_0.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_0.png.meta new file mode 100644 index 0000000000..7e8aa6438a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_0.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: 63cd02fa3b69e430aa14ed2b919071fb +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_1.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_1.png new file mode 100644 index 0000000000..6f66a52754 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_1.png differ diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_1.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_1.png.meta new file mode 100644 index 0000000000..8a8b427966 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_2x2_1.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: 09381a5789a894c39b220f74e8b59a2a +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + 
cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_0.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_0.png new file mode 100644 index 0000000000..217e1f0b0a Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_0.png differ diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_0.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_0.png.meta new file mode 100644 index 0000000000..880d53c638 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_0.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: 79c99419c5c4a4378b93c1496ea40338 +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + 
crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_1.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_1.png new file mode 100644 index 0000000000..6ed8c591ed Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_1.png differ diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_1.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_1.png.meta new file mode 100644 index 0000000000..be876a47d6 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_1.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: 0ba48eb2f5f7e4ccc91a1c28784b2e13 +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_0.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_0.png new file mode 100644 index 0000000000..15a5808be1 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_0.png differ diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_0.png.meta 
b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_0.png.meta new file mode 100644 index 0000000000..c625d01a60 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_0.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: 3364811072f604580bd30b733f84f485 +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_1.png b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_1.png new file mode 100644 index 0000000000..6f66a52754 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_1.png differ diff --git a/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_1.png.meta b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_1.png.meta new file mode 100644 index 0000000000..a05eb54ec5 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Integrations/Match3/match3obs_special_2x2_1.png.meta @@ -0,0 +1,92 @@ +fileFormatVersion: 2 +guid: 8fc3c98253b4846dc99475d4cd3ba93b +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 11 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + 
seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: -1 + aniso: -1 + mipBias: -100 + wrapU: -1 + wrapV: -1 + wrapW: -1 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + applyGammaDecoding: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + spritePackingTag: + pSDRemoveMatte: 0 + pSDShowRemoveMatteOption: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs new file mode 100644 index 0000000000..221fe71cf7 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -0,0 +1,742 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using UnityEngine; +using NUnit.Framework; +using System.Reflection; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; +using Unity.MLAgents.Policies; +using Unity.MLAgents.SideChannels; +using Unity.MLAgents.Utils.Tests; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class EditModeTestGeneration + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + [Test] + public void TestAcademy() + { + var aca = Academy.Instance; + Assert.AreNotEqual(null, aca); + Assert.AreEqual(0, aca.EpisodeCount); + Assert.AreEqual(0, aca.StepCount); + Assert.AreEqual(0, aca.TotalStepCount); + } + + [Test] + public void TestAgent() + { + var agentGo = new GameObject("TestAgent"); + agentGo.AddComponent(); + var agent = agentGo.GetComponent(); + Assert.AreNotEqual(null, agent); + Assert.AreEqual(0, agent.initializeAgentCalls); + } + } + + [TestFixture] + public class EditModeTestInitialization + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + [Test] + public void TestAcademy() + { + Assert.AreEqual(false, Academy.IsInitialized); + var aca = Academy.Instance; + Assert.AreEqual(true, Academy.IsInitialized); + + // Check that init is idempotent + aca.LazyInitialize(); + aca.LazyInitialize(); + + Assert.AreEqual(0, aca.EpisodeCount); + Assert.AreEqual(0, aca.StepCount); + Assert.AreEqual(0, aca.TotalStepCount); + Assert.AreNotEqual(null, SideChannelManager.GetSideChannel()); + Assert.AreNotEqual(null, SideChannelManager.GetSideChannel()); + Assert.AreNotEqual(null, SideChannelManager.GetSideChannel()); + + // Check that Dispose is idempotent + aca.Dispose(); + Assert.AreEqual(false, Academy.IsInitialized); + aca.Dispose(); + } + + [Test] + public void 
TestAcademyDispose()
+        {
+            var envParams1 = SideChannelManager.GetSideChannel<EnvironmentParametersChannel>();
+            var engineParams1 = SideChannelManager.GetSideChannel<EngineConfigurationChannel>();
+            var statsParams1 = SideChannelManager.GetSideChannel<StatsSideChannel>();
+            Academy.Instance.Dispose();
+
+            Academy.Instance.LazyInitialize();
+            var envParams2 = SideChannelManager.GetSideChannel<EnvironmentParametersChannel>();
+            var engineParams2 = SideChannelManager.GetSideChannel<EngineConfigurationChannel>();
+            var statsParams2 = SideChannelManager.GetSideChannel<StatsSideChannel>();
+            Academy.Instance.Dispose();
+
+            Assert.AreNotEqual(envParams1, envParams2);
+            Assert.AreNotEqual(engineParams1, engineParams2);
+            Assert.AreNotEqual(statsParams1, statsParams2);
+        }
+
+        [Test]
+        public void TestAgent()
+        {
+            var agentGo1 = new GameObject("TestAgent");
+            agentGo1.AddComponent<TestAgent>();
+            var agent1 = agentGo1.GetComponent<TestAgent>();
+            var bp1 = agentGo1.GetComponent<BehaviorParameters>();
+            bp1.ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited;
+
+            var agentGo2 = new GameObject("TestAgent");
+            agentGo2.AddComponent<TestAgent>();
+            var agent2 = agentGo2.GetComponent<TestAgent>();
+
+            Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls);
+            Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls);
+            Assert.AreEqual(0, agent1.initializeAgentCalls);
+            Assert.AreEqual(0, agent2.initializeAgentCalls);
+            Assert.AreEqual(0, agent1.agentActionCalls);
+            Assert.AreEqual(0, agent2.agentActionCalls);
+
+
+            agent2.LazyInitialize();
+            agent1.LazyInitialize();
+
+            // agent1 was not enabled when the academy started
+            // The agents have been initialized
+            Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls);
+            Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls);
+            Assert.AreEqual(1, agent1.initializeAgentCalls);
+            Assert.AreEqual(1, agent2.initializeAgentCalls);
+            Assert.AreEqual(0, agent1.agentActionCalls);
+            Assert.AreEqual(0, agent2.agentActionCalls);
+
+            // Make sure the Sensors were sorted
+            Assert.AreEqual(agent1.sensors[0].GetName(), "observableFloat");
+            Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor1");
+            Assert.AreEqual(agent1.sensors[2].GetName(), "testsensor2");
+
+            // agent2 should only have two sensors (no observableFloat)
+            Assert.AreEqual(agent2.sensors[0].GetName(), "testsensor1");
+            Assert.AreEqual(agent2.sensors[1].GetName(), "testsensor2");
+        }
+    }
+
+    [TestFixture]
+    public class EditModeTestStep
+    {
+        [SetUp]
+        public void SetUp()
+        {
+            if (Academy.IsInitialized)
+            {
+                Academy.Instance.Dispose();
+            }
+        }
+
+        [Test]
+        public void TestAcademy()
+        {
+            var aca = Academy.Instance;
+
+            var numberReset = 0;
+            for (var i = 0; i < 10; i++)
+            {
+                Assert.AreEqual(numberReset, aca.EpisodeCount);
+                Assert.AreEqual(i, aca.StepCount);
+
+                // The reset happens at the beginning of the first step
+                if (i == 0)
+                {
+                    numberReset += 1;
+                }
+                Academy.Instance.EnvironmentStep();
+            }
+        }
+
+        [Test]
+        public void TestAcademyAutostep()
+        {
+            var aca = Academy.Instance;
+            Assert.IsTrue(aca.AutomaticSteppingEnabled);
+            aca.AutomaticSteppingEnabled = false;
+            Assert.IsFalse(aca.AutomaticSteppingEnabled);
+            aca.AutomaticSteppingEnabled = true;
+            Assert.IsTrue(aca.AutomaticSteppingEnabled);
+        }
+
+        [Test]
+        public void TestAgent()
+        {
+            var agentGo1 = new GameObject("TestAgent");
+            var bp1 = agentGo1.AddComponent<BehaviorParameters>();
+            bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
+            agentGo1.AddComponent<TestAgent>();
+            var agent1 = agentGo1.GetComponent<TestAgent>();
+            var agentGo2 = new GameObject("TestAgent");
+            var bp2 = agentGo2.AddComponent<BehaviorParameters>();
+            bp2.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
+            agentGo2.AddComponent<TestAgent>();
+            var agent2 = agentGo2.GetComponent<TestAgent>();
+
+            var aca = Academy.Instance;
+
+            var decisionRequester = agent1.gameObject.AddComponent<DecisionRequester>();
+            decisionRequester.DecisionPeriod = 2;
+            decisionRequester.Awake();
+            // agent1 will take an action at every step and request a decision every 2 steps
+            // agent2 will request decisions only when RequestDecision is called
+
+            agent1.LazyInitialize();
+
+            var numberAgent1Episodes = 0;
+            var numberAgent2Episodes = 0;
+            var numberAgent2Initialization = 0;
+            var requestDecision = 0;
+            var requestAction = 0;
+            for (var i = 0; i < 50; i++)
+            {
+                Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
+                Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
+                Assert.AreEqual(1, agent1.initializeAgentCalls);
+                Assert.AreEqual(numberAgent2Initialization, agent2.initializeAgentCalls);
+                Assert.AreEqual(i, agent1.agentActionCalls);
+                Assert.AreEqual(requestAction, agent2.agentActionCalls);
+                Assert.AreEqual((i + 1) / 2, agent1.collectObservationsCalls);
+                Assert.AreEqual(requestDecision, agent2.collectObservationsCalls);
+                // Agent 1 starts a new episode at the first step
+                if (i == 0)
+                {
+                    numberAgent1Episodes += 1;
+                }
+                //Agent 2 is only initialized at step 2
+                if (i == 2)
+                {
+                    // Since Agent2 is initialized after the Academy has stepped, its OnEpisodeBegin should be called now.
+                    Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls);
+                    agent2.LazyInitialize();
+                    Assert.AreEqual(1, agent2.agentOnEpisodeBeginCalls);
+                    numberAgent2Initialization += 1;
+                    numberAgent2Episodes += 1;
+                }
+
+                // We are testing request decision and request actions when called
+                // at different intervals
+                if ((i % 3 == 0) && (i > 2))
+                {
+                    //Every 3 steps after agent 2 is initialized, request decision
+                    requestDecision += 1;
+                    requestAction += 1;
+                    agent2.RequestDecision();
+                }
+                else if ((i % 5 == 0) && (i > 2))
+                {
+                    // Every 5 steps after agent 2 is initialized, request action
+                    requestAction += 1;
+                    agent2.RequestAction();
+                }
+                aca.EnvironmentStep();
+            }
+        }
+    }
+
+    [TestFixture]
+    public class EditModeTestReset
+    {
+        [SetUp]
+        public void SetUp()
+        {
+            if (Academy.IsInitialized)
+            {
+                Academy.Instance.Dispose();
+            }
+        }
+
+        [Test]
+        public void TestAcademy()
+        {
+            var aca = Academy.Instance;
+
+            var numberReset = 0;
+            var stepsSinceReset = 0;
+            for (var i = 0; i < 50; i++)
+            {
+                Assert.AreEqual(stepsSinceReset, aca.StepCount);
+                Assert.AreEqual(numberReset, aca.EpisodeCount);
+                Assert.AreEqual(i, aca.TotalStepCount);
+
+                // Academy resets at the first step
+                if (i == 0)
+                {
+                    numberReset += 1;
+                }
+
+                stepsSinceReset += 1;
+                aca.EnvironmentStep();
+            }
+        }
+
+        [Test]
+        public void TestAgent()
+        {
+            var agentGo1 = new GameObject("TestAgent");
+            var bp1 = agentGo1.AddComponent<BehaviorParameters>();
+            bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
+            agentGo1.AddComponent<TestAgent>();
+            var agent1 = agentGo1.GetComponent<TestAgent>();
+            var agentGo2 = new GameObject("TestAgent");
+            var bp2 = agentGo2.AddComponent<BehaviorParameters>();
+            bp2.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
+            agentGo2.AddComponent<TestAgent>();
+            var agent2 = agentGo2.GetComponent<TestAgent>();
+
+            var aca = Academy.Instance;
+
+            var decisionRequester = agent1.gameObject.AddComponent<DecisionRequester>();
+            decisionRequester.DecisionPeriod = 2;
+
+            agent2.LazyInitialize();
+
+            var numberAgent1Episodes = 0;
+            var numberAgent2Episodes = 0;
+            var numberAcaReset = 0;
+            var acaStepsSinceReset = 0;
+            var agent2StepForEpisode = 0;
+            for (var i = 0; i < 5000; i++)
+            {
+                Assert.AreEqual(acaStepsSinceReset, aca.StepCount);
+                Assert.AreEqual(numberAcaReset, aca.EpisodeCount);
+
+                Assert.AreEqual(i, aca.TotalStepCount);
+
Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls); + Assert.AreEqual(agent2StepForEpisode, agent2.StepCount); + + // Agent 2 and academy reset at the first step + if (i == 0) + { + Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls); + numberAcaReset += 1; + numberAgent2Episodes += 1; + } + + //Agent 1 is only initialized at step 2 + if (i == 2) + { + Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls); + agent1.LazyInitialize(); + numberAgent1Episodes += 1; + Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls); + } + + // Set agent 1 to done every 11 steps to test behavior + if (i % 11 == 5) + { + Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls); + agent1.EndEpisode(); + numberAgent1Episodes += 1; + Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls); + } + + // Ending the episode for agent 2 regularly + if (i % 13 == 3) + { + Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls); + agent2.EndEpisode(); + numberAgent2Episodes += 1; + agent2StepForEpisode = 0; + Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls); + } + + // Request a decision for agent 2 regularly + if (i % 3 == 2) + { + agent2.RequestDecision(); + } + else if (i % 5 == 1) + { + // Request an action without decision regularly + agent2.RequestAction(); + } + + acaStepsSinceReset += 1; + agent2StepForEpisode += 1; + aca.EnvironmentStep(); + } + } + } + + [TestFixture] + public class EditModeTestMiscellaneous + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + [Test] + public void TestCumulativeReward() + { + var agentGo1 = new GameObject("TestAgent"); + var bp1 = agentGo1.AddComponent(); + bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1); + var agent1 = agentGo1.AddComponent(); + var agentGo2 = new GameObject("TestAgent"); + var bp2 = agentGo2.AddComponent(); + bp2.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1); + var agent2 = agentGo2.AddComponent(); + var aca = Academy.Instance; + + var decisionRequester = agent1.gameObject.AddComponent(); + decisionRequester.DecisionPeriod = 2; + decisionRequester.Awake(); + + + agent1.MaxStep = 20; + + agent2.LazyInitialize(); + agent1.LazyInitialize(); + agent2.SetPolicy(new TestPolicy()); + + var expectedAgent1ActionForEpisode = 0; + + for (var i = 0; i < 50; i++) + { + expectedAgent1ActionForEpisode += 1; + if (expectedAgent1ActionForEpisode == agent1.MaxStep || i == 0) + { + expectedAgent1ActionForEpisode = 0; + } + agent2.RequestAction(); + Assert.LessOrEqual(Mathf.Abs(expectedAgent1ActionForEpisode * 10.1f - agent1.GetCumulativeReward()), 0.05f); + Assert.LessOrEqual(Mathf.Abs(i * 0.1f - agent2.GetCumulativeReward()), 0.05f); + + agent1.AddReward(10f); + aca.EnvironmentStep(); + } + } + + [Test] + public void TestMaxStepsReset() + { + var agentGo1 = new GameObject("TestAgent"); + var bp1 = agentGo1.AddComponent(); + bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1); + agentGo1.AddComponent(); + var agent1 = agentGo1.GetComponent(); + var aca = Academy.Instance; + + var decisionRequester = agent1.gameObject.AddComponent(); + decisionRequester.DecisionPeriod = 1; + decisionRequester.Awake(); + + const int maxStep = 6; + agent1.MaxStep = maxStep; + agent1.LazyInitialize(); + + var expectedAgentStepCount = 0; + var expectedEpisodes = 0; + var expectedAgentAction = 0; + var expectedAgentActionForEpisode = 0; + var 
expectedCollectObsCalls = 0; + var expectedCollectObsCallsForEpisode = 0; + var expectedCompletedEpisodes = 0; + var expectedSensorResetCalls = 0; + + for (var i = 0; i < 15; i++) + { + // Agent should observe and act on each Academy step + expectedAgentAction += 1; + expectedAgentActionForEpisode += 1; + expectedCollectObsCalls += 1; + expectedCollectObsCallsForEpisode += 1; + expectedAgentStepCount += 1; + + // If the next step will put the agent at maxSteps, we expect it to reset + if (agent1.StepCount == maxStep - 1 || (i == 0)) + { + expectedEpisodes += 1; + } + + if (agent1.StepCount == maxStep - 1) + { + expectedAgentActionForEpisode = 0; + expectedCollectObsCallsForEpisode = 0; + expectedAgentStepCount = 0; + expectedCompletedEpisodes++; + expectedSensorResetCalls++; + expectedCollectObsCalls += 1; + } + aca.EnvironmentStep(); + + Assert.AreEqual(expectedAgentStepCount, agent1.StepCount); + Assert.AreEqual(expectedEpisodes, agent1.agentOnEpisodeBeginCalls); + Assert.AreEqual(expectedAgentAction, agent1.agentActionCalls); + Assert.AreEqual(expectedAgentActionForEpisode, agent1.agentActionCallsForEpisode); + Assert.AreEqual(expectedCollectObsCalls, agent1.collectObservationsCalls); + Assert.AreEqual(expectedCollectObsCallsForEpisode, agent1.collectObservationsCallsForEpisode); + Assert.AreEqual(expectedCompletedEpisodes, agent1.CompletedEpisodes); + Assert.AreEqual(expectedSensorResetCalls, agent1.sensor1.numResetCalls); + } + } + + [Test] + public void TestHeuristicPolicyStepsSensors() + { + // Make sure that Agents with HeuristicPolicies step their sensors each Academy step. + var agentGo1 = new GameObject("TestAgent"); + var bp1 = agentGo1.AddComponent(); + bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1); + agentGo1.AddComponent(); + var agent1 = agentGo1.GetComponent(); + var aca = Academy.Instance; + + var decisionRequester = agent1.gameObject.AddComponent(); + decisionRequester.DecisionPeriod = 1; + decisionRequester.Awake(); + + agent1.LazyInitialize(); + Assert.AreEqual(agent1.GetPolicy().GetType(), typeof(HeuristicPolicy)); + + var numSteps = 10; + for (var i = 0; i < numSteps; i++) + { + aca.EnvironmentStep(); + } + Assert.AreEqual(numSteps, agent1.heuristicCalls); + Assert.AreEqual(numSteps, agent1.sensor1.numWriteCalls); + Assert.AreEqual(numSteps, agent1.sensor2.numCompressedCalls); + + Assert.AreEqual( + agent1.collectObservationsCallsForEpisode, + agent1.GetStoredActionBuffers().ContinuousActions[0] + ); + } + + [Test] + public void TestNullList() + { + var nullList = new HeuristicPolicy.NullList(); + Assert.Throws(() => + { + _ = ((IEnumerable)nullList).GetEnumerator(); + }); + + Assert.Throws(() => + { + _ = ((IEnumerable)nullList).GetEnumerator(); + }); + + Assert.Throws(() => + { + nullList.CopyTo(new[] { 0f }, 0); + }); + + nullList.Add(0); + Assert.IsTrue(nullList.Count == 0); + + nullList.Clear(); + Assert.IsTrue(nullList.Count == 0); + + nullList.Add(0); + Assert.IsFalse(nullList.Contains(0)); + Assert.IsFalse(nullList.Remove(0)); + Assert.IsFalse(nullList.IsReadOnly); + Assert.IsTrue(nullList.IndexOf(0) == -1); + nullList.Insert(0, 0); + Assert.IsFalse(nullList.Count > 0); + nullList.RemoveAt(0); + Assert.IsTrue(nullList.Count == 0); + Assert.IsTrue(Mathf.Approximately(0f, nullList[0])); + Assert.IsTrue(Mathf.Approximately(0f, nullList[1])); + } + } + + [TestFixture] + public class TestOnEnableOverride + { + public class OnEnableAgent : Agent + { + public bool callBase; + + protected override void OnEnable() + { + if (callBase) + base.OnEnable(); 
+ } + } + + static void _InnerAgentTestOnEnableOverride(bool callBase = false) + { + var go = new GameObject(); + var agent = go.AddComponent(); + agent.callBase = callBase; + var onEnable = typeof(OnEnableAgent).GetMethod("OnEnable", BindingFlags.NonPublic | BindingFlags.Instance); + var sendInfo = typeof(Agent).GetMethod("SendInfoToBrain", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(onEnable); + onEnable.Invoke(agent, null); + Assert.NotNull(sendInfo); + if (agent.callBase) + { + Assert.DoesNotThrow(() => sendInfo.Invoke(agent, null)); + } + else + { + Assert.Throws(() => + { + try + { + sendInfo.Invoke(agent, null); + } + catch (TargetInvocationException e) + { + throw e.GetBaseException(); + } + }); + } + } + + [Test] + public void TestAgentCallBaseOnEnable() + { + _InnerAgentTestOnEnableOverride(true); + } + + [Test] + public void TestAgentDontCallBaseOnEnable() + { + _InnerAgentTestOnEnableOverride(); + } + } + + [TestFixture] + public class ObservableAttributeBehaviorTests + { + public class BaseObservableAgent : Agent + { + [Observable] + public float BaseField; + } + + public class DerivedObservableAgent : BaseObservableAgent + { + [Observable] + public float DerivedField; + } + + + [Test] + public void TestObservableAttributeBehaviorIgnore() + { + var variants = new[] + { + // No observables found + (ObservableAttributeOptions.Ignore, 0), + // Only DerivedField found + (ObservableAttributeOptions.ExcludeInherited, 1), + // DerivedField and BaseField found + (ObservableAttributeOptions.ExamineAll, 2) + }; + + foreach (var (behavior, expectedNumSensors) in variants) + { + var go = new GameObject(); + var agent = go.AddComponent(); + var bp = go.GetComponent(); + bp.ObservableAttributeHandling = behavior; + agent.LazyInitialize(); + int numAttributeSensors = 0; + foreach (var sensor in agent.sensors) + { + if (sensor.GetType() != typeof(VectorSensor)) + { + numAttributeSensors++; + } + } + Assert.AreEqual(expectedNumSensors, numAttributeSensors); + } + } + } + + [TestFixture] + public class AgentRecursionTests + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + class CollectObsEndEpisodeAgent : Agent + { + public override void CollectObservations(VectorSensor sensor) + { + // NEVER DO THIS IN REAL CODE! + EndEpisode(); + } + } + + class OnEpisodeBeginEndEpisodeAgent : Agent + { + public override void OnEpisodeBegin() + { + // NEVER DO THIS IN REAL CODE! 
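+                // Ending the episode from OnEpisodeBegin re-enters the agent's episode logic; the tests below expect the Academy step to throw rather than recurse.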
+                EndEpisode();
+            }
+        }
+
+        void TestRecursiveThrows<T>() where T : Agent
+        {
+            var gameObj = new GameObject();
+            var agent = gameObj.AddComponent<T>();
+            agent.LazyInitialize();
+            agent.RequestDecision();
+
+            Assert.Throws<UnityAgentsException>(() =>
+            {
+                Academy.Instance.EnvironmentStep();
+            });
+        }
+
+        [Test]
+        public void TestRecursiveCollectObsEndEpisodeThrows()
+        {
+            TestRecursiveThrows<CollectObsEndEpisodeAgent>();
+        }
+
+        [Test]
+        public void TestRecursiveOnEpisodeBeginEndEpisodeThrows()
+        {
+            TestRecursiveThrows<OnEpisodeBeginEndEpisodeAgent>();
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs.meta b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs.meta
new file mode 100644
index 0000000000..2823ab05a5
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 3170fcbfa5f4d4a8ca82c50c750e9083
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs
new file mode 100644
index 0000000000..5d987d55df
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs
@@ -0,0 +1,118 @@
+using System;
+using System.Reflection;
+using NUnit.Framework;
+using UnityEngine;
+
+namespace Unity.MLAgents.Tests
+{
+    public class MultiAgentGroupTests
+    {
+        class TestAgent : Agent
+        {
+            internal int _GroupId
+            {
+                get
+                {
+                    return (int)typeof(Agent).GetField("m_GroupId", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
+                }
+            }
+
+            internal float _GroupReward
+            {
+                get
+                {
+                    return (float)typeof(Agent).GetField("m_GroupReward", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
+                }
+            }
+
+            internal Action _OnAgentDisabledActions
+            {
+                get
+                {
+                    return (Action)typeof(Agent).GetField("OnAgentDisabled", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
+                }
+            }
+        }
+
+        [Test]
+        public void TestRegisteredAgentGroupId()
+        {
+            var agentGo = new GameObject("TestAgent");
+            agentGo.AddComponent<TestAgent>();
+            var agent = agentGo.GetComponent<TestAgent>();
+
+            // test register
+            SimpleMultiAgentGroup agentGroup1 = new SimpleMultiAgentGroup();
+            agentGroup1.RegisterAgent(agent);
+            Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);
+            Assert.IsNotNull(agent._OnAgentDisabledActions);
+
+            // should not be able to register to multiple groups
+            SimpleMultiAgentGroup agentGroup2 = new SimpleMultiAgentGroup();
+            Assert.Throws<UnityAgentsException>(
+                () => agentGroup2.RegisterAgent(agent));
+            Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);
+
+            // test unregister
+            agentGroup1.UnregisterAgent(agent);
+            Assert.AreEqual(0, agent._GroupId);
+            Assert.IsNull(agent._OnAgentDisabledActions);
+
+            // test register to another group after unregister
+            agentGroup2.RegisterAgent(agent);
+            Assert.AreEqual(agentGroup2.GetId(), agent._GroupId);
+            Assert.IsNotNull(agent._OnAgentDisabledActions);
+        }
+
+        [Test]
+        public void TestRegisterMultipleAgent()
+        {
+            var agentGo1 = new GameObject("TestAgent");
+            agentGo1.AddComponent<TestAgent>();
+            var agent1 = agentGo1.GetComponent<TestAgent>();
+            var agentGo2 = new GameObject("TestAgent");
+            agentGo2.AddComponent<TestAgent>();
+            var agent2 = agentGo2.GetComponent<TestAgent>();
+
+            SimpleMultiAgentGroup agentGroup = new SimpleMultiAgentGroup();
+            agentGroup.RegisterAgent(agent1); // register
+            Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1);
+            agentGroup.UnregisterAgent(agent2); // unregister non-member agent
+
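            // Unregistering an agent that was never registered should be a harmless no-op, leaving the registered count unchanged.
+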
Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1); + agentGroup.UnregisterAgent(agent1); // unregister + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 0); + agentGroup.RegisterAgent(agent1); + agentGroup.RegisterAgent(agent1); // duplicated register + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 1); + agentGroup.RegisterAgent(agent2); // register another + Assert.AreEqual(agentGroup.GetRegisteredAgents().Count, 2); + + // test add/set group rewards + agentGroup.AddGroupReward(0.1f); + Assert.AreEqual(0.1f, agent1._GroupReward); + agentGroup.AddGroupReward(0.5f); + Assert.AreEqual(0.6f, agent1._GroupReward); + agentGroup.SetGroupReward(0.3f); + Assert.AreEqual(0.3f, agent1._GroupReward); + // unregistered agent should not receive group reward + agentGroup.UnregisterAgent(agent1); + agentGroup.AddGroupReward(0.2f); + Assert.AreEqual(0.3f, agent1._GroupReward); + Assert.AreEqual(0.5f, agent2._GroupReward); + + // dispose group should automatically unregister all + agentGroup.Dispose(); + Assert.AreEqual(0, agent1._GroupId); + Assert.AreEqual(0, agent2._GroupId); + } + + [Test] + public void TestGroupIdCounter() + { + SimpleMultiAgentGroup group1 = new SimpleMultiAgentGroup(); + SimpleMultiAgentGroup group2 = new SimpleMultiAgentGroup(); + // id should be unique + Assert.AreNotEqual(group1.GetId(), group2.GetId()); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta new file mode 100644 index 0000000000..7edd502278 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ef0158fde748d478ca5ee3bbe22a4c9e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/MultinomialTest.cs b/com.unity.ml-agents/Tests/Editor/MultinomialTest.cs new file mode 100644 index 0000000000..28e3b1db6e --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/MultinomialTest.cs @@ -0,0 +1,54 @@ +using NUnit.Framework; +using Unity.MLAgents.Inference.Utils; + +namespace Unity.MLAgents.Tests +{ + public class MultinomialTest + { + [Test] + public void TestDim1() + { + var m = new Multinomial(2018); + var cdf = new[] { 1f }; + + Assert.AreEqual(0, m.Sample(cdf)); + Assert.AreEqual(0, m.Sample(cdf)); + Assert.AreEqual(0, m.Sample(cdf)); + } + + [Test] + public void TestDim1Unscaled() + { + var m = new Multinomial(2018); + var cdf = new[] { 0.1f }; + + Assert.AreEqual(0, m.Sample(cdf)); + Assert.AreEqual(0, m.Sample(cdf)); + Assert.AreEqual(0, m.Sample(cdf)); + } + + [Test] + public void TestDim3() + { + var m = new Multinomial(2018); + var cdf = new[] { 0.1f, 0.3f, 1.0f }; + + Assert.AreEqual(2, m.Sample(cdf)); + Assert.AreEqual(2, m.Sample(cdf)); + Assert.AreEqual(2, m.Sample(cdf)); + Assert.AreEqual(1, m.Sample(cdf)); + } + + [Test] + public void TestDim3Unscaled() + { + var m = new Multinomial(2018); + var cdf = new[] { 0.05f, 0.15f, 0.5f }; + + Assert.AreEqual(2, m.Sample(cdf)); + Assert.AreEqual(2, m.Sample(cdf)); + Assert.AreEqual(2, m.Sample(cdf)); + Assert.AreEqual(1, m.Sample(cdf)); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/MultinomialTest.cs.meta b/com.unity.ml-agents/Tests/Editor/MultinomialTest.cs.meta new file mode 100644 index 0000000000..44fd9bbe91 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/MultinomialTest.cs.meta @@ -0,0 
+1,11 @@
+fileFormatVersion: 2
+guid: 668f4ac2d83814df5a8883722633e4e5
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
new file mode 100644
index 0000000000..3dd0c91156
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
@@ -0,0 +1,75 @@
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+
+
+namespace Unity.MLAgents.Tests
+{
+    [TestFixture]
+    public class ObservationSpecTests
+    {
+        [Test]
+        public void TestVectorObsSpec()
+        {
+            var obsSpec = ObservationSpec.Vector(5);
+            Assert.AreEqual(1, obsSpec.Rank);
+
+            var shape = obsSpec.Shape;
+            Assert.AreEqual(1, shape.Length);
+            Assert.AreEqual(5, shape[0]);
+
+            var dimensionProps = obsSpec.DimensionProperties;
+            Assert.AreEqual(1, dimensionProps.Length);
+            Assert.AreEqual(DimensionProperty.None, dimensionProps[0]);
+
+            Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
+        }
+
+        [Test]
+        public void TestVariableLengthObsSpec()
+        {
+            var obsSpec = ObservationSpec.VariableLength(5, 6);
+            Assert.AreEqual(2, obsSpec.Rank);
+
+            var shape = obsSpec.Shape;
+            Assert.AreEqual(2, shape.Length);
+            Assert.AreEqual(5, shape[0]);
+            Assert.AreEqual(6, shape[1]);
+
+            var dimensionProps = obsSpec.DimensionProperties;
+            Assert.AreEqual(2, dimensionProps.Length);
+            Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]);
+            Assert.AreEqual(DimensionProperty.None, dimensionProps[1]);
+
+            Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
+        }
+
+        [Test]
+        public void TestVisualObsSpec()
+        {
+            var obsSpec = ObservationSpec.Visual(5, 6, 7);
+            Assert.AreEqual(3, obsSpec.Rank);
+
+            var shape = obsSpec.Shape;
+            Assert.AreEqual(3, shape.Length);
+            Assert.AreEqual(5, shape[0]);
+            Assert.AreEqual(6, shape[1]);
+            Assert.AreEqual(7, shape[2]);
+
+            var dimensionProps = obsSpec.DimensionProperties;
+            Assert.AreEqual(3, dimensionProps.Length);
+            Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]);
+            Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]);
+            Assert.AreEqual(DimensionProperty.None, dimensionProps[2]);
+
+            Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
+        }
+
+        [Test]
+        public void TestMismatchShapeDimensionPropThrows()
+        {
+            var shape = new InplaceArray<int>(1, 2);
+            var dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.TranslationalEquivariance);
+            Assert.Throws<UnityAgentsException>(() => new ObservationSpec(shape, dimProps));
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
new file mode 100644
index 0000000000..2ea6756e50
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 27ff1979bd5e4b8ebeb4d98f414a5090
+timeCreated: 1615863866
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Policies.meta b/com.unity.ml-agents/Tests/Editor/Policies.meta
new file mode 100644
index 0000000000..be3f189b91
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Policies.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: df271cac120e4d6893b14599fa8eb64d
+timeCreated: 1617813392
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs
new file mode 100644
index 0000000000..740f6a1a6b
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs
@@ -0,0 +1,125 @@
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Policies;
+using UnityEngine;
+
+namespace Unity.MLAgents.Tests.Policies
+{
+    [TestFixture]
+    public class HeuristicPolicyTest
+    {
+        [SetUp]
+        public void SetUp()
+        {
+            if (Academy.IsInitialized)
+            {
+                Academy.Instance.Dispose();
+            }
+        }
+
+        /// <summary>
+        /// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
+        /// </summary>
+        /// <param name="actionsOut"></param>
+        static void CheckAndSetBuffer(in ActionBuffers actionsOut)
+        {
+            var continuousActions = actionsOut.ContinuousActions;
+            for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
+            {
+                Assert.AreEqual(continuousActions[continuousIndex], 0.0f);
+                continuousActions[continuousIndex] = 1.0f;
+            }
+
+            var discreteActions = actionsOut.DiscreteActions;
+            for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
+            {
+                Assert.AreEqual(discreteActions[discreteIndex], 0);
+                discreteActions[discreteIndex] = 1;
+            }
+        }
+
+
+        class ActionClearedAgent : Agent
+        {
+            public int HeuristicCalls;
+            public override void Heuristic(in ActionBuffers actionsOut)
+            {
+                CheckAndSetBuffer(actionsOut);
+                HeuristicCalls++;
+            }
+        }
+
+        class ActionClearedActuator : IActuator
+        {
+            public int HeuristicCalls;
+            public ActionClearedActuator(ActionSpec actionSpec)
+            {
+                ActionSpec = actionSpec;
+                Name = GetType().Name;
+            }
+
+            public void OnActionReceived(ActionBuffers actionBuffers)
+            {
+            }
+
+            public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
+            {
+            }
+
+            public void Heuristic(in ActionBuffers actionBuffersOut)
+            {
+                CheckAndSetBuffer(actionBuffersOut);
+                HeuristicCalls++;
+            }
+
+            public ActionSpec ActionSpec { get; }
+            public string Name { get; }
+
+            public void ResetData()
+            {
+
+            }
+        }
+
+        class ActionClearedActuatorComponent : ActuatorComponent
+        {
+            public ActionClearedActuator ActionClearedActuator;
+            public ActionClearedActuatorComponent()
+            {
+                ActionSpec = new ActionSpec(2, new[] { 3, 3 });
+            }
+
+            public override IActuator[] CreateActuators()
+            {
+                ActionClearedActuator = new ActionClearedActuator(ActionSpec);
+                return new IActuator[] { ActionClearedActuator };
+            }
+
+            public override ActionSpec ActionSpec { get; }
+        }
+
+        [Test]
+        public void TestActionsCleared()
+        {
+            var gameObj = new GameObject();
+            var agent = gameObj.AddComponent<ActionClearedAgent>();
+            var behaviorParameters = agent.GetComponent<BehaviorParameters>();
+            behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
+            behaviorParameters.BrainParameters.VectorObservationSize = 0;
+            behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;
+
+            var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
+            agent.LazyInitialize();
+
+            const int k_NumSteps = 5;
+            for (var i = 0; i < k_NumSteps; i++)
+            {
+                agent.RequestDecision();
+                Academy.Instance.EnvironmentStep();
+            }
+
+            Assert.AreEqual(agent.HeuristicCalls, k_NumSteps);
+            Assert.AreEqual(actuatorComponent.ActionClearedActuator.HeuristicCalls, k_NumSteps);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta
new file mode 100644
index 0000000000..682a64b746
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 5108e92f91a04ddab9d628c9bc57cadb +timeCreated: 1617813411 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI.meta b/com.unity.ml-agents/Tests/Editor/PublicAPI.meta new file mode 100644 index 0000000000..96673482ab --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: f5aa894d83ebc411581c8475cd2f9ae0 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs new file mode 100644 index 0000000000..44b22d9c11 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs @@ -0,0 +1,100 @@ +using Unity.MLAgents.Sensors; +using NUnit.Framework; +using Unity.MLAgents; +using UnityEngine; + +namespace Unity.MLAgentsExamples +{ + /// + /// The purpose of these tests is to make sure that we can do basic operations like creating + /// an Agent and adding components from code without requiring access to internal methods. + /// The tests aren't intended to add extra test coverage (although they might) and might + /// not check any conditions. + /// + [TestFixture] + public class PublicApiValidation + { + [Test] + public void CheckSetupCameraSensorComponent() + { + var gameObject = new GameObject(); + var width = 24; + var height = 16; + + var sensorComponent = gameObject.AddComponent(); + sensorComponent.Camera = Camera.main; + sensorComponent.SensorName = "camera1"; + sensorComponent.Width = width; + sensorComponent.Height = height; + sensorComponent.Grayscale = true; + + // Make sure the sets actually applied + Assert.AreEqual("camera1", sensorComponent.SensorName); + Assert.AreEqual(width, sensorComponent.Width); + Assert.AreEqual(height, sensorComponent.Height); + Assert.IsTrue(sensorComponent.Grayscale); + } + + [Test] + public void CheckSetupRenderTextureSensorComponent() + { + var gameObject = new GameObject(); + + var sensorComponent = gameObject.AddComponent(); + var width = 24; + var height = 16; + var texture = new RenderTexture(width, height, 0); + sensorComponent.RenderTexture = texture; + sensorComponent.SensorName = "rtx1"; + sensorComponent.Grayscale = true; + + // Make sure the sets actually applied + Assert.AreEqual("rtx1", sensorComponent.SensorName); + Assert.IsTrue(sensorComponent.Grayscale); + } + +#if MLA_UNITY_PHYSICS_MODULE + [Test] + public void CheckSetupRayPerceptionSensorComponent() + { + var gameObject = new GameObject(); + + var sensorComponent = gameObject.AddComponent(); + sensorComponent.SensorName = "ray3d"; + sensorComponent.DetectableTags = new List { "Player", "Respawn" }; + sensorComponent.RaysPerDirection = 3; + sensorComponent.MaxRayDegrees = 30; + sensorComponent.SphereCastRadius = .1f; + sensorComponent.RayLayerMask = 0; + sensorComponent.ObservationStacks = 2; + + sensorComponent.CreateSensors(); + + var sensor = sensorComponent.RaySensor; + sensor.Update(); + var outputs = sensor.RayPerceptionOutput; + Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1); + } +#endif + + /// + /// Make sure we can inherit from DecisionRequester and override some logic. + /// + class CustomDecisionRequester : DecisionRequester + { + /// + /// Example logic. If the killswitch flag is set, the Agent never requests a decision. 
+ /// + public bool KillswitchEnabled; + + public CustomDecisionRequester() + { + } + + protected override bool ShouldRequestDecision(DecisionRequestContext context) + { + return !KillswitchEnabled && base.ShouldRequestDecision(context); + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs.meta b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs.meta new file mode 100644 index 0000000000..b4fadad54a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 016a3ac45b0345e3ab95f14ecaabdb11 +timeCreated: 1583858370 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef b/com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef new file mode 100755 index 0000000000..c96b91cd52 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef @@ -0,0 +1,26 @@ +{ + "name": "Unity.ML-Agents.Editor.Tests.PublicAPI", + "references": [ + "Unity.ML-Agents.Editor", + "Unity.ML-Agents", + "Unity.Barracuda", + "Unity.ML-Agents.CommunicatorObjects" + ], + "optionalUnityReferences": [ + "TestAssemblies" + ], + "includePlatforms": [ + "Editor" + ], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": true, + "precompiledReferences": [ + "System.IO.Abstractions.dll", + "System.IO.Abstractions.TestingHelpers.dll" + ], + "autoReferenced": false, + "defineConstraints": [ + "UNITY_INCLUDE_TESTS" + ] +} diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef.meta b/com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef.meta new file mode 100644 index 0000000000..e1329a95fc --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 38254a2538c8f42fb9f2057d17fd0e70 +AssemblyDefinitionImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/RandomNormalTest.cs b/com.unity.ml-agents/Tests/Editor/RandomNormalTest.cs new file mode 100644 index 0000000000..520e3e656f --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/RandomNormalTest.cs @@ -0,0 +1,91 @@ +using System; +using NUnit.Framework; +using Unity.MLAgents.Inference.Utils; + +namespace Unity.MLAgents.Tests +{ + public class RandomNormalTest + { + const float k_FirstValue = -1.19580f; + const float k_SecondValue = -0.97345f; + const double k_Epsilon = 0.0001; + + [Test] + public void RandomNormalTestTwoDouble() + { + var rn = new RandomNormal(2018); + + Assert.AreEqual(k_FirstValue, rn.NextDouble(), k_Epsilon); + Assert.AreEqual(k_SecondValue, rn.NextDouble(), k_Epsilon); + } + + [Test] + public void RandomNormalTestWithMean() + { + var rn = new RandomNormal(2018, 5.0f); + + Assert.AreEqual(k_FirstValue + 5.0, rn.NextDouble(), k_Epsilon); + Assert.AreEqual(k_SecondValue + 5.0, rn.NextDouble(), k_Epsilon); + } + + [Test] + public void RandomNormalTestWithStddev() + { + var rn = new RandomNormal(2018, 0.0f, 4.2f); + + Assert.AreEqual(k_FirstValue * 4.2, rn.NextDouble(), k_Epsilon); + Assert.AreEqual(k_SecondValue * 4.2, rn.NextDouble(), k_Epsilon); + } + + [Test] + public void RandomNormalTestWithMeanStddev() + { + const float mean = -3.2f; + const float stddev = 2.2f; + var rn = new RandomNormal(2018, mean, 
stddev); + + Assert.AreEqual(k_FirstValue * stddev + mean, rn.NextDouble(), k_Epsilon); + Assert.AreEqual(k_SecondValue * stddev + mean, rn.NextDouble(), k_Epsilon); + } + + [Test] + public void RandomNormalTestDistribution() + { + const float mean = -3.2f; + const float stddev = 2.2f; + var rn = new RandomNormal(2018, mean, stddev); + + const int numSamples = 100000; + // Adapted from https://www.johndcook.com/blog/standard_deviation/ + // Computes stddev and mean without losing precision + double oldM = 0.0, newM = 0.0, oldS = 0.0, newS = 0.0; + + for (var i = 0; i < numSamples; i++) + { + var x = rn.NextDouble(); + if (i == 0) + { + oldM = newM = x; + oldS = 0.0; + } + else + { + newM = oldM + (x - oldM) / i; + newS = oldS + (x - oldM) * (x - newM); + + // set up for next iteration + oldM = newM; + oldS = newS; + } + } + + var sampleMean = newM; + var sampleVariance = newS / (numSamples - 1); + var sampleStddev = Math.Sqrt(sampleVariance); + + // Note a larger epsilon here. We could get closer to the true values with more samples. + Assert.AreEqual(mean, sampleMean, 0.01); + Assert.AreEqual(stddev, sampleStddev, 0.01); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/RandomNormalTest.cs.meta b/com.unity.ml-agents/Tests/Editor/RandomNormalTest.cs.meta new file mode 100644 index 0000000000..f1d5c0833b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/RandomNormalTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 518c8e6e10fd94059a064ffbe65557af +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs b/com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs new file mode 100644 index 0000000000..a06e914dbd --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs @@ -0,0 +1,72 @@ +using System; +using NUnit.Framework; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class RecursionCheckerTests + { + class InfiniteRecurser + { + RecursionChecker m_checker = new RecursionChecker("InfiniteRecurser"); + public int NumCalls; + + public void Implode() + { + NumCalls++; + using (m_checker.Start()) + { + Implode(); + } + } + } + + [Test] + public void TestRecursionCheck() + { + var rc = new InfiniteRecurser(); + Assert.Throws(() => + { + rc.Implode(); + }); + + // Should increment twice before bailing out. + Assert.AreEqual(2, rc.NumCalls); + } + + class OneTimeThrower + { + RecursionChecker m_checker = new RecursionChecker("OneTimeThrower"); + public int NumCalls; + + public void DoStuff() + { + // This method throws from inside the checker the first time. + // Later calls do nothing. + NumCalls++; + using (m_checker.Start()) + { + if (NumCalls == 1) + { + throw new ArgumentException("oops"); + } + } + } + } + + [Test] + public void TestThrowResetsFlag() + { + var ott = new OneTimeThrower(); + Assert.Throws(() => + { + ott.DoStuff(); + }); + + // Make sure the flag is cleared if we throw in the "using". Should be able to step subsequently. 
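+            // The IDisposable returned by RecursionChecker.Start must clear the in-progress flag on Dispose even when an exception unwinds the using block; otherwise the calls below would throw.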
+            ott.DoStuff();
+            ott.DoStuff();
+            Assert.AreEqual(3, ott.NumCalls);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta b/com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta
new file mode 100644
index 0000000000..7240ff8a0b
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 5a7183e11dd5434684a4225c80169173
+timeCreated: 1602781778
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/SamplerTests.cs b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs
new file mode 100644
index 0000000000..95eab64c5a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs
@@ -0,0 +1,106 @@
+using NUnit.Framework;
+using System.IO;
+using Unity.MLAgents.SideChannels;
+
+namespace Unity.MLAgents.Tests
+{
+    public class SamplerTests
+    {
+        const int k_Seed = 1337;
+        const double k_Epsilon = 0.0001;
+        EnvironmentParametersChannel m_Channel;
+
+        public SamplerTests()
+        {
+            m_Channel = SideChannelManager.GetSideChannel<EnvironmentParametersChannel>();
+            // if running test on its own
+            if (m_Channel == null)
+            {
+                m_Channel = new EnvironmentParametersChannel();
+                SideChannelManager.RegisterSideChannel(m_Channel);
+            }
+        }
+        [Test]
+        public void UniformSamplerTest()
+        {
+            float min_value = 1.0f;
+            float max_value = 2.0f;
+            string parameter = "parameter1";
+            using (var outgoingMsg = new OutgoingMessage())
+            {
+                outgoingMsg.WriteString(parameter);
+                // 1 indicates this message is a Sampler
+                outgoingMsg.WriteInt32(1);
+                outgoingMsg.WriteInt32(k_Seed);
+                outgoingMsg.WriteInt32((int)SamplerType.Uniform);
+                outgoingMsg.WriteFloat32(min_value);
+                outgoingMsg.WriteFloat32(max_value);
+                byte[] message = GetByteMessage(m_Channel, outgoingMsg);
+                SideChannelManager.ProcessSideChannelData(message);
+            }
+            Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+            Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+        }
+
+        [Test]
+        public void GaussianSamplerTest()
+        {
+            float mean = 3.0f;
+            float stddev = 0.2f;
+            string parameter = "parameter2";
+            using (var outgoingMsg = new OutgoingMessage())
+            {
+                outgoingMsg.WriteString(parameter);
+                // 1 indicates this message is a Sampler
+                outgoingMsg.WriteInt32(1);
+                outgoingMsg.WriteInt32(k_Seed);
+                outgoingMsg.WriteInt32((int)SamplerType.Gaussian);
+                outgoingMsg.WriteFloat32(mean);
+                outgoingMsg.WriteFloat32(stddev);
+                byte[] message = GetByteMessage(m_Channel, outgoingMsg);
+                SideChannelManager.ProcessSideChannelData(message);
+            }
+            Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+            Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+        }
+
+        [Test]
+        public void MultiRangeUniformSamplerTest()
+        {
+            float[] intervals = new float[4];
+            intervals[0] = 1.2f;
+            intervals[1] = 2f;
+            intervals[2] = 3.2f;
+            intervals[3] = 4.1f;
+            string parameter = "parameter3";
+            using (var outgoingMsg = new OutgoingMessage())
+            {
+                outgoingMsg.WriteString(parameter);
+                // 1 indicates this message is a Sampler
+                outgoingMsg.WriteInt32(1);
+                outgoingMsg.WriteInt32(k_Seed);
+                outgoingMsg.WriteInt32((int)SamplerType.MultiRangeUniform);
+                outgoingMsg.WriteFloatList(intervals);
+                byte[] message = GetByteMessage(m_Channel, outgoingMsg);
+                SideChannelManager.ProcessSideChannelData(message);
+            }
+            Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+            Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+        }
+
+
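        // A serialized side channel message is framed as [16-byte channel GUID][int32 payload length][payload bytes].
+        // The helper below builds that framing by hand so the tests can feed data straight into SideChannelManager.ProcessSideChannelData.
+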
internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg) + { + byte[] message = msg.ToByteArray(); + using (var memStream = new MemoryStream()) + { + using (var binaryWriter = new BinaryWriter(memStream)) + { + binaryWriter.Write(sideChannel.ChannelId.ToByteArray()); + binaryWriter.Write(message.Length); + binaryWriter.Write(message); + } + return memStream.ToArray(); + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta new file mode 100644 index 0000000000..ef0d54e72a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7e6609c51018d4132beda8ddedd46d91 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Serialization.meta b/com.unity.ml-agents/Tests/Editor/Serialization.meta new file mode 100644 index 0000000000..1386f3eca6 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Serialization.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: e76819646e3e4424f9c2802edbd2e41b +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Serialization/TestLoadOldPrefab.cs b/com.unity.ml-agents/Tests/Editor/Serialization/TestLoadOldPrefab.cs new file mode 100644 index 0000000000..7755ea6879 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Serialization/TestLoadOldPrefab.cs @@ -0,0 +1,26 @@ +using NUnit.Framework; +using UnityEngine; +using UnityEditor; +using Unity.MLAgents.Policies; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class TestSerialization + { + const string k_oldPrefabPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab"; + const int k_numVecObs = 212; + const int k_numContinuousActions = 39; + + [Test] + public void TestDeserialization() + { + var prefab = AssetDatabase.LoadAssetAtPath<GameObject>(k_oldPrefabPath); + var agent = GameObject.Instantiate(prefab); + var bp = agent.GetComponent<BehaviorParameters>(); + Assert.AreEqual(bp.BrainParameters.ActionSpec.NumContinuousActions, k_numContinuousActions); + Assert.AreEqual(bp.BrainParameters.VectorObservationSize, k_numVecObs); + } + + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Serialization/TestLoadOldPrefab.cs.meta b/com.unity.ml-agents/Tests/Editor/Serialization/TestLoadOldPrefab.cs.meta new file mode 100644 index 0000000000..11479dd124 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Serialization/TestLoadOldPrefab.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a3070d063598144268171a468db17ddd +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/SideChannels.meta b/com.unity.ml-agents/Tests/Editor/SideChannels.meta new file mode 100644 index 0000000000..67d4d62cd3 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/SideChannels.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 1228f198ceee45a38c7d9ff50425b65d +timeCreated: 1610760867 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/SideChannels/EngineConfigurationChannelTests.cs b/com.unity.ml-agents/Tests/Editor/SideChannels/EngineConfigurationChannelTests.cs new file mode 100644 index
0000000000..74bc2076c2 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/SideChannels/EngineConfigurationChannelTests.cs @@ -0,0 +1,44 @@ +using NUnit.Framework; +using Unity.MLAgents.SideChannels; +using UnityEngine; + +namespace Unity.MLAgents.Tests +{ + public class EngineConfigurationChannelTests + { + float m_OldTimeScale = 1.0f; + + [SetUp] + public void Setup() + { + m_OldTimeScale = Time.timeScale; + } + + [TearDown] + public void TearDown() + { + Time.timeScale = m_OldTimeScale; + } + + [Test] + public void TestTimeScaleClamping() + { + OutgoingMessage pythonMsg = new OutgoingMessage(); + pythonMsg.WriteInt32((int)EngineConfigurationChannel.ConfigurationType.TimeScale); + pythonMsg.WriteFloat32(1000f); + + var sideChannel = new EngineConfigurationChannel(); + sideChannel.ProcessMessage(pythonMsg.ToByteArray()); + +#if UNITY_EDITOR + // Should be clamped + Assert.AreEqual(100.0f, Time.timeScale); +#else + // Not sure we can run this test from a player, but just in case, shouldn't clamp. + Assert.AreEqual(1000.0f, Time.timeScale); +#endif + } + + + } +} diff --git a/com.unity.ml-agents/Tests/Editor/SideChannels/EngineConfigurationChannelTests.cs.meta b/com.unity.ml-agents/Tests/Editor/SideChannels/EngineConfigurationChannelTests.cs.meta new file mode 100644 index 0000000000..68ff93bb9e --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/SideChannels/EngineConfigurationChannelTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 71aa620295f74ca5875e8e4782f08768 +timeCreated: 1610760906 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/SideChannels/SideChannelTests.cs b/com.unity.ml-agents/Tests/Editor/SideChannels/SideChannelTests.cs new file mode 100644 index 0000000000..6ca25f2eda --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/SideChannels/SideChannelTests.cs @@ -0,0 +1,193 @@ +using System; +using NUnit.Framework; +using System.Collections.Generic; +using System.Text; +using Unity.MLAgents.SideChannels; + +namespace Unity.MLAgents.Tests +{ + public class SideChannelTests + { + // This test side channel only deals in integers + public class TestSideChannel : SideChannel + { + public List<int> messagesReceived = new List<int>(); + + public TestSideChannel() + { + ChannelId = new Guid("6afa2c06-4f82-11ea-b238-784f4387d1f7"); + } + + protected override void OnMessageReceived(IncomingMessage msg) + { + messagesReceived.Add(msg.ReadInt32()); + } + + public void SendInt(int value) + { + using (var msg = new OutgoingMessage()) + { + msg.WriteInt32(value); + QueueMessageToSend(msg); + } + } + } + + [Test] + public void TestIntegerSideChannel() + { + var intSender = new TestSideChannel(); + var intReceiver = new TestSideChannel(); + var dictSender = new Dictionary<Guid, SideChannel> { { intSender.ChannelId, intSender } }; + var dictReceiver = new Dictionary<Guid, SideChannel> { { intReceiver.ChannelId, intReceiver } }; + + intSender.SendInt(4); + intSender.SendInt(5); + intSender.SendInt(6); + + byte[] fakeData = SideChannelManager.GetSideChannelMessage(dictSender); + SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData); + + Assert.AreEqual(intReceiver.messagesReceived[0], 4); + Assert.AreEqual(intReceiver.messagesReceived[1], 5); + Assert.AreEqual(intReceiver.messagesReceived[2], 6); + } + + [Test] + public void TestRawBytesSideChannel() + { + var str1 = "Test string"; + var str2 = "Test string, second"; + + var strSender = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7")); + var strReceiver = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7")); + 
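// The sender and receiver deliberately share the same GUID, so the manager + // routes the sender's queued bytes to the receiver. +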
var dictSender = new Dictionary<Guid, SideChannel> { { strSender.ChannelId, strSender } }; + var dictReceiver = new Dictionary<Guid, SideChannel> { { strReceiver.ChannelId, strReceiver } }; + + strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1)); + strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2)); + + byte[] fakeData = SideChannelManager.GetSideChannelMessage(dictSender); + SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData); + + var messages = strReceiver.GetAndClearReceivedMessages(); + + Assert.AreEqual(messages.Count, 2); + Assert.AreEqual(Encoding.ASCII.GetString(messages[0]), str1); + Assert.AreEqual(Encoding.ASCII.GetString(messages[1]), str2); + } + + [Test] + public void TestFloatPropertiesSideChannel() + { + var k1 = "gravity"; + var k2 = "length"; + int wasCalled = 0; + + var propA = new FloatPropertiesChannel(); + var propB = new FloatPropertiesChannel(); + var dictReceiver = new Dictionary<Guid, SideChannel> { { propA.ChannelId, propA } }; + var dictSender = new Dictionary<Guid, SideChannel> { { propB.ChannelId, propB } }; + + propA.RegisterCallback(k1, f => { wasCalled++; }); + var tmp = propB.GetWithDefault(k2, 3.0f); + Assert.AreEqual(tmp, 3.0f); + propB.Set(k2, 1.0f); + tmp = propB.GetWithDefault(k2, 3.0f); + Assert.AreEqual(tmp, 1.0f); + + byte[] fakeData = SideChannelManager.GetSideChannelMessage(dictSender); + SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData); + + tmp = propA.GetWithDefault(k2, 3.0f); + Assert.AreEqual(tmp, 1.0f); + + Assert.AreEqual(wasCalled, 0); + propB.Set(k1, 1.0f); + Assert.AreEqual(wasCalled, 0); + fakeData = SideChannelManager.GetSideChannelMessage(dictSender); + SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData); + Assert.AreEqual(wasCalled, 1); + + var keysA = propA.Keys(); + Assert.AreEqual(2, keysA.Count); + Assert.IsTrue(keysA.Contains(k1)); + Assert.IsTrue(keysA.Contains(k2)); + + var keysB = propB.Keys(); + Assert.AreEqual(2, keysB.Count); + Assert.IsTrue(keysB.Contains(k1)); + Assert.IsTrue(keysB.Contains(k2)); + } + + [Test] + public void TestOutgoingMessageRawBytes() + { + // Make sure that SetRawBytes resets the buffer correctly. + // Write 8 bytes (an int and float) then call SetRawBytes with 4 bytes + var msg = new OutgoingMessage(); + msg.WriteInt32(42); + msg.WriteFloat32(1.0f); + + var data = new byte[] { 1, 2, 3, 4 }; + msg.SetRawBytes(data); + + var result = msg.ToByteArray(); + Assert.AreEqual(data, result); + } + + [Test] + public void TestMessageReadWrites() + { + var boolVal = true; + var intVal = 1337; + var floatVal = 4.2f; + var floatListVal = new float[] { 1001, 1002 }; + var stringVal = "mlagents!"; + + IncomingMessage incomingMsg; + using (var outgoingMsg = new OutgoingMessage()) + { + outgoingMsg.WriteBoolean(boolVal); + outgoingMsg.WriteInt32(intVal); + outgoingMsg.WriteFloat32(floatVal); + outgoingMsg.WriteString(stringVal); + outgoingMsg.WriteFloatList(floatListVal); + + incomingMsg = new IncomingMessage(outgoingMsg.ToByteArray()); + } + + Assert.AreEqual(boolVal, incomingMsg.ReadBoolean()); + Assert.AreEqual(intVal, incomingMsg.ReadInt32()); + Assert.AreEqual(floatVal, incomingMsg.ReadFloat32()); + Assert.AreEqual(stringVal, incomingMsg.ReadString()); + Assert.AreEqual(floatListVal, incomingMsg.ReadFloatList()); + } + + [Test] + public void TestMessageReadDefaults() + { + // Make sure reading past the end of a message will apply defaults.
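+ // An empty OutgoingMessage yields a zero-length buffer, so every read below + // runs past the end and falls back to the supplied (or type) default.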
+ IncomingMessage incomingMsg; + using (var outgoingMsg = new OutgoingMessage()) + { + incomingMsg = new IncomingMessage(outgoingMsg.ToByteArray()); + } + + Assert.AreEqual(false, incomingMsg.ReadBoolean()); + Assert.AreEqual(true, incomingMsg.ReadBoolean(defaultValue: true)); + + Assert.AreEqual(0, incomingMsg.ReadInt32()); + Assert.AreEqual(42, incomingMsg.ReadInt32(defaultValue: 42)); + + Assert.AreEqual(0.0f, incomingMsg.ReadFloat32()); + Assert.AreEqual(1337.0f, incomingMsg.ReadFloat32(defaultValue: 1337.0f)); + + Assert.AreEqual(default(string), incomingMsg.ReadString()); + Assert.AreEqual("foo", incomingMsg.ReadString(defaultValue: "foo")); + + Assert.AreEqual(default(float[]), incomingMsg.ReadFloatList()); + Assert.AreEqual(new float[] { 1001, 1002 }, incomingMsg.ReadFloatList(new float[] { 1001, 1002 })); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/SideChannels/SideChannelTests.cs.meta b/com.unity.ml-agents/Tests/Editor/SideChannels/SideChannelTests.cs.meta new file mode 100644 index 0000000000..cef0d1104e --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/SideChannels/SideChannelTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 589f475debcdb479295a24799777b5e5 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/TestModels.meta b/com.unity.ml-agents/Tests/Editor/TestModels.meta new file mode 100644 index 0000000000..1928478d70 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 95997790219c547e584c3cb50122a95f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn new file mode 100644 index 0000000000..c1cd17e682 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn.meta b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn.meta new file mode 100644 index 0000000000..413ad1d6f4 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: bf4543cc3c6944794bbba065bdf90079 +ScriptedImporter: + fileIDToRecycleName: + 11400000: main obj + 11400002: model data + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 19ed1486aa27d4903b34839f37b8f69f, type: 3} diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx new file mode 100644 index 0000000000..74581eb059 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx.meta new file mode 100644 index 0000000000..73bd62d783 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx.meta @@ 
-0,0 +1,14 @@ +fileFormatVersion: 2 +guid: f90bffb60a3784a2385299a321f354a6 +ScriptedImporter: + fileIDToRecycleName: + 11400000: main obj + 11400002: model data + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx new file mode 100644 index 0000000000..56c1cd4355 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta new file mode 100644 index 0000000000..cc92cc94b8 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: e905d8f9eadcf45aa8c485594fecba6d +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx new file mode 100644 index 0000000000..3aa846e204 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta new file mode 100644 index 0000000000..a141a55235 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: d132cc9c934a54fdc99758427373e038 +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx new file mode 100644 index 0000000000..e7e6c0cce4 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx.meta new file mode 100644 index 0000000000..92c9212a72 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: 9ecb2a56b0c6b42f7ad2b40ab97c5515 +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + 
assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn new file mode 100644 index 0000000000..a427b24d56 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn.meta b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn.meta new file mode 100644 index 0000000000..6f732557b4 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6d6040ad621454dd5b713beb5483e347 +ScriptedImporter: + fileIDToRecycleName: + 11400000: main obj + 11400002: model data + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 19ed1486aa27d4903b34839f37b8f69f, type: 3} diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_v1_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_v1_0.onnx new file mode 100644 index 0000000000..85488707eb Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_v1_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_v1_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_v1_0.onnx.meta new file mode 100644 index 0000000000..2bceceabac --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_v1_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: 68991653e04394f95b15a222253c0729 +ScriptedImporter: + fileIDToRecycleName: + 11400000: main obj + 11400002: model data + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx new file mode 100644 index 0000000000..c0937d733d Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx.meta new file mode 100644 index 0000000000..5cfd605639 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/discrete_rank2_vector_v2_0.onnx.meta @@ -0,0 +1,15 @@ +fileFormatVersion: 2 +guid: b6c7faadd10084c3995ad9fff7aa8c54 +ScriptedImporter: + fileIDToRecycleName: + 11400000: main obj + 11400002: model data + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx 
b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx new file mode 100644 index 0000000000..f04cac9a3a Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx.meta new file mode 100644 index 0000000000..792ae290f7 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: 9f774b4c578c3435da77d2831db84105 +ScriptedImporter: + fileIDToRecycleName: + 11400000: main obj + 11400002: model data + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx new file mode 100644 index 0000000000..cd2a356f1c Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx.meta new file mode 100644 index 0000000000..5abc7e6432 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: 2f6b2ae61d96a4555b60892a0ad924bb +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab b/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab new file mode 100644 index 0000000000..f8786e912b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab @@ -0,0 +1,76 @@ +%YAML 1.1 +%TAG !u! 
tag:unity3d.com,2011: +--- !u!1 &5381908961062339374 +GameObject: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + serializedVersion: 6 + m_Component: + - component: {fileID: 5379309137663827240} + - component: {fileID: 5414230854946179998} + - component: {fileID: 8153975132613398210} + m_Layer: 0 + m_Name: old_serialized_agent + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!4 &5379309137663827240 +Transform: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 5381908961062339374} + m_LocalRotation: {x: -0, y: -0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 0} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 90, z: 0} +--- !u!114 &5414230854946179998 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 5381908961062339374} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 5d1c4e0b1822b495aa52bc52839ecb30, type: 3} + m_Name: + m_EditorClassIdentifier: + m_BrainParameters: + vectorObservationSize: 212 + numStackedVectorObservations: 1 + vectorActionSize: 27000000 + vectorActionDescriptions: [] + vectorActionSpaceType: 1 + m_Model: {fileID: 11400000, guid: 4e86a19e012da43bfa5ab97ae8089b98, type: 3} + m_InferenceDevice: 0 + m_BehaviorType: 0 + m_BehaviorName: Walker + TeamId: 0 + m_UseChildSensors: 1 + m_UseChildActuators: 1 + m_ObservableAttributeHandling: 0 +--- !u!114 &8153975132613398210 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 5381908961062339374} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 88b6042bc9a5d4aa58d931eae49442e5, type: 3} + m_Name: + m_EditorClassIdentifier: + agentParameters: + maxStep: 0 + hasUpgradedFromAgentParameters: 1 + MaxStep: 0 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab.meta b/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab.meta new file mode 100644 index 0000000000..2d444fbbd2 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/old_serialized_agent.prefab.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 07e2de94fa37b49c3bdf7a21857c0f73 +PrefabImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/TimerTest.cs b/com.unity.ml-agents/Tests/Editor/TimerTest.cs new file mode 100644 index 0000000000..58fc048c91 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TimerTest.cs @@ -0,0 +1,86 @@ +using NUnit.Framework; + +namespace Unity.MLAgents.Tests +{ + public class TimerTests + { + [Test] + public void TestNested() + { + TimerStack myTimer = TimerStack.Instance; + myTimer.Reset(); + using (myTimer.Scoped("foo")) + { + for (int i = 0; i < 5; i++) + { + using (myTimer.Scoped("bar")) + { + myTimer.SetGauge("my_gauge", i); + myTimer.AddMetadata("i", $"{i}"); + } + } + } + + var rootChildren = myTimer.RootNode.Children; + Assert.That(rootChildren, Contains.Key("foo")); + Assert.AreEqual(rootChildren["foo"].NumCalls, 1); + var gauge = myTimer.RootNode.Gauges["my_gauge"]; + Assert.NotNull(gauge); + 
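// Five passes through the inner scope wrote the values 0 through 4, + // so we expect count 5, min 0, max 4, and 4 as the last value written. +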
Assert.AreEqual(5, gauge.count); + Assert.AreEqual(0, gauge.minValue); + Assert.AreEqual(4, gauge.maxValue); + Assert.AreEqual(4, gauge.value); + Assert.AreEqual("4", myTimer.RootNode.Metadata["i"]); + + var fooChildren = rootChildren["foo"].Children; + Assert.That(fooChildren, Contains.Key("bar")); + Assert.AreEqual(fooChildren["bar"].NumCalls, 5); + + myTimer.Reset(); + Assert.AreEqual(myTimer.RootNode.Children, null); + } + + [Test] + public void TestGauges() + { + TimerStack myTimer = TimerStack.Instance; + myTimer.Reset(); + + // Simple test - adding 1's should keep that for the weighted and running averages. + myTimer.SetGauge("one", 1.0f); + var oneNode = myTimer.RootNode.Gauges["one"]; + Assert.AreEqual(oneNode.weightedAverage, 1.0f); + Assert.AreEqual(oneNode.runningAverage, 1.0f); + + for (int i = 0; i < 10; i++) + { + myTimer.SetGauge("one", 1.0f); + } + + Assert.AreEqual(oneNode.weightedAverage, 1.0f); + Assert.AreEqual(oneNode.runningAverage, 1.0f); + + // Try some more interesting values + myTimer.SetGauge("increasing", 1.0f); + myTimer.SetGauge("increasing", 2.0f); + myTimer.SetGauge("increasing", 3.0f); + + myTimer.SetGauge("decreasing", 3.0f); + myTimer.SetGauge("decreasing", 2.0f); + myTimer.SetGauge("decreasing", 1.0f); + var increasingNode = myTimer.RootNode.Gauges["increasing"]; + var decreasingNode = myTimer.RootNode.Gauges["decreasing"]; + + // Expect the running average to be (roughly) the same, + // but weighted averages will be biased differently. + Assert.AreEqual(increasingNode.runningAverage, 2.0f); + Assert.AreEqual(decreasingNode.runningAverage, 2.0f); + + // The older values are actually weighted more heavily, so we expect the + // increasing series to have a lower moving average. + Assert.Less(increasingNode.weightedAverage, decreasingNode.weightedAverage); + + + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/TimerTest.cs.meta b/com.unity.ml-agents/Tests/Editor/TimerTest.cs.meta new file mode 100644 index 0000000000..9824d01d30 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TimerTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 506de2f6a1c74967a6f16ebf494c01d5 +timeCreated: 1569370981 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs new file mode 100644 index 0000000000..d71abf6555 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs @@ -0,0 +1,65 @@ +using System; +using System.Linq; +using System.Text; +using NUnit.Framework; +using Google.Protobuf; +using Unity.MLAgents.Analytics; +using Unity.MLAgents.SideChannels; +using Unity.MLAgents.CommunicatorObjects; + + +namespace Unity.MLAgents.Tests +{ + /// + /// These tests send messages through the event handling code. + /// There's no output to test, so just make sure there are no exceptions + /// (and get the code coverage above the minimum). 
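+ /// Analytics sending is disabled in each test via AnalyticsUtils.DisableAnalyticsSending, + /// so nothing is actually reported.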
+ /// + public class TrainingAnalyticsSideChannelTests + { + [Test] + public void TestTrainingEnvironmentReceived() + { + var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingEnvironmentInitialized()); + var anyMsgBytes = anyMsg.ToByteArray(); + var sideChannel = new TrainingAnalyticsSideChannel(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(anyMsgBytes); + } + } + + [Test] + public void TestTrainingBehaviorReceived() + { + var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized()); + var anyMsgBytes = anyMsg.ToByteArray(); + var sideChannel = new TrainingAnalyticsSideChannel(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(anyMsgBytes); + } + } + + [Test] + public void TestInvalidProtobufMessage() + { + // Test an invalid (non-protobuf) message. This should silently ignore the data. + var badBytes = Encoding.ASCII.GetBytes("Lorem ipsum"); + var sideChannel = new TrainingAnalyticsSideChannel(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(badBytes); + } + + // Test an almost-valid message. This should silently ignore the data. + var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized()); + var anyMsgBytes = anyMsg.ToByteArray(); + var truncatedMessage = new ArraySegment<byte>(anyMsgBytes, 0, anyMsgBytes.Length - 1).ToArray(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(truncatedMessage); + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta new file mode 100644 index 0000000000..ebb5915235 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c2a71036ddec4ba4bf83c5e8ba1b8daa +timeCreated: 1610574895 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef b/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef new file mode 100755 index 0000000000..128105b400 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef @@ -0,0 +1,47 @@ +{ + "name": "Unity.ML-Agents.Editor.Tests", + "references": [ + "Unity.ML-Agents.Editor", + "Unity.ML-Agents", + "Unity.Barracuda", + "Unity.Mathematics", + "Unity.ML-Agents.CommunicatorObjects", + "Unity.ML-Agents.Runtime.Utils.Tests", + "Unity.ML-Agents.Runtime.Sensor.Tests" + ], + "optionalUnityReferences": [ + "TestAssemblies" + ], + "includePlatforms": [ + "Editor" + ], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": true, + "precompiledReferences": [ + "System.IO.Abstractions.dll", + "System.IO.Abstractions.TestingHelpers.dll", + "Google.Protobuf.dll" + ], + "autoReferenced": false, + "defineConstraints": [ + "UNITY_INCLUDE_TESTS" + ], + "versionDefines": [ + { + "name": "com.unity.modules.unityanalytics", + "expression": "1.0.0", + "define": "MLA_UNITY_ANALYTICS_MODULE" + }, + { + "name": "com.unity.modules.physics", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS_MODULE" + }, + { + "name": "com.unity.modules.physics2d", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS2D_MODULE" + } + ] +} diff --git a/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef.meta b/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef.meta new file mode 100644
index 0000000000..db46168638 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 57f6004a925b546cd94e94ed518e275d +AssemblyDefinitionImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/UtilitiesTests.cs b/com.unity.ml-agents/Tests/Editor/UtilitiesTests.cs new file mode 100644 index 0000000000..de90a6797a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/UtilitiesTests.cs @@ -0,0 +1,23 @@ +using NUnit.Framework; + +namespace Unity.MLAgents.Tests +{ + public class UtilitiesTests + { + [Test] + public void TestCumSum() + { + var output = Utilities.CumSum(new[] { 1, 2, 3, 10 }); + CollectionAssert.AreEqual(output, new[] { 0, 1, 3, 6, 16 }); + + output = Utilities.CumSum(new int[0]); + CollectionAssert.AreEqual(output, new[] { 0 }); + + output = Utilities.CumSum(new[] { 100 }); + CollectionAssert.AreEqual(output, new[] { 0, 100 }); + + output = Utilities.CumSum(new[] { -1, 10 }); + CollectionAssert.AreEqual(output, new[] { 0, -1, 9 }); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/UtilitiesTests.cs.meta b/com.unity.ml-agents/Tests/Editor/UtilitiesTests.cs.meta new file mode 100644 index 0000000000..a0be9a722d --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/UtilitiesTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 45ab7fc6851444d8ba622b4f63b8290b +timeCreated: 1538775063 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime.meta b/com.unity.ml-agents/Tests/Runtime.meta new file mode 100644 index 0000000000..1260b3a158 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: fec0f6b603d3046f086e09ea2a44ed1f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs new file mode 100644 index 0000000000..743212a69b --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -0,0 +1,135 @@ +using System.Collections; +using System.Collections.Generic; +using Unity.MLAgents; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using UnityEngine; +using UnityEngine.TestTools; + +namespace Tests +{ + public class PublicApiAgent : Agent + { + public int numHeuristicCalls; + + [Observable] + public float ObservableFloat; + + public override void Heuristic(in ActionBuffers actionsOut) + { + numHeuristicCalls++; + base.Heuristic(actionsOut); + } + } + + // Simple SensorComponent that sets up a StackingSensor + public class StackingComponent : SensorComponent + { + public SensorComponent wrappedComponent; + public int numStacks; + + public override ISensor[] CreateSensors() + { + var wrappedSensors = wrappedComponent.CreateSensors(); + var sensorsOut = new ISensor[wrappedSensors.Length]; + for (var i = 0; i < wrappedSensors.Length; i++) + { + sensorsOut[i] = new StackingSensor(wrappedSensors[i], numStacks); + } + + return sensorsOut; + } + } + + public class RuntimeApiTest + { + [SetUp] + public static void Setup() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + Academy.Instance.AutomaticSteppingEnabled = false; + } + + [UnityTest] + public IEnumerator RuntimeApiTestWithEnumeratorPasses() + { + Academy.Instance.InferenceSeed = 1337; + 
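// Assemble the agent piece by piece through the public API; the ordering below + // matters (behavior parameters and sensors first, then the Agent, then the + // DecisionRequester), as the inline comments note. +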
var gameObject = new GameObject(); + + var behaviorParams = gameObject.AddComponent<BehaviorParameters>(); + behaviorParams.BrainParameters.VectorObservationSize = 3; + behaviorParams.BrainParameters.NumStackedVectorObservations = 2; + behaviorParams.BrainParameters.VectorActionDescriptions = new[] { "Continuous1", "TestActionA", "TestActionB" }; + behaviorParams.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 2, 2 }); + behaviorParams.BehaviorName = "TestBehavior"; + behaviorParams.TeamId = 42; + behaviorParams.UseChildSensors = true; + behaviorParams.DeterministicInference = false; + behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll; + + + // Can't actually create an Agent with InferenceOnly and no model, so change back + behaviorParams.BehaviorType = BehaviorType.Default; +#if MLA_UNITY_PHYSICS_MODULE + var sensorComponent = gameObject.AddComponent<RayPerceptionSensorComponent3D>(); + sensorComponent.SensorName = "ray3d"; + sensorComponent.DetectableTags = new List<string> { "Player", "Respawn" }; + sensorComponent.RaysPerDirection = 3; + + // Make a StackingSensor that wraps the RayPerceptionSensorComponent3D + // This isn't necessarily practical, just to ensure that it can be done + var wrappingSensorComponent = gameObject.AddComponent<StackingComponent>(); + wrappingSensorComponent.wrappedComponent = sensorComponent; + wrappingSensorComponent.numStacks = 3; + + // ISensor isn't set up yet. + Assert.IsNull(sensorComponent.RaySensor); +#endif + + + // Make sure we can set the behavior type correctly after the agent is initialized + // (this creates a new policy). + behaviorParams.BehaviorType = BehaviorType.HeuristicOnly; + + // Agent needs to be added after everything else is setup. + var agent = gameObject.AddComponent<PublicApiAgent>(); + + // DecisionRequester has to be added after Agent. + var decisionRequester = gameObject.AddComponent<DecisionRequester>(); + decisionRequester.DecisionPeriod = 2; + decisionRequester.TakeActionsBetweenDecisions = true; + +#if MLA_UNITY_PHYSICS_MODULE + // Initialization should set up the sensors + Assert.IsNotNull(sensorComponent.RaySensor); +#endif + // Let's change the inference device + var otherDevice = behaviorParams.InferenceDevice == InferenceDevice.CPU ? InferenceDevice.GPU : InferenceDevice.CPU; + agent.SetModel(behaviorParams.BehaviorName, behaviorParams.Model, otherDevice); + + agent.AddReward(1.0f); + + // skip a frame. + yield return null; + + Academy.Instance.EnvironmentStep(); + + var actions = agent.GetStoredActionBuffers().DiscreteActions; + // default Heuristic implementation should return zero actions.
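+ // (the ActionSpec above declares two discrete branches, hence two zero entries)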
+ Assert.AreEqual(new ActionSegment<int>(new[] { 0, 0 }), actions); + Assert.AreEqual(1, agent.numHeuristicCalls); + + Academy.Instance.EnvironmentStep(); + Assert.AreEqual(1, agent.numHeuristicCalls); + + Academy.Instance.EnvironmentStep(); + Assert.AreEqual(2, agent.numHeuristicCalls); + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs.meta b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs.meta new file mode 100644 index 0000000000..5f7821402b --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 17878576e4ed14b09875e37394e5ad90 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor.meta b/com.unity.ml-agents/Tests/Runtime/Sensor.meta new file mode 100644 index 0000000000..67c65573c6 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 782dd744bcb0744b8a880aa70c8c7421 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs new file mode 100644 index 0000000000..ea30ed63c2 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs @@ -0,0 +1,305 @@ +#if MLA_UNITY_PHYSICS_MODULE +using System.Collections.Generic; +using System.Reflection; +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + internal class TestBoxOverlapChecker : BoxOverlapChecker + { + public TestBoxOverlapChecker( + Vector3 cellScale, + Vector3Int gridSize, + bool rotateWithAgent, + LayerMask colliderMask, + GameObject centerObject, + GameObject agentGameObject, + string[] detectableTags, + int initialColliderBufferSize, + int maxColliderBufferSize + ) : base( + cellScale, + gridSize, + rotateWithAgent, + colliderMask, + centerObject, + agentGameObject, + detectableTags, + initialColliderBufferSize, + maxColliderBufferSize) + { } + + public Vector3[] CellLocalPositions + { + get + { + return (Vector3[])typeof(BoxOverlapChecker).GetField("m_CellLocalPositions", + BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + } + + public Collider[] ColliderBuffer + { + get + { + return (Collider[])typeof(BoxOverlapChecker).GetField("m_ColliderBuffer", + BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + } + + public static TestBoxOverlapChecker CreateChecker( + float cellScaleX = 1f, + float cellScaleZ = 1f, + int gridSizeX = 10, + int gridSizeZ = 10, + bool rotateWithAgent = true, + GameObject centerObject = null, + GameObject agentGameObject = null, + string[] detectableTags = null, + int initialColliderBufferSize = 4, + int maxColliderBufferSize = 500) + { + return new TestBoxOverlapChecker( + new Vector3(cellScaleX, 0.01f, cellScaleZ), + new Vector3Int(gridSizeX, 1, gridSizeZ), + rotateWithAgent, + LayerMask.GetMask("Default"), + centerObject, + agentGameObject, + detectableTags, + initialColliderBufferSize, + maxColliderBufferSize); + } + } + + public class BoxOverlapCheckerTests + { + [Test] + public void TestCellLocalPosition() + { + var testGo = new GameObject("test"); + testGo.transform.position = Vector3.zero; + var boxOverlapSquare = 
TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, agentGameObject: testGo); + + var localPos = boxOverlapSquare.CellLocalPositions; + Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f), localPos[0]); + Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f), localPos[9]); + Assert.AreEqual(new Vector3(4.5f, 0, -4.5f), localPos[90]); + Assert.AreEqual(new Vector3(4.5f, 0, 4.5f), localPos[99]); + Object.DestroyImmediate(testGo); + + var testGo2 = new GameObject("test"); + testGo2.transform.position = new Vector3(3.5f, 8f, 17f); // random, should have no effect on local positions + var boxOverlapRect = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, agentGameObject: testGo2); + + localPos = boxOverlapRect.CellLocalPositions; + Assert.AreEqual(new Vector3(-2f, 0, -7f), localPos[0]); + Assert.AreEqual(new Vector3(-2f, 0, 7f), localPos[14]); + Assert.AreEqual(new Vector3(2f, 0, -7f), localPos[60]); + Assert.AreEqual(new Vector3(2f, 0, 7f), localPos[74]); + Object.DestroyImmediate(testGo2); + } + + [Test] + public void TestCellGlobalPositionNoRotate() + { + var testGo = new GameObject("test"); + var position = new Vector3(3.5f, 8f, 17f); + testGo.transform.position = position; + var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, agentGameObject: testGo, centerObject: testGo); + + Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(0)); + Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(9)); + Assert.AreEqual(new Vector3(4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(90)); + Assert.AreEqual(new Vector3(4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(99)); + + testGo.transform.Rotate(0, 90, 0); // should have no effect on positions + Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(0)); + Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(9)); + Assert.AreEqual(new Vector3(4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(90)); + Assert.AreEqual(new Vector3(4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(99)); + + Object.DestroyImmediate(testGo); + } + + [Test] + public void TestCellGlobalPositionRotate() + { + var testGo = new GameObject("test"); + var position = new Vector3(15f, 6f, 13f); + testGo.transform.position = position; + var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, agentGameObject: testGo, centerObject: testGo); + + Assert.AreEqual(new Vector3(-2f, 0, -7f) + position, boxOverlap.GetCellGlobalPosition(0)); + Assert.AreEqual(new Vector3(-2f, 0, 7f) + position, boxOverlap.GetCellGlobalPosition(14)); + Assert.AreEqual(new Vector3(2f, 0, -7f) + position, boxOverlap.GetCellGlobalPosition(60)); + Assert.AreEqual(new Vector3(2f, 0, 7f) + position, boxOverlap.GetCellGlobalPosition(74)); + + testGo.transform.Rotate(0, 90, 0); + // round to int to ignore numeric errors + Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(-7f, 0, 2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(0))); + Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(7f, 0, 2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(14))); + Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(-7f, 0, -2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(60))); + Assert.AreEqual(Vector3Int.RoundToInt(new 
Vector3(7f, 0, -2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(74))); + + Object.DestroyImmediate(testGo); + } + + [Test] + public void TestBufferResize() + { + List<GameObject> testObjects = new List<GameObject>(); + var testGo = new GameObject("test"); + testGo.transform.position = Vector3.zero; + testObjects.Add(testGo); + var boxOverlap = TestBoxOverlapChecker.CreateChecker(agentGameObject: testGo, centerObject: testGo, initialColliderBufferSize: 2, maxColliderBufferSize: 5); + boxOverlap.Perceive(); + Assert.AreEqual(2, boxOverlap.ColliderBuffer.Length); + + for (var i = 0; i < 3; i++) + { + var boxGo = new GameObject("test"); + boxGo.transform.position = Vector3.zero; + boxGo.AddComponent<BoxCollider>(); + testObjects.Add(boxGo); + } + boxOverlap.Perceive(); + Assert.AreEqual(4, boxOverlap.ColliderBuffer.Length); + + for (var i = 0; i < 2; i++) + { + var boxGo = new GameObject("test"); + boxGo.transform.position = Vector3.zero; + boxGo.AddComponent<BoxCollider>(); + testObjects.Add(boxGo); + } + boxOverlap.Perceive(); + Assert.AreEqual(5, boxOverlap.ColliderBuffer.Length); + + Object.DestroyImmediate(testGo); + foreach (var go in testObjects) + { + Object.DestroyImmediate(go); + } + } + + [Test] + public void TestParseCollidersClosest() + { + var tag1 = "Player"; + List<GameObject> testObjects = new List<GameObject>(); + var testGo = new GameObject("test"); + testGo.transform.position = Vector3.zero; + var boxOverlap = TestBoxOverlapChecker.CreateChecker( + cellScaleX: 10f, + cellScaleZ: 10f, + gridSizeX: 2, + gridSizeZ: 2, + agentGameObject: testGo, + centerObject: testGo, + detectableTags: new [] { tag1 }); + var helper = new VerifyParseCollidersHelper(); + boxOverlap.GridOverlapDetectedClosest += helper.DetectedAction; + + for (var i = 0; i < 3; i++) + { + var boxGo = new GameObject("test"); + boxGo.transform.position = new Vector3(i + 1, 0, 1); + boxGo.AddComponent<BoxCollider>(); + boxGo.tag = tag1; + testObjects.Add(boxGo); + } + + boxOverlap.Perceive(); + helper.Verify(1, new List<GameObject> { testObjects[0] }); + + Object.DestroyImmediate(testGo); + foreach (var go in testObjects) + { + Object.DestroyImmediate(go); + } + } + + [Test] + public void TestParseCollidersAll() + { + var tag1 = "Player"; + List<GameObject> testObjects = new List<GameObject>(); + var testGo = new GameObject("test"); + testGo.transform.position = Vector3.zero; + var boxOverlap = TestBoxOverlapChecker.CreateChecker( + cellScaleX: 10f, + cellScaleZ: 10f, + gridSizeX: 2, + gridSizeZ: 2, + agentGameObject: testGo, + centerObject: testGo, + detectableTags: new [] { tag1 }); + var helper = new VerifyParseCollidersHelper(); + boxOverlap.GridOverlapDetectedAll += helper.DetectedAction; + + for (var i = 0; i < 3; i++) + { + var boxGo = new GameObject("test"); + boxGo.transform.position = new Vector3(i + 1, 0, 1); + boxGo.AddComponent<BoxCollider>(); + boxGo.tag = tag1; + testObjects.Add(boxGo); + } + + boxOverlap.Perceive(); + helper.Verify(3, testObjects); + + Object.DestroyImmediate(testGo); + foreach (var go in testObjects) + { + Object.DestroyImmediate(go); + } + } + + public class VerifyParseCollidersHelper + { + int m_NumInvoked; + List<GameObject> m_ParsedObjects = new List<GameObject>(); + + public void DetectedAction(GameObject go, int cellIndex) + { + m_NumInvoked += 1; + m_ParsedObjects.Add(go); + } + + public void Verify(int expectNumInvoke, List<GameObject> expectedObjects) + { + Assert.AreEqual(expectNumInvoke, m_NumInvoked); + Assert.AreEqual(expectedObjects.Count, m_ParsedObjects.Count); + foreach (var obj in expectedObjects) + { + Assert.Contains(obj, m_ParsedObjects); + } + } + } + + [Test] + public void TestOnlyOneChecker() + { + var 
testGo = new GameObject("test"); + testGo.transform.position = Vector3.zero; + var gridSensorComponent = testGo.AddComponent(); + gridSensorComponent.SetComponentParameters(useGridSensorBase: true, useTestingGridSensor: true); + var sensors = gridSensorComponent.CreateSensors(); + int numChecker = 0; + foreach (var sensor in sensors) + { + var gridsensor = (GridSensorBase)sensor; + if (gridsensor.m_GridPerception != null) + { + numChecker += 1; + } + } + Assert.AreEqual(1, numChecker); + } + } +} +#endif diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs.meta new file mode 100644 index 0000000000..d0b075e343 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 087f04a0f817c45f4a709ed36fe5ba1a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs new file mode 100644 index 0000000000..a1ec30dd16 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs @@ -0,0 +1,80 @@ +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class BufferSensorTest + { + [Test] + public void TestBufferSensor() + { + var bufferSensor = new BufferSensor(20, 4, "testName"); + var shape = bufferSensor.GetObservationSpec().Shape; + var dimProp = bufferSensor.GetObservationSpec().DimensionProperties; + Assert.AreEqual(shape[0], 20); + Assert.AreEqual(shape[1], 4); + Assert.AreEqual(shape.Length, 2); + Assert.AreEqual(dimProp[0], DimensionProperty.VariableSize); + Assert.AreEqual(dimProp[1], DimensionProperty.None); + Assert.AreEqual(dimProp.Length, 2); + + bufferSensor.AppendObservation(new float[] { 1, 2, 3, 4 }); + bufferSensor.AppendObservation(new float[] { 5, 6, 7, 8 }); + var obsWriter = new ObservationWriter(); + var obs = bufferSensor.GetObservationProto(obsWriter); + + Assert.AreEqual(shape, InplaceArray.FromList(obs.Shape)); + Assert.AreEqual(obs.DimensionProperties.Count, 2); + Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]); + Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]); + + for (int i = 0; i < 8; i++) + { + Assert.AreEqual(obs.FloatData.Data[i], i + 1); + } + for (int i = 8; i < 80; i++) + { + Assert.AreEqual(obs.FloatData.Data[i], 0); + } + } + + [Test] + public void TestBufferSensorComponent() + { + var agentGameObj = new GameObject("agent"); + var bufferComponent = agentGameObj.AddComponent(); + bufferComponent.MaxNumObservables = 20; + bufferComponent.ObservableSize = 4; + bufferComponent.SensorName = "TestName"; + + var sensor = bufferComponent.CreateSensors()[0]; + var shape = sensor.GetObservationSpec().Shape; + + Assert.AreEqual(shape[0], 20); + Assert.AreEqual(shape[1], 4); + Assert.AreEqual(shape.Length, 2); + + bufferComponent.AppendObservation(new float[] { 1, 2, 3, 4 }); + bufferComponent.AppendObservation(new float[] { 5, 6, 7, 8 }); + + var obsWriter = new ObservationWriter(); + var obs = sensor.GetObservationProto(obsWriter); + + Assert.AreEqual(shape, InplaceArray.FromList(obs.Shape)); + Assert.AreEqual(obs.DimensionProperties.Count, 2); + + Assert.AreEqual(sensor.GetName(), "TestName"); + + for (int i = 0; i < 
8; i++) + { + Assert.AreEqual(obs.FloatData.Data[i], i + 1); + } + for (int i = 8; i < 80; i++) + { + Assert.AreEqual(obs.FloatData.Data[i], 0); + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs.meta new file mode 100644 index 0000000000..be1021246e --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5267572aa66d34b49bbc65940674b2a6 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs new file mode 100644 index 0000000000..b71c89e82e --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs @@ -0,0 +1,53 @@ +using System; +using System.Reflection; +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + + [TestFixture] + public class CameraSensorComponentTest + { + [Test] + public void TestCameraSensorComponent() + { + foreach (var grayscale in new[] { true, false }) + { + foreach (SensorCompressionType compression in Enum.GetValues(typeof(SensorCompressionType))) + { + var width = 24; + var height = 16; + var camera = Camera.main; + + var agentGameObj = new GameObject("agent"); + + var cameraComponent = agentGameObj.AddComponent<CameraSensorComponent>(); + cameraComponent.Camera = camera; + cameraComponent.Height = height; + cameraComponent.Width = width; + cameraComponent.Grayscale = grayscale; + cameraComponent.CompressionType = compression; + cameraComponent.RuntimeCameraEnable = true; + + var sensor = cameraComponent.CreateSensors()[0]; + var expectedShape = new InplaceArray<int>(height, width, grayscale ? 
1 : 3); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + Assert.AreEqual(typeof(CameraSensor), sensor.GetType()); + + var flags = BindingFlags.Instance | BindingFlags.NonPublic; + var runtimeCameraEnabled = (bool)typeof(CameraSensorComponent).GetField("m_RuntimeCameraEnable", flags).GetValue(cameraComponent); + Assert.True(runtimeCameraEnabled); + + // Make sure cleaning up the component cleans up the sensor too + cameraComponent.Dispose(); + var cameraComponentSensor = (CameraSensor)typeof(CameraSensorComponent).GetField("m_Sensor", flags).GetValue(cameraComponent); + Assert.IsNull(cameraComponentSensor); + var cameraTexture = (Texture2D)typeof(CameraSensor).GetField("m_Texture", flags).GetValue(sensor); + Assert.IsNull(cameraTexture); + } + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs.meta new file mode 100644 index 0000000000..f66b94c6b7 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4c0b188faef38407e82223854fc8eaf5 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs new file mode 100644 index 0000000000..3a0bd17746 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs @@ -0,0 +1,57 @@ +using System; +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + + [TestFixture] + public class CameraSensorTest + { + [Test] + public void TestCameraSensor() + { + foreach (var grayscale in new[] { true, false }) + { + foreach (SensorCompressionType compression in Enum.GetValues(typeof(SensorCompressionType))) + { + var width = 24; + var height = 16; + var camera = Camera.main; + var c = new GameObject(); + if (ReferenceEquals(null, camera)) + { + camera = c.AddComponent<Camera>(); + } + var sensor = new CameraSensor(camera, width, height, grayscale, "TestCameraSensor", compression); + + var obsWriter = new ObservationWriter(); + var obs = sensor.GetObservationProto(obsWriter); + + Assert.AreEqual((int)compression, (int)obs.CompressionType); + var expectedShape = new[] { height, width, grayscale ? 
1 : 3 }; + Assert.AreEqual(expectedShape, obs.Shape); + UnityEngine.Object.DestroyImmediate(c); + } + } + } + + [Test] + public void TestObservationType() + { + var width = 24; + var height = 16; + var camera = Camera.main; + var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None); + var spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal); + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs.meta new file mode 100644 index 0000000000..6f505b3aa3 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ccdfc5b4015c9465cb1e811375be971c +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs new file mode 100644 index 0000000000..95740a5c7c --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs @@ -0,0 +1,30 @@ +using NUnit.Framework; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class CompressionSpecTests + { + [Test] + public void TestIsTrivialMapping() + { + Assert.IsTrue(CompressionSpec.Default().IsTrivialMapping()); + + var spec = new CompressionSpec(SensorCompressionType.PNG, null); + Assert.AreEqual(spec.IsTrivialMapping(), true); + + spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 0, 0, 0 }); + Assert.AreEqual(spec.IsTrivialMapping(), true); + + spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 0, 1, 2, 3, 4 }); + Assert.AreEqual(spec.IsTrivialMapping(), true); + + spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 1, 2, 3, 4, -1, -1 }); + Assert.AreEqual(spec.IsTrivialMapping(), false); + + spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 0, 0, 0, 1, 1, 1 }); + Assert.AreEqual(spec.IsTrivialMapping(), false); + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs.meta new file mode 100644 index 0000000000..d9df7ce5b2 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: cd0990de0eb646b0b0531b91c840c9da +timeCreated: 1616030728 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs new file mode 100644 index 0000000000..0f4a7ae1de --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs @@ -0,0 +1,106 @@ +using NUnit.Framework; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + 
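+    // A minimal ISensor implementation backed by a 2D float array; the tests
+    // below use it to drive ObservationWriter with single-channel visual data.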
+    public class Float2DSensor : ISensor
+    {
+        public int Width { get; }
+        public int Height { get; }
+        string m_Name;
+        private ObservationSpec m_ObservationSpec;
+        public float[,] floatData;
+
+        public Float2DSensor(int width, int height, string name)
+        {
+            Width = width;
+            Height = height;
+            m_Name = name;
+
+            m_ObservationSpec = ObservationSpec.Visual(height, width, 1);
+            floatData = new float[Height, Width];
+        }
+
+        public Float2DSensor(float[,] floatData, string name)
+        {
+            this.floatData = floatData;
+            Height = floatData.GetLength(0);
+            Width = floatData.GetLength(1);
+            m_Name = name;
+            m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1);
+        }
+
+        public string GetName()
+        {
+            return m_Name;
+        }
+
+        public ObservationSpec GetObservationSpec()
+        {
+            return m_ObservationSpec;
+        }
+
+        public byte[] GetCompressedObservation()
+        {
+            return null;
+        }
+
+        public int Write(ObservationWriter writer)
+        {
+            using (TimerStack.Instance.Scoped("Float2DSensor.Write"))
+            {
+                for (var h = 0; h < Height; h++)
+                {
+                    for (var w = 0; w < Width; w++)
+                    {
+                        writer[h, w, 0] = floatData[h, w];
+                    }
+                }
+                var numWritten = Height * Width;
+                return numWritten;
+            }
+        }
+
+        public void Update() { }
+        public void Reset() { }
+
+        public CompressionSpec GetCompressionSpec()
+        {
+            return CompressionSpec.Default();
+        }
+    }
+
+    public class FloatVisualSensorTests
+    {
+        [Test]
+        public void TestFloat2DSensorWrite()
+        {
+            var sensor = new Float2DSensor(3, 4, "floatsensor");
+            for (var h = 0; h < 4; h++)
+            {
+                for (var w = 0; w < 3; w++)
+                {
+                    sensor.floatData[h, w] = 3 * h + w;
+                }
+            }
+
+            var output = new float[12];
+            var writer = new ObservationWriter();
+            writer.SetTarget(output, sensor.GetObservationSpec(), 0);
+            sensor.Write(writer);
+            for (var i = 0; i < 9; i++)
+            {
+                Assert.AreEqual(i, output[i]);
+            }
+        }
+
+        [Test]
+        public void TestFloat2DSensorExternalData()
+        {
+            var data = new float[4, 3];
+            var sensor = new Float2DSensor(data, "floatsensor");
+            Assert.AreEqual(sensor.Height, 4);
+            Assert.AreEqual(sensor.Width, 3);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs.meta
new file mode 100644
index 0000000000..ea128787e0
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 49b7da14949a486b803e28ed32d91a09
+timeCreated: 1578093005
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs
new file mode 100644
index 0000000000..548262be96
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs
@@ -0,0 +1,88 @@
+using NUnit.Framework;
+using System;
+using System.Linq;
+
+namespace Unity.MLAgents.Tests
+{
+    public static class GridObsTestUtils
+    {
+        /// <summary>
+        /// Utility function to duplicate an array into an array of arrays
+        /// </summary>
+        /// <param name="array">array to duplicate</param>
+        /// <param name="numCopies">number of times to duplicate</param>
+        /// <returns>array of duplicated arrays</returns>
+        public static float[][] DuplicateArray(float[] array, int numCopies)
+        {
+            float[][] duplicated = new float[numCopies][];
+            for (int i = 0; i < numCopies; i++)
+            {
+                duplicated[i] = array;
+            }
+            return duplicated;
+        }
+
+
+        /// <summary>
+        /// Asserts that the sub-arrays of the total array are equal to specific subarrays at specific subarray indices and equal to a default everywhere else.
+        /// </summary>
+        /// <param name="total">Array containing all data of the grid observation. Is a concatenation of N subarrays all of the same length</param>
+        /// <param name="indices">The indices to verify that differ from the default array</param>
+        /// <param name="expectedArrays">The sub-array values that differ from the default array</param>
+        /// <param name="expectedDefaultArray">The default value of a sub-array</param>
+        /// <example>
+        /// If the total array is data from a 4x4x2 grid observation, total will be an array of size 32 and each sub-array will have a size of 2.
+        /// Let 3 cells at indices (0, 1), (2, 2), and (3, 0) have values ([.1, .5]), ([.9, .7]), ([0, .2]), respectively.
+        /// If the default value of a cell is ([0, 0]) then the grid observation will be as follows:
+        ///     [ [0, 0],  [.1, .5], [ 0, 0 ], [0, 0],
+        ///       [0, 0],  [ 0, 0 ], [ 0, 0 ], [0, 0],
+        ///       [0, 0],  [ 0, 0 ], [.9, .7], [0, 0],
+        ///       [0, .2], [ 0, 0 ], [ 0, 0 ], [0, 0] ]
+        ///
+        /// The total array is then the flattened version:
+        ///     total = [0, 0, .1, .5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, .9, .7, 0, 0, 0, .2, 0, 0, 0, 0, 0, 0]
+        ///
+        /// and the indices of the activated cells in the flattened array are 1, 10, and 12.
+        ///
+        /// So to verify that the total array is as expected, AssertSubarraysAtIndex should be called as
+        ///     AssertSubarraysAtIndex(
+        ///         total,
+        ///         indices = new int[] {1, 10, 12},
+        ///         expectedArrays = new float[][] { new float[] {.1, .5}, new float[] {.9, .7}, new float[] {0, .2}},
+        ///         expectedDefaultArray = new float[] {0, 0}
+        ///     )
+        /// </example>
+        public static void AssertSubarraysAtIndex(float[] total, int[] indices, float[][] expectedArrays, float[] expectedDefaultArray)
+        {
+            int totalIndex = 0;
+            int subIndex = 0;
+            int subarrayIndex = 0;
+            int lenOfData = expectedDefaultArray.Length;
+            int numArrays = total.Length / lenOfData;
+            for (int i = 0; i < numArrays; i++)
+            {
+                totalIndex = i * lenOfData;
+
+                if (indices.Contains(i))
+                {
+                    subarrayIndex = Array.IndexOf(indices, i);
+                    for (subIndex = 0; subIndex < lenOfData; subIndex++)
+                    {
+                        Assert.AreEqual(expectedArrays[subarrayIndex][subIndex], total[totalIndex],
+                            "Expected " + expectedArrays[subarrayIndex][subIndex] + " at subarray index " + totalIndex + ", index = " + subIndex + " but was " + total[totalIndex]);
+                        totalIndex++;
+                    }
+                }
+                else
+                {
+                    for (subIndex = 0; subIndex < lenOfData; subIndex++)
+                    {
+                        Assert.AreEqual(expectedDefaultArray[subIndex], total[totalIndex],
+                            "Expected default value " + expectedDefaultArray[subIndex] + " at subarray index " + totalIndex + ", index = " + subIndex + " but was " + total[totalIndex]);
+                        totalIndex++;
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs.meta
new file mode 100644
index 0000000000..216e38aa9d
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 53aa6ab552bca5d49a6d298f1e633717
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
new file mode 100644
index 0000000000..d718133f49
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
@@ -0,0 +1,207 @@
+#if MLA_UNITY_PHYSICS_MODULE
+using System.Collections.Generic;
+using System.Collections;
+using System.Reflection;
+using NUnit.Framework;
+using UnityEngine;
+using UnityEngine.TestTools;
+using Unity.MLAgents.Sensors;
+
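+// These tests create primitive colliders in the scene, so the whole file is
+// compiled only when the built-in Physics module is available (see the
+// MLA_UNITY_PHYSICS_MODULE guard above).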
+namespace Unity.MLAgents.Tests
+{
+    public class GridSensorTests
+    {
+        GameObject testGo;
+        GameObject boxGo;
+        SimpleTestGridSensorComponent gridSensorComponent;
+
+        // Use built-in tags
+        const string k_Tag1 = "Player";
+        const string k_Tag2 = "Respawn";
+
+        [UnitySetUp]
+        public IEnumerator SetupScene()
+        {
+            testGo = new GameObject("test");
+            testGo.transform.position = Vector3.zero;
+            gridSensorComponent = testGo.AddComponent<SimpleTestGridSensorComponent>();
+
+            boxGo = new GameObject("block");
+            boxGo.tag = k_Tag1;
+            boxGo.transform.position = new Vector3(3f, 0f, 3f);
+            boxGo.AddComponent<BoxCollider>();
+
+            TestGridSensorConfig.Reset();
+            yield return null;
+        }
+
+        [TearDown]
+        public void ClearScene()
+        {
+            Object.DestroyImmediate(boxGo);
+            Object.DestroyImmediate(testGo);
+        }
+
+        [Test]
+        public void TestBufferSize()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, gridSizeX: 3, gridSizeZ: 4, useTestingGridSensor: true);
+            TestGridSensorConfig.SetParameters(5, true, false);
+            var gridSensor = (SimpleTestGridSensor)gridSensorComponent.CreateSensors()[0];
+            Assert.AreEqual(gridSensor.PerceptionBuffer.Length, 3 * 4 * 5);
+        }
+
+        [Test]
+        public void TestInvalidSizeConfiguration()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, gridSizeY: 10, useTestingGridSensor: true);
+            gridSensorComponent.CreateSensors(); // expect no exception
+
+            gridSensorComponent.m_GridSize.y = 10;
+            Assert.Throws<UnityAgentsException>(() =>
+            {
+                gridSensorComponent.CreateSensors();
+            });
+        }
+
+        [Test]
+        public void TestInvalidCompressionConfiguration()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, compression: SensorCompressionType.PNG, useTestingGridSensor: true);
+
+            var gridSensor = (GridSensorBase)gridSensorComponent.CreateSensors()[0];
+            LogAssert.Expect(LogType.Warning, $"Compression type {SensorCompressionType.PNG} is only supported with normalized data. " +
+                "The sensor will not compress the data.");
+            Assert.AreEqual(gridSensor.CompressionType, SensorCompressionType.None);
+        }
+
+        [Test]
+        public void TestCreateSensor()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
+
+            gridSensorComponent.CreateSensors();
+            var componentSensor = (List<ISensor>)typeof(GridSensorComponent).GetField("m_Sensors",
+                BindingFlags.Instance | BindingFlags.NonPublic).GetValue(gridSensorComponent);
+            Assert.AreEqual(componentSensor.Count, 1);
+        }
+
+        [Test]
+        public void PerceiveNotSelf()
+        {
+            testGo.tag = k_Tag2;
+
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
+            var gridSensor = (GridSensorBase)gridSensorComponent.CreateSensors()[0];
+
+            gridSensor.Update();
+
+            int[] subarrayIndices = new int[] { 77, 78, 87, 88 };
+            float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 1 }, 4);
+            float[] expectedDefault = new float[] { 0 };
+            GridObsTestUtils.AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndices, expectedSubarrays, expectedDefault);
+        }
+
+        [Test]
+        public void TestReset()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
+            TestGridSensorConfig.SetParameters(3, false, false);
+            var gridSensor = (GridSensorBase)gridSensorComponent.CreateSensors()[0];
+
+            gridSensor.Update();
+
+            int[] subarrayIndices = new int[] { 77, 78, 87, 88 };
+            float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 1 }, 4);
+            float[] expectedDefault = new float[] { 0 };
+            GridObsTestUtils.AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndices, expectedSubarrays, expectedDefault);
+            Object.DestroyImmediate(boxGo);
+
+            gridSensor.Update();
+
+            subarrayIndices = new int[0];
+            expectedSubarrays = new float[0][];
+            GridObsTestUtils.AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndices, expectedSubarrays, expectedDefault);
+        }
+
+        [Test]
+        public void TestOneHotSensor()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, useOneHotTag: true);
+            var gridSensor = (OneHotGridSensor)gridSensorComponent.CreateSensors()[0];
+            Assert.AreEqual(gridSensor.PerceptionBuffer.Length, 10 * 10 * 2);
+
+            gridSensor.Update();
+
+            int[] subarrayIndices = new int[] { 77, 78, 87, 88 };
+            float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 1, 0 }, 4);
+            float[] expectedDefault = new float[] { 0, 0 };
+            GridObsTestUtils.AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndices, expectedSubarrays, expectedDefault);
+        }
+
+        [Test]
+        public void TestCustomSensorInvalidData()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, compression: SensorCompressionType.PNG, useTestingGridSensor: true);
+            TestGridSensorConfig.SetParameters(5, true, false);
+            var gridSensor = (SimpleTestGridSensor)gridSensorComponent.CreateSensors()[0];
+
+            gridSensor.DummyData = new float[] { 1, 2, 3, 4, 5 };
+            Assert.Throws<UnityAgentsException>(() =>
+            {
+                gridSensor.Update();
+            });
+        }
+
+        [Test]
+        public void TestMultipleSensors()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, useOneHotTag: true, useGridSensorBase: true, useTestingGridSensor: true);
+            var gridSensors = gridSensorComponent.CreateSensors();
+            Assert.IsNotNull(((GridSensorBase)gridSensors[0]).m_GridPerception);
+            Assert.IsNull(((GridSensorBase)gridSensors[1]).m_GridPerception);
+            Assert.IsNull(((GridSensorBase)gridSensors[2]).m_GridPerception);
+        }
+
+        [Test]
+        public void TestNoSensors()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags);
+            Assert.Throws<UnityAgentsException>(() =>
+            {
+                gridSensorComponent.CreateSensors();
+            });
+        }
+
+        [Test]
+        public void TestStackedSensors()
+        {
+            testGo.tag = k_Tag2;
+            string[] tags = { k_Tag1, k_Tag2 };
+            gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
+            gridSensorComponent.ObservationStacks = 3;
+            var sensors = gridSensorComponent.CreateSensors();
+            Assert.IsInstanceOf(typeof(StackingSensor), sensors[0]);
+        }
+    }
+}
+#endif
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs.meta
new file mode 100644
index 0000000000..e477ba8405
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 2a8626afd640a4edd942dac3d8d6bc85
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservableAttributeTests.cs
new file mode 100644
index 0000000000..76f76ba037
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservableAttributeTests.cs
@@ -0,0 +1,361 @@
+using System;
+using System.Collections.Generic;
+using NUnit.Framework;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Sensors.Reflection;
+
+namespace Unity.MLAgents.Tests
+{
+    [TestFixture]
+    public class ObservableAttributeTests
+    {
+        public enum TestEnum
+        {
+            ValueA = -100,
+            ValueB = 1,
+            ValueC = 42,
+        }
+
+        [Flags]
+        public enum TestFlags
+        {
+            FlagA = 1,
+            FlagB = 2,
+            FlagC = 4
+        }
+
+        class TestClass
+        {
+            // Non-observables
+            int m_NonObservableInt;
+            float m_NonObservableFloat;
+
+            //
+            // Int
+            //
+            [Observable]
+            public int m_IntMember;
+
+            int m_IntProperty;
+
+            [Observable]
+            public int IntProperty
+            {
+                get => m_IntProperty;
+                set => m_IntProperty = value;
+            }
+
+            //
+            // Float
+            //
+            [Observable("floatMember")]
+            public float m_FloatMember;
+
+            float m_FloatProperty;
+            [Observable("floatProperty")]
+            public float FloatProperty
+            {
+                get => m_FloatProperty;
+                set => m_FloatProperty = value;
+            }
+
+            //
+            // Bool
+            //
+            [Observable("boolMember")]
+            public bool m_BoolMember;
+
+            bool m_BoolProperty;
+            [Observable("boolProperty")]
+            public bool BoolProperty
+            {
+                get => m_BoolProperty;
+                set => m_BoolProperty = value;
+            }
+
+            //
+            // Vector2
+            //
+
+            [Observable("vector2Member")]
+            public Vector2 m_Vector2Member;
+
+            Vector2 m_Vector2Property;
+
+            [Observable("vector2Property")]
+            public Vector2 Vector2Property
+            {
+                get => m_Vector2Property;
+                set => m_Vector2Property = value;
+            }
+
+            //
+            // Vector3
+            //
+            [Observable("vector3Member")]
+            public Vector3 m_Vector3Member;
+
+            Vector3 m_Vector3Property;
+
+            [Observable("vector3Property")]
+            public Vector3 Vector3Property
+            {
+                get => m_Vector3Property;
+                set => m_Vector3Property = value;
+            }
+
+            //
+            // Vector4
+            //
+
+            [Observable("vector4Member")]
+            public Vector4 m_Vector4Member;
+
+            Vector4 m_Vector4Property;
+
[Observable("vector4Property")] + public Vector4 Vector4Property + { + get => m_Vector4Property; + set => m_Vector4Property = value; + } + + // + // Quaternion + // + [Observable("quaternionMember")] + public Quaternion m_QuaternionMember; + + Quaternion m_QuaternionProperty; + + [Observable("quaternionProperty")] + public Quaternion QuaternionProperty + { + get => m_QuaternionProperty; + set => m_QuaternionProperty = value; + } + + // + // Enum + // + + [Observable("enumMember")] + public TestEnum m_EnumMember = TestEnum.ValueA; + + TestEnum m_EnumProperty = TestEnum.ValueC; + + [Observable("enumProperty")] + public TestEnum EnumProperty + { + get => m_EnumProperty; + set => m_EnumProperty = value; + } + + [Observable("badEnumMember")] + public TestEnum m_BadEnumMember = (TestEnum)1337; + + // + // Flags + // + [Observable("flagMember")] + public TestFlags m_FlagMember = TestFlags.FlagA; + + TestFlags m_FlagProperty = TestFlags.FlagB | TestFlags.FlagC; + + [Observable("flagProperty")] + public TestFlags FlagProperty + { + get => m_FlagProperty; + set => m_FlagProperty = value; + } + + } + + [Test] + public void TestGetObservableSensors() + { + var testClass = new TestClass(); + testClass.m_IntMember = 1; + testClass.IntProperty = 2; + + testClass.m_FloatMember = 1.1f; + testClass.FloatProperty = 1.2f; + + testClass.m_BoolMember = true; + testClass.BoolProperty = true; + + testClass.m_Vector2Member = new Vector2(2.0f, 2.1f); + testClass.Vector2Property = new Vector2(2.2f, 2.3f); + + testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f); + testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f); + + testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); + testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); + + testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); + testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); + + testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); + testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); + + var sensors = ObservableAttribute.CreateObservableSensors(testClass, false); + + var sensorsByName = new Dictionary(); + foreach (var sensor in sensors) + { + sensorsByName[sensor.GetName()] = sensor; + } + + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f }); + SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f }); + + SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f }); + SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f }); + SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f }); + SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f }); + + SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); + 
SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); + + // Actual ordering is B, C, A + SensorTestHelper.CompareObservation(sensorsByName["enumMember"], new[] { 0.0f, 0.0f, 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["enumProperty"], new[] { 0.0f, 1.0f, 0.0f }); + SensorTestHelper.CompareObservation(sensorsByName["badEnumMember"], new[] { 0.0f, 0.0f, 0.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["flagMember"], new[] { 1.0f, 0.0f, 0.0f }); + SensorTestHelper.CompareObservation(sensorsByName["flagProperty"], new[] { 0.0f, 1.0f, 1.0f }); + } + + [Test] + public void TestGetTotalObservationSize() + { + var testClass = new TestClass(); + var errors = new List(); + var expectedObsSize = 2 * ( // two fields each of these + 1 // int + + 1 // float + + 1 // bool + + 2 // vector2 + + 3 // vector3 + + 4 // vector4 + + 4 // quaternion + + 3 // TestEnum - 3 values + + 3 // TestFlags - 3 values + ) + + 3; // TestEnum with bad value + Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors)); + Assert.AreEqual(0, errors.Count); + } + + class BadClass + { + [Observable] + double m_Double; + + [Observable] + double DoubleProperty + { + get => m_Double; + set => m_Double = value; + } + + float m_WriteOnlyProperty; + + [Observable] + // No get property, so we shouldn't be able to make a sensor out of this. + public float WriteOnlyProperty + { + set => m_WriteOnlyProperty = value; + } + } + + [Test] + public void TestInvalidObservables() + { + var bad = new BadClass(); + bad.WriteOnlyProperty = 1.0f; + var errors = new List(); + Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); + Assert.AreEqual(3, errors.Count); + + // Should be able to safely generate sensors (and get nothing back) + var sensors = ObservableAttribute.CreateObservableSensors(bad, false); + Assert.AreEqual(0, sensors.Count); + } + + class StackingClass + { + [Observable(numStackedObservations: 2)] + public float FloatVal; + } + + [Test] + public void TestObservableAttributeStacking() + { + var c = new StackingClass(); + c.FloatVal = 1.0f; + var sensors = ObservableAttribute.CreateObservableSensors(c, false); + var sensor = sensors[0]; + Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); + SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); + + sensor.Update(); + c.FloatVal = 3.0f; + SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); + + var errors = new List(); + Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors)); + Assert.AreEqual(0, errors.Count); + } + + class BaseClass + { + [Observable("base")] + public float m_BaseField; + + [Observable("private")] + float m_PrivateField; + } + + class DerivedClass : BaseClass + { + [Observable("derived")] + float m_DerivedField; + } + + [Test] + public void TestObservableAttributeExcludeInherited() + { + var d = new DerivedClass(); + d.m_BaseField = 1.0f; + + // excludeInherited=false will get fields in the derived class, plus public and protected inherited fields + var sensorAll = ObservableAttribute.CreateObservableSensors(d, false); + Assert.AreEqual(2, sensorAll.Count); + // Note - actual order doesn't matter here, we can change this to use a HashSet if neeed. 
+ Assert.AreEqual("derived", sensorAll[0].GetName()); + Assert.AreEqual("base", sensorAll[1].GetName()); + + // excludeInherited=true will only get fields in the derived class + var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true); + Assert.AreEqual(1, sensorsDerivedOnly.Count); + Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); + + var b = new BaseClass(); + var baseSensors = ObservableAttribute.CreateObservableSensors(b, false); + Assert.AreEqual(2, baseSensors.Count); + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/ObservableAttributeTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservableAttributeTests.cs.meta new file mode 100644 index 0000000000..611fdcfa12 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservableAttributeTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 33d7912e6b3504412bd261b40e46df32 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/ObservationWriterTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservationWriterTests.cs new file mode 100644 index 0000000000..bc805f954e --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservationWriterTests.cs @@ -0,0 +1,99 @@ +using NUnit.Framework; +using Unity.Barracuda; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Inference; + + +namespace Unity.MLAgents.Tests +{ + public class ObservationWriterTests + { + [Test] + public void TestWritesToIList() + { + ObservationWriter writer = new ObservationWriter(); + var buffer = new[] { 0f, 0f, 0f }; + var shape = new InplaceArray(3); + + writer.SetTarget(buffer, shape, 0); + // Elementwise writes + writer[0] = 1f; + writer[2] = 2f; + Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer); + + // Elementwise writes with offset + writer.SetTarget(buffer, shape, 1); + writer[0] = 3f; + Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer); + + // AddList + writer.SetTarget(buffer, shape, 0); + writer.AddList(new[] { 4f, 5f }); + Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer); + + // AddList with offset + writer.SetTarget(buffer, shape, 1); + writer.AddList(new[] { 6f, 7f }); + Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer); + } + + [Test] + public void TestWritesToTensor() + { + ObservationWriter writer = new ObservationWriter(); + var t = new TensorProxy + { + valueType = TensorProxy.TensorType.FloatingPoint, + data = new Tensor(2, 3) + }; + + writer.SetTarget(t, 0, 0); + Assert.AreEqual(0f, t.data[0, 0]); + writer[0] = 1f; + Assert.AreEqual(1f, t.data[0, 0]); + + writer.SetTarget(t, 1, 1); + writer[0] = 2f; + writer[1] = 3f; + // [0, 0] shouldn't change + Assert.AreEqual(1f, t.data[0, 0]); + Assert.AreEqual(2f, t.data[1, 1]); + Assert.AreEqual(3f, t.data[1, 2]); + + // AddList + t = new TensorProxy + { + valueType = TensorProxy.TensorType.FloatingPoint, + data = new Tensor(2, 3) + }; + + writer.SetTarget(t, 1, 1); + writer.AddList(new[] { -1f, -2f }); + Assert.AreEqual(0f, t.data[0, 0]); + Assert.AreEqual(0f, t.data[0, 1]); + Assert.AreEqual(0f, t.data[0, 2]); + Assert.AreEqual(0f, t.data[1, 0]); + Assert.AreEqual(-1f, t.data[1, 1]); + Assert.AreEqual(-2f, t.data[1, 2]); + } + + [Test] + public void TestWritesToTensor3D() + { + ObservationWriter writer = new ObservationWriter(); + var t = new TensorProxy + { + valueType = TensorProxy.TensorType.FloatingPoint, + data = new Tensor(2, 2, 2, 3) + }; + + 
+            writer.SetTarget(t, 0, 0);
+            writer[1, 0, 1] = 1f;
+            Assert.AreEqual(1f, t.data[0, 1, 0, 1]);
+
+            writer.SetTarget(t, 0, 1);
+            writer[1, 0, 0] = 2f;
+            Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/ObservationWriterTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservationWriterTests.cs.meta
new file mode 100644
index 0000000000..31e61311e9
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/ObservationWriterTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 3de9cbda816e4d7b907e765577dd54f7
+timeCreated: 1572568337
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs
new file mode 100644
index 0000000000..da2ce43f0d
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs
@@ -0,0 +1,445 @@
+using System.Collections.Generic;
+using NUnit.Framework;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+using UnityEngine.TestTools;
+
+namespace Unity.MLAgents.Tests
+{
+    public class RayPerceptionSensorTests
+    {
+        [Test]
+        public void TestGetRayAngles()
+        {
+            var anglesAlternating = RayPerceptionSensorComponentBase.GetRayAnglesAlternating(3, 90f);
+            var expectedAnglesAlternating = new[] { 90f, 60f, 120f, 30f, 150f, 0f, 180f };
+            Assert.AreEqual(expectedAnglesAlternating.Length, anglesAlternating.Length);
+            for (var i = 0; i < anglesAlternating.Length; i++)
+            {
+                Assert.AreEqual(expectedAnglesAlternating[i], anglesAlternating[i], .01);
+            }
+
+            var angles = RayPerceptionSensorComponentBase.GetRayAngles(3, 90f);
+            var expectedAngles = new[] { 0f, 30f, 60f, 90f, 120f, 150f, 180f };
+            Assert.AreEqual(expectedAngles.Length, angles.Length);
+            for (var i = 0; i < angles.Length; i++)
+            {
+                Assert.AreEqual(expectedAngles[i], angles[i], .01);
+            }
+        }
+    }
+
+    public class RayPerception3DTests
+    {
+        [Test]
+        public void TestDefaultLayersAreNegativeFive()
+        {
+#if MLA_UNITY_PHYSICS_MODULE
+            Assert.IsTrue(Physics.DefaultRaycastLayers == -5);
+#endif
+#if MLA_UNITY_PHYSICS2D_MODULE
+            Assert.IsTrue(Physics2D.DefaultRaycastLayers == -5);
+#endif
+        }
+
+#if MLA_UNITY_PHYSICS_MODULE
+        // Use built-in tags
+        const string k_CubeTag = "Player";
+        const string k_SphereTag = "Respawn";
+
+        [TearDown]
+        public void RemoveGameObjects()
+        {
+            var objects = GameObject.FindObjectsOfType<GameObject>();
+            foreach (var o in objects)
+            {
+                UnityEngine.Object.DestroyImmediate(o);
+            }
+        }
+
+        void SetupScene()
+        {
+            /* Creates game objects in the world for testing.
+             * C is a cube
+             * S are spheres
+             * @ is the agent (at the origin)
+             * Each space or line is 5 world units, +x is right, +z is up
+             *
+             *      C
+             *    S   S
+             *      @
+             *
+             *      S
+             */
+            var cube = GameObject.CreatePrimitive(PrimitiveType.Cube);
+            cube.transform.position = new Vector3(0, 0, 10);
+            cube.tag = k_CubeTag;
+            cube.name = "cube";
+
+            var sphere1 = GameObject.CreatePrimitive(PrimitiveType.Sphere);
+            sphere1.transform.position = new Vector3(-5, 0, 5);
+            sphere1.tag = k_SphereTag;
+            sphere1.name = "sphere1";
+
+            var sphere2 = GameObject.CreatePrimitive(PrimitiveType.Sphere);
+            sphere2.transform.position = new Vector3(5, 0, 5);
+            // No tag for sphere2
+            sphere2.name = "sphere2";
+
+            var sphere3 = GameObject.CreatePrimitive(PrimitiveType.Sphere);
+            sphere3.transform.position = new Vector3(0, 0, -10);
+            sphere3.tag = k_SphereTag;
+            sphere3.name = "sphere3";
+
+
+            Physics.SyncTransforms();
+        }
+
+        [Test]
+        public void TestRaycasts()
+        {
+            SetupScene();
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+
+            perception.RaysPerDirection = 1;
+            perception.MaxRayDegrees = 45;
+            perception.RayLength = 20;
+            perception.DetectableTags = new List<string>();
+            perception.DetectableTags.Add(k_CubeTag);
+            perception.DetectableTags.Add(k_SphereTag);
+
+            var radii = new[] { 0f, .5f };
+            foreach (var castRadius in radii)
+            {
+                perception.SphereCastRadius = castRadius;
+                var sensor = perception.CreateSensors()[0];
+                sensor.Update();
+
+                var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
+                Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
+                var outputBuffer = new float[expectedObs];
+
+                ObservationWriter writer = new ObservationWriter();
+                writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
+
+                var numWritten = sensor.Write(writer);
+                Assert.AreEqual(numWritten, expectedObs);
+
+                // Expected hits:
+                // ray 0 should hit the cube at roughly halfway
+                // ray 1 should hit a sphere but no tag
+                // ray 2 should hit a sphere with the k_SphereTag tag
+                // The hit fraction should be the same for rays 1 and 2.
+                //
+                Assert.AreEqual(1.0f, outputBuffer[0]); // hit cube
+                Assert.AreEqual(0.0f, outputBuffer[1]); // missed sphere
+                Assert.AreEqual(0.0f, outputBuffer[2]); // missed unknown tag
+
+                // Hit is at z=9.5 in world space (cube at z=10 with 0.5 half-extent), ray length is 20
+                Assert.That(
+                    outputBuffer[3], Is.EqualTo((9.5f - castRadius) / perception.RayLength).Within(.0005f)
+                );
+
+                // Spheres are at -5,0,5 and 5,0,5, so 5*sqrt(2) units from origin
+                // Minus 0.5 for the sphere radius to get the length of the hit.
+                var expectedHitLengthWorldSpace = 5.0f * Mathf.Sqrt(2.0f) - 0.5f - castRadius;
+                Assert.AreEqual(0.0f, outputBuffer[4]); // missed cube
+                Assert.AreEqual(0.0f, outputBuffer[5]); // missed sphere
+                Assert.AreEqual(0.0f, outputBuffer[6]); // hit unknown tag -> all 0
+                Assert.That(
+                    outputBuffer[7], Is.EqualTo(expectedHitLengthWorldSpace / perception.RayLength).Within(.0005f)
+                );
+
+                Assert.AreEqual(0.0f, outputBuffer[8]); // missed cube
+                Assert.AreEqual(1.0f, outputBuffer[9]); // hit sphere
+                Assert.AreEqual(0.0f, outputBuffer[10]); // missed unknown tag
+                Assert.That(
+                    outputBuffer[11], Is.EqualTo(expectedHitLengthWorldSpace / perception.RayLength).Within(.0005f)
+                );
+            }
+        }
+
+        [Test]
+        public void TestRaycastMiss()
+        {
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+
+            perception.RaysPerDirection = 0;
+            perception.MaxRayDegrees = 45;
+            perception.RayLength = 20;
+            perception.DetectableTags = new List<string>();
+            perception.DetectableTags.Add(k_CubeTag);
+            perception.DetectableTags.Add(k_SphereTag);
+
+            var sensor = perception.CreateSensors()[0];
+            sensor.Update();
+            var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
+            Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
+            var outputBuffer = new float[expectedObs];
+
+            ObservationWriter writer = new ObservationWriter();
+            writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
+
+            var numWritten = sensor.Write(writer);
+            Assert.AreEqual(numWritten, expectedObs);
+
+            // Everything missed
+            Assert.AreEqual(new float[] { 0, 0, 1, 1 }, outputBuffer);
+        }
+
+        [Test]
+        public void TestRayFilter()
+        {
+            var cube = GameObject.CreatePrimitive(PrimitiveType.Cube);
+            cube.transform.position = new Vector3(0, 0, 10);
+            cube.tag = k_CubeTag;
+            cube.name = "cubeFar";
+
+            var cubeFiltered = GameObject.CreatePrimitive(PrimitiveType.Cube);
+            cubeFiltered.transform.position = new Vector3(0, 0, 5);
+            cubeFiltered.tag = k_CubeTag;
+            cubeFiltered.name = "cubeNear";
+            cubeFiltered.layer = 7;
+
+            Physics.SyncTransforms();
+
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+            perception.RaysPerDirection = 0;
+            perception.RayLength = 20;
+            perception.DetectableTags = new List<string>();
+
+            var filterCubeLayers = new[] { false, true };
+            foreach (var filterCubeLayer in filterCubeLayers)
+            {
+                // Set the layer mask to either the default, or one that ignores the close cube's layer
+                var layerMask = Physics.DefaultRaycastLayers;
+                if (filterCubeLayer)
+                {
+                    layerMask &= ~(1 << cubeFiltered.layer);
+                }
+                perception.RayLayerMask = layerMask;
+
+                var sensor = perception.CreateSensors()[0];
+                sensor.Update();
+                var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
+                Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
+                var outputBuffer = new float[expectedObs];
+
+                ObservationWriter writer = new ObservationWriter();
+                writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
+
+                var numWritten = sensor.Write(writer);
+                Assert.AreEqual(numWritten, expectedObs);
+
+                if (filterCubeLayer)
+                {
+                    // Hit the far cube because close was filtered.
+                    Assert.That(outputBuffer[outputBuffer.Length - 1],
+                        Is.EqualTo((9.5f - perception.SphereCastRadius) / perception.RayLength).Within(.0005f)
+                    );
+                }
+                else
+                {
+                    // Hit the close cube because not filtered.
+                    Assert.That(outputBuffer[outputBuffer.Length - 1],
+                        Is.EqualTo((4.5f - perception.SphereCastRadius) / perception.RayLength).Within(.0005f)
+                    );
+                }
+            }
+        }
+
+        [Test]
+        public void TestRaycastsScaled()
+        {
+            SetupScene();
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+            obj.transform.localScale = new Vector3(2, 2, 2);
+
+            perception.RaysPerDirection = 0;
+            perception.MaxRayDegrees = 45;
+            perception.RayLength = 20;
+            perception.DetectableTags = new List<string>();
+            perception.DetectableTags.Add(k_CubeTag);
+
+            var radii = new[] { 0f, .5f };
+            foreach (var castRadius in radii)
+            {
+                perception.SphereCastRadius = castRadius;
+                var sensor = perception.CreateSensors()[0];
+                sensor.Update();
+
+                var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
+                Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
+                var outputBuffer = new float[expectedObs];
+
+                ObservationWriter writer = new ObservationWriter();
+                writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
+
+                var numWritten = sensor.Write(writer);
+                Assert.AreEqual(numWritten, expectedObs);
+
+                // Expected hits:
+                // ray 0 should hit the cube at roughly 1/4 way
+                //
+                Assert.AreEqual(1.0f, outputBuffer[0]); // hit cube
+                Assert.AreEqual(0.0f, outputBuffer[1]); // missed unknown tag
+
+                // Hit is at z=9.5 in world space, ray length was 20
+                // But scale increases the cast size and the ray length
+                var scaledRayLength = 2 * perception.RayLength;
+                var scaledCastRadius = 2 * castRadius;
+                Assert.That(
+                    outputBuffer[2], Is.EqualTo((9.5f - scaledCastRadius) / scaledRayLength).Within(.0005f)
+                );
+            }
+        }
+
+        [Test]
+        public void TestRayZeroLength()
+        {
+            // Place the cube touching the origin
+            var cube = GameObject.CreatePrimitive(PrimitiveType.Cube);
+            cube.transform.position = new Vector3(0, 0, .5f);
+            cube.tag = k_CubeTag;
+
+            Physics.SyncTransforms();
+
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+            perception.RaysPerDirection = 0;
+            perception.RayLength = 0.0f;
+            perception.SphereCastRadius = .5f;
+            perception.DetectableTags = new List<string>();
+            perception.DetectableTags.Add(k_CubeTag);
+
+            {
+                // Set the layer mask to either the default, or one that ignores the close cube's layer
+
+                var sensor = perception.CreateSensors()[0];
+                sensor.Update();
+                var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
+                Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
+                var outputBuffer = new float[expectedObs];
+
+                ObservationWriter writer = new ObservationWriter();
+                writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
+
+                var numWritten = sensor.Write(writer);
+                Assert.AreEqual(numWritten, expectedObs);
+
+                // hit fraction is arbitrary but should be finite in [0,1]
+                Assert.GreaterOrEqual(outputBuffer[2], 0.0f);
+                Assert.LessOrEqual(outputBuffer[2], 1.0f);
+            }
+        }
+
+        [Test]
+        public void TestStaticPerceive()
+        {
+            SetupScene();
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+
+            perception.RaysPerDirection = 0; // single ray
+            perception.MaxRayDegrees = 45;
+            perception.RayLength = 20;
+            perception.DetectableTags = new List<string>();
+            perception.DetectableTags.Add(k_CubeTag);
+            perception.DetectableTags.Add(k_SphereTag);
+
+            var radii = new[] { 0f, .5f };
+            foreach (var castRadius in radii)
+            {
+                perception.SphereCastRadius = castRadius;
+                var castInput = perception.GetRayPerceptionInput();
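+                // Perceive is a static helper, so it can be exercised without
+                // creating a sensor; it returns one RayOutput per cast ray.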
+                var castOutput = RayPerceptionSensor.Perceive(castInput);
+
+                Assert.AreEqual(1, castOutput.RayOutputs.Length);
+
+                // Expected to hit the cube
+                Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
+                Assert.AreEqual(0, castOutput.RayOutputs[0].HitTagIndex);
+            }
+        }
+
+        [Test]
+        public void TestStaticPerceiveInvalidTags()
+        {
+            SetupScene();
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+
+            perception.RaysPerDirection = 0; // single ray
+            perception.MaxRayDegrees = 45;
+            perception.RayLength = 20;
+            perception.DetectableTags = new List<string>();
+            perception.DetectableTags.Add("Bad tag");
+            perception.DetectableTags.Add(null);
+            perception.DetectableTags.Add("");
+            perception.DetectableTags.Add(k_CubeTag);
+
+            var radii = new[] { 0f, .5f };
+            foreach (var castRadius in radii)
+            {
+                perception.SphereCastRadius = castRadius;
+                var castInput = perception.GetRayPerceptionInput();
+
+                // There's no clean way that I can find to check for a defined tag without
+                // logging an error.
+                LogAssert.Expect(LogType.Error, "Tag: Bad tag is not defined.");
+                var castOutput = RayPerceptionSensor.Perceive(castInput);
+
+                Assert.AreEqual(1, castOutput.RayOutputs.Length);
+
+                // Expected to hit the cube
+                Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
+                Assert.AreEqual(3, castOutput.RayOutputs[0].HitTagIndex);
+            }
+        }
+
+        [Test]
+        public void TestStaticPerceiveNoTags()
+        {
+            SetupScene();
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+
+            perception.RaysPerDirection = 0; // single ray
+            perception.MaxRayDegrees = 45;
+            perception.RayLength = 20;
+            perception.DetectableTags = null;
+
+            var radii = new[] { 0f, .5f };
+            foreach (var castRadius in radii)
+            {
+                perception.SphereCastRadius = castRadius;
+                var castInput = perception.GetRayPerceptionInput();
+                var castOutput = RayPerceptionSensor.Perceive(castInput);
+
+                Assert.AreEqual(1, castOutput.RayOutputs.Length);
+
+                // Expected to hit the cube
+                Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
+                Assert.AreEqual(-1, castOutput.RayOutputs[0].HitTagIndex);
+            }
+        }
+
+        [Test]
+        public void TestCreateDefault()
+        {
+            SetupScene();
+            var obj = new GameObject("agent");
+            var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
+
+            Assert.DoesNotThrow(() =>
+            {
+                perception.CreateSensors();
+            });
+        }
+#endif
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs.meta
new file mode 100644
index 0000000000..ae0be2c197
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: d2983e2bca9a40398f287727dc0472a5
+timeCreated: 1573242741
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs
new file mode 100644
index 0000000000..c4dcc93ef5
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs
@@ -0,0 +1,38 @@
+using System;
+using NUnit.Framework;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests
+{
+    [TestFixture]
+    public class RenderTextureSensorComponentTest
+    {
+        [Test]
+        public void TestRenderTextureSensorComponent()
+        {
+            foreach (var grayscale in new[] { true, false })
+            {
+                foreach (SensorCompressionType compression in Enum.GetValues(typeof(SensorCompressionType)))
+                {
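+                    // Unlike CameraSensor, the observation shape here is taken
+                    // directly from the RenderTexture's dimensions.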
+                    var width = 24;
+                    var height = 16;
+                    var texture = new RenderTexture(width, height, 0);
+
+                    var agentGameObj = new GameObject("agent");
+
+                    var renderTexComponent = agentGameObj.AddComponent<RenderTextureSensorComponent>();
+                    renderTexComponent.RenderTexture = texture;
+                    renderTexComponent.Grayscale = grayscale;
+                    renderTexComponent.CompressionType = compression;
+
+                    var expectedShape = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
+
+                    var sensor = renderTexComponent.CreateSensors()[0];
+                    Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+                    Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
+                }
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs.meta
new file mode 100644
index 0000000000..0e4c37fa29
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 6be53c3cd01244f179a58c96560c54cf
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorTests.cs
new file mode 100644
index 0000000000..767dd50de5
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorTests.cs
@@ -0,0 +1,50 @@
+using System;
+using NUnit.Framework;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests
+{
+    [TestFixture]
+    public class RenderTextureSensorTests
+    {
+        [Test]
+        public void TestRenderTextureSensor()
+        {
+            foreach (var grayscale in new[] { true, false })
+            {
+                foreach (SensorCompressionType compression in Enum.GetValues(typeof(SensorCompressionType)))
+                {
+                    var width = 24;
+                    var height = 16;
+                    var texture = new RenderTexture(width, height, 0);
+                    var sensor = new RenderTextureSensor(texture, grayscale, "TestCameraSensor", compression);
+
+                    var obsWriter = new ObservationWriter();
+                    var obs = sensor.GetObservationProto(obsWriter);
+
+                    Assert.AreEqual((int)compression, (int)obs.CompressionType);
+                    var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
+                    Assert.AreEqual(expectedShape, obs.Shape);
+                }
+            }
+        }
+
+        [Test]
+        public void TestObservationType()
+        {
+            var width = 24;
+            var height = 16;
+            var camera = Camera.main;
+            var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
+            var spec = sensor.GetObservationSpec();
+            Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
+            sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
+            spec = sensor.GetObservationSpec();
+            Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
+            sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal);
+            spec = sensor.GetObservationSpec();
+            Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorTests.cs.meta
new file mode 100644
index 0000000000..e73904f94f
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: be9f7d8ce17d8407e92d46fbee2ab809
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs
new file mode 100644
index 0000000000..84b69b6172
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs
@@ -0,0 +1,146 @@
+using System.Collections.Generic;
+using System.Text.RegularExpressions;
+using NUnit.Framework;
+using UnityEngine;
+using UnityEngine.TestTools;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests
+{
+    public class DummySensor : ISensor
+    {
+        string m_Name = "DummySensor";
+        ObservationSpec m_ObservationSpec;
+
+        public DummySensor(int dim1)
+        {
+            m_ObservationSpec = ObservationSpec.Vector(dim1);
+        }
+
+        public DummySensor(int dim1, int dim2)
+        {
+            m_ObservationSpec = ObservationSpec.VariableLength(dim1, dim2);
+        }
+
+        public DummySensor(int dim1, int dim2, int dim3)
+        {
+            m_ObservationSpec = ObservationSpec.Visual(dim1, dim2, dim3);
+        }
+
+        public string GetName()
+        {
+            return m_Name;
+        }
+
+        public ObservationSpec GetObservationSpec()
+        {
+            return m_ObservationSpec;
+        }
+
+        public byte[] GetCompressedObservation()
+        {
+            return null;
+        }
+
+        public int Write(ObservationWriter writer)
+        {
+            return this.ObservationSize();
+        }
+
+        public void Update() { }
+        public void Reset() { }
+
+        public CompressionSpec GetCompressionSpec()
+        {
+            return CompressionSpec.Default();
+        }
+    }
+
+    public class SensorShapeValidatorTests
+    {
+        [Test]
+        public void TestShapesAgree()
+        {
+            var validator = new SensorShapeValidator();
+            var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
+            validator.ValidateSensors(sensorList1);
+
+            var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
+            validator.ValidateSensors(sensorList2);
+        }
+
+        [Test]
+        public void TestNumSensorMismatch()
+        {
+            var validator = new SensorShapeValidator();
+            var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
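+            // The first ValidateSensors call records the sensor count and shapes;
+            // later calls are asserted against that recorded baseline.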
+            validator.ValidateSensors(sensorList1);
+
+            var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), };
+            LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
+            validator.ValidateSensors(sensorList2);
+
+            // Add the sensors in the other order
+            validator = new SensorShapeValidator();
+            validator.ValidateSensors(sensorList2);
+            LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
+            validator.ValidateSensors(sensorList1);
+        }
+
+        [Test]
+        public void TestDimensionMismatch()
+        {
+            var validator = new SensorShapeValidator();
+            var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
+            validator.ValidateSensors(sensorList1);
+
+            var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
+            LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
+            validator.ValidateSensors(sensorList2);
+
+            // Add the sensors in the other order
+            validator = new SensorShapeValidator();
+            validator.ValidateSensors(sensorList2);
+            LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
+            validator.ValidateSensors(sensorList1);
+        }
+
+        [Test]
+        public void TestSizeMismatch()
+        {
+            var validator = new SensorShapeValidator();
+            var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
+            validator.ValidateSensors(sensorList1);
+
+            var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
+            LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
+            validator.ValidateSensors(sensorList2);
+
+            // Add the sensors in the other order
+            validator = new SensorShapeValidator();
+            validator.ValidateSensors(sensorList2);
+            LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
+            validator.ValidateSensors(sensorList1);
+        }
+
+        [Test]
+        public void TestEverythingMismatch()
+        {
+            var validator = new SensorShapeValidator();
+            var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
+            validator.ValidateSensors(sensorList1);
+
+            var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
+            LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
+            LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
+            validator.ValidateSensors(sensorList2);
+
+            // Add the sensors in the other order
+            validator = new SensorShapeValidator();
+            validator.ValidateSensors(sensorList2);
+            LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
+            LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
+            validator.ValidateSensors(sensorList1);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs.meta
new file mode 100644
index 0000000000..c538a8a743
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: bbfcd7a9de490454cbc37b8d7d900e7e
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorTestHelper.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorTestHelper.cs
new file mode 100644
index 0000000000..208e0ade25
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorTestHelper.cs
@@ -0,0 +1,22 @@
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests
+{
+    public static class SensorTestHelper
+    {
+        public static void CompareObservation(ISensor sensor, float[] expected)
+        {
+            string errorMessage;
+            bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
+            Assert.IsTrue(isOK, errorMessage);
+        }
+
+        public static void CompareObservation(ISensor sensor, float[,,] expected)
+        {
+            string errorMessage;
+            bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
+            Assert.IsTrue(isOK, errorMessage);
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorTestHelper.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorTestHelper.cs.meta
new file mode 100644
index 0000000000..487ace557e
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorTestHelper.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: e769354f8bd404ca180d7cd7302a5d61
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorUtilTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorUtilTests.cs
new file mode 100644
index 0000000000..e186f800c4
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorUtilTests.cs
@@ -0,0 +1,56 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+using Unity.MLAgents.Utils.Tests;
+
+namespace Unity.MLAgents.Tests
+{
+
+    [TestFixture]
+    public class SensorUtilTests
+    {
+        internal class TempCulture : IDisposable
+        {
+            private CultureInfo m_OriginalCulture;
+
+            internal TempCulture(CultureInfo newCulture)
+            {
+                m_OriginalCulture = CultureInfo.CurrentCulture;
+                CultureInfo.CurrentCulture = newCulture;
+            }
+
+            public void Dispose()
+            {
+                CultureInfo.CurrentCulture = m_OriginalCulture;
+            }
+        }
+
+        /// <summary>
+        /// Test that sensors sort by name consistently across culture settings.
+ /// Example strings and cultures taken from + /// https://docs.microsoft.com/en-us/globalization/locale/sorting-and-string-comparison + /// </summary> + /// + [TestCase("da-DK")] + [TestCase("en-US")] + public void TestSortCulture(string culture) + { + List<ISensor> sensors = new List<ISensor>(); + var sensor0 = new TestSensor("Apple"); + var sensor1 = new TestSensor("Æble"); + sensors.Add(sensor0); + sensors.Add(sensor1); + + var originalCulture = CultureInfo.CurrentCulture; + CultureInfo.CurrentCulture = new CultureInfo(culture); + SensorUtils.SortSensors(sensors); + CultureInfo.CurrentCulture = originalCulture; + + Assert.AreEqual(sensor1, sensors[0]); + Assert.AreEqual(sensor0, sensors[1]); + } + + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorUtilTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorUtilTests.cs.meta new file mode 100644 index 0000000000..c9e661ce9c --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorUtilTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 929b34a718bc42c8aa75a3e1c8c11103 +timeCreated: 1617049947 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs new file mode 100644 index 0000000000..eb576d8f78 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs @@ -0,0 +1,141 @@ +using System.Collections.Generic; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + public static class TestGridSensorConfig + { + public static int ObservationSize; + public static bool IsNormalized; + public static bool ParseAllColliders; + + public static void SetParameters(int observationSize, bool isNormalized, bool parseAllColliders) + { + ObservationSize = observationSize; + IsNormalized = isNormalized; + ParseAllColliders = parseAllColliders; + } + + public static void Reset() + { + ObservationSize = 0; + IsNormalized = false; + ParseAllColliders = false; + } + } + + public class SimpleTestGridSensor : GridSensorBase + { + public float[] DummyData; + + public SimpleTestGridSensor( + string name, + Vector3 cellScale, + Vector3Int gridSize, + string[] detectableTags, + SensorCompressionType compression + ) : base( + name, + cellScale, + gridSize, + detectableTags, + compression) + { } + + protected override int GetCellObservationSize() + { + return TestGridSensorConfig.ObservationSize; + } + + protected override bool IsDataNormalized() + { + return TestGridSensorConfig.IsNormalized; + } + + protected internal override ProcessCollidersMethod GetProcessCollidersMethod() + { + return TestGridSensorConfig.ParseAllColliders ? 
ProcessCollidersMethod.ProcessAllColliders : ProcessCollidersMethod.ProcessClosestColliders; + } + + protected override void GetObjectData(GameObject detectedObject, int typeIndex, float[] dataBuffer) + { + for (var i = 0; i < DummyData.Length; i++) + { + dataBuffer[i] = DummyData[i]; + } + } + } + + public class SimpleTestGridSensorComponent : GridSensorComponent + { + bool m_UseOneHotTag; + bool m_UseTestingGridSensor; + bool m_UseGridSensorBase; + + protected override GridSensorBase[] GetGridSensors() + { + List<GridSensorBase> sensorList = new List<GridSensorBase>(); + if (m_UseOneHotTag) + { + var testSensor = new OneHotGridSensor( + SensorName, + CellScale, + GridSize, + DetectableTags, + CompressionType + ); + sensorList.Add(testSensor); + } + if (m_UseGridSensorBase) + { + var testSensor = new GridSensorBase( + SensorName, + CellScale, + GridSize, + DetectableTags, + CompressionType + ); + sensorList.Add(testSensor); + } + if (m_UseTestingGridSensor) + { + var testSensor = new SimpleTestGridSensor( + SensorName, + CellScale, + GridSize, + DetectableTags, + CompressionType + ); + sensorList.Add(testSensor); + } + return sensorList.ToArray(); + } + + public void SetComponentParameters( + string[] detectableTags = null, + float cellScaleX = 1f, + float cellScaleZ = 1f, + int gridSizeX = 10, + int gridSizeY = 1, + int gridSizeZ = 10, + int colliderMaskInt = -1, + SensorCompressionType compression = SensorCompressionType.None, + bool rotateWithAgent = false, + bool useOneHotTag = false, + bool useTestingGridSensor = false, + bool useGridSensorBase = false + ) + { + DetectableTags = detectableTags; + CellScale = new Vector3(cellScaleX, 0.01f, cellScaleZ); + GridSize = new Vector3Int(gridSizeX, gridSizeY, gridSizeZ); + ColliderMask = colliderMaskInt < 0 ? LayerMask.GetMask("Default") : colliderMaskInt; + RotateWithAgent = rotateWithAgent; + CompressionType = compression; + m_UseOneHotTag = useOneHotTag; + m_UseGridSensorBase = useGridSensorBase; + m_UseTestingGridSensor = useTestingGridSensor; + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs.meta new file mode 100644 index 0000000000..233b23c8d5 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 72f121528c63749688e53aa926f2cb0a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs new file mode 100644 index 0000000000..5cd81ac2b0 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs @@ -0,0 +1,294 @@ +using NUnit.Framework; +using System; +using System.Linq; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Policies; +using UnityEngine; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Utils.Tests; + +namespace Unity.MLAgents.Tests +{ + public class StackingSensorTests + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + + Academy.Instance.AutomaticSteppingEnabled = false; + } + + [TearDown] + public void TearDown() + { + CommunicatorFactory.ClearCreator(); + } + + [Test] + public void TestCtor() + { + ISensor wrapped = new VectorSensor(4); + ISensor sensor = new StackingSensor(wrapped, 4); + 
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName()); + Assert.AreEqual(sensor.GetObservationSpec().Shape, new InplaceArray<int>(16)); + } + + [Test] + public void AssertStackingReset() + { + var agentGo1 = new GameObject("TestAgent"); + var bp1 = agentGo1.AddComponent<BehaviorParameters>(); + bp1.BrainParameters.NumStackedVectorObservations = 3; + bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1); + var aca = Academy.Instance; + var agent1 = agentGo1.AddComponent<TestAgent>(); + var policy = new TestPolicy(); + agent1.SetPolicy(policy); + + StackingSensor sensor = null; + foreach (ISensor s in agent1.sensors) + { + if (s is StackingSensor) + { + sensor = s as StackingSensor; + } + } + + Assert.NotNull(sensor); + + for (int i = 0; i < 20; i++) + { + agent1.RequestDecision(); + aca.EnvironmentStep(); + } + SensorTestHelper.CompareObservation(sensor, new[] { 18f, 19f, 20f }); + policy.OnRequestDecision = () => SensorTestHelper.CompareObservation(sensor, new[] { 19f, 20f, 21f }); + agent1.EndEpisode(); + policy.OnRequestDecision = () => { }; + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f }); + for (int i = 0; i < 20; i++) + { + agent1.RequestDecision(); + aca.EnvironmentStep(); + SensorTestHelper.CompareObservation(sensor, new[] { Math.Max(0, i - 1f), i, i + 1 }); + } + } + + [Test] + public void TestVectorStacking() + { + VectorSensor wrapped = new VectorSensor(2); + StackingSensor sensor = new StackingSensor(wrapped, 3); + + wrapped.AddObservation(new[] { 1f, 2f }); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f }); + var data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f })); + + sensor.Update(); + wrapped.AddObservation(new[] { 3f, 4f }); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f })); + + sensor.Update(); + wrapped.AddObservation(new[] { 5f, 6f }); + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f })); + + sensor.Update(); + wrapped.AddObservation(new[] { 7f, 8f }); + SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f })); + + sensor.Update(); + wrapped.AddObservation(new[] { 9f, 10f }); + SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f })); + + // Check that if we don't call Update(), the same observations are produced + SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f })); + } + + [Test] + public void TestVectorStackingReset() + { + VectorSensor wrapped = new VectorSensor(2); + ISensor sensor = new StackingSensor(wrapped, 3); + + wrapped.AddObservation(new[] { 1f, 2f }); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f }); + + sensor.Update(); + wrapped.AddObservation(new[] { 3f, 4f }); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f }); + + sensor.Reset(); + 
wrapped.AddObservation(new[] { 5f, 6f }); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f }); + } + + class Dummy3DSensor : ISensor + { + public SensorCompressionType CompressionType = SensorCompressionType.PNG; + public int[] Mapping; + public ObservationSpec ObservationSpec; + public float[,,] CurrentObservation; + + public ObservationSpec GetObservationSpec() + { + return ObservationSpec; + } + + public int Write(ObservationWriter writer) + { + for (var h = 0; h < ObservationSpec.Shape[0]; h++) + { + for (var w = 0; w < ObservationSpec.Shape[1]; w++) + { + for (var c = 0; c < ObservationSpec.Shape[2]; c++) + { + writer[h, w, c] = CurrentObservation[h, w, c]; + } + } + } + return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]; + } + + public byte[] GetCompressedObservation() + { + var writer = new ObservationWriter(); + var flattenedObservation = new float[ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]]; + writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0); + Write(writer); + byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z); + return bytes; + } + + public void Update() { } + + public void Reset() { } + + public CompressionSpec GetCompressionSpec() + { + return new CompressionSpec(CompressionType, Mapping); + } + + public string GetName() + { + return "Dummy"; + } + } + + [Test] + public void TestStackingMapping() + { + // Test grayscale stacked mapping with CameraSensor + var cameraSensor = new CameraSensor(new Camera(), 64, 64, + true, "grayscaleCamera", SensorCompressionType.PNG); + var stackedCameraSensor = new StackingSensor(cameraSensor, 2); + Assert.AreEqual(stackedCameraSensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 0, 0, 1, 1, 1 }); + + // Test RGB stacked mapping with RenderTextureSensor + var renderTextureSensor = new RenderTextureSensor(new RenderTexture(24, 16, 0), + false, "renderTexture", SensorCompressionType.PNG); + var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2); + Assert.AreEqual(stackedRenderTextureSensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 1, 2, 3, 4, 5 }); + + // Test mapping with number of layers not being multiple of 3 + var dummySensor = new Dummy3DSensor(); + dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4); + dummySensor.Mapping = new[] { 0, 1, 2, 3 }; + var stackedDummySensor = new StackingSensor(dummySensor, 2); + Assert.AreEqual(stackedDummySensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); + + // Test mapping with dummy layers that should be dropped + var paddedDummySensor = new Dummy3DSensor(); + paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4); + paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 }; + var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2); + Assert.AreEqual(stackedPaddedDummySensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); + } + + [Test] + public void Test3DStacking() + { + var wrapped = new Dummy3DSensor(); + wrapped.ObservationSpec = ObservationSpec.Visual(2, 1, 2); + var sensor = new StackingSensor(wrapped, 2); + + // Check the stacking is on the last dimension + wrapped.CurrentObservation = new[, ,] { { { 1f, 2f } }, { { 3f, 4f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 1f, 2f } }, { { 0f, 0f, 3f, 4f } } }); + + sensor.Update(); + 
wrapped.CurrentObservation = new[, ,] { { { 5f, 6f } }, { { 7f, 8f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 1f, 2f, 5f, 6f } }, { { 3f, 4f, 7f, 8f } } }); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 9f, 10f } }, { { 11f, 12f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } }); + + // Check that if we don't call Update(), the same observations are produced + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } }); + + // Test reset + sensor.Reset(); + wrapped.CurrentObservation = new[, ,] { { { 13f, 14f } }, { { 15f, 16f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 13f, 14f } }, { { 0f, 0f, 15f, 16f } } }); + } + + [Test] + public void TestStackedGetCompressedObservation() + { + var wrapped = new Dummy3DSensor(); + wrapped.ObservationSpec = ObservationSpec.Visual(1, 1, 3); + var sensor = new StackingSensor(wrapped, 2); + + wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } }; + var expected1 = sensor.CreateEmptyPNG(); + expected1 = expected1.Concat(Array.ConvertAll(new[] { 1f, 2f, 3f }, (z) => (byte)z)).ToArray(); + Assert.AreEqual(sensor.GetCompressedObservation(), expected1); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 4f, 5f, 6f } } }; + var expected2 = Array.ConvertAll(new[] { 1f, 2f, 3f, 4f, 5f, 6f }, (z) => (byte)z); + Assert.AreEqual(sensor.GetCompressedObservation(), expected2); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 7f, 8f, 9f } } }; + var expected3 = Array.ConvertAll(new[] { 4f, 5f, 6f, 7f, 8f, 9f }, (z) => (byte)z); + Assert.AreEqual(sensor.GetCompressedObservation(), expected3); + + // Test reset + sensor.Reset(); + wrapped.CurrentObservation = new[, ,] { { { 10f, 11f, 12f } } }; + var expected4 = sensor.CreateEmptyPNG(); + expected4 = expected4.Concat(Array.ConvertAll(new[] { 10f, 11f, 12f }, (z) => (byte)z)).ToArray(); + Assert.AreEqual(sensor.GetCompressedObservation(), expected4); + } + + [Test] + public void TestStackingSensorBuiltInSensorType() + { + var dummySensor = new Dummy3DSensor(); + dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4); + dummySensor.Mapping = new[] { 0, 1, 2, 3 }; + var stackedDummySensor = new StackingSensor(dummySensor, 2); + Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown); + + var vectorSensor = new VectorSensor(4); + var stackedVectorSensor = new StackingSensor(vectorSensor, 4); + Assert.AreEqual(stackedVectorSensor.GetBuiltInSensorType(), BuiltInSensorType.VectorSensor); + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs.meta new file mode 100644 index 0000000000..81723dd4cc --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 7b071fdf91474d18a05ea20175c6b3bd +timeCreated: 1572564843 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/Unity.ML-Agents.Runtime.Sensor.Tests.asmdef b/com.unity.ml-agents/Tests/Runtime/Sensor/Unity.ML-Agents.Runtime.Sensor.Tests.asmdef new file mode 100644 index 0000000000..d779451c40 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/Unity.ML-Agents.Runtime.Sensor.Tests.asmdef @@ -0,0 +1,37 @@ +{ + "name": "Unity.ML-Agents.Runtime.Sensor.Tests", + "references": [ + "Unity.ML-Agents", + 
"Unity.Barracuda", + "Unity.ML-Agents.CommunicatorObjects", + "Unity.ML-Agents.Runtime.Utils.Tests" + ], + "optionalUnityReferences": [ + "TestAssemblies" + ], + "includePlatforms": [], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": true, + "precompiledReferences": [ + "System.IO.Abstractions.dll", + "System.IO.Abstractions.TestingHelpers.dll", + "Google.Protobuf.dll" + ], + "autoReferenced": false, + "defineConstraints": [ + "UNITY_INCLUDE_TESTS" + ], + "versionDefines": [ + { + "name": "com.unity.modules.physics", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS_MODULE" + }, + { + "name": "com.unity.modules.physics2d", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS2D_MODULE" + } + ] +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/Unity.ML-Agents.Runtime.Sensor.Tests.asmdef.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/Unity.ML-Agents.Runtime.Sensor.Tests.asmdef.meta new file mode 100644 index 0000000000..2b92b0bbb2 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/Unity.ML-Agents.Runtime.Sensor.Tests.asmdef.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 917f332e7ad944428f0246683188de8f +timeCreated: 1616137285 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs new file mode 100644 index 0000000000..f58606e99c --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs @@ -0,0 +1,131 @@ +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + public class VectorSensorTests + { + [Test] + public void TestCtor() + { + ISensor sensor = new VectorSensor(4); + Assert.AreEqual("VectorSensor_size4", sensor.GetName()); + + sensor = new VectorSensor(3, "test_sensor"); + Assert.AreEqual("test_sensor", sensor.GetName()); + } + + [Test] + public void TestWrite() + { + var sensor = new VectorSensor(4); + sensor.AddObservation(1f); + sensor.AddObservation(2f); + sensor.AddObservation(3f); + sensor.AddObservation(4f); + + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f }); + // Check that if we don't call Update(), the same observations are produced + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f }); + + // Check that Update() clears the data + sensor.Update(); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f }); + } + + [Test] + public void TestAddObservationFloat() + { + var sensor = new VectorSensor(1); + sensor.AddObservation(1.2f); + SensorTestHelper.CompareObservation(sensor, new[] { 1.2f }); + } + + [Test] + public void TestObservationType() + { + var sensor = new VectorSensor(1); + var spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new VectorSensor(1, observationType: ObservationType.Default); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal); + } + + [Test] + public void TestAddObservationInt() + { + var sensor = new VectorSensor(1); + sensor.AddObservation(42); + SensorTestHelper.CompareObservation(sensor, new[] { 42f }); + } + + [Test] + public void TestAddObservationVec() + { + var sensor = new VectorSensor(3); + sensor.AddObservation(new 
Vector3(1, 2, 3)); + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f }); + + sensor = new VectorSensor(2); + sensor.AddObservation(new Vector2(4, 5)); + SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f }); + } + + [Test] + public void TestAddObservationQuaternion() + { + var sensor = new VectorSensor(4); + sensor.AddObservation(Quaternion.identity); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 1f }); + } + + [Test] + public void TestWriteEnumerable() + { + var sensor = new VectorSensor(4); + sensor.AddObservation(new[] { 1f, 2f, 3f, 4f }); + + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f }); + } + + [Test] + public void TestAddObservationBool() + { + var sensor = new VectorSensor(1); + sensor.AddObservation(true); + SensorTestHelper.CompareObservation(sensor, new[] { 1f }); + } + + [Test] + public void TestAddObservationOneHot() + { + var sensor = new VectorSensor(4); + sensor.AddOneHotObservation(2, 4); + SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 0f }); + } + + [Test] + public void TestWriteTooMany() + { + var sensor = new VectorSensor(2); + sensor.AddObservation(new[] { 1f, 2f, 3f, 4f }); + + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f }); + } + + [Test] + public void TestWriteNotEnough() + { + var sensor = new VectorSensor(4); + sensor.AddObservation(new[] { 1f, 2f }); + + // Make sure extra zeros are added + SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f }); + } + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs.meta new file mode 100644 index 0000000000..05c14f9206 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 18c0d390ce4c5464ab48b96db0392eb0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Unity.ML-Agents.Runtime.Tests.asmdef b/com.unity.ml-agents/Tests/Runtime/Unity.ML-Agents.Runtime.Tests.asmdef new file mode 100644 index 0000000000..b740695c2a --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Unity.ML-Agents.Runtime.Tests.asmdef @@ -0,0 +1,37 @@ +{ + "name": "Unity.ML-Agents.Runtime.Tests", + "references": [ + "Unity.ML-Agents", + "Unity.Barracuda", + "Unity.ML-Agents.CommunicatorObjects", + "Unity.ML-Agents.Editor" + ], + "optionalUnityReferences": [ + "TestAssemblies" + ], + "includePlatforms": [], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": true, + "precompiledReferences": [ + "System.IO.Abstractions.dll", + "System.IO.Abstractions.TestingHelpers.dll", + "Google.Protobuf.dll" + ], + "autoReferenced": false, + "defineConstraints": [ + "UNITY_INCLUDE_TESTS" + ], + "versionDefines": [ + { + "name": "com.unity.modules.physics", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS_MODULE" + }, + { + "name": "com.unity.modules.physics2d", + "expression": "1.0.0", + "define": "MLA_UNITY_PHYSICS2D_MODULE" + } + ] +} diff --git a/com.unity.ml-agents/Tests/Runtime/Unity.ML-Agents.Runtime.Tests.asmdef.meta b/com.unity.ml-agents/Tests/Runtime/Unity.ML-Agents.Runtime.Tests.asmdef.meta new file mode 100644 index 0000000000..4fa9a793f6 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Unity.ML-Agents.Runtime.Tests.asmdef.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 
+guid: d29014db7ebcd4cf4a14f537fbf02110 +AssemblyDefinitionImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Utils.meta b/com.unity.ml-agents/Tests/Runtime/Utils.meta new file mode 100644 index 0000000000..412c47bbf6 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Utils.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e4c1564ff83a4d51855e5b5b461f6a59 +timeCreated: 1616137685 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs new file mode 100644 index 0000000000..f17b88279e --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs @@ -0,0 +1,169 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Runtime.CompilerServices; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; + +[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Sensor.Tests")] +[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")] + +namespace Unity.MLAgents.Utils.Tests +{ + internal class TestPolicy : IPolicy + { + public Action OnRequestDecision; + ObservationWriter m_ObsWriter = new ObservationWriter(); + static ActionSpec s_ActionSpec = ActionSpec.MakeContinuous(1); + static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(new float[1], Array.Empty<int>()); + + public void RequestDecision(AgentInfo info, List<ISensor> sensors) + { + foreach (var sensor in sensors) + { + sensor.GetObservationProto(m_ObsWriter); + } + OnRequestDecision?.Invoke(); + } + + public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; } + + public void Dispose() { } + } + + public class TestAgent : Agent + { + internal AgentInfo _Info + { + get + { + return (AgentInfo)typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + set + { + typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, value); + } + } + + internal void SetPolicy(IPolicy policy) + { + typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, policy); + } + + internal IPolicy GetPolicy() + { + return (IPolicy)typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); + } + + public int initializeAgentCalls; + public int collectObservationsCalls; + public int collectObservationsCallsForEpisode; + public int agentActionCalls; + public int agentActionCallsForEpisode; + public int agentOnEpisodeBeginCalls; + public int heuristicCalls; + public TestSensor sensor1; + public TestSensor sensor2; + + [Observable("observableFloat")] + public float observableFloat; + + public override void Initialize() + { + initializeAgentCalls += 1; + + // Add in some custom Sensors so we can confirm they get sorted as expected. 
+ sensor1 = new TestSensor("testsensor1"); + sensor2 = new TestSensor("testsensor2"); + sensor2.compressionType = SensorCompressionType.PNG; + + sensors.Add(sensor2); + sensors.Add(sensor1); + } + + public override void CollectObservations(VectorSensor sensor) + { + collectObservationsCalls += 1; + collectObservationsCallsForEpisode += 1; + sensor.AddObservation(collectObservationsCallsForEpisode); + } + + public override void OnActionReceived(ActionBuffers buffers) + { + agentActionCalls += 1; + agentActionCallsForEpisode += 1; + AddReward(0.1f); + } + + public override void OnEpisodeBegin() + { + agentOnEpisodeBeginCalls += 1; + collectObservationsCallsForEpisode = 0; + agentActionCallsForEpisode = 0; + } + + public override void Heuristic(in ActionBuffers actionsOut) + { + var obs = GetObservations(); + var continuousActions = actionsOut.ContinuousActions; + continuousActions[0] = (int)obs[0]; + heuristicCalls++; + } + } + + public class TestSensor : ISensor + { + public string sensorName; + public int numWriteCalls; + public int numCompressedCalls; + public int numResetCalls; + public SensorCompressionType compressionType = SensorCompressionType.None; + + public TestSensor(string n) + { + sensorName = n; + } + + public ObservationSpec GetObservationSpec() + { + return ObservationSpec.Vector(0); + } + + public int Write(ObservationWriter writer) + { + numWriteCalls++; + // No-op + return 0; + } + + public byte[] GetCompressedObservation() + { + numCompressedCalls++; + return new byte[] { 0 }; + } + + public CompressionSpec GetCompressionSpec() + { + return new CompressionSpec(compressionType); + } + + public string GetName() + { + return sensorName; + } + + public void Update() { } + + public void Reset() + { + numResetCalls++; + } + } + + public class TestClasses + { + } +} diff --git a/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs.meta b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs.meta new file mode 100644 index 0000000000..a7cb7f2276 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 329ca71a721948a9a64a7b3b48604058 +timeCreated: 1616137640 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Runtime/Utils/Unity.ML-Agents.Runtime.Utils.Tests.asmdef b/com.unity.ml-agents/Tests/Runtime/Utils/Unity.ML-Agents.Runtime.Utils.Tests.asmdef new file mode 100644 index 0000000000..0827203bcc --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Utils/Unity.ML-Agents.Runtime.Utils.Tests.asmdef @@ -0,0 +1,24 @@ +{ + "name": "Unity.ML-Agents.Runtime.Utils.Tests", + "references": [ + "Unity.ML-Agents", + "Unity.Barracuda", + "Unity.ML-Agents.CommunicatorObjects" + ], + "optionalUnityReferences": [ + "TestAssemblies" + ], + "includePlatforms": [], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": true, + "precompiledReferences": [ + "System.IO.Abstractions.dll", + "System.IO.Abstractions.TestingHelpers.dll", + "Google.Protobuf.dll" + ], + "autoReferenced": false, + "defineConstraints": [ + "UNITY_INCLUDE_TESTS" + ] +} diff --git a/com.unity.ml-agents/Tests/Runtime/Utils/Unity.ML-Agents.Runtime.Utils.Tests.asmdef.meta b/com.unity.ml-agents/Tests/Runtime/Utils/Unity.ML-Agents.Runtime.Utils.Tests.asmdef.meta new file mode 100644 index 0000000000..cf380bec78 --- /dev/null +++ b/com.unity.ml-agents/Tests/Runtime/Utils/Unity.ML-Agents.Runtime.Utils.Tests.asmdef.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: daaa01096d9848a38185ee08fa7321c0 +timeCreated: 1616137694 \ No 
newline at end of file diff --git a/com.unity.ml-agents/Third Party Notices.md b/com.unity.ml-agents/Third Party Notices.md new file mode 100644 index 0000000000..fc124fc468 --- /dev/null +++ b/com.unity.ml-agents/Third Party Notices.md @@ -0,0 +1,381 @@ +This package contains third-party software components governed by the license(s) indicated below: + --------- + + Component Name: System.Buffers.dll + + License Type: MIT + + The MIT License (MIT) + + Copyright (c) .NET Foundation and Contributors + + All rights reserved. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + + --------- + +Component Name: System.Numerics.Vectors.dll + +License Type: MIT + +The MIT License (MIT) + +Copyright (c) .NET Foundation and Contributors + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + --------- + +Component Name: System.Runtime.CompilerServices.Unsafe + +License Type: MIT + +The MIT License (MIT) + +Copyright (c) .NET Foundation and Contributors + +All rights reserved. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + --------- + +Component Name: System.Memory.dll + +License Type: MIT + +The MIT License (MIT) + +Copyright (c) .NET Foundation and Contributors + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + --------- + +Component Name: System.IO.Abstractions + +License Type: MIT + +The MIT License (MIT) + +Copyright (c) Tatham Oddie and Contributors + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ + --------- + +Component Name: System.Interactive.Async.dll + +License Type: Apache-2.0 + +Copyright (c) .NET Foundation and Contributors +All Rights Reserved + +Licensed under the Apache License, Version 2.0 (the "License"); you +may not use this file except in compliance with the License. You may +obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions +and limitations under the License. + + --------- + +Component Name: Grpc + +License Type: Apache-2.0 + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/com.unity.ml-agents/Third Party Notices.md.meta b/com.unity.ml-agents/Third Party Notices.md.meta new file mode 100644 index 0000000000..00901a0666 --- /dev/null +++ b/com.unity.ml-agents/Third Party Notices.md.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 3fb7f1407083340b8921a0520b2d8870 +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/catalog-info.yaml b/com.unity.ml-agents/catalog-info.yaml new file mode 100644 index 0000000000..67d98bb424 --- /dev/null +++ b/com.unity.ml-agents/catalog-info.yaml @@ -0,0 +1,21 @@ +# For more information about the available options please visit: http://go/backstage (VPN required) +apiVersion: backstage.io/v1alpha1 +kind: Component +metadata: + annotations: + github.com/project-slug: unity/com.unity.ml-agents + name: com.unity.ml-agents + description: "Unity ML-Agents Package" + labels: + costcenter: "5160" + tags: + - planned-public + - enterprise + links: + - url: https://unity.slack.com/messages/C8FECS6L9/ + title: "#devs-ml-agents" + icon: chat +spec: + type: unity-package + lifecycle: production + owner: unity/behavior-authoring diff --git a/com.unity.ml-agents/catalog-info.yaml.meta b/com.unity.ml-agents/catalog-info.yaml.meta new file mode 100644 index 0000000000..15b7d888e5 --- /dev/null +++ b/com.unity.ml-agents/catalog-info.yaml.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 7d543dec1acb6455fb97a799ca89315c +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/com.unity.ml-agents.sln.DotSettings b/com.unity.ml-agents/com.unity.ml-agents.sln.DotSettings new file mode 100644 index 0000000000..c1c2b34e0f --- /dev/null +++ b/com.unity.ml-agents/com.unity.ml-agents.sln.DotSettings @@ -0,0 +1,21 @@ + + BLAS + CPU + GPU + NN + PNG + RL + True + True + True + + + True + True + True + True + True + True + True + True + True \ No newline at end of file diff --git a/com.unity.ml-agents/com.unity.ml-agents.sln.DotSettings.meta b/com.unity.ml-agents/com.unity.ml-agents.sln.DotSettings.meta new file mode 100644 index 0000000000..ee052ef017 --- /dev/null +++ b/com.unity.ml-agents/com.unity.ml-agents.sln.DotSettings.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: c8f6f0111d3fb4d71af892263a7a614d +DefaultImporter: + externalObjects: {} + 
userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/package.json b/com.unity.ml-agents/package.json new file mode 100755 index 0000000000..b82a76d6f3 --- /dev/null +++ b/com.unity.ml-agents/package.json @@ -0,0 +1,12 @@ +{ + "name": "com.unity.ml-agents", + "displayName": "ML Agents", + "version": "2.3.0-exp.3", + "unity": "2021.3", + "description": "Use state-of-the-art machine learning to create intelligent character behaviors in any Unity environment (games, robotics, film, etc.).", + "dependencies": { + "com.unity.barracuda": "3.0.0", + "com.unity.modules.imageconversion": "1.0.0", + "com.unity.modules.jsonserialize": "1.0.0" + } +} diff --git a/com.unity.ml-agents/package.json.meta b/com.unity.ml-agents/package.json.meta new file mode 100644 index 0000000000..d76c84a5fa --- /dev/null +++ b/com.unity.ml-agents/package.json.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: e13b73fbfc8e74e4d87af46bf55d7df6 +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/conftest.py b/conftest.py index 6f5b1ff4ad..846ff868c7 100644 --- a/conftest.py +++ b/conftest.py @@ -14,6 +14,7 @@ from filelock import FileLock # TODO: Use this in all ml-agents tests so they can all run in parallel. +import mlagents.plugins.trainer_type _BASE_PORT = 6005 @@ -76,3 +77,8 @@ def test_something(base_port: int) -> None: :return: The base port number. """ return PortAllocator().reserve_n_ports(n_ports) + + +@pytest.fixture(scope="session", autouse=True) +def setup_plugin_trainers(): + _, _ = mlagents.plugins.trainer_type.register_trainer_plugins() diff --git a/docs/Background-Machine-Learning.md b/docs/Background-Machine-Learning.md index 3e47c57684..58c8b2f838 100644 --- a/docs/Background-Machine-Learning.md +++ b/docs/Background-Machine-Learning.md @@ -111,9 +111,7 @@ every step, but only when a robot arrives at a success or failure situation), is a defining characteristic of reinforcement learning and precisely why learning good policies can be difficult (and/or time-consuming) for complex environments. -

 [image: "The reinforcement learning cycle."]
+ [image: "The reinforcement learning lifecycle."]
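To make the cycle in the figure above concrete, here is a minimal, self-contained sketch of the observe-act-reward loop it depicts; `ToyEnv` and the random action choice are illustrative stand-ins invented for this sketch, not ML-Agents classes:

```python
import random

# Toy stand-ins for an environment and a policy; not part of the ML-Agents API.
class ToyEnv:
    def reset(self):
        self.steps = 0
        return 0.0  # initial observation

    def step(self, action):
        self.steps += 1
        done = self.steps >= 10
        # Sparse reward: nothing arrives until the episode ends, as discussed above.
        reward = 1.0 if done and action > 0 else 0.0
        return float(self.steps), reward, done

env = ToyEnv()
obs, done = env.reset(), False
while not done:
    action = random.choice([-1, 1])       # the "policy" maps an observation to an action
    obs, reward, done = env.step(action)  # environment returns the next observation and reward
```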
[Learning a policy](https://blogs.unity3d.com/2017/08/22/unity-ai-reinforcement-learning-with-q-learning/) usually requires many trials and iterative policy updates. More specifically, diff --git a/docs/CODE_OF_CONDUCT.md b/docs/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..14a6a4a839 --- /dev/null +++ b/docs/CODE_OF_CONDUCT.md @@ -0,0 +1 @@ +{!../CODE_OF_CONDUCT.md!} diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 100644 index 0000000000..657c00df4b --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1 @@ +{!../com.unity.ml-agents/CONTRIBUTING.md!} diff --git a/docs/Installation-Anaconda-Windows.md b/docs/Installation-Anaconda-Windows.md index 4bc5a1342a..d5af6befce 100644 --- a/docs/Installation-Anaconda-Windows.md +++ b/docs/Installation-Anaconda-Windows.md @@ -18,8 +18,8 @@ versions and features. [Download](https://www.anaconda.com/download/#windows) and install Anaconda for Windows. By using Anaconda, you can manage separate environments for different -distributions of Python. Python 3.7.2 or higher is required as we no longer -support Python 2. In this guide, we are using Python version 3.7 and Anaconda +distributions of Python. Python 3.8.13 or higher is required as we no longer +support Python 2. In this guide, we are using Python version 3.8 and Anaconda version 5.1 ([64-bit](https://repo.continuum.io/archive/Anaconda3-5.1.0-Windows-x86_64.exe) or [32-bit](https://repo.continuum.io/archive/Anaconda3-5.1.0-Windows-x86.exe) @@ -80,12 +80,12 @@ To create a new Conda environment, open a new Anaconda Prompt (_Anaconda Prompt_ in the search bar) and type in the following command: ```sh -conda create -n ml-agents python=3.7 +conda create -n ml-agents python=3.8 ``` You may be asked to install new packages. Type `y` and press enter _(make sure you are connected to the Internet)_. You must install these required packages. -The new Conda environment is called ml-agents and uses Python version 3.7. +The new Conda environment is called ml-agents and uses Python version 3.8.

[image: "Anaconda Install"] @@ -151,7 +151,7 @@ config files in this directory when running `mlagents-learn`. Make sure you are connected to the Internet and then type in the Anaconda Prompt: ```console -python -m pip install mlagents==0.28.0 +python -m pip install mlagents==0.29.0 ``` This will complete the installation of all the required Python packages to run @@ -162,7 +162,7 @@ pip will get stuck when trying to read the cache of the package. If you see this, you can try: ```console -python -m pip install mlagents==0.28.0 --no-cache-dir +python -m pip install mlagents==0.29.0 --no-cache-dir ``` This `--no-cache-dir` tells pip to disable the cache. diff --git a/docs/Installation.md b/docs/Installation.md index 58c97c06a3..b2fe1778df 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -10,11 +10,11 @@ The ML-Agents Toolkit contains several components: contains experimental C#/Unity components that are not yet ready to be part of the base `com.unity.ml-agents` package. `com.unity.ml-agents.extensions` has a direct dependency on `com.unity.ml-agents`. -- Three Python packages: +- Two Python packages: - [`mlagents`](../ml-agents/) contains the machine learning algorithms that enable you to train behaviors in your Unity scene. Most users of ML-Agents will only need to directly install `mlagents`. - - [`mlagents_envs`](../ml-agents-envs/) contains a Python API to interact with + - [`mlagents_envs`](../ml-agents-envs/) contains a set of Python APIs to interact with a Unity scene. It is a foundational layer that facilitates data messaging between a Unity scene and the Python machine learning algorithms. Consequently, `mlagents` depends on `mlagents_envs`. @@ -24,8 +24,8 @@ The ML-Agents Toolkit contains several components: Consequently, to install and use the ML-Agents Toolkit you will need to: -- Install Unity (2020.3 or later) -- Install Python (3.7.2 or higher) +- Install Unity (2021.3 or later) +- Install Python (3.8.13 or higher) - Clone this repository (Optional) - __Note:__ If you do not clone the repository, then you will not be able to access the example environments and training configurations or the @@ -36,15 +36,15 @@ Consequently, to install and use the ML-Agents Toolkit you will need to: - Install the `com.unity.ml-agents.extensions` Unity package (Optional) - Install the `mlagents` Python package -### Install **Unity 2020.3** or Later +### Install **Unity 2021.3** or Later [Download](https://unity3d.com/get-unity/download) and install Unity. We strongly recommend that you install Unity through the Unity Hub as it will enable you to manage multiple Unity versions. -### Install **Python 3.7.2** or Higher +### Install **Python 3.8.13** or Higher -We recommend [installing](https://www.python.org/downloads/) Python 3.7. +We recommend [installing](https://www.python.org/downloads/) Python 3.8. If you are using Windows, please install the x86-64 version and not x86. If your Python environment doesn't include `pip3`, see these [instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers) @@ -74,7 +74,7 @@ You will need to clone the repository if you plan to modify or extend the ML-Agents Toolkit for your purposes. If you plan to contribute those changes back, make sure to clone the `main` branch (by omitting `--branch release_19` from the command above). 
See our -[Contributions Guidelines](../com.unity.ml-agents/CONTRIBUTING.md) for more +[Contributions Guidelines](CONTRIBUTING.md) for more information on contributing to the ML-Agents Toolkit. ### Install the `com.unity.ml-agents` Unity package @@ -153,7 +153,7 @@ To install the `mlagents` Python package, activate your virtual environment and run from the command line: ```sh -python -m pip install mlagents==0.28.0 +python -m pip install mlagents==0.29.0 ``` Note that this will install `mlagents` from PyPi, _not_ from the cloned diff --git a/docs/LICENSE.md b/docs/LICENSE.md new file mode 100644 index 0000000000..b5bebb551d --- /dev/null +++ b/docs/LICENSE.md @@ -0,0 +1 @@ +{!../LICENSE.md!} diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md index d6bd892938..15c52f9671 100644 --- a/docs/Learning-Environment-Design-Agents.md +++ b/docs/Learning-Environment-Design-Agents.md @@ -454,7 +454,7 @@ Agent. - To collect visual observations, attach `CameraSensor` or `RenderTextureSensor` components to the agent GameObject. -- Visual observations should generally be used unless vector observations are +- Visual observations should generally only be used when vector observations are not sufficient. - Image size should be kept as small as possible, without the loss of needed details for decision making. @@ -494,6 +494,12 @@ Both sensor components have several settings: `Behavior Parameters`. - _Start Vertical Offset_ (3D only) The vertical offset of the ray start point. - _End Vertical Offset_ (3D only) The vertical offset of the ray end point. +- _Alternating Ray Order_ Alternating is the default; it gives an order of (0, + -delta, delta, -2*delta, 2*delta, ..., -n*delta, n*delta). If alternating is + disabled, the order is left to right (-n*delta, -(n-1)*delta, ..., -delta, 0, + delta, ..., (n-1)*delta, n*delta). For general usage there is no difference, + but if you use custom models, the left-to-right layout that matches the spatial + structure can be preferable (e.g. for processing with conv nets). In the example image above, the Agent has two `RayPerceptionSensorComponent3D`s. Both use 3 Rays Per Direction and 90 Max Ray Degrees. One of the components had diff --git a/docs/Learning-Environment-Examples.md b/docs/Learning-Environment-Examples.md index 3ff3cb466f..9f3914250c 100644 --- a/docs/Learning-Environment-Examples.md +++ b/docs/Learning-Environment-Examples.md @@ -18,7 +18,7 @@ This page only overviews the example environments we provide. To learn more on how to design and build your own environments see our [Making a New Learning Environment](Learning-Environment-Create-New.md) page. If you would like to contribute environments, please see our -[contribution guidelines](../com.unity.ml-agents/CONTRIBUTING.md) page. +[contribution guidelines](CONTRIBUTING.md) page. ## Basic diff --git a/docs/Learning-Environment-Executable.md b/docs/Learning-Environment-Executable.md index f47d0f4c02..f5ea6936cc 100644 --- a/docs/Learning-Environment-Executable.md +++ b/docs/Learning-Environment-Executable.md @@ -194,5 +194,5 @@ graphics display in the Unity executable. There are two ways to achieve this: If you want to train with graphics (for example, using camera and visual observations), you'll need to set up display rendering support (e.g. xvfb) on your server machine.
In our -[Colab Notebook Tutorials](Readme.md#python-tutorial-with-google-colab), the Setup section has +[Colab Notebook Tutorials](ML-Agents-Toolkit-Documentation.md#python-tutorial-with-google-colab), the Setup section has examples of setting up xvfb on servers. diff --git a/docs/Limitations.md b/docs/Limitations.md index 5a126e9fce..a6ee31cb27 100644 --- a/docs/Limitations.md +++ b/docs/Limitations.md @@ -2,6 +2,6 @@ See the package-specific Limitations pages: -- [`com.unity.mlagents` Unity package](../com.unity.ml-agents/Documentation~/com.unity.ml-agents.md#known-limitations) -- [`mlagents` Python package](../ml-agents/README.md#limitations) +- [`com.unity.mlagents` Unity package](https://docs.unity3d.com/Packages/com.unity.ml-agents@2.1/manual/index.html#known-limitations) +- [`mlagents` Python package](ML-Agents-README.md) - [`mlagents_envs` Python package](../ml-agents-envs/README.md#limitations) diff --git a/docs/ML-Agents-Envs-README.md b/docs/ML-Agents-Envs-README.md new file mode 100644 index 0000000000..9e382027ea --- /dev/null +++ b/docs/ML-Agents-Envs-README.md @@ -0,0 +1 @@ +{!../ml-agents-envs/README.md!} diff --git a/docs/ML-Agents-Overview.md b/docs/ML-Agents-Overview.md index 0798430d72..30ea5cfc46 100644 --- a/docs/ML-Agents-Overview.md +++ b/docs/ML-Agents-Overview.md @@ -27,7 +27,7 @@ - [Model Types](#model-types) - [Learning from Vector Observations](#learning-from-vector-observations) - [Learning from Cameras using Convolutional Neural Networks](#learning-from-cameras-using-convolutional-neural-networks) - - [Learning from Variable Length Observations using Attention](#learning-from-ariable-length-observations-using-attention) + - [Learning from Variable Length Observations using Attention](#learning-from-variable-length-observations-using-attention) - [Memory-enhanced Agents using Recurrent Neural Networks](#memory-enhanced-agents-using-recurrent-neural-networks) - [Additional Features](#additional-features) - [Summary and Next Steps](#summary-and-next-steps) @@ -179,9 +179,8 @@ The ML-Agents Toolkit contains five high-level components: - **Gym Wrapper** (not pictured). A common way in which machine learning researchers interact with simulation environments is via a wrapper provided by OpenAI called [gym](https://github.com/openai/gym). We provide a gym wrapper - in the `ml-agents-envs` package and - [instructions](Python-Gym-API.md) for using it with existing machine - learning algorithms which utilize gym. + in the `ml-agents-envs` package and [instructions](Python-Gym-API.md) for using + it with existing machine learning algorithms which utilize gym. - **PettingZoo Wrapper** (not pictured) PettingZoo is a Python API for interacting with multi-agent simulation environments that provides a gym-like interface. We provide a PettingZoo wrapper for Unity ML-Agents @@ -190,7 +189,7 @@ The ML-Agents Toolkit contains five high-level components: algorithms.

- Simplified ML-Agents Scene Block Diagram @@ -225,7 +224,7 @@ can have the same Behavior. This does not mean that at each instance they will have identical observation and action _values_.

- Example ML-Agents Scene Block Diagram @@ -247,7 +246,7 @@ Channels_ is to exchange data with Python about _Environment Parameters_. The following diagram illustrates the above.

- More Complete Example ML-Agents Scene Block Diagram
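To make the side-channel flow pictured above concrete, here is a minimal Python sketch of setting an _Environment Parameter_; the parameter name `gravity_scale` is a hypothetical example, and the C# side would read it via `Academy.Instance.EnvironmentParameters`.

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.environment_parameters_channel import (
    EnvironmentParametersChannel,
)

# Create the side channel and hand it to the environment at construction time.
channel = EnvironmentParametersChannel()
env = UnityEnvironment(side_channels=[channel])  # connects to the Editor when no file_name is given

# Send an Environment Parameter to the running Unity simulation.
channel.set_float_parameter("gravity_scale", 1.5)  # "gravity_scale" is a hypothetical key
env.reset()
env.close()
```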

@@ -467,7 +466,7 @@ episodes of demonstrations can reduce training steps by more than 4 times. See Behavioral Cloning + GAIL + Curiosity + RL below.

- Using Demonstrations with Reinforcement Learning

@@ -622,6 +621,8 @@ MA-POCA can also be combined with self-play to train teams of agents to play aga To learn more about enabling cooperative behaviors for agents in an ML-Agents environment, check out [this page](Learning-Environment-Design-Agents.md#groups-for-cooperative-scenarios). +To learn more about MA-POCA, please see our paper +[On the Use and Misuse of Absorbing States in Multi-Agent Reinforcement Learning](https://arxiv.org/pdf/2111.05992.pdf). For further reading, MA-POCA builds on previous work in multi-agent cooperative learning ([Lowe et al.](https://arxiv.org/abs/1706.02275), [Foerster et al.](https://arxiv.org/pdf/1705.08926.pdf), among others) to enable the above use-cases. diff --git a/docs/ML-Agents-README.md b/docs/ML-Agents-README.md new file mode 100644 index 0000000000..5e2c7107a8 --- /dev/null +++ b/docs/ML-Agents-README.md @@ -0,0 +1 @@ +{!../ml-agents/README.md!} diff --git a/docs/ML-Agents-Toolkit-Documentation.md b/docs/ML-Agents-Toolkit-Documentation.md new file mode 100644 index 0000000000..4a9d14f509 --- /dev/null +++ b/docs/ML-Agents-Toolkit-Documentation.md @@ -0,0 +1,80 @@ +# Unity ML-Agents Toolkit Documentation + +## Installation & Set-up + +- [Installation](Installation.md) + - [Using Virtual Environment](Using-Virtual-Environment.md) + +## Getting Started + +- [Getting Started Guide](Getting-Started.md) +- [ML-Agents Toolkit Overview](ML-Agents-Overview.md) + - [Background: Unity](Background-Unity.md) + - [Background: Machine Learning](Background-Machine-Learning.md) + - [Background: PyTorch](Background-PyTorch.md) +- [Example Environments](Learning-Environment-Examples.md) + +## Creating Learning Environments + +- [Making a New Learning Environment](Learning-Environment-Create-New.md) +- [Designing a Learning Environment](Learning-Environment-Design.md) + - [Designing Agents](Learning-Environment-Design-Agents.md) +- [Using an Executable Environment](Learning-Environment-Executable.md) +- [ML-Agents Package Settings](Package-Settings.md) + +## Training & Inference + +- [Training ML-Agents](Training-ML-Agents.md) + - [Training Configuration File](Training-Configuration-File.md) + - [Using TensorBoard to Observe Training](Using-Tensorboard.md) + - [Profiling Trainers](Profiling-Python.md) +- [Unity Inference Engine](Unity-Inference-Engine.md) + +## Extending ML-Agents + +- [Creating Custom Side Channels](Custom-SideChannels.md) +- [Creating Custom Samplers for Environment Parameter Randomization](Training-ML-Agents.md#defining-a-new-sampler-type) + +## Python Tutorial with Google Colab + +- [Using a UnityEnvironment](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/release_19_docs/colab/Colab_UnityEnvironment_1_Run.ipynb) +- [Q-Learning with a UnityEnvironment](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/release_19_docs/colab/Colab_UnityEnvironment_2_Train.ipynb) +- [Using Side Channels on a UnityEnvironment](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/release_19_docs/colab/Colab_UnityEnvironment_3_SideChannel.ipynb) + +## Help + +- [Migrating from earlier versions of ML-Agents](Migrating.md) +- [Frequently Asked Questions](FAQ.md) +- [ML-Agents Glossary](Glossary.md) +- [Limitations](Limitations.md) + +## API Docs + +- [API Reference](API-Reference.md) +- [Python API Documentation](Python-LLAPI-Documentation.md) +- [How to use the Python API](Python-LLAPI.md) +- [How to use the Unity Environment Registry](Unity-Environment-Registry.md) +- [Wrapping Learning 
Environment as a Gym (+Baselines/Dopamine Integration)](Python-Gym-API.md) + +## Translations + +To make the Unity ML-Agents Toolkit accessible to the global research and Unity +developer communities, we're attempting to create and maintain translations of +our documentation. We've started with translating a subset of the documentation +to one language (Chinese), but we hope to continue translating more pages into +other languages. Consequently, we welcome any enhancements and improvements from +the community. + +- [Chinese](../localized_docs/zh-CN/) +- [Korean](../localized_docs/KR/) + +## Deprecated Docs + +We no longer use them ourselves and so they may not be up-to-date. We've decided +to keep them up just in case they are helpful to you. + +- [Windows Anaconda Installation](Installation-Anaconda-Windows.md) +- [Using Docker](Using-Docker.md) +- [Training on the Cloud with Amazon Web Services](Training-on-Amazon-Web-Service.md) +- [Training on the Cloud with Microsoft Azure](Training-on-Microsoft-Azure.md) +- [Using the Video Recorder](https://github.com/Unity-Technologies/video-recorder) diff --git a/docs/Migrating.md b/docs/Migrating.md index 73a12a9165..4d0e8100ea 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -4,10 +4,9 @@ -## Migrating to the ml-agents-envs 0.29.0.dev0 package -- Python 3.7 is now the minimum version of python supported due to [python3.6 EOL](https://endoflife.date/python). - Please update your python installation to 3.7.2 or higher. Note: Due to an issue with the typing system, the maximum - version of python supported is python 3.9.9. +## Migrating to the ml-agents-envs 0.29.0 package +- Python 3.8 is now the minimum version of python supported due to [python3.6 EOL](https://endoflife.date/python). + Please update your python installation to 3.8.13 or higher. - The `gym-unity` package has been refactored into the `ml-agents-envs` package. Please update your imports accordingly. - Example: - Before @@ -21,7 +20,7 @@ from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper ## Migrating the package to version 2.0 -- The official version of Unity ML-Agents supports is now 2020.3 LTS. If you run +- The official version of Unity that ML-Agents supports is now 2021.3 LTS. If you run into issues, please consider deleting your project's Library folder and reopening your project. - If you used any of the APIs that were deprecated before version 2.0, you need to use their replacement. These diff --git a/docs/Python-Custom-Trainer-Plugin.md b/docs/Python-Custom-Trainer-Plugin.md new file mode 100644 index 0000000000..3432e9523b --- /dev/null +++ b/docs/Python-Custom-Trainer-Plugin.md @@ -0,0 +1,51 @@ +# Unity ML-Agents Custom Trainers Plugin + +To bring a wider variety of reinforcement learning algorithms to our users, we have added support for custom trainers: +an extensible plugin system for defining new trainers based on the high-level trainer API +in the `ml-agents` package. This allows rerouting the `mlagents-learn` CLI to custom trainers and extending the config files +with hyperparameters specific to your new trainers. We expose high-level extensible trainer (both on-policy +and off-policy), optimizer, and hyperparameter classes, with documentation for the use of this plugin. For more +information on how the Python plugin system works, see [Plugin interfaces](Training-Plugins.md). +## Overview +Model-free RL algorithms generally fall into two broad categories: on-policy and off-policy.
On-policy algorithms perform updates based on data gathered from the current policy. Off-policy algorithms learn a Q function from a buffer of previous data, then use this Q function to make decisions. Off-policy algorithms have two key benefits in the context of ML-Agents: they tend to use fewer samples than on-policy algorithms, as they can pull and re-use data from the buffer many times, and they allow player demonstrations to be inserted in-line with RL data into the buffer, enabling new ways of doing imitation learning by streaming player data. + +To add new custom trainers to ML-Agents, you need to create a new Python package. +To give you an idea of how to structure your package, we have created a [mlagents_trainer_plugin](../ml-agents-trainer-plugin) package ourselves as an +example, with implementations of the `A2c` and `DQN` algorithms. You need a `setup.py` file to list extra requirements, +register the new RL algorithm in the ML-Agents ecosystem, and enable calling the `mlagents-learn` CLI with your customized +configuration. + + +```shell +├── mlagents_trainer_plugin +│ ├── __init__.py +│ ├── a2c +│ │ ├── __init__.py +│ │ ├── a2c_3DBall.yaml +│ │ ├── a2c_optimizer.py +│ │ └── a2c_trainer.py +│ └── dqn +│ ├── __init__.py +│ ├── dqn_basic.yaml +│ ├── dqn_optimizer.py +│ └── dqn_trainer.py +└── setup.py +``` +## Installation and Execution +If you haven't already, follow the [installation instructions](Installation.md). Once you have the `ml-agents-envs` and `ml-agents` packages, you can install the plugin package. From the repository's root directory, install `ml-agents-trainer-plugin` (or replace it with the name of your plugin folder). + +```sh +pip3 install -e <./ml-agents-trainer-plugin> +``` + +Following the previous installation steps, your package is added as an entry point and you can use a config file with the new +trainers: +```sh +mlagents-learn ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_3DBall.yaml --run-id +--env +``` + +## Tutorial +Here’s a step-by-step [tutorial](Tutorial-Custom-Trainer-Plugin.md) on how to write a setup file and extend the ml-agents trainers, optimizers, and +hyperparameter settings. To extend ML-Agents classes, see the references on +[trainers](Python-On-Off-Policy-Trainer-Documentation.md) and [Optimizer](Python-Optimizer-Documentation.md). \ No newline at end of file diff --git a/docs/Python-Gym-API.md b/docs/Python-Gym-API.md index 816c167f62..50051195ed 100644 --- a/docs/Python-Gym-API.md +++ b/docs/Python-Gym-API.md @@ -12,7 +12,7 @@ Unity environment via Python. ## Installation The gym wrapper is part of the `mlagents_envs` package. Please refer to the -[mlagents_envs installation instructions](../ml-agents-envs/README.md). +[mlagents_envs installation instructions](ML-Agents-Envs-README.md).
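As a quick sanity check after installation, here is a minimal sketch of the wrapper in action; the build path `./Builds/MyUnityEnv` is a hypothetical placeholder, and the imports follow the `UnityToGymWrapper` example shown in the migration guide above.

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper

# Point file_name at your own single-agent build (hypothetical path).
unity_env = UnityEnvironment(file_name="./Builds/MyUnityEnv")
env = UnityToGymWrapper(unity_env)

obs = env.reset()
for _ in range(100):
    # Sample random actions through the standard gym interface.
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
env.close()
```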
## Using the Gym Wrapper diff --git a/docs/Python-On-Off-Policy-Trainer-Documentation.md b/docs/Python-On-Off-Policy-Trainer-Documentation.md new file mode 100644 index 0000000000..4f13bdf72e --- /dev/null +++ b/docs/Python-On-Off-Policy-Trainer-Documentation.md @@ -0,0 +1,787 @@ +# Table of Contents + +* [mlagents.trainers.trainer.on\_policy\_trainer](#mlagents.trainers.trainer.on_policy_trainer) + * [OnPolicyTrainer](#mlagents.trainers.trainer.on_policy_trainer.OnPolicyTrainer) + * [\_\_init\_\_](#mlagents.trainers.trainer.on_policy_trainer.OnPolicyTrainer.__init__) + * [add\_policy](#mlagents.trainers.trainer.on_policy_trainer.OnPolicyTrainer.add_policy) +* [mlagents.trainers.trainer.off\_policy\_trainer](#mlagents.trainers.trainer.off_policy_trainer) + * [OffPolicyTrainer](#mlagents.trainers.trainer.off_policy_trainer.OffPolicyTrainer) + * [\_\_init\_\_](#mlagents.trainers.trainer.off_policy_trainer.OffPolicyTrainer.__init__) + * [save\_model](#mlagents.trainers.trainer.off_policy_trainer.OffPolicyTrainer.save_model) + * [save\_replay\_buffer](#mlagents.trainers.trainer.off_policy_trainer.OffPolicyTrainer.save_replay_buffer) + * [load\_replay\_buffer](#mlagents.trainers.trainer.off_policy_trainer.OffPolicyTrainer.load_replay_buffer) + * [add\_policy](#mlagents.trainers.trainer.off_policy_trainer.OffPolicyTrainer.add_policy) +* [mlagents.trainers.trainer.rl\_trainer](#mlagents.trainers.trainer.rl_trainer) + * [RLTrainer](#mlagents.trainers.trainer.rl_trainer.RLTrainer) + * [end\_episode](#mlagents.trainers.trainer.rl_trainer.RLTrainer.end_episode) + * [create\_optimizer](#mlagents.trainers.trainer.rl_trainer.RLTrainer.create_optimizer) + * [save\_model](#mlagents.trainers.trainer.rl_trainer.RLTrainer.save_model) + * [advance](#mlagents.trainers.trainer.rl_trainer.RLTrainer.advance) +* [mlagents.trainers.trainer.trainer](#mlagents.trainers.trainer.trainer) + * [Trainer](#mlagents.trainers.trainer.trainer.Trainer) + * [\_\_init\_\_](#mlagents.trainers.trainer.trainer.Trainer.__init__) + * [stats\_reporter](#mlagents.trainers.trainer.trainer.Trainer.stats_reporter) + * [parameters](#mlagents.trainers.trainer.trainer.Trainer.parameters) + * [get\_max\_steps](#mlagents.trainers.trainer.trainer.Trainer.get_max_steps) + * [get\_step](#mlagents.trainers.trainer.trainer.Trainer.get_step) + * [threaded](#mlagents.trainers.trainer.trainer.Trainer.threaded) + * [should\_still\_train](#mlagents.trainers.trainer.trainer.Trainer.should_still_train) + * [reward\_buffer](#mlagents.trainers.trainer.trainer.Trainer.reward_buffer) + * [save\_model](#mlagents.trainers.trainer.trainer.Trainer.save_model) + * [end\_episode](#mlagents.trainers.trainer.trainer.Trainer.end_episode) + * [create\_policy](#mlagents.trainers.trainer.trainer.Trainer.create_policy) + * [add\_policy](#mlagents.trainers.trainer.trainer.Trainer.add_policy) + * [get\_policy](#mlagents.trainers.trainer.trainer.Trainer.get_policy) + * [advance](#mlagents.trainers.trainer.trainer.Trainer.advance) + * [publish\_policy\_queue](#mlagents.trainers.trainer.trainer.Trainer.publish_policy_queue) + * [subscribe\_trajectory\_queue](#mlagents.trainers.trainer.trainer.Trainer.subscribe_trajectory_queue) +* [mlagents.trainers.settings](#mlagents.trainers.settings) + * [deep\_update\_dict](#mlagents.trainers.settings.deep_update_dict) + * [RewardSignalSettings](#mlagents.trainers.settings.RewardSignalSettings) + * [structure](#mlagents.trainers.settings.RewardSignalSettings.structure) + * 
[ParameterRandomizationSettings](#mlagents.trainers.settings.ParameterRandomizationSettings) + * [\_\_str\_\_](#mlagents.trainers.settings.ParameterRandomizationSettings.__str__) + * [structure](#mlagents.trainers.settings.ParameterRandomizationSettings.structure) + * [unstructure](#mlagents.trainers.settings.ParameterRandomizationSettings.unstructure) + * [apply](#mlagents.trainers.settings.ParameterRandomizationSettings.apply) + * [ConstantSettings](#mlagents.trainers.settings.ConstantSettings) + * [\_\_str\_\_](#mlagents.trainers.settings.ConstantSettings.__str__) + * [apply](#mlagents.trainers.settings.ConstantSettings.apply) + * [UniformSettings](#mlagents.trainers.settings.UniformSettings) + * [\_\_str\_\_](#mlagents.trainers.settings.UniformSettings.__str__) + * [apply](#mlagents.trainers.settings.UniformSettings.apply) + * [GaussianSettings](#mlagents.trainers.settings.GaussianSettings) + * [\_\_str\_\_](#mlagents.trainers.settings.GaussianSettings.__str__) + * [apply](#mlagents.trainers.settings.GaussianSettings.apply) + * [MultiRangeUniformSettings](#mlagents.trainers.settings.MultiRangeUniformSettings) + * [\_\_str\_\_](#mlagents.trainers.settings.MultiRangeUniformSettings.__str__) + * [apply](#mlagents.trainers.settings.MultiRangeUniformSettings.apply) + * [CompletionCriteriaSettings](#mlagents.trainers.settings.CompletionCriteriaSettings) + * [need\_increment](#mlagents.trainers.settings.CompletionCriteriaSettings.need_increment) + * [Lesson](#mlagents.trainers.settings.Lesson) + * [EnvironmentParameterSettings](#mlagents.trainers.settings.EnvironmentParameterSettings) + * [structure](#mlagents.trainers.settings.EnvironmentParameterSettings.structure) + * [TrainerSettings](#mlagents.trainers.settings.TrainerSettings) + * [structure](#mlagents.trainers.settings.TrainerSettings.structure) + * [CheckpointSettings](#mlagents.trainers.settings.CheckpointSettings) + * [prioritize\_resume\_init](#mlagents.trainers.settings.CheckpointSettings.prioritize_resume_init) + * [RunOptions](#mlagents.trainers.settings.RunOptions) + * [from\_argparse](#mlagents.trainers.settings.RunOptions.from_argparse) + + +# mlagents.trainers.trainer.on\_policy\_trainer + + +## OnPolicyTrainer Objects + +```python +class OnPolicyTrainer(RLTrainer) +``` + +The PPOTrainer is an implementation of the PPO algorithm. + + +#### \_\_init\_\_ + +```python + | __init__(behavior_name: str, reward_buff_cap: int, trainer_settings: TrainerSettings, training: bool, load: bool, seed: int, artifact_path: str) +``` + +Responsible for collecting experiences and training an on-policy model. + +**Arguments**: + +- `behavior_name`: The name of the behavior associated with trainer config +- `reward_buff_cap`: Max reward history to track in the reward buffer +- `trainer_settings`: The parameters for the trainer. +- `training`: Whether the trainer is set for training. +- `load`: Whether the model should be loaded. +- `seed`: The seed the model will be initialized with +- `artifact_path`: The directory within which to store artifacts from this trainer. + + +#### add\_policy + +```python + | add_policy(parsed_behavior_id: BehaviorIdentifiers, policy: Policy) -> None +``` + +Adds policy to trainer. + +**Arguments**: + +- `parsed_behavior_id`: Behavior identifiers that the policy should belong to. +- `policy`: Policy to associate with name_behavior_id. 
+ + +# mlagents.trainers.trainer.off\_policy\_trainer + + +## OffPolicyTrainer Objects + +```python +class OffPolicyTrainer(RLTrainer) +``` + +The SACTrainer is an implementation of the SAC algorithm, with support +for discrete actions and recurrent networks. + + +#### \_\_init\_\_ + +```python + | __init__(behavior_name: str, reward_buff_cap: int, trainer_settings: TrainerSettings, training: bool, load: bool, seed: int, artifact_path: str) +``` + +Responsible for collecting experiences and training an off-policy model. + +**Arguments**: + +- `behavior_name`: The name of the behavior associated with trainer config +- `reward_buff_cap`: Max reward history to track in the reward buffer +- `trainer_settings`: The parameters for the trainer. +- `training`: Whether the trainer is set for training. +- `load`: Whether the model should be loaded. +- `seed`: The seed the model will be initialized with +- `artifact_path`: The directory within which to store artifacts from this trainer. + + +#### save\_model + +```python + | save_model() -> None +``` + +Saves the final training model to memory. +Overrides the default to save the replay buffer. + + +#### save\_replay\_buffer + +```python + | save_replay_buffer() -> None +``` + +Save the training buffer's update buffer to a pickle file. + + +#### load\_replay\_buffer + +```python + | load_replay_buffer() -> None +``` + +Loads the last saved replay buffer from a file. + + +#### add\_policy + +```python + | add_policy(parsed_behavior_id: BehaviorIdentifiers, policy: Policy) -> None +``` + +Adds policy to trainer. + + +# mlagents.trainers.trainer.rl\_trainer + + +## RLTrainer Objects + +```python +class RLTrainer(Trainer) +``` + +This class is the base class for trainers that use Reward Signals. + + +#### end\_episode + +```python + | end_episode() -> None +``` + +A signal that the Episode has ended. The buffer must be reset. +Gets called only when the academy resets. + + +#### create\_optimizer + +```python + | @abc.abstractmethod + | create_optimizer() -> TorchOptimizer +``` + +Creates an Optimizer object + + +#### save\_model + +```python + | save_model() -> None +``` + +Saves the policy associated with this trainer. + + +#### advance + +```python + | advance() -> None +``` + +Steps the trainer, taking in trajectories and updating if ready. +Will block and wait briefly if there are no trajectories. + + +# mlagents.trainers.trainer.trainer + + +## Trainer Objects + +```python +class Trainer(abc.ABC) +``` + +This class is the base class for the trainers in mlagents.trainers. + + +#### \_\_init\_\_ + +```python + | __init__(brain_name: str, trainer_settings: TrainerSettings, training: bool, load: bool, artifact_path: str, reward_buff_cap: int = 1) +``` + +Responsible for collecting experiences and training a neural network model. + +**Arguments**: + +- `brain_name`: Brain name of brain to be trained. +- `trainer_settings`: The parameters for the trainer (dictionary). +- `training`: Whether the trainer is set for training. +- `artifact_path`: The directory within which to store artifacts from this trainer +- `reward_buff_cap`: + + +#### stats\_reporter + +```python + | @property + | stats_reporter() +``` + +Returns the stats reporter associated with this Trainer. + + +#### parameters + +```python + | @property + | parameters() -> TrainerSettings +``` + +Returns the trainer parameters of the trainer. + + +#### get\_max\_steps + +```python + | @property + | get_max_steps() -> int +``` + +Returns the maximum number of steps.
Is used to know when the trainer should be stopped. + +**Returns**: + +The maximum number of steps of the trainer + + +#### get\_step + +```python + | @property + | get_step() -> int +``` + +Returns the number of steps the trainer has performed + +**Returns**: + +the step count of the trainer + + +#### threaded + +```python + | @property + | threaded() -> bool +``` + +Whether or not to run the trainer in a thread. True allows the trainer to +update the policy while the environment is taking steps. Set to False to +enforce strict on-policy updates (i.e. don't update the policy when taking steps.) + + +#### should\_still\_train + +```python + | @property + | should_still_train() -> bool +``` + +Returns whether or not the trainer should train. A Trainer could +stop training if it wasn't training to begin with, or if max_steps +is reached. + + +#### reward\_buffer + +```python + | @property + | reward_buffer() -> Deque[float] +``` + +Returns the reward buffer. The reward buffer contains the cumulative +rewards of the most recent episodes completed by agents using this +trainer. + +**Returns**: + +the reward buffer. + + +#### save\_model + +```python + | @abc.abstractmethod + | save_model() -> None +``` + +Saves model file(s) for the policy or policies associated with this trainer. + + +#### end\_episode + +```python + | @abc.abstractmethod + | end_episode() +``` + +A signal that the Episode has ended. The buffer must be reset. +Gets called only when the academy resets. + + +#### create\_policy + +```python + | @abc.abstractmethod + | create_policy(parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec) -> Policy +``` + +Creates a Policy object + + +#### add\_policy + +```python + | @abc.abstractmethod + | add_policy(parsed_behavior_id: BehaviorIdentifiers, policy: Policy) -> None +``` + +Adds policy to trainer. + + +#### get\_policy + +```python + | get_policy(name_behavior_id: str) -> Policy +``` + +Gets policy associated with name_behavior_id + +**Arguments**: + +- `name_behavior_id`: Fully qualified behavior name + +**Returns**: + +Policy associated with name_behavior_id + + +#### advance + +```python + | @abc.abstractmethod + | advance() -> None +``` + +Advances the trainer. Typically, this means grabbing trajectories +from all subscribed trajectory queues (self.trajectory_queues), updating +a policy using the steps in them, and, if needed, pushing a new policy onto the right +policy queues (self.policy_queues). + + +#### publish\_policy\_queue + +```python + | publish_policy_queue(policy_queue: AgentManagerQueue[Policy]) -> None +``` + +Adds a policy queue to the list of queues to publish to when this Trainer +makes a policy update. + +**Arguments**: + +- `policy_queue`: Policy queue to publish to. + + +#### subscribe\_trajectory\_queue + +```python + | subscribe_trajectory_queue(trajectory_queue: AgentManagerQueue[Trajectory]) -> None +``` + +Adds a trajectory queue to the list of queues for the trainer to ingest Trajectories from. + +**Arguments**: + +- `trajectory_queue`: Trajectory queue to read from. + + +# mlagents.trainers.settings + + +#### deep\_update\_dict + +```python +deep_update_dict(d: Dict, update_d: Mapping) -> None +``` + +Similar to dict.update(), but works for nested dicts of dicts as well.
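As a quick illustration of the behavior described above (the values here are arbitrary, and the function mutates `d` in place):

```python
from mlagents.trainers.settings import deep_update_dict

d = {"hyperparameters": {"batch_size": 1024, "learning_rate": 3e-4}}
update = {"hyperparameters": {"batch_size": 2048}, "max_steps": 500000}

# Unlike dict.update(), nested dicts are merged key by key instead of
# being replaced wholesale.
deep_update_dict(d, update)
# d == {"hyperparameters": {"batch_size": 2048, "learning_rate": 3e-4},
#       "max_steps": 500000}
```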
+ + +## RewardSignalSettings Objects + +```python +@attr.s(auto_attribs=True) +class RewardSignalSettings() +``` + + +#### structure + +```python + | @staticmethod + | structure(d: Mapping, t: type) -> Any +``` + +Helper method to structure a Dict of RewardSignalSettings classes. Meant to be registered with +cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle +the special Enum selection of RewardSignalSettings classes. + + +## ParameterRandomizationSettings Objects + +```python +@attr.s(auto_attribs=True) +class ParameterRandomizationSettings(abc.ABC) +``` + + +#### \_\_str\_\_ + +```python + | __str__() -> str +``` + +Helper method to output sampler stats to console. + + +#### structure + +```python + | @staticmethod + | structure(d: Union[Mapping, float], t: type) -> "ParameterRandomizationSettings" +``` + +Helper method to structure a ParameterRandomizationSettings class. Meant to be registered with +cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle +the special Enum selection of ParameterRandomizationSettings classes. + + +#### unstructure + +```python + | @staticmethod + | unstructure(d: "ParameterRandomizationSettings") -> Mapping +``` + +Helper method to unstructure a ParameterRandomizationSettings class. Meant to be registered with +cattr.register_unstructure_hook() and called with cattr.unstructure(). + + +#### apply + +```python + | @abc.abstractmethod + | apply(key: str, env_channel: EnvironmentParametersChannel) -> None +``` + +Helper method to send sampler settings over the EnvironmentParametersChannel. +Calls the appropriate sampler type set method. + +**Arguments**: + +- `key`: environment parameter to be sampled +- `env_channel`: The EnvironmentParametersChannel to communicate sampler settings to environment + + +## ConstantSettings Objects + +```python +@attr.s(auto_attribs=True) +class ConstantSettings(ParameterRandomizationSettings) +``` + + +#### \_\_str\_\_ + +```python + | __str__() -> str +``` + +Helper method to output sampler stats to console. + + +#### apply + +```python + | apply(key: str, env_channel: EnvironmentParametersChannel) -> None +``` + +Helper method to send sampler settings over the EnvironmentParametersChannel. +Calls the constant sampler type set method. + +**Arguments**: + +- `key`: environment parameter to be sampled +- `env_channel`: The EnvironmentParametersChannel to communicate sampler settings to environment + + +## UniformSettings Objects + +```python +@attr.s(auto_attribs=True) +class UniformSettings(ParameterRandomizationSettings) +``` + + +#### \_\_str\_\_ + +```python + | __str__() -> str +``` + +Helper method to output sampler stats to console. + + +#### apply + +```python + | apply(key: str, env_channel: EnvironmentParametersChannel) -> None +``` + +Helper method to send sampler settings over the EnvironmentParametersChannel. +Calls the uniform sampler type set method. + +**Arguments**: + +- `key`: environment parameter to be sampled +- `env_channel`: The EnvironmentParametersChannel to communicate sampler settings to environment + + +## GaussianSettings Objects + +```python +@attr.s(auto_attribs=True) +class GaussianSettings(ParameterRandomizationSettings) +``` + + +#### \_\_str\_\_ + +```python + | __str__() -> str +``` + +Helper method to output sampler stats to console.
+ + +#### apply + +```python + | apply(key: str, env_channel: EnvironmentParametersChannel) -> None +``` + +Helper method to send sampler settings over the EnvironmentParametersChannel. +Calls the gaussian sampler type set method. + +**Arguments**: + +- `key`: environment parameter to be sampled +- `env_channel`: The EnvironmentParametersChannel to communicate sampler settings to environment + + +## MultiRangeUniformSettings Objects + +```python +@attr.s(auto_attribs=True) +class MultiRangeUniformSettings(ParameterRandomizationSettings) +``` + + +#### \_\_str\_\_ + +```python + | __str__() -> str +``` + +Helper method to output sampler stats to console. + + +#### apply + +```python + | apply(key: str, env_channel: EnvironmentParametersChannel) -> None +``` + +Helper method to send sampler settings over the EnvironmentParametersChannel. +Calls the multirangeuniform sampler type set method. + +**Arguments**: + +- `key`: environment parameter to be sampled +- `env_channel`: The EnvironmentParametersChannel to communicate sampler settings to environment + + +## CompletionCriteriaSettings Objects + +```python +@attr.s(auto_attribs=True) +class CompletionCriteriaSettings() +``` + +CompletionCriteriaSettings contains the information needed to figure out if the next +lesson must start. + + +#### need\_increment + +```python + | need_increment(progress: float, reward_buffer: List[float], smoothing: float) -> Tuple[bool, float] +``` + +Given measures, this method returns a boolean indicating if the lesson +needs to change now, and a float corresponding to the new smoothed value. + + +## Lesson Objects + +```python +@attr.s(auto_attribs=True) +class Lesson() +``` + +Gathers the data of one lesson for one environment parameter including its name, +the condition that must be fulfilled for the lesson to be completed, and a sampler +for the environment parameter. If the completion_criteria is None, then this is +the last lesson in the curriculum. + + +## EnvironmentParameterSettings Objects + +```python +@attr.s(auto_attribs=True) +class EnvironmentParameterSettings() +``` + +EnvironmentParameterSettings is an ordered list of lessons for one environment +parameter. + + +#### structure + +```python + | @staticmethod + | structure(d: Mapping, t: type) -> Dict[str, "EnvironmentParameterSettings"] +``` + +Helper method to structure a Dict of EnvironmentParameterSettings classes. Meant +to be registered with cattr.register_structure_hook() and called with +cattr.structure(). + + +## TrainerSettings Objects + +```python +@attr.s(auto_attribs=True) +class TrainerSettings(ExportableSettings) +``` + + +#### structure + +```python + | @staticmethod + | structure(d: Mapping, t: type) -> Any +``` + +Helper method to structure a TrainerSettings class. Meant to be registered with +cattr.register_structure_hook() and called with cattr.structure(). + + +## CheckpointSettings Objects + +```python +@attr.s(auto_attribs=True) +class CheckpointSettings() +``` + + +#### prioritize\_resume\_init + +```python + | prioritize_resume_init() -> None +``` + +Prioritize explicit command line resume/init over conflicting yaml options.
+If both resume and init are set in the same place, resume is used. + + +## RunOptions Objects + +```python +@attr.s(auto_attribs=True) +class RunOptions(ExportableSettings) +``` + + +#### from\_argparse + +```python + | @staticmethod + | from_argparse(args: argparse.Namespace) -> "RunOptions" +``` + +Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files +from file paths, and converts to a RunOptions instance. + +**Arguments**: + +- `args`: collection of command-line parameters passed to mlagents-learn + +**Returns**: + +RunOptions representing the passed in arguments, with trainer config, curriculum and sampler +configs loaded from files. diff --git a/docs/Python-Optimizer-Documentation.md b/docs/Python-Optimizer-Documentation.md new file mode 100644 index 0000000000..9b7e1b993c --- /dev/null +++ b/docs/Python-Optimizer-Documentation.md @@ -0,0 +1,87 @@ +# Table of Contents + +* [mlagents.trainers.optimizer.torch\_optimizer](#mlagents.trainers.optimizer.torch_optimizer) + * [TorchOptimizer](#mlagents.trainers.optimizer.torch_optimizer.TorchOptimizer) + * [create\_reward\_signals](#mlagents.trainers.optimizer.torch_optimizer.TorchOptimizer.create_reward_signals) + * [get\_trajectory\_value\_estimates](#mlagents.trainers.optimizer.torch_optimizer.TorchOptimizer.get_trajectory_value_estimates) +* [mlagents.trainers.optimizer.optimizer](#mlagents.trainers.optimizer.optimizer) + * [Optimizer](#mlagents.trainers.optimizer.optimizer.Optimizer) + * [update](#mlagents.trainers.optimizer.optimizer.Optimizer.update) + + +# mlagents.trainers.optimizer.torch\_optimizer + + +## TorchOptimizer Objects + +```python +class TorchOptimizer(Optimizer) +``` + + +#### create\_reward\_signals + +```python + | create_reward_signals(reward_signal_configs: Dict[RewardSignalType, RewardSignalSettings]) -> None +``` + +Create reward signals + +**Arguments**: + +- `reward_signal_configs`: Reward signal config. + + +#### get\_trajectory\_value\_estimates + +```python + | get_trajectory_value_estimates(batch: AgentBuffer, next_obs: List[np.ndarray], done: bool, agent_id: str = "") -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]] +``` + +Get value estimates and memories for a trajectory, in batch form. + +**Arguments**: + +- `batch`: An AgentBuffer that consists of a trajectory. +- `next_obs`: the next observation (after the trajectory). Used for bootstrapping + if this is not a terminal trajectory. +- `done`: Set true if this is a terminal trajectory. +- `agent_id`: Agent ID of the agent that this trajectory belongs to. + +**Returns**: + +A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)], + the final value estimate as a Dict of [name, float], and optionally (if using memories) + an AgentBufferField of initial critic memories to be used during update. + + +# mlagents.trainers.optimizer.optimizer + + +## Optimizer Objects + +```python +class Optimizer(abc.ABC) +``` + +Creates loss functions and auxiliary networks (e.g. Q or Value) needed for training. +Provides methods to update the Policy. + + +#### update + +```python + | @abc.abstractmethod + | update(batch: AgentBuffer, num_sequences: int) -> Dict[str, float] +``` + +Update the Policy based on the batch that was passed in. + +**Arguments**: + +- `batch`: AgentBuffer that contains the minibatch of data used for this update. +- `num_sequences`: Number of recurrent sequences found in the minibatch. + +**Returns**: + +A Dict containing statistics (name, value) from the update (e.g.
loss) diff --git a/docs/Python-PettingZoo-API.md b/docs/Python-PettingZoo-API.md index 78e9439113..9af94d2dbb 100644 --- a/docs/Python-PettingZoo-API.md +++ b/docs/Python-PettingZoo-API.md @@ -7,6 +7,9 @@ interfacing with a Unity environment via Python. ## Installation and Examples +The PettingZoo wrapper is part of the `mlagents_envs` package. Please refer to the +[mlagents_envs installation instructions](ML-Agents-Envs-README.md). + [[Colab] PettingZoo Wrapper Example](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/develop-python-api-ga/ml-agents-envs/colabs/Colab_PettingZoo.ipynb) This colab notebook demonstrates the example usage of the wrapper, including installation, diff --git a/docs/Readme.md b/docs/Readme.md index ecac1cd763..c3351187d4 100644 --- a/docs/Readme.md +++ b/docs/Readme.md @@ -1,80 +1,198 @@ -# Unity ML-Agents Toolkit Documentation - -## Installation & Set-up - -- [Installation](Installation.md) - - [Using Virtual Environment](Using-Virtual-Environment.md) - -## Getting Started - -- [Getting Started Guide](Getting-Started.md) -- [ML-Agents Toolkit Overview](ML-Agents-Overview.md) - - [Background: Unity](Background-Unity.md) - - [Background: Machine Learning](Background-Machine-Learning.md) - - [Background: PyTorch](Background-PyTorch.md) -- [Example Environments](Learning-Environment-Examples.md) - -## Creating Learning Environments - -- [Making a New Learning Environment](Learning-Environment-Create-New.md) -- [Designing a Learning Environment](Learning-Environment-Design.md) - - [Designing Agents](Learning-Environment-Design-Agents.md) -- [Using an Executable Environment](Learning-Environment-Executable.md) -- [ML-Agents Package Settings](Package-Settings.md) - -## Training & Inference - -- [Training ML-Agents](Training-ML-Agents.md) - - [Training Configuration File](Training-Configuration-File.md) - - [Using TensorBoard to Observe Training](Using-Tensorboard.md) - - [Profiling Trainers](Profiling-Python.md) -- [Unity Inference Engine](Unity-Inference-Engine.md) - -## Extending ML-Agents - -- [Creating Custom Side Channels](Custom-SideChannels.md) - [Creating Custom Samplers for Environment Parameter Randomization](Training-ML-Agents.md#defining-a-new-sampler-type) - -## Python Tutorial with Google Colab - -- [Using a UnityEnvironment](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/release_19_docs/colab/Colab_UnityEnvironment_1_Run.ipynb) -- [Q-Learning with a UnityEnvironment](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/release_19_docs/colab/Colab_UnityEnvironment_2_Train.ipynb) -- [Using Side Channels on a UnityEnvironment](https://colab.research.google.com/github/Unity-Technologies/ml-agents/blob/release_19_docs/colab/Colab_UnityEnvironment_3_SideChannel.ipynb) - -## Help - -- [Migrating from earlier versions of ML-Agents](Migrating.md) -- [Frequently Asked Questions](FAQ.md) -- [ML-Agents Glossary](Glossary.md) -- [Limitations](Limitations.md) - -## API Docs - -- [API Reference](API-Reference.md) -- [Python API Documentation](Python-LLAPI-Documentation.md) -- [How to use the Python API](Python-LLAPI.md) -- [How to use the Unity Environment Registry](Unity-Environment-Registry.md) -- [Wrapping Learning Environment as a Gym (+Baselines/Dopamine Integration)](Python-Gym-API.md) - -## Translations - -To make the Unity ML-Agents Toolkit accessible to the global research and Unity -developer communities, we're attempting to create and maintain translations of -our documentation.
We've started with translating a subset of the documentation -to one language (Chinese), but we hope to continue translating more pages and to -other languages. Consequently, we welcome any enhancements and improvements from -the community. - -- [Chinese](localized/zh-CN/) -- [Korean](localized/KR/) - -## Deprecated Docs - -We no longer use them ourselves and so they may not be up-to-date. We've decided -to keep them up just in case they are helpful to you. - -- [Windows Anaconda Installation](Installation-Anaconda-Windows.md) -- [Using Docker](Using-Docker.md) -- [Training on the Cloud with Amazon Web Services](Training-on-Amazon-Web-Service.md) -- [Training on the Cloud with Microsoft Azure](Training-on-Microsoft-Azure.md) -- [Using the Video Recorder](https://github.com/Unity-Technologies/video-recorder) +# Unity ML-Agents Toolkit + +[![docs badge](https://img.shields.io/badge/docs-reference-blue.svg)](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/) + +[![license badge](https://img.shields.io/badge/license-Apache--2.0-green.svg)](../LICENSE.md) + +([latest release](https://github.com/Unity-Technologies/ml-agents/releases/tag/latest_release)) +([all releases](https://github.com/Unity-Technologies/ml-agents/releases)) + +**The Unity Machine Learning Agents Toolkit** (ML-Agents) is an open-source +project that enables games and simulations to serve as environments for +training intelligent agents. We provide implementations (based on PyTorch) +of state-of-the-art algorithms to enable game developers and hobbyists to easily +train intelligent agents for 2D, 3D and VR/AR games. Researchers can also use the +provided simple-to-use Python API to train Agents using reinforcement learning, +imitation learning, neuroevolution, or any other methods. These trained agents can be +used for multiple purposes, including controlling NPC behavior (in a variety of +settings such as multi-agent and adversarial), automated testing of game builds +and evaluating different game design decisions pre-release. The ML-Agents +Toolkit is mutually beneficial for both game developers and AI researchers as it +provides a central platform where advances in AI can be evaluated on Unity’s +rich environments and then made accessible to the wider research and game +developer communities. + +## Features +- 17+ [example Unity environments](Learning-Environment-Examples.md) +- Support for multiple environment configurations and training scenarios +- Flexible Unity SDK that can be integrated into your game or custom Unity scene +- Support for training single-agent, multi-agent cooperative, and multi-agent + competitive scenarios via several Deep Reinforcement Learning algorithms (PPO, SAC, MA-POCA, self-play). +- Support for learning from demonstrations through two Imitation Learning algorithms (BC and GAIL). +- Easily definable Curriculum Learning scenarios for complex tasks +- Train robust agents using environment randomization +- Flexible agent control with On Demand Decision Making +- Train using multiple concurrent Unity environment instances +- Utilizes the [Unity Inference Engine](Unity-Inference-Engine.md) to + provide native cross-platform support +- Unity environment [control from Python](Python-LLAPI.md) +- Wrap Unity learning environments as a [gym](Python-Gym-API.md) environment +- Wrap Unity learning environments as a [PettingZoo](Python-PettingZoo-API.md) environment + +See our [ML-Agents Overview](ML-Agents-Overview.md) page for detailed +descriptions of all these features. 
+## Releases & Documentation + +**Our latest, stable release is `Release 19`. Click +[here](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/Readme.md) +to get started with the latest release of ML-Agents.** + +The table below lists all our releases, including our `main` branch which is +under active development and may be unstable. A few helpful guidelines: +- The [Versioning page](Versioning.md) overviews how we manage our GitHub + releases and the versioning process for each of the ML-Agents components. +- The [Releases page](https://github.com/Unity-Technologies/ml-agents/releases) + contains details of the changes between releases. +- The [Migration page](Migrating.md) contains details on how to upgrade + from earlier releases of the ML-Agents Toolkit. +- The **Documentation** links in the table below include installation and usage + instructions specific to each release. Remember to always use the + documentation that corresponds to the release version you're using. +- The `com.unity.ml-agents` package is [verified](https://docs.unity3d.com/2020.1/Documentation/Manual/pack-safe.html) + for Unity 2020.1 and later. Verified package releases are numbered 1.0.x. + +| **Version** | **Release Date** | **Source** | **Documentation** | **Download** | **Python Package** | **Unity Package** | +|:-------:|:------:|:-------------:|:-------:|:------------:|:------------:|:------------:| +| **main (unstable)** | -- | [source](https://github.com/Unity-Technologies/ml-agents/tree/main) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/main/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/main.zip) | -- | -- | +| **Release 19** | **January 14, 2022** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_19)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_19_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_19.zip)** | **[0.28.0](https://pypi.org/project/mlagents/0.28.0/)** | **[2.2.1](https://docs.unity3d.com/Packages/com.unity.ml-agents@2.2/manual/index.html)** | +| **Verified Package 1.0.8** | **May 26, 2021** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/com.unity.ml-agents_1.0.8)** | **[docs](https://github.com/Unity-Technologies/ml-agents/blob/release_2_verified_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/com.unity.ml-agents_1.0.8.zip)** | **[0.16.1](https://pypi.org/project/mlagents/0.16.1/)** | **[1.0.8](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.0/manual/index.html)** | + +If you are a researcher interested in a discussion of Unity as an AI platform, +see a pre-print of our +[reference paper on Unity and the ML-Agents Toolkit](https://arxiv.org/abs/1809.02627).
+ +If you use Unity or the ML-Agents Toolkit to conduct research, we ask that you +cite the following paper as a reference: + +``` +@article{juliani2020, + title={Unity: A general platform for intelligent agents}, + author={Juliani, Arthur and Berges, Vincent-Pierre and Teng, Ervin and Cohen, Andrew and Harper, Jonathan and Elion, Chris and Goy, Chris and Gao, Yuan and Henry, Hunter and Mattar, Marwan and Lange, Danny}, + journal={arXiv preprint arXiv:1809.02627}, + year={2020} +} +``` + +Additionally, if you use the MA-POCA trainer in your research, we ask that you +cite the following paper as a reference: + +``` +@article{cohen2022, + title={On the Use and Misuse of Absorbing States in Multi-agent Reinforcement Learning}, + author={Cohen, Andrew and Teng, Ervin and Berges, Vincent-Pierre and Dong, Ruo-Ping and Henry, Hunter and Mattar, Marwan and Zook, Alexander and Ganguly, Sujoy}, + journal={RL in Games Workshop AAAI 2022}, + year={2022} +} +``` + + + +## Additional Resources + +We have a Unity Learn course, +[ML-Agents: Hummingbirds](https://learn.unity.com/course/ml-agents-hummingbirds), +that provides a gentle introduction to Unity and the ML-Agents Toolkit. + +We've also partnered with +[CodeMonkeyUnity](https://www.youtube.com/c/CodeMonkeyUnity) to create a +[series of tutorial videos](https://www.youtube.com/playlist?list=PLzDRvYVwl53vehwiN_odYJkPBzcqFw110) +on how to implement and use the ML-Agents Toolkit. + +We have also published a series of blog posts that are relevant for ML-Agents: + +- (July 12, 2021) + [ML-Agents plays Dodgeball](https://blog.unity.com/technology/ml-agents-plays-dodgeball) +- (May 5, 2021) + [ML-Agents v2.0 release: Now supports training complex cooperative behaviors](https://blogs.unity3d.com/2021/05/05/ml-agents-v2-0-release-now-supports-training-complex-cooperative-behaviors/) +- (December 28, 2020) + [Happy holidays from the Unity ML-Agents team!](https://blogs.unity3d.com/2020/12/28/happy-holidays-from-the-unity-ml-agents-team/) +- (November 20, 2020) + [How Eidos-Montréal created Grid Sensors to improve observations for training agents](https://blogs.unity3d.com/2020/11/20/how-eidos-montreal-created-grid-sensors-to-improve-observations-for-training-agents/) +- (November 11, 2020) + [2020 AI@Unity interns shoutout](https://blogs.unity3d.com/2020/11/11/2020-aiunity-interns-shoutout/) +- (May 12, 2020) + [Announcing ML-Agents Unity Package v1.0!](https://blogs.unity3d.com/2020/05/12/announcing-ml-agents-unity-package-v1-0/) +- (February 28, 2020) + [Training intelligent adversaries using self-play with ML-Agents](https://blogs.unity3d.com/2020/02/28/training-intelligent-adversaries-using-self-play-with-ml-agents/) +- (November 11, 2019) + [Training your agents 7 times faster with ML-Agents](https://blogs.unity3d.com/2019/11/11/training-your-agents-7-times-faster-with-ml-agents/) +- (October 21, 2019) + [The AI@Unity interns help shape the world](https://blogs.unity3d.com/2019/10/21/the-aiunity-interns-help-shape-the-world/) +- (April 15, 2019) + [Unity ML-Agents Toolkit v0.8: Faster training on real games](https://blogs.unity3d.com/2019/04/15/unity-ml-agents-toolkit-v0-8-faster-training-on-real-games/) +- (March 1, 2019) + [Unity ML-Agents Toolkit v0.7: A leap towards cross-platform inference](https://blogs.unity3d.com/2019/03/01/unity-ml-agents-toolkit-v0-7-a-leap-towards-cross-platform-inference/) +- (December 17, 2018) + [ML-Agents Toolkit v0.6: Improved usability of Brains and Imitation
Learning](https://blogs.unity3d.com/2018/12/17/ml-agents-toolkit-v0-6-improved-usability-of-brains-and-imitation-learning/) +- (October 2, 2018) + [Puppo, The Corgi: Cuteness Overload with the Unity ML-Agents Toolkit](https://blogs.unity3d.com/2018/10/02/puppo-the-corgi-cuteness-overload-with-the-unity-ml-agents-toolkit/) +- (September 11, 2018) + [ML-Agents Toolkit v0.5, new resources for AI researchers available now](https://blogs.unity3d.com/2018/09/11/ml-agents-toolkit-v0-5-new-resources-for-ai-researchers-available-now/) +- (June 26, 2018) + [Solving sparse-reward tasks with Curiosity](https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/) +- (June 19, 2018) + [Unity ML-Agents Toolkit v0.4 and Udacity Deep Reinforcement Learning Nanodegree](https://blogs.unity3d.com/2018/06/19/unity-ml-agents-toolkit-v0-4-and-udacity-deep-reinforcement-learning-nanodegree/) +- (May 24, 2018) + [Imitation Learning in Unity: The Workflow](https://blogs.unity3d.com/2018/05/24/imitation-learning-in-unity-the-workflow/) +- (March 15, 2018) + [ML-Agents Toolkit v0.3 Beta released: Imitation Learning, feedback-driven features, and more](https://blogs.unity3d.com/2018/03/15/ml-agents-v0-3-beta-released-imitation-learning-feedback-driven-features-and-more/) +- (December 11, 2017) + [Using Machine Learning Agents in a real game: a beginner’s guide](https://blogs.unity3d.com/2017/12/11/using-machine-learning-agents-in-a-real-game-a-beginners-guide/) +- (December 8, 2017) + [Introducing ML-Agents Toolkit v0.2: Curriculum Learning, new environments, and more](https://blogs.unity3d.com/2017/12/08/introducing-ml-agents-v0-2-curriculum-learning-new-environments-and-more/) +- (September 19, 2017) + [Introducing: Unity Machine Learning Agents Toolkit](https://blogs.unity3d.com/2017/09/19/introducing-unity-machine-learning-agents/) +- Overviewing reinforcement learning concepts + ([multi-armed bandit](https://blogs.unity3d.com/2017/06/26/unity-ai-themed-blog-entries/) + and + [Q-learning](https://blogs.unity3d.com/2017/08/22/unity-ai-reinforcement-learning-with-q-learning/)) + +### More from Unity + +- [Unity Simulation Pro](https://unity.com/products/unity-simulation-pro) +- [Unity Robotics](https://github.com/Unity-Technologies/Unity-Robotics-Hub) +- [Unity Computer Vision](https://unity.com/computer-vision) + +## Community and Feedback + +The ML-Agents Toolkit is an open-source project and we encourage and welcome +contributions. If you wish to contribute, be sure to review our +[contribution guidelines](CONTRIBUTING.md) and +[code of conduct](../CODE_OF_CONDUCT.md). + +For problems with the installation and setup of the ML-Agents Toolkit, or +discussions about how to best setup or train your agents, please create a new +thread on the +[Unity ML-Agents forum](https://forum.unity.com/forums/ml-agents.453/) and make +sure to include as much detail as possible. If you run into any other problems +using the ML-Agents Toolkit or have a specific feature request, please +[submit a GitHub issue](https://github.com/Unity-Technologies/ml-agents/issues). + +Please tell us which samples you would like to see shipped with the ML-Agents Unity +package by replying to +[this forum thread](https://forum.unity.com/threads/feedback-wanted-shipping-sample-s-with-the-ml-agents-package.1073468/). + + +Your opinion matters a great deal to us. Only by hearing your thoughts on the +Unity ML-Agents Toolkit can we continue to improve and grow. 
Please take a few
+minutes to
+[let us know about it](https://unitysoftware.co1.qualtrics.com/jfe/form/SV_55pQKCZ578t0kbc).
+
+For any other questions or feedback, connect directly with the ML-Agents team at
+ml-agents@unity3d.com.
+
+## Privacy
+
+In order to improve the developer experience for the Unity ML-Agents Toolkit, we have added in-editor analytics.
+Please refer to "Information that is passively collected by Unity" in the
+[Unity Privacy Policy](https://unity3d.com/legal/privacy-policy).
diff --git a/docs/Training-Configuration-File.md b/docs/Training-Configuration-File.md
index 537bea2f3e..1f4cce5f4d 100644
--- a/docs/Training-Configuration-File.md
+++ b/docs/Training-Configuration-File.md
@@ -63,6 +63,7 @@ the `trainer` setting above).
 | `hyperparameters -> epsilon_schedule` | (default = `learning_rate_schedule`) Determines how epsilon changes over time (PPO only).

`linear` decays epsilon linearly, reaching 0 at max_steps, while `constant` keeps the epsilon constant for the entire training run. If not explicitly set, the default epsilon schedule will be set to `hyperparameters -> learning_rate_schedule`. | `hyperparameters -> lambd` | (default = `0.95`) Regularization parameter (lambda) used when calculating the Generalized Advantage Estimate ([GAE](https://arxiv.org/abs/1506.02438)). This can be thought of as how much the agent relies on its current value estimate when calculating an updated value estimate. Low values correspond to relying more on the current value estimate (which can be high bias), and high values correspond to relying more on the actual rewards received in the environment (which can be high variance). The parameter provides a trade-off between the two, and the right value can lead to a more stable training process.

Typical range: `0.9` - `0.95` | | `hyperparameters -> num_epoch` | (default = `3`) Number of passes to make through the experience buffer when performing gradient descent optimization. The larger the batch_size, the larger it is acceptable to make this. Decreasing this will ensure more stable updates, at the cost of slower learning.

Typical range: `3` - `10` |
+| `hyperparameters -> shared_critic` | (default = `False`) Whether or not the policy and value function networks share a backbone. It may be useful to use a shared backbone when learning from image observations.

 ### SAC-specific Configurations

@@ -145,7 +146,7 @@ To enable RND, provide these settings:

 To enable Behavioral Cloning as a pre-training option (assuming you have
 recorded demonstrations), provide the following configurations under the
-`behavior_cloning` section:
+`behavioral_cloning` section:

 | **Setting**          | **Description** |
 | :------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------- |
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index b40277fa70..ed25392057 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -284,6 +284,7 @@ behaviors:
       epsilon_schedule: linear
       lambd: 0.95
       num_epoch: 3
+      shared_critic: False

     # Configuration of the neural network (common to PPO/SAC)
     network_settings:
@@ -509,7 +510,7 @@ Below is a list of the `sampler_type` values supported by the toolkit.
   - **parameters** - `intervals`

 The implementation of the samplers can be found in the
-[Samplers.cs file](../com.unity.ml-agents/Runtime/Sampler.cs).
+[Samplers.cs file](https://github.com/Unity-Technologies/ml-agents/blob/main/com.unity.ml-agents/Runtime/Sampler.cs).

 ##### Training with Environment Parameter Randomization
diff --git a/docs/Training-on-Microsoft-Azure.md b/docs/Training-on-Microsoft-Azure.md
index 768797f6f7..f3f647f9c9 100644
--- a/docs/Training-on-Microsoft-Azure.md
+++ b/docs/Training-on-Microsoft-Azure.md
@@ -33,7 +33,7 @@ view the documentation for doing so [here](#custom-instances).
    instance, and set it as the working directory.
 2. Install the required packages:
    Torch: `pip3 install torch==1.7.0 -f https://download.pytorch.org/whl/torch_stable.html` and
-   MLAgents: `python -m pip install mlagents==0.28.0`
+   MLAgents: `python -m pip install mlagents==0.29.0`

 ## Testing
diff --git a/docs/Tutorial-Custom-Trainer-Plugin.md b/docs/Tutorial-Custom-Trainer-Plugin.md
new file mode 100644
index 0000000000..986030611f
--- /dev/null
+++ b/docs/Tutorial-Custom-Trainer-Plugin.md
@@ -0,0 +1,300 @@
+### Step 1: Write your custom trainer class
+Before you start writing your code, make sure to use your favorite environment management tool (e.g. `venv` or `conda`) to create and activate a Python virtual environment. The following command uses `conda`, but other tools work similarly:
+```shell
+conda create -n trainer-env python=3.8.13
+conda activate trainer-env
+```
+
+Users of the plug-in system are responsible for implementing the trainer class subject to the API standard. Let us follow an example by implementing a custom trainer named "YourCustomTrainer". You can extend either the `OnPolicyTrainer` or the `OffPolicyTrainer` class, depending on the training strategy you choose.
+
+Please refer to the internal [PPO implementation](../ml-agents/mlagents/trainers/ppo/trainer.py) for a complete code example.
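+For orientation, a bare-bones declaration of such a class might look like the following sketch. This is our own illustration rather than code from the toolkit: the import path and the base-class location are assumptions modeled on the internal PPO trainer layout.
+
+```python
+# Hypothetical skeleton; the module path below is an assumption based on the
+# layout of the internal PPO trainer, not a guaranteed public API.
+from mlagents.trainers.trainer.on_policy_trainer import OnPolicyTrainer
+
+
+class YourCustomTrainer(OnPolicyTrainer):
+    """Skeleton of a custom trainer; the methods discussed below go here."""
+
+    @staticmethod
+    def get_trainer_name() -> str:
+        # This name is what you later reference as `trainer_type` in config files.
+        return "your_trainer_type"
+```
+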
+We do not provide complete, runnable code in this document. The purpose of the tutorial is to introduce you to the core components and interfaces of our plugin framework. We use code snippets and patterns to demonstrate the control and data flow.
+
+Your custom trainers are responsible for collecting experiences and training the models. Your custom trainer class acts as a coordinator for the policy and the optimizer. To start implementing methods in the class, create a policy object in the `create_policy` method:
+
+
+```python
+def create_policy(
+    self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+) -> TorchPolicy:
+
+    actor_cls: Union[Type[SimpleActor], Type[SharedActorCritic]] = SimpleActor
+    actor_kwargs: Dict[str, Any] = {
+        "conditional_sigma": False,
+        "tanh_squash": False,
+    }
+    if self.shared_critic:
+        reward_signal_configs = self.trainer_settings.reward_signals
+        reward_signal_names = [
+            key.value for key, _ in reward_signal_configs.items()
+        ]
+        actor_cls = SharedActorCritic
+        actor_kwargs.update({"stream_names": reward_signal_names})
+
+    policy = TorchPolicy(
+        self.seed,
+        behavior_spec,
+        self.trainer_settings.network_settings,
+        actor_cls,
+        actor_kwargs,
+    )
+    return policy
+
+```
+
+Depending on whether you use a shared or separate network architecture for your policy, we provide `SimpleActor` and `SharedActorCritic` from `mlagents.trainers.torch_entities.networks` that you can choose from. In our example above, we use a `SimpleActor`.
+
+Next, create an optimizer object in the `create_optimizer` method and connect it to the policy object you created above:
+
+
+```python
+def create_optimizer(self) -> TorchOptimizer:
+    return TorchPPOOptimizer(  # type: ignore
+        cast(TorchPolicy, self.policy), self.trainer_settings  # type: ignore
+    )  # type: ignore
+
+```
+
+There are a couple of abstract methods (`_process_trajectory` and `_update_policy`) inherited from `RLTrainer` that you need to implement in your custom trainer class. `_process_trajectory` takes a trajectory and processes it, putting it into the update buffer. Processing involves calculating value and advantage targets for the model updating step. Given an input `trajectory: Trajectory`, users are responsible for processing the data in the trajectory and appending `agent_buffer_trajectory` to the back of the update buffer by calling `self._append_to_update_buffer(agent_buffer_trajectory)`; the buffer's contents are then used to update the model in the `optimizer` class.
+
+A typical `_process_trajectory` function (incomplete) will convert a trajectory object to an agent buffer, then get all value estimates from the trajectory by calling `self.optimizer.get_trajectory_value_estimates`.
From the returned dictionary of value estimates we extract reward signals keyed by their names:
+
+```python
+def _process_trajectory(self, trajectory: Trajectory) -> None:
+    super()._process_trajectory(trajectory)
+    agent_id = trajectory.agent_id  # All the agents should have the same ID
+
+    agent_buffer_trajectory = trajectory.to_agentbuffer()
+
+    # Get all value estimates
+    (
+        value_estimates,
+        value_next,
+        value_memories,
+    ) = self.optimizer.get_trajectory_value_estimates(
+        agent_buffer_trajectory,
+        trajectory.next_obs,
+        trajectory.done_reached and not trajectory.interrupted,
+    )
+
+    for name, v in value_estimates.items():
+        agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend(
+            v
+        )
+        self._stats_reporter.add_stat(
+            f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate",
+            np.mean(v),
+        )
+
+    # Evaluate all reward functions
+    self.collected_rewards["environment"][agent_id] += np.sum(
+        agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS]
+    )
+    for name, reward_signal in self.optimizer.reward_signals.items():
+        evaluate_result = (
+            reward_signal.evaluate(agent_buffer_trajectory) * reward_signal.strength
+        )
+        agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend(
+            evaluate_result
+        )
+        # Report the reward signals
+        self.collected_rewards[name][agent_id] += np.sum(evaluate_result)
+
+    self._append_to_update_buffer(agent_buffer_trajectory)
+
+```
+
+A trajectory is a list of dictionaries mapping strings to anything (i.e. `Dict[str, Any]`). When calling `forward` on a policy, the argument will include an "experience" dictionary from the last step. The `forward` method will generate an action and the next "experience" dictionary. Examples of fields in the "experience" dictionary include observation, action, reward, done status, group_reward, LSTM memory state, etc.
+
+
+
+### Step 2: Implement your custom optimizer for the trainer
+We will show an example we implemented, `class TorchPPOOptimizer(TorchOptimizer)`, which takes a policy and a dict of trainer parameters and creates an optimizer that connects to the policy. Your optimizer should include a value estimator and a loss function in the `update` method.
+
+Before writing your optimizer class, first define a settings class, `class PPOSettings(OnPolicyHyperparamSettings)`, for your custom optimizer:
+
+
+
+```python
+class PPOSettings(OnPolicyHyperparamSettings):
+    beta: float = 5.0e-3
+    epsilon: float = 0.2
+    lambd: float = 0.95
+    num_epoch: int = 3
+    shared_critic: bool = False
+    learning_rate_schedule: ScheduleType = ScheduleType.LINEAR
+    beta_schedule: ScheduleType = ScheduleType.LINEAR
+    epsilon_schedule: ScheduleType = ScheduleType.LINEAR
+
+```
+
+You should implement the `update` function following this interface:
+
+
+```python
+def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
+
+```
+
+In `update`, losses and other metrics are calculated from an `AgentBuffer` generated by your trainer class; depending on which model you choose to implement, the loss functions will differ. In our case we calculate a value loss from the critic and a trust-region policy loss.
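+The "trust region" policy loss referred to here is the standard PPO clipped surrogate objective. As a minimal, self-contained sketch of what that computation does (our own illustration, not the exact signature of `ModelUtils.trust_region_policy_loss`, which additionally applies loss masks and a decayed epsilon):
+
+```python
+import torch
+
+
+def clipped_surrogate_loss(advantages, log_probs, old_log_probs, epsilon=0.2):
+    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space
+    # for numerical stability.
+    ratio = torch.exp(log_probs - old_log_probs)
+    unclipped = ratio * advantages
+    # Clipping the ratio keeps each update inside an approximate trust region.
+    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
+    # PPO maximizes the surrogate objective, so the loss is its negated minimum.
+    return -torch.min(unclipped, clipped).mean()
+```
+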
+A typical pattern (incomplete) of the calculations will look like the following:
+
+
+```python
+run_out = self.policy.actor.get_stats(
+    current_obs,
+    actions,
+    masks=act_masks,
+    memories=memories,
+    sequence_length=self.policy.sequence_length,
+)
+
+log_probs = run_out["log_probs"]
+entropy = run_out["entropy"]
+
+values, _ = self.critic.critic_pass(
+    current_obs,
+    memories=value_memories,
+    sequence_length=self.policy.sequence_length,
+)
+policy_loss = ModelUtils.trust_region_policy_loss(
+    ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
+    log_probs,
+    old_log_probs,
+    loss_masks,
+    decay_eps,
+)
+loss = (
+    policy_loss
+    + 0.5 * value_loss
+    - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
+)
+
+```
+
+Finally, update the model and return a dictionary including the calculated losses and the decayed learning rate:
+
+
+```python
+ModelUtils.update_learning_rate(self.optimizer, decay_lr)
+self.optimizer.zero_grad()
+loss.backward()
+
+self.optimizer.step()
+update_stats = {
+    "Losses/Policy Loss": torch.abs(policy_loss).item(),
+    "Losses/Value Loss": value_loss.item(),
+    "Policy/Learning Rate": decay_lr,
+    "Policy/Epsilon": decay_eps,
+    "Policy/Beta": decay_bet,
+}
+
+```
+
+### Step 3: Integrate your custom trainer into the plugin system
+
+By integrating a custom trainer into the plugin system, users can distribute and use published packages containing their implementations. To do that, you need to add a `setup.py` file. In the call to `setup()`, add an entry to the `entry_points` dictionary for each plugin interface that you implement, in the form `{entry point name}={plugin module}:{plugin function}`. For example:
+
+
+
+```python
+entry_points={
+        ML_AGENTS_TRAINER_TYPE: [
+            "your_trainer_type=your_package.your_custom_trainer:get_type_and_setting"
+        ]
+    },
+```
+
+Some key elements in the code:
+
+```
+ML_AGENTS_TRAINER_TYPE: a string constant for the trainer type
+your_trainer_type: the name of your trainer type, used in the configuration file
+your_package: your pip-installable package containing the custom trainer implementation
+```
+
+Also define the `get_type_and_setting` function in the module that contains `YourCustomTrainer`:
+
+
+```python
+def get_type_and_setting():
+    return {YourCustomTrainer.get_trainer_name(): YourCustomTrainer}, {
+        YourCustomTrainer.get_trainer_name(): YourCustomSetting
+    }
+
+```
+
+Finally, specify the trainer type in the config file:
+
+
+```yaml
+behaviors:
+  3DBall:
+    trainer_type: your_trainer_type
+...
+```
+
+### Step 4: Install your custom trainer and run training:
+Before installing your custom trainer package, make sure you have `ml-agents-envs` and `ml-agents` installed:
+
+```shell
+pip3 install -e ./ml-agents-envs && pip3 install -e ./ml-agents
+```
+
+Install your custom trainer package (if your package is pip-installable):
+```shell
+pip3 install your_custom_package
+```
+Or install our internal implementations:
+```shell
+pip3 install -e ./ml-agents-trainer-plugin
+```
+
+Following the previous installations, your package is added as an entry point and you can use a config file with the new
+trainers:
+```shell
+mlagents-learn ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_3DBall.yaml --run-id
+--env
+```
+
+### Validate your implementations:
+Create a clean Python environment with Python 3.8+ and activate it before you start, if you haven't done so already:
+```shell
+conda create -n trainer-env python=3.8.13
+conda activate trainer-env
+```
+
+Make sure you have followed the previous steps and installed all required packages.
+We are testing internal implementations in this tutorial, but ML-Agents users can run similar validations once they have their own implementations installed:
+```shell
+pip3 install -e ./ml-agents-envs && pip3 install -e ./ml-agents
+pip3 install -e ./ml-agents-trainer-plugin
+```
+Once your package is added as an `entrypoint`, you can add the new trainer type to the config file. Check that the trainer type is specified in the config file `a2c_3DBall.yaml`:
+```
+trainer_type: a2c
+```
+
+Test whether the custom trainer package is installed by running:
+```shell
+mlagents-learn ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_3DBall.yaml --run-id test-trainer
+```
+
+You can also list all trainers installed in the registry. Type `python` in your shell to open a REPL session. Run the Python code below; you should see all trainer types currently installed:
+```python
+>>> import pkg_resources
+>>> for entry in pkg_resources.iter_entry_points('mlagents.trainer_type'):
+...     print(entry)
+...
+default = mlagents.plugins.trainer_type:get_default_trainer_types
+a2c = mlagents_trainer_plugin.a2c.a2c_trainer:get_type_and_setting
+dqn = mlagents_trainer_plugin.dqn.dqn_trainer:get_type_and_setting
+```
+
+If it is properly installed, you will see the Unity logo and a message indicating that training will start:
+```
+[INFO] Listening on port 5004. Start training by pressing the Play button in the Unity Editor.
+```
+
+If you see the following error message, the trainer type may be wrong or the specified trainer type may not be installed:
+```shell
+mlagents.trainers.exception.TrainerConfigError: Invalid trainer type a2c was found
+```
+
diff --git a/docs/Using-Virtual-Environment.md b/docs/Using-Virtual-Environment.md
index 97460a2f06..5ec0592a54 100644
--- a/docs/Using-Virtual-Environment.md
+++ b/docs/Using-Virtual-Environment.md
@@ -18,7 +18,7 @@ from dependencies of other projects. This has a few advantages:
   with the different version.

 ## Python Version Requirement (Required)
-This guide has been tested with Python 3.7.2 through Python 3.9.9. Newer versions might not
+This guide has been tested with Python 3.8.13 through Python 3.10.x. Newer versions might not
 have support for the dependent libraries, so are not recommended.

 ## Installing Pip (Required)
@@ -63,7 +63,7 @@ then python3-distutils needs to be installed. Install python3-distutils using
 environment using the same `activate` command listed above)

 Note:
-- Verify that you are using a Python version between 3.7.2 and 3.9.9. Launch a
+- Verify that you are using a Python version between 3.8.13 and 3.10.x. Launch a
  command prompt using `cmd` and execute `python --version` to verify the version.
- Python3 installation may require admin privileges on Windows.
- This guide is for Windows 10 using a 64-bit architecture only.
diff --git a/docs/com.unity.ml-agents.md b/docs/com.unity.ml-agents.md
new file mode 100644
index 0000000000..3d3d806687
--- /dev/null
+++ b/docs/com.unity.ml-agents.md
@@ -0,0 +1 @@
+{!../com.unity.ml-agents/Documentation~/com.unity.ml-agents.md!}
diff --git a/docs/doxygen/unity.css b/docs/doxygen/unity.css
index eb46e20638..3c3b9d18da 100644
--- a/docs/doxygen/unity.css
+++ b/docs/doxygen/unity.css
@@ -169,7 +169,7 @@ input.blue-btn, input.gray-btn { padding: 0 20px 4px 20px; }
 .gray-btn:hover { color: #fff; background-color: #222c37; }
 .bbtn { height: 50px; line-height: 50px; padding: 0 40px !important; font-size: 1.0em; }
 .sbtn { height: 24px; line-height: 24px; padding: 0 10px !important; font-size: 0.75em; }
-.dbtn, .dbtn:hover, .dbtn:active { cursor: default; background-color: #ccc; color: #f0f0f0; background-color: #ccc; }
+.dbtn, .dbtn:hover, .dbtn:active { cursor: default; background-color: #ccc; color: #f0f0f0; }
 .centerbtn { float: none; display: inline-block; margin: 0; }

 /****************************************
@@ -400,7 +400,7 @@ div.content-wrap { width: 480px; float: left; margin: 0; }
 div.content-block { margin: 0; }
 }

-@media only screen and (-moz-min-device-pixel-ratio: 2), only screen and (-o-min-device-pixel-ratio: 2/1), only screen and (-webkit-min-device-pixel-ratio: 2), only screen and (min-device-pixel-ratio: 2) {
+@media only screen and (-moz-min-device-pixel-ratio: 2), only screen and (-o-min-device-pixel-ratio: 2/1), only screen and (-webkit-min-device-pixel-ratio: 2) {
 div.header .menu .logo a, div.header .more ul li a, div.toolbar div.script-lang div.dialog div.close, div.lang-switcher div.current div.arrow, div.sidebar-menu ul li div.arrow { background-image: url(../images/sprites@2x.png); -webkit-background-size: 500px 250px; -moz-background-size: 500px 250px; -o-background-size: 500px 250px; background-size: 500px 250px; }
 input[type="text"].error, input[type="tel"].error, input[type="email"].error, input[type="password"].error, textarea.error { background-image: url(../images/error-red.png); -webkit-background-size: 24px 12px; -moz-background-size: 24px 12px; background-size: 24px 12px; }
diff --git a/docs/extra.css b/docs/extra.css
new file mode 100644
index 0000000000..3c5c35ed4b
--- /dev/null
+++ b/docs/extra.css
@@ -0,0 +1,3 @@
+.wy-nav-top, .wy-side-nav-search {
+  background: #439b47;
+}
diff --git a/docs/images/U_MachineLearningAgents_Logo_Black_RGB.png b/docs/images/U_MachineLearningAgents_Logo_Black_RGB.png
new file mode 100644
index 0000000000..88e2173ac4
Binary files /dev/null and b/docs/images/U_MachineLearningAgents_Logo_Black_RGB.png differ
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000000..7eb39d98ff
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,2 @@
+
+{!README.md!}
diff --git a/docs/localized/KR/README.md b/localized_docs/KR/README.md
similarity index 100%
rename from docs/localized/KR/README.md
rename to localized_docs/KR/README.md
diff --git a/docs/localized/KR/docs/Installation-Anaconda-Windows.md b/localized_docs/KR/docs/Installation-Anaconda-Windows.md
similarity index 98%
rename from docs/localized/KR/docs/Installation-Anaconda-Windows.md
rename to localized_docs/KR/docs/Installation-Anaconda-Windows.md
index 1ebaf8897e..71dc8e4870 100644
--- a/docs/localized/KR/docs/Installation-Anaconda-Windows.md
+++ b/localized_docs/KR/docs/Installation-Anaconda-Windows.md
@@ -12,8 +12,8 @@ To use the ML-Agents toolkit, [download](https://www.anaconda.com/download/#windows) and install Anaconda for Windows as described below.
 With Anaconda, you can manage separate environments for different versions of Python.
-Since Python 2 is no longer supported, Python 3.7 is required. In this guide we will
-use Anaconda 5.1 with Python 3.7.
+Since Python 2 is no longer supported, Python 3.8 is required. In this guide we will
+use Anaconda 5.1 with Python 3.8.
 ([64-bit](https://repo.continuum.io/archive/Anaconda3-5.1.0-Windows-x86_64.exe)
 or [32-bit](https://repo.continuum.io/archive/Anaconda3-5.1.0-Windows-x86.exe) link).
@@ -65,11 +65,11 @@ Type the following command to create a new [Conda environment] to use with the ML-Agents toolkit:

```sh
-conda create -n ml-agents python=3.7
+conda create -n ml-agents python=3.8
```

Type `y` and press Enter when asked whether to install these packages _(make sure you are connected to the internet)_.
-These required packages must be installed. The new Conda environment uses Python 3.7 and is called ml-agents.
+These required packages must be installed. The new Conda environment uses Python 3.8 and is called ml-agents.

Anaconda Install
diff --git a/docs/localized/KR/docs/Installation.md b/localized_docs/KR/docs/Installation.md
similarity index 96%
rename from docs/localized/KR/docs/Installation.md
rename to localized_docs/KR/docs/Installation.md
index 5bf25a8389..633c2ea671 100644
--- a/docs/localized/KR/docs/Installation.md
+++ b/localized_docs/KR/docs/Installation.md
@@ -40,13 +40,13 @@ git clone https://github.com/Unity-Technologies/ml-agents.git

 ### Installing Python and the mlagents package

-To use the ML-Agents toolkit, Python 3.7 is required, along with the dependencies listed in the [setup.py file](../ml-agents/setup.py).
+To use the ML-Agents toolkit, Python 3.8 is required, along with the dependencies listed in the [setup.py file](../ml-agents/setup.py).
 Some of the key dependencies include:

 - [TensorFlow](Background-TensorFlow.md) (Requires a CPU w/ AVX support)
 - [Jupyter](Background-Jupyter.md)

-If Python 3.7 is not installed, [download](https://www.python.org/downloads/) and install it.
+If Python 3.8 is not installed, [download](https://www.python.org/downloads/) and install it.

 If your Python environment does not include `pip3`, follow these
 [instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers)
@@ -64,7 +64,7 @@ pip3 install mlagents

 **Note:**

-- Currently, Python 3.7 or Python 3.5 is not supported.
+- Currently, Python 3.8 or Python 3.5 is not supported.
 - If you are using Anaconda and have trouble with TensorFlow, check this [link](https://www.tensorflow.org/install/pip) for how to install TensorFlow in an Anaconda environment.

 ### Installation for development
diff --git a/docs/localized/KR/docs/Training-Imitation-Learning.md b/localized_docs/KR/docs/Training-Imitation-Learning.md
similarity index 100%
rename from docs/localized/KR/docs/Training-Imitation-Learning.md
rename to localized_docs/KR/docs/Training-Imitation-Learning.md
diff --git a/docs/localized/KR/docs/Training-PPO.md b/localized_docs/KR/docs/Training-PPO.md
similarity index 100%
rename from docs/localized/KR/docs/Training-PPO.md
rename to localized_docs/KR/docs/Training-PPO.md
diff --git a/docs/localized/KR/docs/Using-Docker.md b/localized_docs/KR/docs/Using-Docker.md
similarity index 100%
rename from docs/localized/KR/docs/Using-Docker.md
rename to localized_docs/KR/docs/Using-Docker.md
diff --git a/docs/localized/KR/docs/images/3dball_big.png b/localized_docs/KR/docs/images/3dball_big.png
similarity index 100%
rename from docs/localized/KR/docs/images/3dball_big.png
rename to localized_docs/KR/docs/images/3dball_big.png
diff --git a/docs/localized/KR/docs/images/3dball_learning_brain.png b/localized_docs/KR/docs/images/3dball_learning_brain.png
similarity index 100%
rename from docs/localized/KR/docs/images/3dball_learning_brain.png
rename to localized_docs/KR/docs/images/3dball_learning_brain.png
diff --git a/docs/localized/KR/docs/images/3dball_small.png b/localized_docs/KR/docs/images/3dball_small.png
similarity index 100%
rename from docs/localized/KR/docs/images/3dball_small.png
rename to localized_docs/KR/docs/images/3dball_small.png
diff --git a/docs/localized/KR/docs/images/TensorBoard-download.png b/localized_docs/KR/docs/images/TensorBoard-download.png
similarity index 100%
rename from docs/localized/KR/docs/images/TensorBoard-download.png
rename to localized_docs/KR/docs/images/TensorBoard-download.png
diff --git a/docs/localized/KR/docs/images/anaconda_default.PNG b/localized_docs/KR/docs/images/anaconda_default.PNG
similarity index 100%
rename from docs/localized/KR/docs/images/anaconda_default.PNG
rename to localized_docs/KR/docs/images/anaconda_default.PNG
diff --git a/docs/localized/KR/docs/images/anaconda_install.PNG b/localized_docs/KR/docs/images/anaconda_install.PNG
similarity index
100% rename from docs/localized/KR/docs/images/anaconda_install.PNG rename to localized_docs/KR/docs/images/anaconda_install.PNG diff --git a/docs/localized/KR/docs/images/balance.png b/localized_docs/KR/docs/images/balance.png similarity index 100% rename from docs/localized/KR/docs/images/balance.png rename to localized_docs/KR/docs/images/balance.png diff --git a/docs/localized/KR/docs/images/banner.png b/localized_docs/KR/docs/images/banner.png similarity index 100% rename from docs/localized/KR/docs/images/banner.png rename to localized_docs/KR/docs/images/banner.png diff --git a/docs/localized/KR/docs/images/basic.png b/localized_docs/KR/docs/images/basic.png similarity index 100% rename from docs/localized/KR/docs/images/basic.png rename to localized_docs/KR/docs/images/basic.png diff --git a/docs/localized/KR/docs/images/bouncer.png b/localized_docs/KR/docs/images/bouncer.png similarity index 100% rename from docs/localized/KR/docs/images/bouncer.png rename to localized_docs/KR/docs/images/bouncer.png diff --git a/docs/localized/KR/docs/images/conda_new.PNG b/localized_docs/KR/docs/images/conda_new.PNG similarity index 100% rename from docs/localized/KR/docs/images/conda_new.PNG rename to localized_docs/KR/docs/images/conda_new.PNG diff --git a/docs/localized/KR/docs/images/crawler.png b/localized_docs/KR/docs/images/crawler.png similarity index 100% rename from docs/localized/KR/docs/images/crawler.png rename to localized_docs/KR/docs/images/crawler.png diff --git a/docs/localized/KR/docs/images/cuDNN_membership_required.png b/localized_docs/KR/docs/images/cuDNN_membership_required.png similarity index 100% rename from docs/localized/KR/docs/images/cuDNN_membership_required.png rename to localized_docs/KR/docs/images/cuDNN_membership_required.png diff --git a/docs/localized/KR/docs/images/cuda_toolkit_directory.PNG b/localized_docs/KR/docs/images/cuda_toolkit_directory.PNG similarity index 100% rename from docs/localized/KR/docs/images/cuda_toolkit_directory.PNG rename to localized_docs/KR/docs/images/cuda_toolkit_directory.PNG diff --git a/docs/localized/KR/docs/images/cudnn_zip_files.PNG b/localized_docs/KR/docs/images/cudnn_zip_files.PNG similarity index 100% rename from docs/localized/KR/docs/images/cudnn_zip_files.PNG rename to localized_docs/KR/docs/images/cudnn_zip_files.PNG diff --git a/docs/localized/KR/docs/images/curriculum.png b/localized_docs/KR/docs/images/curriculum.png similarity index 100% rename from docs/localized/KR/docs/images/curriculum.png rename to localized_docs/KR/docs/images/curriculum.png diff --git a/docs/localized/KR/docs/images/demo_component.png b/localized_docs/KR/docs/images/demo_component.png similarity index 100% rename from docs/localized/KR/docs/images/demo_component.png rename to localized_docs/KR/docs/images/demo_component.png diff --git a/docs/localized/KR/docs/images/demo_inspector.png b/localized_docs/KR/docs/images/demo_inspector.png similarity index 100% rename from docs/localized/KR/docs/images/demo_inspector.png rename to localized_docs/KR/docs/images/demo_inspector.png diff --git a/docs/localized/KR/docs/images/docker_build_settings.png b/localized_docs/KR/docs/images/docker_build_settings.png similarity index 100% rename from docs/localized/KR/docs/images/docker_build_settings.png rename to localized_docs/KR/docs/images/docker_build_settings.png diff --git a/docs/localized/KR/docs/images/edit_env_var.png b/localized_docs/KR/docs/images/edit_env_var.png similarity index 100% rename from docs/localized/KR/docs/images/edit_env_var.png 
rename to localized_docs/KR/docs/images/edit_env_var.png diff --git a/docs/localized/KR/docs/images/example-envs.png b/localized_docs/KR/docs/images/example-envs.png similarity index 100% rename from docs/localized/KR/docs/images/example-envs.png rename to localized_docs/KR/docs/images/example-envs.png diff --git a/docs/localized/KR/docs/images/foodCollector.png b/localized_docs/KR/docs/images/foodCollector.png similarity index 100% rename from docs/localized/KR/docs/images/foodCollector.png rename to localized_docs/KR/docs/images/foodCollector.png diff --git a/docs/localized/KR/docs/images/gridworld.png b/localized_docs/KR/docs/images/gridworld.png similarity index 100% rename from docs/localized/KR/docs/images/gridworld.png rename to localized_docs/KR/docs/images/gridworld.png diff --git a/docs/localized/KR/docs/images/hallway.png b/localized_docs/KR/docs/images/hallway.png similarity index 100% rename from docs/localized/KR/docs/images/hallway.png rename to localized_docs/KR/docs/images/hallway.png diff --git a/docs/localized/KR/docs/images/image-banner.png b/localized_docs/KR/docs/images/image-banner.png similarity index 100% rename from docs/localized/KR/docs/images/image-banner.png rename to localized_docs/KR/docs/images/image-banner.png diff --git a/docs/localized/KR/docs/images/learning_environment_basic.png b/localized_docs/KR/docs/images/learning_environment_basic.png similarity index 100% rename from docs/localized/KR/docs/images/learning_environment_basic.png rename to localized_docs/KR/docs/images/learning_environment_basic.png diff --git a/docs/localized/KR/docs/images/learning_environment_example.png b/localized_docs/KR/docs/images/learning_environment_example.png similarity index 100% rename from docs/localized/KR/docs/images/learning_environment_example.png rename to localized_docs/KR/docs/images/learning_environment_example.png diff --git a/docs/localized/KR/docs/images/learning_environment_full.png b/localized_docs/KR/docs/images/learning_environment_full.png similarity index 100% rename from docs/localized/KR/docs/images/learning_environment_full.png rename to localized_docs/KR/docs/images/learning_environment_full.png diff --git a/docs/localized/KR/docs/images/match3.png b/localized_docs/KR/docs/images/match3.png similarity index 100% rename from docs/localized/KR/docs/images/match3.png rename to localized_docs/KR/docs/images/match3.png diff --git a/docs/localized/KR/docs/images/math.png b/localized_docs/KR/docs/images/math.png similarity index 100% rename from docs/localized/KR/docs/images/math.png rename to localized_docs/KR/docs/images/math.png diff --git a/docs/localized/KR/docs/images/ml-agents-LSTM.png b/localized_docs/KR/docs/images/ml-agents-LSTM.png similarity index 100% rename from docs/localized/KR/docs/images/ml-agents-LSTM.png rename to localized_docs/KR/docs/images/ml-agents-LSTM.png diff --git a/docs/localized/KR/docs/images/mlagents-3DBallHierarchy.png b/localized_docs/KR/docs/images/mlagents-3DBallHierarchy.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-3DBallHierarchy.png rename to localized_docs/KR/docs/images/mlagents-3DBallHierarchy.png diff --git a/docs/localized/KR/docs/images/mlagents-BuildWindow.png b/localized_docs/KR/docs/images/mlagents-BuildWindow.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-BuildWindow.png rename to localized_docs/KR/docs/images/mlagents-BuildWindow.png diff --git a/docs/localized/KR/docs/images/mlagents-ImitationAndRL.png 
b/localized_docs/KR/docs/images/mlagents-ImitationAndRL.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-ImitationAndRL.png rename to localized_docs/KR/docs/images/mlagents-ImitationAndRL.png diff --git a/docs/localized/KR/docs/images/mlagents-NewTutSplash.png b/localized_docs/KR/docs/images/mlagents-NewTutSplash.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-NewTutSplash.png rename to localized_docs/KR/docs/images/mlagents-NewTutSplash.png diff --git a/docs/localized/KR/docs/images/mlagents-Open3DBall.png b/localized_docs/KR/docs/images/mlagents-Open3DBall.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-Open3DBall.png rename to localized_docs/KR/docs/images/mlagents-Open3DBall.png diff --git a/docs/localized/KR/docs/images/mlagents-RollerAgentStats.png b/localized_docs/KR/docs/images/mlagents-RollerAgentStats.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-RollerAgentStats.png rename to localized_docs/KR/docs/images/mlagents-RollerAgentStats.png diff --git a/docs/localized/KR/docs/images/mlagents-TensorBoard.png b/localized_docs/KR/docs/images/mlagents-TensorBoard.png similarity index 100% rename from docs/localized/KR/docs/images/mlagents-TensorBoard.png rename to localized_docs/KR/docs/images/mlagents-TensorBoard.png diff --git a/docs/localized/KR/docs/images/new_system_variable.PNG b/localized_docs/KR/docs/images/new_system_variable.PNG similarity index 100% rename from docs/localized/KR/docs/images/new_system_variable.PNG rename to localized_docs/KR/docs/images/new_system_variable.PNG diff --git a/docs/localized/KR/docs/images/path_variables.PNG b/localized_docs/KR/docs/images/path_variables.PNG similarity index 100% rename from docs/localized/KR/docs/images/path_variables.PNG rename to localized_docs/KR/docs/images/path_variables.PNG diff --git a/docs/localized/KR/docs/images/platform_prefab.png b/localized_docs/KR/docs/images/platform_prefab.png similarity index 100% rename from docs/localized/KR/docs/images/platform_prefab.png rename to localized_docs/KR/docs/images/platform_prefab.png diff --git a/docs/localized/KR/docs/images/push.png b/localized_docs/KR/docs/images/push.png similarity index 100% rename from docs/localized/KR/docs/images/push.png rename to localized_docs/KR/docs/images/push.png diff --git a/docs/localized/KR/docs/images/pyramids.png b/localized_docs/KR/docs/images/pyramids.png similarity index 100% rename from docs/localized/KR/docs/images/pyramids.png rename to localized_docs/KR/docs/images/pyramids.png diff --git a/docs/localized/KR/docs/images/ray_perception.png b/localized_docs/KR/docs/images/ray_perception.png similarity index 100% rename from docs/localized/KR/docs/images/ray_perception.png rename to localized_docs/KR/docs/images/ray_perception.png diff --git a/docs/localized/KR/docs/images/reacher.png b/localized_docs/KR/docs/images/reacher.png similarity index 100% rename from docs/localized/KR/docs/images/reacher.png rename to localized_docs/KR/docs/images/reacher.png diff --git a/docs/localized/KR/docs/images/rl_cycle.png b/localized_docs/KR/docs/images/rl_cycle.png similarity index 100% rename from docs/localized/KR/docs/images/rl_cycle.png rename to localized_docs/KR/docs/images/rl_cycle.png diff --git a/docs/localized/KR/docs/images/roller-ball-agent.png b/localized_docs/KR/docs/images/roller-ball-agent.png similarity index 100% rename from docs/localized/KR/docs/images/roller-ball-agent.png rename to 
localized_docs/KR/docs/images/roller-ball-agent.png diff --git a/docs/localized/KR/docs/images/roller-ball-floor.png b/localized_docs/KR/docs/images/roller-ball-floor.png similarity index 100% rename from docs/localized/KR/docs/images/roller-ball-floor.png rename to localized_docs/KR/docs/images/roller-ball-floor.png diff --git a/docs/localized/KR/docs/images/roller-ball-hierarchy.png b/localized_docs/KR/docs/images/roller-ball-hierarchy.png similarity index 100% rename from docs/localized/KR/docs/images/roller-ball-hierarchy.png rename to localized_docs/KR/docs/images/roller-ball-hierarchy.png diff --git a/docs/localized/KR/docs/images/roller-ball-projects.png b/localized_docs/KR/docs/images/roller-ball-projects.png similarity index 100% rename from docs/localized/KR/docs/images/roller-ball-projects.png rename to localized_docs/KR/docs/images/roller-ball-projects.png diff --git a/docs/localized/KR/docs/images/roller-ball-target.png b/localized_docs/KR/docs/images/roller-ball-target.png similarity index 100% rename from docs/localized/KR/docs/images/roller-ball-target.png rename to localized_docs/KR/docs/images/roller-ball-target.png diff --git a/docs/localized/KR/docs/images/soccer.png b/localized_docs/KR/docs/images/soccer.png similarity index 100% rename from docs/localized/KR/docs/images/soccer.png rename to localized_docs/KR/docs/images/soccer.png diff --git a/docs/localized/KR/docs/images/strikersvsgoalie.png b/localized_docs/KR/docs/images/strikersvsgoalie.png similarity index 100% rename from docs/localized/KR/docs/images/strikersvsgoalie.png rename to localized_docs/KR/docs/images/strikersvsgoalie.png diff --git a/docs/localized/KR/docs/images/system_variable_name_value.PNG b/localized_docs/KR/docs/images/system_variable_name_value.PNG similarity index 100% rename from docs/localized/KR/docs/images/system_variable_name_value.PNG rename to localized_docs/KR/docs/images/system_variable_name_value.PNG diff --git a/docs/localized/KR/docs/images/team_id.png b/localized_docs/KR/docs/images/team_id.png similarity index 100% rename from docs/localized/KR/docs/images/team_id.png rename to localized_docs/KR/docs/images/team_id.png diff --git a/docs/localized/KR/docs/images/tennis.png b/localized_docs/KR/docs/images/tennis.png similarity index 100% rename from docs/localized/KR/docs/images/tennis.png rename to localized_docs/KR/docs/images/tennis.png diff --git a/docs/localized/KR/docs/images/unity-wide.png b/localized_docs/KR/docs/images/unity-wide.png similarity index 100% rename from docs/localized/KR/docs/images/unity-wide.png rename to localized_docs/KR/docs/images/unity-wide.png diff --git a/docs/localized/KR/docs/images/unity_linux_build_support.png b/localized_docs/KR/docs/images/unity_linux_build_support.png similarity index 100% rename from docs/localized/KR/docs/images/unity_linux_build_support.png rename to localized_docs/KR/docs/images/unity_linux_build_support.png diff --git a/docs/localized/KR/docs/images/unity_package_json.png b/localized_docs/KR/docs/images/unity_package_json.png similarity index 100% rename from docs/localized/KR/docs/images/unity_package_json.png rename to localized_docs/KR/docs/images/unity_package_json.png diff --git a/docs/localized/KR/docs/images/unity_package_manager_git_url.png b/localized_docs/KR/docs/images/unity_package_manager_git_url.png similarity index 100% rename from docs/localized/KR/docs/images/unity_package_manager_git_url.png rename to localized_docs/KR/docs/images/unity_package_manager_git_url.png diff --git 
a/docs/localized/KR/docs/images/unity_package_manager_window.png b/localized_docs/KR/docs/images/unity_package_manager_window.png
similarity index 100%
rename from docs/localized/KR/docs/images/unity_package_manager_window.png
rename to localized_docs/KR/docs/images/unity_package_manager_window.png
diff --git a/docs/localized/KR/docs/images/visual-observation-rawimage.png b/localized_docs/KR/docs/images/visual-observation-rawimage.png
similarity index 100%
rename from docs/localized/KR/docs/images/visual-observation-rawimage.png
rename to localized_docs/KR/docs/images/visual-observation-rawimage.png
diff --git a/docs/localized/KR/docs/images/visual-observation-rendertexture.png b/localized_docs/KR/docs/images/visual-observation-rendertexture.png
similarity index 100%
rename from docs/localized/KR/docs/images/visual-observation-rendertexture.png
rename to localized_docs/KR/docs/images/visual-observation-rendertexture.png
diff --git a/docs/localized/KR/docs/images/visual-observation.png b/localized_docs/KR/docs/images/visual-observation.png
similarity index 100%
rename from docs/localized/KR/docs/images/visual-observation.png
rename to localized_docs/KR/docs/images/visual-observation.png
diff --git a/docs/localized/KR/docs/images/walker.png b/localized_docs/KR/docs/images/walker.png
similarity index 100%
rename from docs/localized/KR/docs/images/walker.png
rename to localized_docs/KR/docs/images/walker.png
diff --git a/docs/localized/KR/docs/images/wall.png b/localized_docs/KR/docs/images/wall.png
similarity index 100%
rename from docs/localized/KR/docs/images/wall.png
rename to localized_docs/KR/docs/images/wall.png
diff --git a/docs/localized/KR/docs/images/worm.png b/localized_docs/KR/docs/images/worm.png
similarity index 100%
rename from docs/localized/KR/docs/images/worm.png
rename to localized_docs/KR/docs/images/worm.png
diff --git a/docs/localized/RU/README.md b/localized_docs/RU/README.md
similarity index 100%
rename from docs/localized/RU/README.md
rename to localized_docs/RU/README.md
diff --git "a/docs/localized/RU/docs/\320\235\320\260\321\207\320\260\320\273\320\276 \321\200\320\260\320\261\320\276\321\202\321\213.md" "b/localized_docs/RU/docs/\320\235\320\260\321\207\320\260\320\273\320\276 \321\200\320\260\320\261\320\276\321\202\321\213.md"
similarity index 100%
rename from "docs/localized/RU/docs/\320\235\320\260\321\207\320\260\320\273\320\276 \321\200\320\260\320\261\320\276\321\202\321\213.md"
rename to "localized_docs/RU/docs/\320\235\320\260\321\207\320\260\320\273\320\276 \321\200\320\260\320\261\320\276\321\202\321\213.md"
diff --git "a/docs/localized/RU/docs/\320\243\321\201\321\202\320\260\320\275\320\276\320\262\320\272\320\260.md" "b/localized_docs/RU/docs/\320\243\321\201\321\202\320\260\320\275\320\276\320\262\320\272\320\260.md"
similarity index 97%
rename from "docs/localized/RU/docs/\320\243\321\201\321\202\320\260\320\275\320\276\320\262\320\272\320\260.md"
rename to "localized_docs/RU/docs/\320\243\321\201\321\202\320\260\320\275\320\276\320\262\320\272\320\260.md"
index d50b0c9c2c..6a81fbff20 100644
--- "a/docs/localized/RU/docs/\320\243\321\201\321\202\320\260\320\275\320\276\320\262\320\272\320\260.md"
+++ "b/localized_docs/RU/docs/\320\243\321\201\321\202\320\260\320\275\320\276\320\262\320\272\320\260.md"
@@ -18,22 +18,22 @@ The ML-Agents Toolkit consists of several components, where the various ML-Agents features are implemented for illustration.
In total, to install and use ML-Agents, you need to:

-- Install Unity (2020.3 or a later version)
-- Install Python (3.7.2 or a later version)
+- Install Unity (2021.3 or a later version)
+- Install Python (3.8.13 or a later version)
 - Clone this repository (Optional)
   - __Note:__ if you do not clone the repository, you will not have access to the example
     environments and training configurations. Also, the "Getting Started" section assumes
     that you have cloned the repository.
 - Install the `com.unity.ml-agents` Unity package
 - Install the `mlagents` Python package

-### Installing **Unity 2020.3** or a later version
+### Installing **Unity 2021.3** or a later version

 [Download](https://unity3d.com/get-unity/download) and install the Unity engine. We strongly
 recommend installing Unity via the Unity Hub, since it lets you manage multiple versions of
 the engine.

-### Installing **Python 3.7.2** or a later version
+### Installing **Python 3.8.13** or a later version

-We recommend [installing](https://www.python.org/downloads/) Python 3.7.
+We recommend [installing](https://www.python.org/downloads/) Python 3.8.
 If you are using Windows, please install the x86-64 version, not x86.
 If you do not have the `pip3` package manager for Python, please use
 [these instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers)
diff --git a/docs/localized/TR/README.md b/localized_docs/TR/README.md
similarity index 100%
rename from docs/localized/TR/README.md
rename to localized_docs/TR/README.md
diff --git a/docs/localized/TR/docs/Getting-Started.md b/localized_docs/TR/docs/Getting-Started.md
similarity index 100%
rename from docs/localized/TR/docs/Getting-Started.md
rename to localized_docs/TR/docs/Getting-Started.md
diff --git a/docs/localized/TR/docs/Installation.md b/localized_docs/TR/docs/Installation.md
similarity index 96%
rename from docs/localized/TR/docs/Installation.md
rename to localized_docs/TR/docs/Installation.md
index df4ba85242..b57209f258 100644
--- a/docs/localized/TR/docs/Installation.md
+++ b/localized_docs/TR/docs/Installation.md
@@ -14,20 +14,20 @@ The ML-Agents Toolkit contains several components:

 To install the ML-Agents Toolkit, you need to:

-- Install Unity (2020.3 or a later version)
-- Install Python (3.7.2 or a higher version)
+- Install Unity (2021.3 or a later version)
+- Install Python (3.8.13 or a higher version)
 - Clone this repository (Optional)
   - __Note:__ If you do not clone the repository, you cannot access the example
     environments and training configurations. In addition, the
     [Getting Started Guide](Getting-Started.md) assumes that you have cloned the repository.
 - Install the `com.unity.ml-agents` ML-Agents Unity package.
 - Install the `mlagents` Python package.

-### Install **Unity 2020.3** or a Later Version
+### Install **Unity 2021.3** or a Later Version

 [Download](https://unity3d.com/get-unity/download) and install Unity. We strongly
 recommend installing via the Unity Hub so that you can manage multiple Unity versions.

-### Install **Python 3.7.2** or a Higher Version
+### Install **Python 3.8.13** or a Higher Version

-We recommend [installing](https://www.python.org/downloads/) Python 3.7 or a higher version. If you are using Windows, please install the x86-64 version, never the plain x86 version.
If your Python environment does not include `pip3`, you can install it by following
these [instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers).
+We recommend [installing](https://www.python.org/downloads/) Python 3.8 or a higher version. If you are using Windows, please install the x86-64 version, never the plain x86 version. If your Python environment does not include `pip3`, you can install it by following these [instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers).

 Although we do not provide support for Anaconda installation on Windows, you can refer to the
 previous [Anaconda Installation for Windows (Deprecated) Guide](Installation-Anaconda-Windows.md).
diff --git a/docs/localized/TR/docs/Readme.md b/localized_docs/TR/docs/Readme.md
similarity index 100%
rename from docs/localized/TR/docs/Readme.md
rename to localized_docs/TR/docs/Readme.md
diff --git a/docs/localized/TR/docs/images/3dball_learning_brain.png b/localized_docs/TR/docs/images/3dball_learning_brain.png
similarity index 100%
rename from docs/localized/TR/docs/images/3dball_learning_brain.png
rename to localized_docs/TR/docs/images/3dball_learning_brain.png
diff --git a/docs/localized/TR/docs/images/balance.png b/localized_docs/TR/docs/images/balance.png
similarity index 100%
rename from docs/localized/TR/docs/images/balance.png
rename to localized_docs/TR/docs/images/balance.png
diff --git a/docs/localized/TR/docs/images/image-banner.png b/localized_docs/TR/docs/images/image-banner.png
similarity index 100%
rename from docs/localized/TR/docs/images/image-banner.png
rename to localized_docs/TR/docs/images/image-banner.png
diff --git a/docs/localized/TR/docs/images/mlagents-3DBallHierarchy.png b/localized_docs/TR/docs/images/mlagents-3DBallHierarchy.png
similarity index 100%
rename from docs/localized/TR/docs/images/mlagents-3DBallHierarchy.png
rename to localized_docs/TR/docs/images/mlagents-3DBallHierarchy.png
diff --git a/docs/localized/TR/docs/images/mlagents-TensorBoard.png b/localized_docs/TR/docs/images/mlagents-TensorBoard.png
similarity index 100%
rename from docs/localized/TR/docs/images/mlagents-TensorBoard.png
rename to localized_docs/TR/docs/images/mlagents-TensorBoard.png
diff --git a/docs/localized/TR/docs/images/platform_prefab.png b/localized_docs/TR/docs/images/platform_prefab.png
similarity index 100%
rename from docs/localized/TR/docs/images/platform_prefab.png
rename to localized_docs/TR/docs/images/platform_prefab.png
diff --git a/docs/localized/TR/docs/images/unity_package_json.png b/localized_docs/TR/docs/images/unity_package_json.png
similarity index 100%
rename from docs/localized/TR/docs/images/unity_package_json.png
rename to localized_docs/TR/docs/images/unity_package_json.png
diff --git a/docs/localized/TR/docs/images/unity_package_manager_window.png b/localized_docs/TR/docs/images/unity_package_manager_window.png
similarity index 100%
rename from docs/localized/TR/docs/images/unity_package_manager_window.png
rename to localized_docs/TR/docs/images/unity_package_manager_window.png
diff --git a/docs/localized/zh-CN/README.md b/localized_docs/zh-CN/README.md
similarity index 100%
rename from docs/localized/zh-CN/README.md
rename to localized_docs/zh-CN/README.md
diff --git a/docs/localized/zh-CN/docs/Getting-Started-with-Balance-Ball.md b/localized_docs/zh-CN/docs/Getting-Started-with-Balance-Ball.md
similarity index 100%
rename from
docs/localized/zh-CN/docs/Getting-Started-with-Balance-Ball.md rename to localized_docs/zh-CN/docs/Getting-Started-with-Balance-Ball.md diff --git a/docs/localized/zh-CN/docs/Installation.md b/localized_docs/zh-CN/docs/Installation.md similarity index 100% rename from docs/localized/zh-CN/docs/Installation.md rename to localized_docs/zh-CN/docs/Installation.md diff --git a/docs/localized/zh-CN/docs/Learning-Environment-Create-New.md b/localized_docs/zh-CN/docs/Learning-Environment-Create-New.md similarity index 100% rename from docs/localized/zh-CN/docs/Learning-Environment-Create-New.md rename to localized_docs/zh-CN/docs/Learning-Environment-Create-New.md diff --git a/docs/localized/zh-CN/docs/Learning-Environment-Design.md b/localized_docs/zh-CN/docs/Learning-Environment-Design.md similarity index 100% rename from docs/localized/zh-CN/docs/Learning-Environment-Design.md rename to localized_docs/zh-CN/docs/Learning-Environment-Design.md diff --git a/docs/localized/zh-CN/docs/Learning-Environment-Examples.md b/localized_docs/zh-CN/docs/Learning-Environment-Examples.md similarity index 100% rename from docs/localized/zh-CN/docs/Learning-Environment-Examples.md rename to localized_docs/zh-CN/docs/Learning-Environment-Examples.md diff --git a/docs/localized/zh-CN/docs/ML-Agents-Overview.md b/localized_docs/zh-CN/docs/ML-Agents-Overview.md similarity index 100% rename from docs/localized/zh-CN/docs/ML-Agents-Overview.md rename to localized_docs/zh-CN/docs/ML-Agents-Overview.md diff --git a/docs/localized/zh-CN/docs/Readme.md b/localized_docs/zh-CN/docs/Readme.md similarity index 100% rename from docs/localized/zh-CN/docs/Readme.md rename to localized_docs/zh-CN/docs/Readme.md diff --git a/docs/localized/zh-CN/docs/images/academy.png b/localized_docs/zh-CN/docs/images/academy.png similarity index 100% rename from docs/localized/zh-CN/docs/images/academy.png rename to localized_docs/zh-CN/docs/images/academy.png diff --git a/docs/localized/zh-CN/docs/images/agent.png b/localized_docs/zh-CN/docs/images/agent.png similarity index 100% rename from docs/localized/zh-CN/docs/images/agent.png rename to localized_docs/zh-CN/docs/images/agent.png diff --git a/docs/localized/zh-CN/docs/images/anaconda_default.PNG b/localized_docs/zh-CN/docs/images/anaconda_default.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/anaconda_default.PNG rename to localized_docs/zh-CN/docs/images/anaconda_default.PNG diff --git a/docs/localized/zh-CN/docs/images/anaconda_install.PNG b/localized_docs/zh-CN/docs/images/anaconda_install.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/anaconda_install.PNG rename to localized_docs/zh-CN/docs/images/anaconda_install.PNG diff --git a/docs/localized/zh-CN/docs/images/balance.png b/localized_docs/zh-CN/docs/images/balance.png similarity index 100% rename from docs/localized/zh-CN/docs/images/balance.png rename to localized_docs/zh-CN/docs/images/balance.png diff --git a/docs/localized/zh-CN/docs/images/banana.png b/localized_docs/zh-CN/docs/images/banana.png similarity index 100% rename from docs/localized/zh-CN/docs/images/banana.png rename to localized_docs/zh-CN/docs/images/banana.png diff --git a/docs/localized/zh-CN/docs/images/banner.png b/localized_docs/zh-CN/docs/images/banner.png similarity index 100% rename from docs/localized/zh-CN/docs/images/banner.png rename to localized_docs/zh-CN/docs/images/banner.png diff --git a/docs/localized/zh-CN/docs/images/basic.png b/localized_docs/zh-CN/docs/images/basic.png similarity index 100% 
rename from docs/localized/zh-CN/docs/images/basic.png rename to localized_docs/zh-CN/docs/images/basic.png diff --git a/docs/localized/zh-CN/docs/images/bc_teacher_helper.png b/localized_docs/zh-CN/docs/images/bc_teacher_helper.png similarity index 100% rename from docs/localized/zh-CN/docs/images/bc_teacher_helper.png rename to localized_docs/zh-CN/docs/images/bc_teacher_helper.png diff --git a/docs/localized/zh-CN/docs/images/bouncer.png b/localized_docs/zh-CN/docs/images/bouncer.png similarity index 100% rename from docs/localized/zh-CN/docs/images/bouncer.png rename to localized_docs/zh-CN/docs/images/bouncer.png diff --git a/docs/localized/zh-CN/docs/images/brain.png b/localized_docs/zh-CN/docs/images/brain.png similarity index 100% rename from docs/localized/zh-CN/docs/images/brain.png rename to localized_docs/zh-CN/docs/images/brain.png diff --git a/docs/localized/zh-CN/docs/images/broadcast.png b/localized_docs/zh-CN/docs/images/broadcast.png similarity index 100% rename from docs/localized/zh-CN/docs/images/broadcast.png rename to localized_docs/zh-CN/docs/images/broadcast.png diff --git a/docs/localized/zh-CN/docs/images/conda_new.PNG b/localized_docs/zh-CN/docs/images/conda_new.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/conda_new.PNG rename to localized_docs/zh-CN/docs/images/conda_new.PNG diff --git a/docs/localized/zh-CN/docs/images/crawler.png b/localized_docs/zh-CN/docs/images/crawler.png similarity index 100% rename from docs/localized/zh-CN/docs/images/crawler.png rename to localized_docs/zh-CN/docs/images/crawler.png diff --git a/docs/localized/zh-CN/docs/images/cuDNN_membership_required.png b/localized_docs/zh-CN/docs/images/cuDNN_membership_required.png similarity index 100% rename from docs/localized/zh-CN/docs/images/cuDNN_membership_required.png rename to localized_docs/zh-CN/docs/images/cuDNN_membership_required.png diff --git a/docs/localized/zh-CN/docs/images/cuda_toolkit_directory.PNG b/localized_docs/zh-CN/docs/images/cuda_toolkit_directory.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/cuda_toolkit_directory.PNG rename to localized_docs/zh-CN/docs/images/cuda_toolkit_directory.PNG diff --git a/docs/localized/zh-CN/docs/images/cudnn_zip_files.PNG b/localized_docs/zh-CN/docs/images/cudnn_zip_files.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/cudnn_zip_files.PNG rename to localized_docs/zh-CN/docs/images/cudnn_zip_files.PNG diff --git a/docs/localized/zh-CN/docs/images/curriculum.png b/localized_docs/zh-CN/docs/images/curriculum.png similarity index 100% rename from docs/localized/zh-CN/docs/images/curriculum.png rename to localized_docs/zh-CN/docs/images/curriculum.png diff --git a/docs/localized/zh-CN/docs/images/curriculum_progress.png b/localized_docs/zh-CN/docs/images/curriculum_progress.png similarity index 100% rename from docs/localized/zh-CN/docs/images/curriculum_progress.png rename to localized_docs/zh-CN/docs/images/curriculum_progress.png diff --git a/docs/localized/zh-CN/docs/images/docker_build_settings.png b/localized_docs/zh-CN/docs/images/docker_build_settings.png similarity index 100% rename from docs/localized/zh-CN/docs/images/docker_build_settings.png rename to localized_docs/zh-CN/docs/images/docker_build_settings.png diff --git a/docs/localized/zh-CN/docs/images/edit_env_var.png b/localized_docs/zh-CN/docs/images/edit_env_var.png similarity index 100% rename from docs/localized/zh-CN/docs/images/edit_env_var.png rename to 
localized_docs/zh-CN/docs/images/edit_env_var.png diff --git a/docs/localized/zh-CN/docs/images/gridworld.png b/localized_docs/zh-CN/docs/images/gridworld.png similarity index 100% rename from docs/localized/zh-CN/docs/images/gridworld.png rename to localized_docs/zh-CN/docs/images/gridworld.png diff --git a/docs/localized/zh-CN/docs/images/hallway.png b/localized_docs/zh-CN/docs/images/hallway.png similarity index 100% rename from docs/localized/zh-CN/docs/images/hallway.png rename to localized_docs/zh-CN/docs/images/hallway.png diff --git a/docs/localized/zh-CN/docs/images/internal_brain.png b/localized_docs/zh-CN/docs/images/internal_brain.png similarity index 100% rename from docs/localized/zh-CN/docs/images/internal_brain.png rename to localized_docs/zh-CN/docs/images/internal_brain.png diff --git a/docs/localized/zh-CN/docs/images/learning_environment.png b/localized_docs/zh-CN/docs/images/learning_environment.png similarity index 100% rename from docs/localized/zh-CN/docs/images/learning_environment.png rename to localized_docs/zh-CN/docs/images/learning_environment.png diff --git a/docs/localized/zh-CN/docs/images/learning_environment_basic.png b/localized_docs/zh-CN/docs/images/learning_environment_basic.png similarity index 100% rename from docs/localized/zh-CN/docs/images/learning_environment_basic.png rename to localized_docs/zh-CN/docs/images/learning_environment_basic.png diff --git a/docs/localized/zh-CN/docs/images/learning_environment_example.png b/localized_docs/zh-CN/docs/images/learning_environment_example.png similarity index 100% rename from docs/localized/zh-CN/docs/images/learning_environment_example.png rename to localized_docs/zh-CN/docs/images/learning_environment_example.png diff --git a/docs/localized/zh-CN/docs/images/math.png b/localized_docs/zh-CN/docs/images/math.png similarity index 100% rename from docs/localized/zh-CN/docs/images/math.png rename to localized_docs/zh-CN/docs/images/math.png diff --git a/docs/localized/zh-CN/docs/images/ml-agents-LSTM.png b/localized_docs/zh-CN/docs/images/ml-agents-LSTM.png similarity index 100% rename from docs/localized/zh-CN/docs/images/ml-agents-LSTM.png rename to localized_docs/zh-CN/docs/images/ml-agents-LSTM.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-3DBallHierarchy.png b/localized_docs/zh-CN/docs/images/mlagents-3DBallHierarchy.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-3DBallHierarchy.png rename to localized_docs/zh-CN/docs/images/mlagents-3DBallHierarchy.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-BuildWindow.png b/localized_docs/zh-CN/docs/images/mlagents-BuildWindow.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-BuildWindow.png rename to localized_docs/zh-CN/docs/images/mlagents-BuildWindow.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewProject.png b/localized_docs/zh-CN/docs/images/mlagents-NewProject.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewProject.png rename to localized_docs/zh-CN/docs/images/mlagents-NewProject.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutAcademy.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutAcademy.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutAcademy.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutAcademy.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutAssignBrain.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutAssignBrain.png 
similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutAssignBrain.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutAssignBrain.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutBlock.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutBlock.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutBlock.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutBlock.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutBrain.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutBrain.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutBrain.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutBrain.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutFloor.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutFloor.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutFloor.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutFloor.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutHierarchy.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutHierarchy.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutHierarchy.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutHierarchy.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutSphere.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutSphere.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutSphere.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutSphere.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-NewTutSplash.png b/localized_docs/zh-CN/docs/images/mlagents-NewTutSplash.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-NewTutSplash.png rename to localized_docs/zh-CN/docs/images/mlagents-NewTutSplash.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-Open3DBall.png b/localized_docs/zh-CN/docs/images/mlagents-Open3DBall.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-Open3DBall.png rename to localized_docs/zh-CN/docs/images/mlagents-Open3DBall.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-SetExternalBrain.png b/localized_docs/zh-CN/docs/images/mlagents-SetExternalBrain.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-SetExternalBrain.png rename to localized_docs/zh-CN/docs/images/mlagents-SetExternalBrain.png diff --git a/docs/localized/zh-CN/docs/images/mlagents-TensorBoard.png b/localized_docs/zh-CN/docs/images/mlagents-TensorBoard.png similarity index 100% rename from docs/localized/zh-CN/docs/images/mlagents-TensorBoard.png rename to localized_docs/zh-CN/docs/images/mlagents-TensorBoard.png diff --git a/docs/localized/zh-CN/docs/images/monitor.png b/localized_docs/zh-CN/docs/images/monitor.png similarity index 100% rename from docs/localized/zh-CN/docs/images/monitor.png rename to localized_docs/zh-CN/docs/images/monitor.png diff --git a/docs/localized/zh-CN/docs/images/new_system_variable.PNG b/localized_docs/zh-CN/docs/images/new_system_variable.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/new_system_variable.PNG rename to localized_docs/zh-CN/docs/images/new_system_variable.PNG diff --git a/docs/localized/zh-CN/docs/images/normalization.png b/localized_docs/zh-CN/docs/images/normalization.png similarity index 100% rename from 
docs/localized/zh-CN/docs/images/normalization.png rename to localized_docs/zh-CN/docs/images/normalization.png diff --git a/docs/localized/zh-CN/docs/images/path_variables.PNG b/localized_docs/zh-CN/docs/images/path_variables.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/path_variables.PNG rename to localized_docs/zh-CN/docs/images/path_variables.PNG diff --git a/docs/localized/zh-CN/docs/images/player_brain.png b/localized_docs/zh-CN/docs/images/player_brain.png similarity index 100% rename from docs/localized/zh-CN/docs/images/player_brain.png rename to localized_docs/zh-CN/docs/images/player_brain.png diff --git a/docs/localized/zh-CN/docs/images/push.png b/localized_docs/zh-CN/docs/images/push.png similarity index 100% rename from docs/localized/zh-CN/docs/images/push.png rename to localized_docs/zh-CN/docs/images/push.png diff --git a/docs/localized/zh-CN/docs/images/reacher.png b/localized_docs/zh-CN/docs/images/reacher.png similarity index 100% rename from docs/localized/zh-CN/docs/images/reacher.png rename to localized_docs/zh-CN/docs/images/reacher.png diff --git a/docs/localized/zh-CN/docs/images/rl_cycle.png b/localized_docs/zh-CN/docs/images/rl_cycle.png similarity index 100% rename from docs/localized/zh-CN/docs/images/rl_cycle.png rename to localized_docs/zh-CN/docs/images/rl_cycle.png diff --git a/docs/localized/zh-CN/docs/images/scene-hierarchy.png b/localized_docs/zh-CN/docs/images/scene-hierarchy.png similarity index 100% rename from docs/localized/zh-CN/docs/images/scene-hierarchy.png rename to localized_docs/zh-CN/docs/images/scene-hierarchy.png diff --git a/docs/localized/zh-CN/docs/images/soccer.png b/localized_docs/zh-CN/docs/images/soccer.png similarity index 100% rename from docs/localized/zh-CN/docs/images/soccer.png rename to localized_docs/zh-CN/docs/images/soccer.png diff --git a/docs/localized/zh-CN/docs/images/splitbar.png b/localized_docs/zh-CN/docs/images/splitbar.png similarity index 100% rename from docs/localized/zh-CN/docs/images/splitbar.png rename to localized_docs/zh-CN/docs/images/splitbar.png diff --git a/docs/localized/zh-CN/docs/images/system_variable_name_value.PNG b/localized_docs/zh-CN/docs/images/system_variable_name_value.PNG similarity index 100% rename from docs/localized/zh-CN/docs/images/system_variable_name_value.PNG rename to localized_docs/zh-CN/docs/images/system_variable_name_value.PNG diff --git a/docs/localized/zh-CN/docs/images/tennis.png b/localized_docs/zh-CN/docs/images/tennis.png similarity index 100% rename from docs/localized/zh-CN/docs/images/tennis.png rename to localized_docs/zh-CN/docs/images/tennis.png diff --git a/docs/localized/zh-CN/docs/images/unity-logo-rgb.png b/localized_docs/zh-CN/docs/images/unity-logo-rgb.png similarity index 100% rename from docs/localized/zh-CN/docs/images/unity-logo-rgb.png rename to localized_docs/zh-CN/docs/images/unity-logo-rgb.png diff --git a/docs/localized/zh-CN/docs/images/unity-wide.png b/localized_docs/zh-CN/docs/images/unity-wide.png similarity index 100% rename from docs/localized/zh-CN/docs/images/unity-wide.png rename to localized_docs/zh-CN/docs/images/unity-wide.png diff --git a/docs/localized/zh-CN/docs/images/unity_linux_build_support.png b/localized_docs/zh-CN/docs/images/unity_linux_build_support.png similarity index 100% rename from docs/localized/zh-CN/docs/images/unity_linux_build_support.png rename to localized_docs/zh-CN/docs/images/unity_linux_build_support.png diff --git a/docs/localized/zh-CN/docs/images/visual-observation.png 
b/localized_docs/zh-CN/docs/images/visual-observation.png similarity index 100% rename from docs/localized/zh-CN/docs/images/visual-observation.png rename to localized_docs/zh-CN/docs/images/visual-observation.png diff --git a/docs/localized/zh-CN/docs/images/wall.png b/localized_docs/zh-CN/docs/images/wall.png similarity index 100% rename from docs/localized/zh-CN/docs/images/wall.png rename to localized_docs/zh-CN/docs/images/wall.png diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000000..9de0b224cb --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,33 @@ +site_name: Unity ML-Agents Toolkit +site_url: https://unity-technologies.github.io/ml-agents/ +repo_url: https://github.com/Unity-Technologies/ml-agents +edit_uri: edit/main/docs/ +site_description: The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents. +site_author: Unity Technologies +copyright: com.unity.ml-agents copyright © 2017 Unity Technologies +pages: +- Home: index.md +- ML-Agents Overview: ML-Agents-Overview.md +- Toolkit Documentation: ML-Agents-Toolkit-Documentation.md +- Background: + - Machine Learning: Background-Machine-Learning.md + - PyTorch: Background-PyTorch.md + - Unity: Background-Unity.md +- Interfacing with Unity Builds: + - Getting started with the Gym API: Python-Gym-API.md + - Getting started with the PettingZoo API: Python-PettingZoo-API.md + - Getting started with the LLAPI: Python-LLAPI.md +- Python API Docs: + - Gym API Documentation: Python-Gym-API-Documentation.md + - Petting Zoo Documentation: Python-PettingZoo-API-Documentation.md + - LLAPI Documentation: Python-LLAPI-Documentation.md +- About: + - FAQs: FAQ.md + - Limitations: Limitations.md + - Migrating: Migrating.md + - Versioning: Versioning.md +theme: readthedocs +extra_css: [extra.css] +markdown_extensions: + - markdown_include.include: + base_path: docs diff --git a/ml-agents-envs/mlagents_envs/__init__.py b/ml-agents-envs/mlagents_envs/__init__.py index 5d158f10d1..8650fdbc2a 100644 --- a/ml-agents-envs/mlagents_envs/__init__.py +++ b/ml-agents-envs/mlagents_envs/__init__.py @@ -1,5 +1,5 @@ # Version of the library that will be used to upload to pypi -__version__ = "0.29.0.dev0" +__version__ = "0.30.0" # Git tag that will be checked to determine whether to trigger upload to pypi __release_tag__ = None diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index 02edc4eab2..a993d8a7c5 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -257,7 +257,7 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": return TerminalSteps( obs=obs, reward=np.zeros(0, dtype=np.float32), - interrupted=np.zeros(0, dtype=np.bool), + interrupted=np.zeros(0, dtype=bool), agent_id=np.zeros(0, dtype=np.int32), group_id=np.zeros(0, dtype=np.int32), group_reward=np.zeros(0, dtype=np.float32), diff --git a/ml-agents-envs/mlagents_envs/envs/unity_vec_env.py b/ml-agents-envs/mlagents_envs/envs/unity_vec_env.py deleted file mode 100644 index b09407259d..0000000000 --- a/ml-agents-envs/mlagents_envs/envs/unity_vec_env.py +++ /dev/null @@ -1,122 +0,0 @@ -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Any - -import gym -from gym import Env -from stable_baselines3.common.vec_env import VecEnv, SubprocVecEnv -from supersuit import observation_lambda_v0 - -from mlagents_envs.environment import UnityEnvironment 
-from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper -from mlagents_envs.registry import UnityEnvRegistry, default_registry -from mlagents_envs.side_channel.engine_configuration_channel import ( - EngineConfig, - EngineConfigurationChannel, -) - -# Default values from CLI (See cli_utils.py) -DEFAULT_ENGINE_CONFIG = EngineConfig( - width=84, - height=84, - quality_level=4, - time_scale=20, - target_frame_rate=-1, - capture_frame_rate=60, -) - - -# Some config subset of an actual config.yaml file for MLA. -@dataclass -class LimitedConfig: - # The local path to a Unity executable or the name of an entry in the registry. - env_path_or_name: str - base_port: int - base_seed: int = 0 - num_env: int = 1 - engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG - visual_obs: bool = False - # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow. - # WARNING: Make sure to use MultiInputPolicy if you turn this on. - allow_multiple_obs: bool = False - env_registry: UnityEnvRegistry = default_registry - - -def _unity_env_from_path_or_registry( - env: str, registry: UnityEnvRegistry, **kwargs: Any -) -> UnityEnvironment: - if Path(env).expanduser().absolute().exists(): - return UnityEnvironment(file_name=env, **kwargs) - elif env in registry: - return registry.get(env).make(**kwargs) - else: - raise ValueError(f"Environment '{env}' wasn't a local path or registry entry") - - -def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv: - """ - Create a VecEnv (Stable Baselines 3) of multiple UnityEnvironments. - :param config: Specifics around initializing the UnityEnvironment. - :param kwargs: Any other args that need to be passed to the UnityEnvironment constructor that aren't supported - through the config. - :return: A VecEnv backed by Unity environments as specified in the conifg. - - Example Usage: - # See ml-agents-envs/tests/test_unity_vec_env.py or colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb - sb3_vec_env = make_mla_sb3_env( - config=LimitedConfig( - env_path_or_name=BASIC_ID, - base_port=6000, - num_env=2, - ), - # Other args to UnityEnvironment - no_graphics=True, - num_areas=1, - ) - """ - - def handle_obs(obs, space): - if isinstance(space, gym.spaces.Tuple): - if len(space) == 1: - return obs[0] - # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple). - return {str(i): v for i, v in enumerate(obs)} - return obs - - def handle_obs_space(space): - if isinstance(space, gym.spaces.Tuple): - if len(space) == 1: - return space[0] - # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple). 
- return gym.spaces.Dict({str(i): v for i, v in enumerate(space)}) - return space - - def create_env(env: str, worker_id: int) -> Callable[[], Env]: - def _f() -> Env: - engine_configuration_channel = EngineConfigurationChannel() - engine_configuration_channel.set_configuration(config.engine_config) - kwargs["side_channels"] = kwargs.get("side_channels", []) + [ - engine_configuration_channel - ] - unity_env = _unity_env_from_path_or_registry( - env=env, - registry=config.env_registry, - worker_id=worker_id, - base_port=config.base_port, - seed=config.base_seed + worker_id, - **kwargs, - ) - new_env = UnityToGymWrapper( - unity_env=unity_env, - uint8_visual=config.visual_obs, - allow_multiple_obs=config.allow_multiple_obs, - ) - new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space) - return new_env - - return _f - - env_facts = [ - create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env) - ] - return SubprocVecEnv(env_facts) diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py index 5ea4db76cd..f2e3d1d468 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -373,7 +373,7 @@ def steps_from_proto( max_step = np.array( [agent_info.max_step_reached for agent_info in terminal_agent_info_list], - dtype=np.bool, + dtype=bool, ) decision_agent_id = np.array( [agent_info.id for agent_info in decision_agent_info_list], dtype=np.int32 @@ -389,7 +389,7 @@ def steps_from_proto( ): n_agents = len(decision_agent_info_list) a_size = np.sum(behavior_spec.action_spec.discrete_branches) - mask_matrix = np.ones((n_agents, a_size), dtype=np.bool) + mask_matrix = np.ones((n_agents, a_size), dtype=bool) for agent_index, agent_info in enumerate(decision_agent_info_list): if agent_info.action_mask is not None: if len(agent_info.action_mask) == a_size: @@ -397,7 +397,7 @@ def steps_from_proto( False if agent_info.action_mask[k] else True for k in range(a_size) ] - action_mask = (1 - mask_matrix).astype(np.bool) + action_mask = (1 - mask_matrix).astype(bool) indices = _generate_split_indices( behavior_spec.action_spec.discrete_branches ) diff --git a/ml-agents-envs/pydoc-config.yaml b/ml-agents-envs/pydoc-config.yaml index fc7c29d43c..9fec4a8951 100644 --- a/ml-agents-envs/pydoc-config.yaml +++ b/ml-agents-envs/pydoc-config.yaml @@ -1,4 +1,4 @@ -# config to specify which modules will be used to render api docs +# config to specify which modules will be used to render api docs from ml-agents-env package folder: docs modules: - name: mlagents_envs diff --git a/ml-agents-envs/setup.py b/ml-agents-envs/setup.py index 8665457aaa..a050f4a408 100644 --- a/ml-agents-envs/setup.py +++ b/ml-agents-envs/setup.py @@ -40,8 +40,9 @@ def run(self): "Intended Audience :: Developers", "Topic :: Scientific/Engineering :: Artificial Intelligence", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], packages=find_packages( exclude=["*.tests", "*.tests.*", "tests.*", "tests", "colabs", "*.ipynb"] @@ -55,13 +56,11 @@ def run(self): "protobuf>=3.6", "pyyaml>=3.1.0", "gym>=0.21.0", - "pettingzoo>=1.15.0", + "pettingzoo==1.15.0", "numpy==1.21.2", "filelock>=3.4.0", - "stable_baselines3[extra]", - "supersuit>=3.3.3", ], - python_requires=">=3.7.2,<3.10.0", + python_requires=">=3.8.13,<=3.10.8", # TODO: Remove this once mypy
stops having spurious setuptools issues. cmdclass={"verify": VerifyVersionCommand}, # type: ignore ) diff --git a/ml-agents-envs/tests/dummy_config.py b/ml-agents-envs/tests/dummy_config.py index a7f74afa26..40109e7b39 100644 --- a/ml-agents-envs/tests/dummy_config.py +++ b/ml-agents-envs/tests/dummy_config.py @@ -4,24 +4,26 @@ import copy import os from mlagents.trainers.settings import ( - POCASettings, TrainerSettings, - PPOSettings, - SACSettings, GAILSettings, CuriositySettings, RewardSignalSettings, NetworkSettings, - TrainerType, RewardSignalType, ScheduleType, ) +from mlagents.trainers.ppo.trainer import PPOSettings, TRAINER_NAME as PPO_TRAINER_NAME +from mlagents.trainers.sac.trainer import SACSettings, TRAINER_NAME as SAC_TRAINER_NAME +from mlagents.trainers.poca.trainer import ( + POCASettings, + TRAINER_NAME as POCA_TRAINER_NAME, +) CONTINUOUS_DEMO_PATH = os.path.dirname(os.path.abspath(__file__)) + "/test.demo" DISCRETE_DEMO_PATH = os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo" _PPO_CONFIG = TrainerSettings( - trainer_type=TrainerType.PPO, + trainer_type=PPO_TRAINER_NAME, hyperparameters=PPOSettings( learning_rate=5.0e-3, learning_rate_schedule=ScheduleType.CONSTANT, @@ -35,7 +37,7 @@ ) _SAC_CONFIG = TrainerSettings( - trainer_type=TrainerType.SAC, + trainer_type=SAC_TRAINER_NAME, hyperparameters=SACSettings( learning_rate=5.0e-3, learning_rate_schedule=ScheduleType.CONSTANT, @@ -52,7 +54,7 @@ ) _POCA_CONFIG = TrainerSettings( - trainer_type=TrainerType.POCA, + trainer_type=POCA_TRAINER_NAME, hyperparameters=POCASettings( learning_rate=5.0e-3, learning_rate_schedule=ScheduleType.CONSTANT, diff --git a/ml-agents-envs/tests/simple_test_envs.py b/ml-agents-envs/tests/simple_test_envs.py index 976d4449a5..64b720fad3 100644 --- a/ml-agents-envs/tests/simple_test_envs.py +++ b/ml-agents-envs/tests/simple_test_envs.py @@ -179,9 +179,7 @@ def _generate_mask(self): action_mask = None if self.action_spec.discrete_size > 0: # LL-Python API will return an empty dim if there is only 1 agent. 
- ndmask = np.array( - 2 * self.action_spec.discrete_size * [False], dtype=np.bool - ) + ndmask = np.array(2 * self.action_spec.discrete_size * [False], dtype=bool) ndmask = np.expand_dims(ndmask, axis=0) action_mask = [ndmask] return action_mask @@ -253,7 +251,7 @@ def _construct_reset_step( self, name: str ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: new_reward = np.array([0.0], dtype=np.float32) - new_done = np.array([False], dtype=np.bool) + new_done = np.array([False], dtype=bool) new_agent_id = np.array([self.agent_id[name]], dtype=np.int32) new_action_mask = self._generate_mask() new_group_id = np.array([0], dtype=np.int32) @@ -270,7 +268,6 @@ def _construct_reset_step( def step(self) -> None: assert all(action is not None for action in self.action.values()) for name in self.names: - done = self._take_action(name) reward = self._compute_reward(name, done) self.rewards[name] += reward diff --git a/ml-agents-envs/tests/test_rpc_utils.py b/ml-agents-envs/tests/test_rpc_utils.py index 57148e2b6f..9959aa4a99 100644 --- a/ml-agents-envs/tests/test_rpc_utils.py +++ b/ml-agents-envs/tests/test_rpc_utils.py @@ -144,7 +144,7 @@ def proto_from_steps( agent_mask = np.concatenate( (agent_mask, _branch[agent_id_index, :]), axis=0 ) - agent_mask = agent_mask.astype(np.bool).tolist() + agent_mask = agent_mask.astype(bool).tolist() observations: List[ObservationProto] = [] for all_observations_of_type in decision_steps.obs: observation = all_observations_of_type[agent_id_index] diff --git a/ml-agents-envs/tests/test_steps.py b/ml-agents-envs/tests/test_steps.py index a3c1e6601a..01e39f3ea7 100644 --- a/ml-agents-envs/tests/test_steps.py +++ b/ml-agents-envs/tests/test_steps.py @@ -15,7 +15,7 @@ def test_decision_steps(): obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], reward=np.array(range(3), dtype=np.float32), agent_id=np.array(range(10, 13), dtype=np.int32), - action_mask=[np.zeros((3, 4), dtype=np.bool)], + action_mask=[np.zeros((3, 4), dtype=bool)], group_id=np.array(range(3), dtype=np.int32), group_reward=np.array(range(3), dtype=np.float32), ) @@ -30,7 +30,7 @@ def test_decision_steps(): mask_agent = ds[10].action_mask assert isinstance(mask_agent, list) assert len(mask_agent) == 1 - assert np.array_equal(mask_agent[0], np.zeros((4), dtype=np.bool)) + assert np.array_equal(mask_agent[0], np.zeros((4), dtype=bool)) for agent_id in ds: assert ds.agent_id_to_index[agent_id] in range(3) @@ -52,7 +52,7 @@ def test_terminal_steps(): obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], reward=np.array(range(3), dtype=np.float32), agent_id=np.array(range(10, 13), dtype=np.int32), - interrupted=np.array([1, 0, 1], dtype=np.bool), + interrupted=np.array([1, 0, 1], dtype=bool), group_id=np.array(range(3), dtype=np.int32), group_reward=np.array(range(3), dtype=np.float32), ) diff --git a/ml-agents-envs/tests/test_unity_vec_env.py b/ml-agents-envs/tests/test_unity_vec_env.py deleted file mode 100644 index 3bf65a5cfd..0000000000 --- a/ml-agents-envs/tests/test_unity_vec_env.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest - -from stable_baselines3 import PPO - -from mlagents_envs.envs.unity_vec_env import LimitedConfig, make_mla_sb3_env -from mlagents_envs.registry import default_registry - -BASIC_ID = "Basic" - - -@pytest.mark.parametrize("n_ports", [2]) -def test_vec_env_basic(base_port: int) -> None: - num_envs = 2 - sb3_vec_env = make_mla_sb3_env( - config=LimitedConfig( - env_path_or_name=BASIC_ID, - base_port=base_port, - num_env=num_envs, - 
visual_obs=False, - allow_multiple_obs=True, - env_registry=default_registry, - ), - # Args to UnityEnvironment - no_graphics=True, - num_areas=1, - ) - assert sb3_vec_env.num_envs == num_envs - sb3_vec_env.reset() - observation, reward, done, info = sb3_vec_env.step( - [sb3_vec_env.action_space.sample()] * 2 - ) - assert len(observation) == num_envs - assert len(reward) == num_envs - assert len(done) == num_envs - assert len(info) == num_envs - sb3_vec_env.close() - - -@pytest.mark.slow -@pytest.mark.parametrize("n_ports", [4]) -def test_vec_env_trains(base_port: int) -> None: - sb3_vec_env = make_mla_sb3_env( - config=LimitedConfig( - env_path_or_name=BASIC_ID, - base_port=base_port, - num_env=4, - visual_obs=False, - allow_multiple_obs=True, - env_registry=default_registry, - ), - # Args to UnityEnvironment - no_graphics=True, - num_areas=1, - ) - - model = PPO( - "MlpPolicy", - sb3_vec_env, - verbose=1, - learning_rate=lambda progress: 0.0003 * (1.0 - progress), - ) - model.learn(total_timesteps=6000) - sb3_vec_env.close() - - -# TODO(https://jira.unity3d.com/browse/MLA-2404): Add longer running nightly tests to make sure this trains. diff --git a/ml-agents/mlagents/trainers/tests/torch/__init__.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/__init__.py similarity index 100% rename from ml-agents/mlagents/trainers/tests/torch/__init__.py rename to ml-agents-trainer-plugin/mlagents_trainer_plugin/__init__.py diff --git a/ml-agents/mlagents/trainers/torch/__init__.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/__init__.py similarity index 100% rename from ml-agents/mlagents/trainers/torch/__init__.py rename to ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/__init__.py diff --git a/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_3DBall.yaml b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_3DBall.yaml new file mode 100644 index 0000000000..2b0af46070 --- /dev/null +++ b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_3DBall.yaml @@ -0,0 +1,24 @@ +behaviors: + 3DBall: + trainer_type: a2c + hyperparameters: + batch_size: 1000 + buffer_size: 1000 + learning_rate: 0.0003 + beta: 0.001 + lambd: 0.99 + num_epoch: 1 + learning_rate_schedule: linear + network_settings: + normalize: true + hidden_units: 128 + num_layers: 2 + vis_encode_type: simple + reward_signals: + extrinsic: + gamma: 0.99 + strength: 1.0 + keep_checkpoints: 5 + max_steps: 500000 + time_horizon: 1000 + summary_freq: 1000 diff --git a/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_optimizer.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_optimizer.py new file mode 100644 index 0000000000..8af95ee41e --- /dev/null +++ b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_optimizer.py @@ -0,0 +1,192 @@ +from typing import Dict, cast +import attr + +from mlagents.torch_utils import torch, default_device + +from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil + +from mlagents_envs.timers import timed +from mlagents.trainers.policy.torch_policy import TorchPolicy +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.settings import ( + TrainerSettings, + OnPolicyHyperparamSettings, + ScheduleType, +) +from mlagents.trainers.torch_entities.networks import ValueNetwork +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.trajectory import ObsUtil + +from mlagents.trainers.exception import 
TrainerConfigError + + +@attr.s(auto_attribs=True) +class A2CSettings(OnPolicyHyperparamSettings): + beta: float = 5.0e-3 + lambd: float = 0.95 + num_epoch: int = attr.ib(default=1) # A2C does just one pass + shared_critic: bool = False + + @num_epoch.validator + def _check_num_epoch_one(self, attribute, value): + if value != 1: + raise TrainerConfigError("A2C requires num_epoch = 1") + + learning_rate_schedule: ScheduleType = ScheduleType.LINEAR + beta_schedule: ScheduleType = ScheduleType.LINEAR + + +class A2COptimizer(TorchOptimizer): + def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): + """ + Takes a Policy and TrainerSettings and creates an A2C Optimizer around the policy. + The A2C optimizer has a value estimator and a loss function. + :param policy: A TorchPolicy object that will be updated by this A2C Optimizer. + :param trainer_settings: The trainer settings that specify the + properties of the trainer. + """ + # Build the critic and the torch optimizer here to keep granular control over which parameters get updated. + + super().__init__(policy, trainer_settings) + self.hyperparameters: A2CSettings = cast( + A2CSettings, trainer_settings.hyperparameters + ) + + params = list(self.policy.actor.parameters()) + if self.hyperparameters.shared_critic: + self._critic = policy.actor + else: + + self._critic = ValueNetwork( + list(self.reward_signals.keys()), + policy.behavior_spec.observation_specs, + network_settings=trainer_settings.network_settings, + ) + self._critic.to(default_device()) + params += list(self._critic.parameters()) + + self.decay_learning_rate = ModelUtils.DecayedValue( + self.hyperparameters.learning_rate_schedule, + self.hyperparameters.learning_rate, + 1e-10, + self.trainer_settings.max_steps, + ) + + self.decay_beta = ModelUtils.DecayedValue( + self.hyperparameters.beta_schedule, + self.hyperparameters.beta, + 1e-10, + self.trainer_settings.max_steps, + ) + + self.optimizer = torch.optim.Adam( + params, lr=self.trainer_settings.hyperparameters.learning_rate + ) + self.stats_name_to_update_name = { + "Losses/Value Loss": "value_loss", + "Losses/Policy Loss": "policy_loss", + } + + self.stream_names = list(self.reward_signals.keys()) + + @property + def critic(self): + return self._critic + + @timed + def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: + """ + Performs update on model. + :param batch: Batch of experiences. + :param num_sequences: Number of sequences to process. + :return: Results of update.
+ """ + # Get decayed parameters + decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) + decay_bet = self.decay_beta.get_value(self.policy.get_current_step()) + returns = {} + for name in self.reward_signals: + returns[name] = ModelUtils.list_to_tensor( + batch[RewardSignalUtil.returns_key(name)] + ) + + n_obs = len(self.policy.behavior_spec.observation_specs) + current_obs = ObsUtil.from_buffer(batch, n_obs) + # Convert to tensors + current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs] + + act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK]) + actions = AgentAction.from_buffer(batch) + + memories = [ + ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i]) + for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length) + ] + if len(memories) > 0: + memories = torch.stack(memories).unsqueeze(0) + + # Get value memories + value_memories = [ + ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i]) + for i in range( + 0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length + ) + ] + if len(value_memories) > 0: + value_memories = torch.stack(value_memories).unsqueeze(0) + + run_out = self.policy.actor.get_stats( + current_obs, + masks=act_masks, + actions=actions, + memories=memories, + sequence_length=self.policy.sequence_length, + ) + + log_probs = run_out["log_probs"] + entropy = run_out["entropy"] + + values, _ = self.critic.critic_pass( + current_obs, + memories=value_memories, + sequence_length=self.policy.sequence_length, + ) + log_probs = log_probs.flatten() + + value_loss_per_head = [] + for name, head in values.items(): + returns_tensor = returns[name] + be = (returns_tensor - head) ** 2 + value_loss_per_head.append(be) + value_loss = torch.mean(torch.stack(value_loss_per_head)) + + advantages = ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]) + policy_loss = -1 * torch.mean(torch.sum(log_probs, dim=1) * advantages) + + loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy) + + # Set optimizer learning rate + ModelUtils.update_learning_rate(self.optimizer, decay_lr) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + update_stats = { + # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow. + # TODO: After PyTorch is default, change to something more correct. 
+ "Losses/Policy Loss": torch.abs(policy_loss).item(), + "Losses/Value Loss": value_loss.item(), + "Policy/Learning Rate": decay_lr, + "Policy/Beta": decay_bet, + } + + return update_stats + + def get_modules(self): + modules = { + "Optimizer:value_optimizer": self.optimizer, + "Optimizer:critic": self._critic, + } + for reward_provider in self.reward_signals.values(): + modules.update(reward_provider.get_modules()) + return modules diff --git a/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_trainer.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_trainer.py new file mode 100644 index 0000000000..36cf762fb8 --- /dev/null +++ b/ml-agents-trainer-plugin/mlagents_trainer_plugin/a2c/a2c_trainer.py @@ -0,0 +1,213 @@ +# # Unity ML-Agents Toolkit +# ## ML-Agent Learning (A2C) +# Contains an implementation of A2C, the synchronous variant of the algorithm described in: https://arxiv.org/abs/1602.01783 + +from typing import cast + +import numpy as np + +from mlagents_envs.base_env import BehaviorSpec +from mlagents_envs.logging_util import get_logger +from mlagents.trainers.buffer import BufferKey, RewardSignalUtil +from mlagents.trainers.trainer.on_policy_trainer import OnPolicyTrainer +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.trainer.trainer_utils import get_gae +from mlagents.trainers.policy.torch_policy import TorchPolicy +from .a2c_optimizer import A2COptimizer, A2CSettings +from mlagents.trainers.trajectory import Trajectory +from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers +from mlagents.trainers.settings import TrainerSettings + +from mlagents.trainers.torch_entities.networks import SimpleActor, SharedActorCritic + +logger = get_logger(__name__) + +TRAINER_NAME = "a2c" + + +class A2CTrainer(OnPolicyTrainer): + """The A2CTrainer is an implementation of the A2C algorithm.""" + + def __init__( + self, + behavior_name: str, + reward_buff_cap: int, + trainer_settings: TrainerSettings, + training: bool, + load: bool, + seed: int, + artifact_path: str, + ): + """ + Responsible for collecting experiences and training an A2C model. + :param behavior_name: The name of the behavior associated with trainer config + :param reward_buff_cap: Max reward history to track in the reward buffer + :param trainer_settings: The parameters for the trainer. + :param training: Whether the trainer is set for training. + :param load: Whether the model should be loaded. + :param seed: The seed the model will be initialized with + :param artifact_path: The directory within which to store artifacts from this trainer. + """ + super().__init__( + behavior_name, + reward_buff_cap, + trainer_settings, + training, + load, + seed, + artifact_path, + ) + self.hyperparameters: A2CSettings = cast( + A2CSettings, self.trainer_settings.hyperparameters + ) + self.shared_critic = self.hyperparameters.shared_critic + self.policy: TorchPolicy = None # type: ignore + + def _process_trajectory(self, trajectory: Trajectory) -> None: + """ + Takes a trajectory and processes it, putting it into the update buffer. + Processing involves calculating value and advantage targets for the model update step. + :param trajectory: The Trajectory tuple containing the steps to be processed. + """ + super()._process_trajectory(trajectory) + agent_id = trajectory.agent_id # All the agents should have the same ID + + agent_buffer_trajectory = trajectory.to_agentbuffer() + # Check if we used group rewards, warn if so.
+ self._warn_if_group_reward(agent_buffer_trajectory) + + # Update the normalization + if self.is_training: + self.policy.actor.update_normalization(agent_buffer_trajectory) + self.optimizer.critic.update_normalization(agent_buffer_trajectory) + + # Get all value estimates + ( + value_estimates, + value_next, + value_memories, + ) = self.optimizer.get_trajectory_value_estimates( + agent_buffer_trajectory, + trajectory.next_obs, + trajectory.done_reached and not trajectory.interrupted, + ) + if value_memories is not None: + agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories) + + for name, v in value_estimates.items(): + agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend( + v + ) + self._stats_reporter.add_stat( + f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate", + np.mean(v), + ) + + # Evaluate all reward functions + self.collected_rewards["environment"][agent_id] += np.sum( + agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS] + ) + for name, reward_signal in self.optimizer.reward_signals.items(): + evaluate_result = ( + reward_signal.evaluate(agent_buffer_trajectory) * reward_signal.strength + ) + agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend( + evaluate_result + ) + # Report the reward signals + self.collected_rewards[name][agent_id] += np.sum(evaluate_result) + + # Compute GAE and returns + tmp_advantages = [] + tmp_returns = [] + for name in self.optimizer.reward_signals: + bootstrap_value = value_next[name] + + local_rewards = agent_buffer_trajectory[ + RewardSignalUtil.rewards_key(name) + ].get_batch() + local_value_estimates = agent_buffer_trajectory[ + RewardSignalUtil.value_estimates_key(name) + ].get_batch() + + local_advantage = get_gae( + rewards=local_rewards, + value_estimates=local_value_estimates, + value_next=bootstrap_value, + gamma=self.optimizer.reward_signals[name].gamma, + lambd=self.hyperparameters.lambd, + ) + local_return = local_advantage + local_value_estimates + # This is later used as the target for the different value estimates + agent_buffer_trajectory[RewardSignalUtil.returns_key(name)].set( + local_return + ) + agent_buffer_trajectory[RewardSignalUtil.advantage_key(name)].set( + local_advantage + ) + tmp_advantages.append(local_advantage) + tmp_returns.append(local_return) + + # Get global advantages + global_advantages = list( + np.mean(np.array(tmp_advantages, dtype=np.float32), axis=0) + ) + global_returns = list(np.mean(np.array(tmp_returns, dtype=np.float32), axis=0)) + agent_buffer_trajectory[BufferKey.ADVANTAGES].set(global_advantages) + agent_buffer_trajectory[BufferKey.DISCOUNTED_RETURNS].set(global_returns) + + self._append_to_update_buffer(agent_buffer_trajectory) + + # If this was a terminal trajectory, append stats and reset reward collection + if trajectory.done_reached: + self._update_end_episode_stats(agent_id, self.optimizer) + + def create_optimizer(self) -> TorchOptimizer: + """ + Creates an Optimizer object + """ + return A2COptimizer( # type: ignore + cast(TorchPolicy, self.policy), self.trainer_settings # type: ignore + ) # type: ignore + + def create_policy( + self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec + ) -> TorchPolicy: + """ + Creates a policy with a PyTorch backend and A2C hyperparameters + :param parsed_behavior_id: + :param behavior_spec: specifications for policy construction + :return policy + """ + actor_cls = SimpleActor + actor_kwargs = {"conditional_sigma": False, "tanh_squash": False} + if
self.shared_critic: + reward_signal_configs = self.trainer_settings.reward_signals + reward_signal_names = [ + key.value for key, _ in reward_signal_configs.items() + ] + actor_cls = SharedActorCritic + actor_kwargs.update({"stream_names": reward_signal_names}) + + policy = TorchPolicy( + self.seed, + behavior_spec, + self.trainer_settings.network_settings, + actor_cls, + actor_kwargs, + ) + return policy + + @staticmethod + def get_settings_type(): + return A2CSettings + + @staticmethod + def get_trainer_name() -> str: + return TRAINER_NAME + + +def get_type_and_setting(): + return {A2CTrainer.get_trainer_name(): A2CTrainer}, { + A2CTrainer.get_trainer_name(): A2CSettings + } diff --git a/ml-agents/mlagents/trainers/torch/components/__init__.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/__init__.py similarity index 100% rename from ml-agents/mlagents/trainers/torch/components/__init__.py rename to ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/__init__.py diff --git a/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_basic.yaml b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_basic.yaml new file mode 100644 index 0000000000..07237dac4e --- /dev/null +++ b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_basic.yaml @@ -0,0 +1,27 @@ +behaviors: + Basic: + trainer_type: dqn + hyperparameters: + learning_rate: 0.0003 + learning_rate_schedule: constant + batch_size: 64 + buffer_size: 50000 + tau: 0.005 + steps_per_update: 10.0 + save_replay_buffer: false + exploration_schedule: linear + exploration_initial_eps: 0.8 + exploration_final_eps: 0.05 + network_settings: + normalize: false + hidden_units: 20 + num_layers: 2 + vis_encode_type: simple + reward_signals: + extrinsic: + gamma: 0.99 + strength: 1.0 + keep_checkpoints: 5 + max_steps: 500000 + time_horizon: 10 + summary_freq: 1000 diff --git a/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_optimizer.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_optimizer.py new file mode 100644 index 0000000000..e01b6b5020 --- /dev/null +++ b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_optimizer.py @@ -0,0 +1,290 @@ +from typing import cast +from mlagents.torch_utils import torch, nn, default_device +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.policy.torch_policy import TorchPolicy +from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil +from mlagents_envs.timers import timed +from typing import List, Dict, Tuple, Optional, Union, Any +from mlagents.trainers.torch_entities.networks import ValueNetwork, Actor +from mlagents_envs.base_env import ActionSpec, ObservationSpec +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.trajectory import ObsUtil +from mlagents.trainers.settings import TrainerSettings, OffPolicyHyperparamSettings +from mlagents.trainers.settings import ScheduleType, NetworkSettings + +from mlagents.trainers.torch_entities.networks import Critic +import numpy as np +import attr + + +# TODO: fix saving to onnx + + +@attr.s(auto_attribs=True) +class DQNSettings(OffPolicyHyperparamSettings): + gamma: float = 0.99 + exploration_schedule: ScheduleType = ScheduleType.LINEAR + exploration_initial_eps: float = 0.1 + exploration_final_eps: float = 0.05 + target_update_interval: int = 10000 + tau: float = 0.005 + steps_per_update: float = 1 + save_replay_buffer: bool = False + 
reward_signal_steps_per_update: float = attr.ib() + + @reward_signal_steps_per_update.default + def _reward_signal_steps_per_update_default(self): + return self.steps_per_update + + +class DQNOptimizer(TorchOptimizer): + def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): + super().__init__(policy, trainer_settings) + + # initialize hyper parameters + params = list(self.policy.actor.parameters()) + self.optimizer = torch.optim.Adam( + params, lr=self.trainer_settings.hyperparameters.learning_rate + ) + self.stream_names = list(self.reward_signals.keys()) + self.gammas = [_val.gamma for _val in trainer_settings.reward_signals.values()] + self.use_dones_in_backup = { + name: int(not self.reward_signals[name].ignore_done) + for name in self.stream_names + } + + self.hyperparameters: DQNSettings = cast( + DQNSettings, trainer_settings.hyperparameters + ) + self.tau = self.hyperparameters.tau + self.decay_learning_rate = ModelUtils.DecayedValue( + self.hyperparameters.learning_rate_schedule, + self.hyperparameters.learning_rate, + 1e-10, + self.trainer_settings.max_steps, + ) + + self.decay_exploration_rate = ModelUtils.DecayedValue( + self.hyperparameters.exploration_schedule, + self.hyperparameters.exploration_initial_eps, + self.hyperparameters.exploration_final_eps, + 20000, + ) + + # initialize Target Q_network + self.q_net_target = QNetwork( + stream_names=self.reward_signals.keys(), + observation_specs=policy.behavior_spec.observation_specs, + network_settings=policy.network_settings, + action_spec=policy.behavior_spec.action_spec, + ) + ModelUtils.soft_update(self.policy.actor, self.q_net_target, 1.0) + + self.q_net_target.to(default_device()) + + @property + def critic(self): + return self.q_net_target + + @timed + def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: + """ + Performs update on model. + :param batch: Batch of experiences. + :param num_sequences: Number of sequences to process. + :return: Results of update. 
+ """ + # Get decayed parameters + decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) + exp_rate = self.decay_exploration_rate.get_value(self.policy.get_current_step()) + self.policy.actor.exploration_rate = exp_rate + rewards = {} + for name in self.reward_signals: + rewards[name] = ModelUtils.list_to_tensor( + batch[RewardSignalUtil.rewards_key(name)] + ) + + n_obs = len(self.policy.behavior_spec.observation_specs) + current_obs = ObsUtil.from_buffer(batch, n_obs) + # Convert to tensors + current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs] + + next_obs = ObsUtil.from_buffer_next(batch, n_obs) + # Convert to tensors + next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs] + + actions = AgentAction.from_buffer(batch) + + dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE]) + + current_q_values, _ = self.policy.actor.critic_pass( + current_obs, sequence_length=self.policy.sequence_length + ) + + qloss = [] + with torch.no_grad(): + greedy_actions = self.policy.actor.get_greedy_action(current_q_values) + next_q_values_list, _ = self.q_net_target.critic_pass( + next_obs, sequence_length=self.policy.sequence_length + ) + for name_i, name in enumerate(rewards.keys()): + with torch.no_grad(): + next_q_values = torch.gather( + next_q_values_list[name], dim=1, index=greedy_actions + ).squeeze() + target_q_values = rewards[name] + ( + (1.0 - self.use_dones_in_backup[name] * dones) + * self.gammas[name_i] + * next_q_values + ) + target_q_values = target_q_values.reshape(-1, 1) + curr_q = torch.gather( + current_q_values[name], dim=1, index=actions.discrete_tensor + ) + qloss.append(torch.nn.functional.smooth_l1_loss(curr_q, target_q_values)) + + loss = torch.mean(torch.stack(qloss)) + ModelUtils.update_learning_rate(self.optimizer, decay_lr) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + ModelUtils.soft_update(self.policy.actor, self.q_net_target, self.tau) + update_stats = { + "Losses/Value Loss": loss.item(), + "Policy/Learning Rate": decay_lr, + "Policy/epsilon": exp_rate, + } + + for reward_provider in self.reward_signals.values(): + update_stats.update(reward_provider.update(batch)) + return update_stats + + def get_modules(self): + modules = { + "Optimizer:value_optimizer": self.optimizer, + "Optimizer:critic": self.critic, + } + for reward_provider in self.reward_signals.values(): + modules.update(reward_provider.get_modules()) + return modules + + +class QNetwork(nn.Module, Actor, Critic): + MODEL_EXPORT_VERSION = 3 + + def __init__( + self, + stream_names: List[str], + observation_specs: List[ObservationSpec], + network_settings: NetworkSettings, + action_spec: ActionSpec, + exploration_initial_eps: float = 1.0, + ): + self.exploration_rate = exploration_initial_eps + nn.Module.__init__(self) + output_act_size = max(sum(action_spec.discrete_branches), 1) + self.network_body = ValueNetwork( + stream_names, + observation_specs, + network_settings, + outputs_per_stream=output_act_size, + ) + + # extra tensors for exporting to ONNX + self.action_spec = action_spec + self.version_number = torch.nn.Parameter( + torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False + ) + self.is_continuous_int_deprecated = torch.nn.Parameter( + torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False + ) + self.continuous_act_size_vector = torch.nn.Parameter( + torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False + ) + self.discrete_act_size_vector = torch.nn.Parameter( + 
torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False + ) + self.act_size_vector_deprecated = torch.nn.Parameter( + torch.Tensor( + [ + self.action_spec.continuous_size + + sum(self.action_spec.discrete_branches) + ] + ), + requires_grad=False, + ) + self.memory_size_vector = torch.nn.Parameter( + torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False + ) + + def update_normalization(self, buffer: AgentBuffer) -> None: + self.network_body.update_normalization(buffer) + + def critic_pass( + self, + inputs: List[torch.Tensor], + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + value_outputs, critic_mem_out = self.network_body( + inputs, memories=memories, sequence_length=sequence_length + ) + return value_outputs, critic_mem_out + + @property + def memory_size(self) -> int: + return self.network_body.memory_size + + def forward( + self, + inputs: List[torch.Tensor], + masks: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + ) -> Tuple[Union[int, torch.Tensor], ...]: + out_vals, memories = self.critic_pass(inputs, memories, sequence_length) + + # fixme random action tensor + export_out = [self.version_number, self.memory_size_vector] + + disc_action_out = self.get_greedy_action(out_vals) + deterministic_disc_action_out = self.get_random_action(out_vals) + export_out += [ + disc_action_out, + self.discrete_act_size_vector, + deterministic_disc_action_out, + ] + return tuple(export_out) + + def get_random_action(self, inputs) -> torch.Tensor: + action_out = torch.randint( + 0, self.action_spec.discrete_branches[0], (len(inputs), 1) + ) + return action_out + + @staticmethod + def get_greedy_action(q_values) -> torch.Tensor: + all_q = torch.cat([val.unsqueeze(0) for val in q_values.values()]) + return torch.argmax(all_q.sum(dim=0), dim=1, keepdim=True) + + def get_action_and_stats( + self, + inputs: List[torch.Tensor], + masks: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + deterministic=False, + ) -> Tuple[AgentAction, Dict[str, Any], torch.Tensor]: + run_out = {} + if not deterministic and np.random.rand() < self.exploration_rate: + action_out = self.get_random_action(inputs) + action_out = AgentAction(None, [action_out]) + run_out["env_action"] = action_out.to_action_tuple() + else: + out_vals, _ = self.critic_pass(inputs, memories, sequence_length) + action_out = self.get_greedy_action(out_vals) + action_out = AgentAction(None, [action_out]) + run_out["env_action"] = action_out.to_action_tuple() + return action_out, run_out, torch.Tensor([]) diff --git a/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_trainer.py b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_trainer.py new file mode 100644 index 0000000000..f2d8011b47 --- /dev/null +++ b/ml-agents-trainer-plugin/mlagents_trainer_plugin/dqn/dqn_trainer.py @@ -0,0 +1,162 @@ +from typing import cast + +import numpy as np +from mlagents_envs.logging_util import get_logger +from mlagents.trainers.buffer import BufferKey +from mlagents.trainers.policy.torch_policy import TorchPolicy +from mlagents.trainers.trainer.off_policy_trainer import OffPolicyTrainer +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.trajectory import Trajectory, ObsUtil +from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers +from mlagents_envs.base_env import BehaviorSpec +from 
mlagents.trainers.settings import TrainerSettings +from .dqn_optimizer import DQNOptimizer, DQNSettings, QNetwork + +logger = get_logger(__name__) +TRAINER_NAME = "dqn" + + +class DQNTrainer(OffPolicyTrainer): + """The DQNTrainer is an implementation of the DQN algorithm.""" + + def __init__( + self, + behavior_name: str, + reward_buff_cap: int, + trainer_settings: TrainerSettings, + training: bool, + load: bool, + seed: int, + artifact_path: str, + ): + """ + Responsible for collecting experiences and training a DQN model. + :param behavior_name: The name of the behavior associated with trainer config + :param reward_buff_cap: Max reward history to track in the reward buffer + :param trainer_settings: The parameters for the trainer. + :param training: Whether the trainer is set for training. + :param load: Whether the model should be loaded. + :param seed: The seed the model will be initialized with + :param artifact_path: The directory within which to store artifacts from this trainer. + """ + super().__init__( + behavior_name, + reward_buff_cap, + trainer_settings, + training, + load, + seed, + artifact_path, + ) + self.policy: TorchPolicy = None # type: ignore + self.optimizer: DQNOptimizer = None # type: ignore + + def _process_trajectory(self, trajectory: Trajectory) -> None: + """ + Takes a trajectory and processes it, putting it into the replay buffer. + """ + super()._process_trajectory(trajectory) + last_step = trajectory.steps[-1] + agent_id = trajectory.agent_id # All the agents should have the same ID + + agent_buffer_trajectory = trajectory.to_agentbuffer() + # Check if we used group rewards, warn if so. + self._warn_if_group_reward(agent_buffer_trajectory) + + # Update the normalization + if self.is_training: + self.policy.actor.update_normalization(agent_buffer_trajectory) + self.optimizer.critic.update_normalization(agent_buffer_trajectory) + + # Evaluate all reward functions for reporting purposes + self.collected_rewards["environment"][agent_id] += np.sum( + agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS] + ) + for name, reward_signal in self.optimizer.reward_signals.items(): + evaluate_result = ( + reward_signal.evaluate(agent_buffer_trajectory) * reward_signal.strength + ) + + # Report the reward signals + self.collected_rewards[name][agent_id] += np.sum(evaluate_result) + + # Get all value estimates for reporting purposes + ( + value_estimates, + _, + value_memories, + ) = self.optimizer.get_trajectory_value_estimates( + agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached + ) + if value_memories is not None: + agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories) + + for name, v in value_estimates.items(): + self._stats_reporter.add_stat( + f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value", + np.mean(v), + ) + + # Bootstrap using the last step rather than the bootstrap step if max step is reached. + # Set last element to duplicate obs and remove dones.
+        if last_step.interrupted:
+            last_step_obs = last_step.obs
+            for i, obs in enumerate(last_step_obs):
+                agent_buffer_trajectory[ObsUtil.get_name_at_next(i)][-1] = obs
+            agent_buffer_trajectory[BufferKey.DONE][-1] = False
+
+        self._append_to_update_buffer(agent_buffer_trajectory)
+
+        if trajectory.done_reached:
+            self._update_end_episode_stats(agent_id, self.optimizer)
+
+    def create_optimizer(self) -> TorchOptimizer:
+        """
+        Creates an Optimizer object
+        """
+        return DQNOptimizer(  # type: ignore
+            cast(TorchPolicy, self.policy), self.trainer_settings  # type: ignore
+        )  # type: ignore
+
+    def create_policy(
+        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+    ) -> TorchPolicy:
+        """
+        Creates a policy with a PyTorch backend and gives it the DQN hyperparameters
+        :param parsed_behavior_id:
+        :param behavior_spec: specifications for policy construction
+        :return policy
+        """
+        # Initialize the online Q-network, which acts as the actor
+        exploration_initial_eps = cast(
+            DQNSettings, self.trainer_settings.hyperparameters
+        ).exploration_initial_eps
+        actor_kwargs = {
+            "exploration_initial_eps": exploration_initial_eps,
+            "stream_names": [
+                signal.value for signal in self.trainer_settings.reward_signals.keys()
+            ],
+        }
+        policy = TorchPolicy(
+            self.seed,
+            behavior_spec,
+            self.trainer_settings.network_settings,
+            actor_cls=QNetwork,
+            actor_kwargs=actor_kwargs,
+        )
+        self.maybe_load_replay_buffer()
+        return policy
+
+    @staticmethod
+    def get_settings_type():
+        return DQNSettings
+
+    @staticmethod
+    def get_trainer_name() -> str:
+        return TRAINER_NAME
+
+
+def get_type_and_setting():
+    return {DQNTrainer.get_trainer_name(): DQNTrainer}, {
+        DQNTrainer.get_trainer_name(): DQNTrainer.get_settings_type()
+    }
diff --git a/ml-agents-trainer-plugin/setup.py b/ml-agents-trainer-plugin/setup.py
new file mode 100644
index 0000000000..01f8c66aab
--- /dev/null
+++ b/ml-agents-trainer-plugin/setup.py
@@ -0,0 +1,18 @@
+from setuptools import setup
+from mlagents.plugins import ML_AGENTS_TRAINER_TYPE
+
+setup(
+    name="mlagents_trainer_plugin",
+    version="0.0.1",
+    # Example of how to add your own registration functions that will be called
+    # by mlagents-learn.
+    #
+    # Here, the get_type_and_setting() functions in a2c_trainer.py and dqn_trainer.py
+    # will get registered with the ML_AGENTS_TRAINER_TYPE plugin interface.
+    entry_points={
+        ML_AGENTS_TRAINER_TYPE: [
+            "a2c=mlagents_trainer_plugin.a2c.a2c_trainer:get_type_and_setting",
+            "dqn=mlagents_trainer_plugin.dqn.dqn_trainer:get_type_and_setting",
+        ]
+    },
+)
diff --git a/ml-agents/README.md b/ml-agents/README.md
index ad9d53e900..8d347344fc 100644
--- a/ml-agents/README.md
+++ b/ml-agents/README.md
@@ -16,7 +16,7 @@ package. Install the `mlagents` package with:
 
 ```sh
-python -m pip install mlagents==0.28.0
+python -m pip install mlagents==0.29.0
 ```
 
 ## Usage & More Information
@@ -27,7 +27,5 @@ scene with the ML-Agents SDK, check out the main
 
 ## Limitations
 
-- `mlagents` does not yet explicitly support multi-agent scenarios so training
-  cooperative behavior among different agents is not stable.
 - Resuming self-play from a checkpoint resets the reported ELO to the default value.
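For context: the `entry_points` block in `ml-agents-trainer-plugin/setup.py` above is the whole plugin contract. Once the package is pip-installed, `mlagents-learn` can discover the extra trainers by group name without importing the plugin directly (the discovery code lands in `mlagents/plugins/trainer_type.py` below). A minimal sketch of that lookup, assuming the plugin package has been pip-installed and targeting the dict-style `entry_points()` API of `importlib.metadata` on Python 3.8/3.9, which this changeset supports:

```python
# Sketch only: mirrors how trainer plugins registered under the
# "mlagents.trainer_type" entry-point group can be discovered and loaded.
# Assumes mlagents_trainer_plugin has been pip-installed.
import importlib.metadata

ML_AGENTS_TRAINER_TYPE = "mlagents.trainer_type"

trainer_types = {}
trainer_settings = {}
# On Python 3.8/3.9, entry_points() returns a dict keyed by group name.
for entry_point in importlib.metadata.entry_points().get(ML_AGENTS_TRAINER_TYPE, []):
    # Each entry point resolves to a get_type_and_setting() function that
    # returns ({trainer_name: TrainerClass}, {trainer_name: SettingsClass}).
    get_type_and_setting = entry_point.load()
    types, settings = get_type_and_setting()
    trainer_types.update(types)
    trainer_settings.update(settings)

print(sorted(trainer_types))  # e.g. ['a2c', 'dqn'] once the plugin is installed
```

With the plugin installed, a run config that sets `trainer_type: dqn` should then resolve to the `DQNTrainer` registered above.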
diff --git a/ml-agents/mlagents/plugins/__init__.py b/ml-agents/mlagents/plugins/__init__.py
index a5a5353a15..b63a39732d 100644
--- a/ml-agents/mlagents/plugins/__init__.py
+++ b/ml-agents/mlagents/plugins/__init__.py
@@ -1 +1,8 @@
+from typing import Dict, Any
+
 ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
+ML_AGENTS_TRAINER_TYPE = "mlagents.trainer_type"
+
+# TODO: the real type is Dict[str, HyperparamSettings]
+all_trainer_types: Dict[str, Any] = {}
+all_trainer_settings: Dict[str, Any] = {}
diff --git a/ml-agents/mlagents/plugins/trainer_type.py b/ml-agents/mlagents/plugins/trainer_type.py
new file mode 100644
index 0000000000..2766368863
--- /dev/null
+++ b/ml-agents/mlagents/plugins/trainer_type.py
@@ -0,0 +1,80 @@
+import sys
+from typing import Dict, Tuple, Any
+
+# importlib.metadata is new in python3.8
+# We use the backport for older python versions.
+if sys.version_info < (3, 8):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata  # pylint: disable=E0611
+
+
+from mlagents_envs import logging_util
+from mlagents.plugins import ML_AGENTS_TRAINER_TYPE
+from mlagents.trainers.ppo.trainer import PPOTrainer
+from mlagents.trainers.sac.trainer import SACTrainer
+from mlagents.trainers.poca.trainer import POCATrainer
+from mlagents.trainers.ppo.optimizer_torch import PPOSettings
+from mlagents.trainers.sac.optimizer_torch import SACSettings
+from mlagents.trainers.poca.optimizer_torch import POCASettings
+from mlagents import plugins as mla_plugins
+
+logger = logging_util.get_logger(__name__)
+
+
+def get_default_trainer_types() -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """
+    Registers and returns the default Trainers that mlagents-learn always uses: PPO, SAC and POCA.
+    """
+
+    mla_plugins.all_trainer_types.update(
+        {
+            PPOTrainer.get_trainer_name(): PPOTrainer,
+            SACTrainer.get_trainer_name(): SACTrainer,
+            POCATrainer.get_trainer_name(): POCATrainer,
+        }
+    )
+    # global all_trainer_settings
+    mla_plugins.all_trainer_settings.update(
+        {
+            PPOTrainer.get_trainer_name(): PPOSettings,
+            SACTrainer.get_trainer_name(): SACSettings,
+            POCATrainer.get_trainer_name(): POCASettings,
+        }
+    )
+
+    return mla_plugins.all_trainer_types, mla_plugins.all_trainer_settings
+
+
+def register_trainer_plugins() -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """
+    Registers all Trainer plugins (including the defaults), evaluates them, and
+    returns dictionaries of all the Trainer implementations and their settings.
+    """
+    if ML_AGENTS_TRAINER_TYPE not in importlib_metadata.entry_points():
+        logger.warning(
+            f"Unable to find any entry points for {ML_AGENTS_TRAINER_TYPE}, even the default ones. "
+            "Uninstalling and reinstalling ml-agents via pip should resolve this. "
+            "Using default plugins for now."
+        )
+        return get_default_trainer_types()
+
+    entry_points = importlib_metadata.entry_points()[ML_AGENTS_TRAINER_TYPE]
+
+    for entry_point in entry_points:
+
+        try:
+            logger.debug(f"Initializing Trainer plugins: {entry_point.name}")
+            plugin_func = entry_point.load()
+            plugin_trainer_types, plugin_trainer_settings = plugin_func()
+            logger.debug(
+                f"Found {len(plugin_trainer_types)} Trainers for plugin {entry_point.name}"
+            )
+            mla_plugins.all_trainer_types.update(plugin_trainer_types)
+            mla_plugins.all_trainer_settings.update(plugin_trainer_settings)
+        except BaseException:
+            # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
+            logger.exception(
+                f"Error initializing Trainer plugins for {entry_point.name}. This plugin will not be used."
+ ) + return mla_plugins.all_trainer_types, mla_plugins.all_trainer_settings diff --git a/ml-agents/mlagents/trainers/__init__.py b/ml-agents/mlagents/trainers/__init__.py index 5d158f10d1..8650fdbc2a 100644 --- a/ml-agents/mlagents/trainers/__init__.py +++ b/ml-agents/mlagents/trainers/__init__.py @@ -1,5 +1,5 @@ # Version of the library that will be used to upload to pypi -__version__ = "0.29.0.dev0" +__version__ = "0.30.0" # Git tag that will be checked to determine whether to trigger upload to pypi __release_tag__ = None diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py index 3702e44ada..720f3d14bd 100644 --- a/ml-agents/mlagents/trainers/agent_processor.py +++ b/ml-agents/mlagents/trainers/agent_processor.py @@ -3,6 +3,7 @@ from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union from collections import defaultdict, Counter import queue +from mlagents.torch_utils import torch from mlagents_envs.base_env import ( ActionTuple, @@ -19,7 +20,6 @@ from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience from mlagents.trainers.policy import Policy from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs -from mlagents.trainers.torch.action_log_probs import LogProbsTuple from mlagents.trainers.stats import StatsReporter from mlagents.trainers.behavior_id_utils import ( get_global_agent_id, @@ -27,6 +27,8 @@ GlobalAgentId, GlobalGroupId, ) +from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple +from mlagents.trainers.torch_entities.utils import ModelUtils T = TypeVar("T") @@ -100,8 +102,13 @@ def add_experiences( """ take_action_outputs = previous_action.outputs if take_action_outputs: - for _entropy in take_action_outputs["entropy"]: - self._stats_reporter.add_stat("Policy/Entropy", _entropy) + try: + for _entropy in take_action_outputs["entropy"]: + if isinstance(_entropy, torch.Tensor): + _entropy = ModelUtils.to_numpy(_entropy) + self._stats_reporter.add_stat("Policy/Entropy", _entropy) + except KeyError: + pass # Make unique agent_ids that are global across workers action_global_agent_ids = [ @@ -233,11 +240,17 @@ def _process_step( continuous=stored_actions.continuous[idx], discrete=stored_actions.discrete[idx], ) - stored_action_probs = stored_take_action_outputs["log_probs"] - log_probs_tuple = LogProbsTuple( - continuous=stored_action_probs.continuous[idx], - discrete=stored_action_probs.discrete[idx], - ) + try: + stored_action_probs = stored_take_action_outputs["log_probs"] + if not isinstance(stored_action_probs, LogProbsTuple): + stored_action_probs = stored_action_probs.to_log_probs_tuple() + log_probs_tuple = LogProbsTuple( + continuous=stored_action_probs.continuous[idx], + discrete=stored_action_probs.discrete[idx], + ) + except KeyError: + log_probs_tuple = LogProbsTuple.empty_log_probs() + action_mask = stored_decision_step.action_mask prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :] diff --git a/ml-agents/mlagents/trainers/buffer.py b/ml-agents/mlagents/trainers/buffer.py index f1c93a41ee..ea6a2d5111 100644 --- a/ml-agents/mlagents/trainers/buffer.py +++ b/ml-agents/mlagents/trainers/buffer.py @@ -137,7 +137,6 @@ def set(self, data: List[BufferEntry]) -> None: Sets the list of BufferEntry to the input data :param data: The BufferEntry list to be set. 
""" - self[:] = [] self[:] = data def get_batch( @@ -245,6 +244,12 @@ def padded_to_batch( ) return new_list + def to_ndarray(self): + """ + Returns the AgentBufferField which is a list of numpy ndarrays (or List[np.ndarray]) as an ndarray. + """ + return np.array(self) + class AgentBuffer(MutableMapping): """ diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py index faa04f1312..f49a643574 100644 --- a/ml-agents/mlagents/trainers/ghost/trainer.py +++ b/ml-agents/mlagents/trainers/ghost/trainer.py @@ -11,6 +11,7 @@ from mlagents.trainers.policy import Policy from mlagents.trainers.trainer import Trainer +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.trajectory import Trajectory from mlagents.trainers.agent_processor import AgentManagerQueue from mlagents.trainers.stats import StatsPropertyType @@ -336,7 +337,6 @@ def create_policy( self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec, - create_graph: bool = False, ) -> Policy: """ Creates policy with the wrapped trainer's create_policy function @@ -345,9 +345,7 @@ def create_policy( team are grouped. All policies associated with this team are added to the wrapped trainer to be trained. """ - policy = self.trainer.create_policy( - parsed_behavior_id, behavior_spec, create_graph=True - ) + policy = self.trainer.create_policy(parsed_behavior_id, behavior_spec) team_id = parsed_behavior_id.team_id self.controller.subscribe_team_id(team_id, self) @@ -372,6 +370,9 @@ def create_policy( ) return policy + def create_optimizer(self) -> TorchOptimizer: + pass + def add_policy( self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy ) -> None: @@ -384,14 +385,6 @@ def add_policy( self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id self.policies[name_behavior_id] = policy - def get_policy(self, name_behavior_id: str) -> Policy: - """ - Gets policy associated with name_behavior_id - :param name_behavior_id: Fully qualified behavior name - :return: Policy associated with name_behavior_id - """ - return self.policies[name_behavior_id] - def _save_snapshot(self) -> None: """ Saves a snapshot of the current weights of the policy and maintains the policy_snapshots diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index b79a4d85bc..69320920a5 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -33,6 +33,7 @@ ) from mlagents_envs import logging_util from mlagents.plugins.stats_writer import register_stats_writer_plugins +from mlagents.plugins.trainer_type import register_trainer_plugins logger = logging_util.get_logger(__name__) @@ -47,7 +48,10 @@ def get_version_string() -> str: PyTorch: {torch_utils.torch.__version__}""" -def parse_command_line(argv: Optional[List[str]] = None) -> RunOptions: +def parse_command_line( + argv: Optional[List[str]] = None, +) -> RunOptions: + _, _ = register_trainer_plugins() args = parser.parse_args(argv) return RunOptions.from_argparse(args) diff --git a/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py b/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py index e75e50d4cd..70c3f19e43 100644 --- a/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py +++ b/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py @@ -8,7 +8,7 @@ from mlagents.trainers.settings import TrainerSettings, SerializationSettings from mlagents.trainers.policy.torch_policy import TorchPolicy 
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer -from mlagents.trainers.torch.model_serialization import ModelSerializer +from mlagents.trainers.torch_entities.model_serialization import ModelSerializer logger = get_logger(__name__) diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py index 7f1637164b..8cb0a6ee8c 100644 --- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py +++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py @@ -5,8 +5,10 @@ from mlagents.trainers.buffer import AgentBuffer, AgentBufferField from mlagents.trainers.trajectory import ObsUtil -from mlagents.trainers.torch.components.bc.module import BCModule -from mlagents.trainers.torch.components.reward_providers import create_reward_provider +from mlagents.trainers.torch_entities.components.bc.module import BCModule +from mlagents.trainers.torch_entities.components.reward_providers import ( + create_reward_provider, +) from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.optimizer import Optimizer @@ -15,7 +17,7 @@ RewardSignalSettings, RewardSignalType, ) -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.utils import ModelUtils class TorchOptimizer(Optimizer): @@ -130,6 +132,12 @@ def _evaluate_by_sequence( next_mem = _mem return all_value_tensors, all_next_memories, next_mem + def update_reward_signals(self, batch: AgentBuffer) -> Dict[str, float]: + update_stats: Dict[str, float] = {} + for reward_provider in self.reward_signals.values(): + update_stats.update(reward_provider.update(batch)) + return update_stats + def get_trajectory_value_estimates( self, batch: AgentBuffer, diff --git a/ml-agents/mlagents/trainers/poca/optimizer_torch.py b/ml-agents/mlagents/trainers/poca/optimizer_torch.py index cf672b7679..4f77de4ebb 100644 --- a/ml-agents/mlagents/trainers/poca/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/poca/optimizer_torch.py @@ -1,6 +1,8 @@ from typing import Dict, cast, List, Tuple, Optional from collections import defaultdict -from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import ( +import attr + +from mlagents.trainers.torch_entities.components.reward_providers.extrinsic_reward_provider import ( ExtrinsicRewardProvider, ) import numpy as np @@ -21,21 +23,33 @@ RewardSignalSettings, RewardSignalType, TrainerSettings, - POCASettings, + NetworkSettings, + OnPolicyHyperparamSettings, + ScheduleType, ) -from mlagents.trainers.torch.networks import Critic, MultiAgentNetworkBody -from mlagents.trainers.torch.decoders import ValueHeads -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.networks import Critic, MultiAgentNetworkBody +from mlagents.trainers.torch_entities.decoders import ValueHeads +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.trajectory import ObsUtil, GroupObsUtil -from mlagents.trainers.settings import NetworkSettings from mlagents_envs.logging_util import get_logger logger = get_logger(__name__) +@attr.s(auto_attribs=True) +class POCASettings(OnPolicyHyperparamSettings): + beta: float = 5.0e-3 + epsilon: float = 
0.2 + lambd: float = 0.95 + num_epoch: int = 3 + learning_rate_schedule: ScheduleType = ScheduleType.LINEAR + beta_schedule: ScheduleType = ScheduleType.LINEAR + epsilon_schedule: ScheduleType = ScheduleType.LINEAR + + class TorchPOCAOptimizer(TorchOptimizer): class POCAValueNetwork(torch.nn.Module, Critic): """ @@ -162,9 +176,11 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): self._critic.to(default_device()) params = list(self.policy.actor.parameters()) + list(self.critic.parameters()) + self.hyperparameters: POCASettings = cast( POCASettings, trainer_settings.hyperparameters ) + self.decay_learning_rate = ModelUtils.DecayedValue( self.hyperparameters.learning_rate_schedule, self.hyperparameters.learning_rate, @@ -285,13 +301,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: value_memories = torch.stack(value_memories).unsqueeze(0) baseline_memories = torch.stack(baseline_memories).unsqueeze(0) - log_probs, entropy = self.policy.evaluate_actions( + run_out = self.policy.actor.get_stats( current_obs, + actions, masks=act_masks, - actions=actions, memories=memories, - seq_len=self.policy.sequence_length, + sequence_length=self.policy.sequence_length, ) + + log_probs = run_out["log_probs"] + entropy = run_out["entropy"] + all_obs = [current_obs] + groupmate_obs values, _ = self.critic.critic_pass( all_obs, @@ -346,9 +366,6 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: "Policy/Beta": decay_bet, } - for reward_provider in self.reward_signals.values(): - update_stats.update(reward_provider.update(batch)) - return update_stats def get_modules(self): diff --git a/ml-agents/mlagents/trainers/poca/trainer.py b/ml-agents/mlagents/trainers/poca/trainer.py index 45f561321f..266a149321 100644 --- a/ml-agents/mlagents/trainers/poca/trainer.py +++ b/ml-agents/mlagents/trainers/poca/trainer.py @@ -3,7 +3,7 @@ # Contains an implementation of MA-POCA. 
from collections import defaultdict -from typing import cast, Dict +from typing import cast, Dict, Union, Any, Type import numpy as np @@ -11,18 +11,23 @@ from mlagents_envs.logging_util import get_logger from mlagents_envs.base_env import BehaviorSpec from mlagents.trainers.buffer import BufferKey, RewardSignalUtil -from mlagents.trainers.trainer.rl_trainer import RLTrainer +from mlagents.trainers.trainer.on_policy_trainer import OnPolicyTrainer +from mlagents.trainers.trainer.trainer_utils import lambda_return from mlagents.trainers.policy import Policy from mlagents.trainers.policy.torch_policy import TorchPolicy -from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer +from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer, POCASettings from mlagents.trainers.trajectory import Trajectory from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers -from mlagents.trainers.settings import TrainerSettings, POCASettings +from mlagents.trainers.settings import TrainerSettings + +from mlagents.trainers.torch_entities.networks import SimpleActor, SharedActorCritic logger = get_logger(__name__) +TRAINER_NAME = "poca" + -class POCATrainer(RLTrainer): +class POCATrainer(OnPolicyTrainer): """The POCATrainer is an implementation of the MA-POCA algorithm.""" def __init__( @@ -47,17 +52,19 @@ def __init__( """ super().__init__( behavior_name, + reward_buff_cap, trainer_settings, training, load, + seed, artifact_path, - reward_buff_cap, ) self.hyperparameters: POCASettings = cast( POCASettings, self.trainer_settings.hyperparameters ) self.seed = seed self.policy: TorchPolicy = None # type: ignore + self.optimizer: TorchPOCAOptimizer = None # type: ignore self.collected_group_rewards: Dict[str, int] = defaultdict(lambda: 0) def _process_trajectory(self, trajectory: Trajectory) -> None: @@ -72,7 +79,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: agent_buffer_trajectory = trajectory.to_agentbuffer() # Update the normalization if self.is_training: - self.policy.update_normalization(agent_buffer_trajectory) + self.policy.actor.update_normalization(agent_buffer_trajectory) self.optimizer.critic.update_normalization(agent_buffer_trajectory) # Get all value estimates @@ -193,56 +200,6 @@ def _is_ready_update(self): size_of_buffer = self.update_buffer.num_experiences return size_of_buffer > self.hyperparameters.buffer_size - def _update_policy(self): - """ - Uses demonstration_buffer to update the policy. - The reward signal generators must be updated in this method at their own pace. - """ - buffer_length = self.update_buffer.num_experiences - self.cumulative_returns_since_policy_update.clear() - - # Make sure batch_size is a multiple of sequence length. During training, we - # will need to reshape the data into a batch_size x sequence_length tensor. 
- batch_size = ( - self.hyperparameters.batch_size - - self.hyperparameters.batch_size % self.policy.sequence_length - ) - # Make sure there is at least one sequence - batch_size = max(batch_size, self.policy.sequence_length) - - n_sequences = max( - int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 - ) - - advantages = np.array( - self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32 - ) - self.update_buffer[BufferKey.ADVANTAGES].set( - (advantages - advantages.mean()) / (advantages.std() + 1e-10) - ) - num_epoch = self.hyperparameters.num_epoch - batch_update_stats = defaultdict(list) - for _ in range(num_epoch): - self.update_buffer.shuffle(sequence_length=self.policy.sequence_length) - buffer = self.update_buffer - max_num_batch = buffer_length // batch_size - for i in range(0, max_num_batch * batch_size, batch_size): - update_stats = self.optimizer.update( - buffer.make_mini_batch(i, i + batch_size), n_sequences - ) - for stat_name, value in update_stats.items(): - batch_update_stats[stat_name].append(value) - - for stat, stat_list in batch_update_stats.items(): - self._stats_reporter.add_stat(stat, np.mean(stat_list)) - - if self.optimizer.bc_module: - update_stats = self.optimizer.bc_module.update() - for stat, val in update_stats.items(): - self._stats_reporter.add_stat(stat, val) - self._clear_update_buffer() - return True - def end_episode(self) -> None: """ A signal that the Episode has ended. The buffer must be reset. @@ -252,7 +209,7 @@ def end_episode(self) -> None: super().end_episode() self.collected_group_rewards.clear() - def create_torch_policy( + def create_policy( self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec ) -> TorchPolicy: """ @@ -261,41 +218,24 @@ def create_torch_policy( :param behavior_spec: specifications for policy construction :return policy """ + actor_cls: Union[Type[SimpleActor], Type[SharedActorCritic]] = SimpleActor + actor_kwargs: Dict[str, Any] = { + "conditional_sigma": False, + "tanh_squash": False, + } + policy = TorchPolicy( self.seed, behavior_spec, - self.trainer_settings, - condition_sigma_on_obs=False, # Faster training for POCA - separate_critic=True, # Match network architecture with TF + self.trainer_settings.network_settings, + actor_cls, + actor_kwargs, ) return policy - def create_poca_optimizer(self) -> TorchPOCAOptimizer: + def create_optimizer(self) -> TorchPOCAOptimizer: return TorchPOCAOptimizer(self.policy, self.trainer_settings) - def add_policy( - self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy - ) -> None: - """ - Adds policy to trainer. - :param parsed_behavior_id: Behavior identifiers that the policy should belong to. - :param policy: Policy to associate with name_behavior_id. 
- """ - if not isinstance(policy, TorchPolicy): - raise RuntimeError(f"policy {policy} must be an instance of TorchPolicy.") - self.policy = policy - self.policies[parsed_behavior_id.behavior_id] = policy - self.optimizer = self.create_poca_optimizer() - for _reward_signal in self.optimizer.reward_signals.keys(): - self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) - - self.model_saver.register(self.policy) - self.model_saver.register(self.optimizer) - self.model_saver.initialize_or_load() - - # Needed to resume loads properly - self._step = policy.get_current_step() - def get_policy(self, name_behavior_id: str) -> Policy: """ Gets policy from trainer associated with name_behavior_id @@ -304,14 +244,6 @@ def get_policy(self, name_behavior_id: str) -> Policy: return self.policy - -def lambda_return(r, value_estimates, gamma=0.99, lambd=0.8, value_next=0.0): - returns = np.zeros_like(r) - returns[-1] = r[-1] + gamma * value_next - for t in reversed(range(0, r.size - 1)): - returns[t] = ( - gamma * lambd * returns[t + 1] - + r[t] - + (1 - lambd) * gamma * value_estimates[t + 1] - ) - return returns + @staticmethod + def get_trainer_name() -> str: + return TRAINER_NAME diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py index 5208a39cf8..0c5e9f7247 100644 --- a/ml-agents/mlagents/trainers/policy/policy.py +++ b/ml-agents/mlagents/trainers/policy/policy.py @@ -6,8 +6,7 @@ from mlagents_envs.exception import UnityException from mlagents.trainers.action_info import ActionInfo -from mlagents.trainers.settings import TrainerSettings, NetworkSettings -from mlagents.trainers.buffer import AgentBuffer +from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.behavior_id_utils import GlobalAgentId @@ -24,40 +23,22 @@ def __init__( self, seed: int, behavior_spec: BehaviorSpec, - trainer_settings: TrainerSettings, - tanh_squash: bool = False, - condition_sigma_on_obs: bool = True, + network_settings: NetworkSettings, ): self.behavior_spec = behavior_spec - self.trainer_settings = trainer_settings - self.network_settings: NetworkSettings = trainer_settings.network_settings + self.network_settings: NetworkSettings = network_settings self.seed = seed self.previous_action_dict: Dict[str, np.ndarray] = {} self.previous_memory_dict: Dict[str, np.ndarray] = {} self.memory_dict: Dict[str, np.ndarray] = {} - self.normalize = trainer_settings.network_settings.normalize + self.normalize = network_settings.normalize self.use_recurrent = self.network_settings.memory is not None - self.h_size = self.network_settings.hidden_units - num_layers = self.network_settings.num_layers - if num_layers < 1: - num_layers = 1 - self.num_layers = num_layers - - self.vis_encode_type = self.network_settings.vis_encode_type - self.tanh_squash = tanh_squash - self.condition_sigma_on_obs = condition_sigma_on_obs - self.m_size = 0 self.sequence_length = 1 - if self.network_settings.memory is not None: + if self.use_recurrent: self.m_size = self.network_settings.memory.memory_size self.sequence_length = self.network_settings.memory.sequence_length - # Non-exposed parameters; these aren't exposed because they don't have a - # good explanation and usually shouldn't be touched. 
- self.log_std_min = -20 - self.log_std_max = 2 - def make_empty_memory(self, num_agents): """ Creates empty memory for use with RNNs @@ -144,10 +125,6 @@ def check_nan_action(action: Optional[ActionTuple]) -> None: if has_nan: raise RuntimeError("Continuous NaN action detected.") - @abstractmethod - def update_normalization(self, buffer: AgentBuffer) -> None: - pass - @abstractmethod def increment_step(self, n_steps): pass diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index 3ac769f8a9..fceacda6e9 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List import numpy as np from mlagents.torch_utils import torch, default_device import copy @@ -9,13 +9,10 @@ from mlagents_envs.base_env import DecisionSteps, BehaviorSpec from mlagents_envs.timers import timed -from mlagents.trainers.settings import TrainerSettings -from mlagents.trainers.torch.networks import SimpleActor, SharedActorCritic, GlobalSteps +from mlagents.trainers.settings import NetworkSettings +from mlagents.trainers.torch_entities.networks import GlobalSteps -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.buffer import AgentBuffer -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.utils import ModelUtils EPSILON = 1e-7 # Small value to avoid divide by zero @@ -25,10 +22,9 @@ def __init__( self, seed: int, behavior_spec: BehaviorSpec, - trainer_settings: TrainerSettings, - tanh_squash: bool = False, - separate_critic: bool = True, - condition_sigma_on_obs: bool = True, + network_settings: NetworkSettings, + actor_cls: type, + actor_kwargs: Dict[str, Any], ): """ Policy that uses a multilayer perceptron to map the observations to actions. Could @@ -36,46 +32,26 @@ def __init__( continuous actions, as well as recurrent networks. :param seed: Random seed. :param behavior_spec: Assigned BehaviorSpec object. - :param trainer_settings: Defined training parameters. - :param load: Whether a pre-trained model will be loaded or a new one created. - :param tanh_squash: Whether to use a tanh function on the continuous output, - or a clipped output. + :param network_settings: Defined network parameters. 
+ :param actor_cls: The type of Actor + :param actor_kwargs: Keyword args for the Actor class """ - super().__init__( - seed, behavior_spec, trainer_settings, tanh_squash, condition_sigma_on_obs - ) + super().__init__(seed, behavior_spec, network_settings) self.global_step = ( GlobalSteps() ) # could be much simpler if TorchPolicy is nn.Module - self.grads = None self.stats_name_to_update_name = { "Losses/Value Loss": "value_loss", "Losses/Policy Loss": "policy_loss", } - if separate_critic: - self.actor = SimpleActor( - observation_specs=self.behavior_spec.observation_specs, - network_settings=trainer_settings.network_settings, - action_spec=behavior_spec.action_spec, - conditional_sigma=self.condition_sigma_on_obs, - tanh_squash=tanh_squash, - ) - self.shared_critic = False - else: - reward_signal_configs = trainer_settings.reward_signals - reward_signal_names = [ - key.value for key, _ in reward_signal_configs.items() - ] - self.actor = SharedActorCritic( - observation_specs=self.behavior_spec.observation_specs, - network_settings=trainer_settings.network_settings, - action_spec=behavior_spec.action_spec, - stream_names=reward_signal_names, - conditional_sigma=self.condition_sigma_on_obs, - tanh_squash=tanh_squash, - ) - self.shared_critic = True + + self.actor = actor_cls( + observation_specs=self.behavior_spec.observation_specs, + network_settings=network_settings, + action_spec=behavior_spec.action_spec, + **actor_kwargs, + ) # Save the m_size needed for export self._export_m_size = self.m_size @@ -83,7 +59,6 @@ def __init__( self.m_size = self.actor.memory_size self.actor.to(default_device()) - self._clip_action = not tanh_squash @property def export_memory_size(self) -> int: @@ -104,49 +79,6 @@ def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray: ) return mask - def update_normalization(self, buffer: AgentBuffer) -> None: - """ - If this policy normalizes vector observations, this will update the norm values in the graph. - :param buffer: The buffer with the observations to add to the running estimate - of the distribution. - """ - - if self.normalize: - self.actor.update_normalization(buffer) - - @timed - def sample_actions( - self, - obs: List[torch.Tensor], - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - seq_len: int = 1, - ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]: - """ - :param obs: List of observations. - :param masks: Loss masks for RNN, else None. - :param memories: Input memories when using RNN, else None. - :param seq_len: Sequence length when using RNN. - :return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories. 
- """ - actions, log_probs, entropies, memories = self.actor.get_action_and_stats( - obs, masks, memories, seq_len - ) - return (actions, log_probs, entropies, memories) - - def evaluate_actions( - self, - obs: List[torch.Tensor], - actions: AgentAction, - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - seq_len: int = 1, - ) -> Tuple[ActionLogProbs, torch.Tensor]: - log_probs, entropies = self.actor.get_stats( - obs, actions, masks, memories, seq_len - ) - return log_probs, entropies - @timed def evaluate( self, decision_requests: DecisionSteps, global_agent_ids: List[str] @@ -164,21 +96,15 @@ def evaluate( memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze( 0 ) - - run_out = {} with torch.no_grad(): - action, log_probs, entropy, memories = self.sample_actions( + action, run_out, memories = self.actor.get_action_and_stats( tensor_obs, masks=masks, memories=memories ) - action_tuple = action.to_action_tuple() - run_out["action"] = action_tuple - # This is the clipped action which is not saved to the buffer - # but is exclusively sent to the environment. - env_action_tuple = action.to_action_tuple(clip=self._clip_action) - run_out["env_action"] = env_action_tuple - run_out["log_probs"] = log_probs.to_log_probs_tuple() - run_out["entropy"] = ModelUtils.to_numpy(entropy) - run_out["learning_rate"] = 0.0 + run_out["action"] = action.to_action_tuple() + if "log_probs" in run_out: + run_out["log_probs"] = run_out["log_probs"].to_log_probs_tuple() + if "entropy" in run_out: + run_out["entropy"] = ModelUtils.to_numpy(run_out["entropy"]) if self.use_recurrent: run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0) return run_out diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 196d6c42b4..41a452c65c 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -1,4 +1,6 @@ from typing import Dict, cast +import attr + from mlagents.torch_utils import torch, default_device from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil @@ -6,14 +8,30 @@ from mlagents_envs.timers import timed from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer -from mlagents.trainers.settings import TrainerSettings, PPOSettings -from mlagents.trainers.torch.networks import ValueNetwork -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.settings import ( + TrainerSettings, + OnPolicyHyperparamSettings, + ScheduleType, +) +from mlagents.trainers.torch_entities.networks import ValueNetwork +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.trajectory import ObsUtil +@attr.s(auto_attribs=True) +class PPOSettings(OnPolicyHyperparamSettings): + beta: float = 5.0e-3 + epsilon: float = 0.2 + lambd: float = 0.95 + num_epoch: int = 3 + shared_critic: bool = False + learning_rate_schedule: ScheduleType = ScheduleType.LINEAR + beta_schedule: ScheduleType = ScheduleType.LINEAR + epsilon_schedule: ScheduleType = ScheduleType.LINEAR + + class TorchPPOOptimizer(TorchOptimizer): def __init__(self, 
policy: TorchPolicy, trainer_settings: TrainerSettings): """ @@ -29,7 +47,12 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): reward_signal_configs = trainer_settings.reward_signals reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] - if policy.shared_critic: + self.hyperparameters: PPOSettings = cast( + PPOSettings, trainer_settings.hyperparameters + ) + + params = list(self.policy.actor.parameters()) + if self.hyperparameters.shared_critic: self._critic = policy.actor else: self._critic = ValueNetwork( @@ -38,11 +61,8 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): network_settings=trainer_settings.network_settings, ) self._critic.to(default_device()) + params += list(self._critic.parameters()) - params = list(self.policy.actor.parameters()) + list(self._critic.parameters()) - self.hyperparameters: PPOSettings = cast( - PPOSettings, trainer_settings.hyperparameters - ) self.decay_learning_rate = ModelUtils.DecayedValue( self.hyperparameters.learning_rate_schedule, self.hyperparameters.learning_rate, @@ -123,13 +143,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: if len(value_memories) > 0: value_memories = torch.stack(value_memories).unsqueeze(0) - log_probs, entropy = self.policy.evaluate_actions( + run_out = self.policy.actor.get_stats( current_obs, + actions, masks=act_masks, - actions=actions, memories=memories, - seq_len=self.policy.sequence_length, + sequence_length=self.policy.sequence_length, ) + + log_probs = run_out["log_probs"] + entropy = run_out["entropy"] + values, _ = self.critic.critic_pass( current_obs, memories=value_memories, @@ -170,11 +194,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: "Policy/Beta": decay_bet, } - for reward_provider in self.reward_signals.values(): - update_stats.update(reward_provider.update(batch)) - return update_stats + # TODO move module update into TorchOptimizer for reward_provider def get_modules(self): modules = { "Optimizer:value_optimizer": self.optimizer, diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py index 845e9dbed2..e7421f0da1 100644 --- a/ml-agents/mlagents/trainers/ppo/trainer.py +++ b/ml-agents/mlagents/trainers/ppo/trainer.py @@ -2,26 +2,31 @@ # ## ML-Agent Learning (PPO) # Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347 -from collections import defaultdict -from typing import cast +from typing import cast, Type, Union, Dict, Any import numpy as np -from mlagents_envs.logging_util import get_logger from mlagents_envs.base_env import BehaviorSpec +from mlagents_envs.logging_util import get_logger from mlagents.trainers.buffer import BufferKey, RewardSignalUtil -from mlagents.trainers.trainer.rl_trainer import RLTrainer -from mlagents.trainers.policy import Policy +from mlagents.trainers.trainer.on_policy_trainer import OnPolicyTrainer +from mlagents.trainers.policy.policy import Policy +from mlagents.trainers.trainer.trainer_utils import get_gae +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.policy.torch_policy import TorchPolicy -from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer +from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer, PPOSettings from mlagents.trainers.trajectory import Trajectory from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers -from mlagents.trainers.settings import 
TrainerSettings, PPOSettings +from mlagents.trainers.settings import TrainerSettings + +from mlagents.trainers.torch_entities.networks import SimpleActor, SharedActorCritic logger = get_logger(__name__) +TRAINER_NAME = "ppo" -class PPOTrainer(RLTrainer): + +class PPOTrainer(OnPolicyTrainer): """The PPOTrainer is an implementation of the PPO algorithm.""" def __init__( @@ -46,17 +51,19 @@ def __init__( """ super().__init__( behavior_name, + reward_buff_cap, trainer_settings, training, load, + seed, artifact_path, - reward_buff_cap, ) self.hyperparameters: PPOSettings = cast( PPOSettings, self.trainer_settings.hyperparameters ) self.seed = seed - self.policy: Policy = None # type: ignore + self.shared_critic = self.hyperparameters.shared_critic + self.policy: TorchPolicy = None # type: ignore def _process_trajectory(self, trajectory: Trajectory) -> None: """ @@ -73,7 +80,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: # Update the normalization if self.is_training: - self.policy.update_normalization(agent_buffer_trajectory) + self.policy.actor.update_normalization(agent_buffer_trajectory) self.optimizer.critic.update_normalization(agent_buffer_trajectory) # Get all value estimates @@ -157,65 +164,12 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: if trajectory.done_reached: self._update_end_episode_stats(agent_id, self.optimizer) - def _is_ready_update(self): - """ - Returns whether or not the trainer has enough elements to run update model - :return: A boolean corresponding to whether or not update_model() can be run - """ - size_of_buffer = self.update_buffer.num_experiences - return size_of_buffer > self.hyperparameters.buffer_size - - def _update_policy(self): - """ - Uses demonstration_buffer to update the policy. - The reward signal generators must be updated in this method at their own pace. - """ - buffer_length = self.update_buffer.num_experiences - self.cumulative_returns_since_policy_update.clear() - - # Make sure batch_size is a multiple of sequence length. During training, we - # will need to reshape the data into a batch_size x sequence_length tensor. 
- batch_size = ( - self.hyperparameters.batch_size - - self.hyperparameters.batch_size % self.policy.sequence_length - ) - # Make sure there is at least one sequence - batch_size = max(batch_size, self.policy.sequence_length) - - n_sequences = max( - int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 - ) - - advantages = np.array( - self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32 - ) - self.update_buffer[BufferKey.ADVANTAGES].set( - (advantages - advantages.mean()) / (advantages.std() + 1e-10) - ) - num_epoch = self.hyperparameters.num_epoch - batch_update_stats = defaultdict(list) - for _ in range(num_epoch): - self.update_buffer.shuffle(sequence_length=self.policy.sequence_length) - buffer = self.update_buffer - max_num_batch = buffer_length // batch_size - for i in range(0, max_num_batch * batch_size, batch_size): - update_stats = self.optimizer.update( - buffer.make_mini_batch(i, i + batch_size), n_sequences - ) - for stat_name, value in update_stats.items(): - batch_update_stats[stat_name].append(value) - - for stat, stat_list in batch_update_stats.items(): - self._stats_reporter.add_stat(stat, np.mean(stat_list)) - - if self.optimizer.bc_module: - update_stats = self.optimizer.bc_module.update() - for stat, val in update_stats.items(): - self._stats_reporter.add_stat(stat, val) - self._clear_update_buffer() - return True + def create_optimizer(self) -> TorchOptimizer: + return TorchPPOOptimizer( # type: ignore + cast(TorchPolicy, self.policy), self.trainer_settings # type: ignore + ) # type: ignore - def create_torch_policy( + def create_policy( self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec ) -> TorchPolicy: """ @@ -224,49 +178,28 @@ def create_torch_policy( :param behavior_spec: specifications for policy construction :return policy """ + actor_cls: Union[Type[SimpleActor], Type[SharedActorCritic]] = SimpleActor + actor_kwargs: Dict[str, Any] = { + "conditional_sigma": False, + "tanh_squash": False, + } + if self.shared_critic: + reward_signal_configs = self.trainer_settings.reward_signals + reward_signal_names = [ + key.value for key, _ in reward_signal_configs.items() + ] + actor_cls = SharedActorCritic + actor_kwargs.update({"stream_names": reward_signal_names}) + policy = TorchPolicy( self.seed, behavior_spec, - self.trainer_settings, - condition_sigma_on_obs=False, # Faster training for PPO - separate_critic=True, # Match network architecture with TF + self.trainer_settings.network_settings, + actor_cls, + actor_kwargs, ) return policy - def create_ppo_optimizer(self) -> TorchPPOOptimizer: - return TorchPPOOptimizer( # type: ignore - cast(TorchPolicy, self.policy), self.trainer_settings # type: ignore - ) # type: ignore - - def add_policy( - self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy - ) -> None: - """ - Adds policy to trainer. - :param parsed_behavior_id: Behavior identifiers that the policy should belong to. - :param policy: Policy to associate with name_behavior_id. - """ - if self.policy: - logger.warning( - "Your environment contains multiple teams, but {} doesn't support adversarial games. 
Enable self-play to \ - train adversarial games.".format( - self.__class__.__name__ - ) - ) - self.policy = policy - self.policies[parsed_behavior_id.behavior_id] = policy - - self.optimizer = self.create_ppo_optimizer() - for _reward_signal in self.optimizer.reward_signals.keys(): - self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) - - self.model_saver.register(self.policy) - self.model_saver.register(self.optimizer) - self.model_saver.initialize_or_load() - - # Needed to resume loads properly - self._step = policy.get_current_step() - def get_policy(self, name_behavior_id: str) -> Policy: """ Gets policy from trainer associated with name_behavior_id @@ -275,34 +208,6 @@ def get_policy(self, name_behavior_id: str) -> Policy: return self.policy - -def discount_rewards(r, gamma=0.99, value_next=0.0): - """ - Computes discounted sum of future rewards for use in updating value estimate. - :param r: List of rewards. - :param gamma: Discount factor. - :param value_next: T+1 value estimate for returns calculation. - :return: discounted sum of future rewards as list. - """ - discounted_r = np.zeros_like(r) - running_add = value_next - for t in reversed(range(0, r.size)): - running_add = running_add * gamma + r[t] - discounted_r[t] = running_add - return discounted_r - - -def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95): - """ - Computes generalized advantage estimate for use in updating policy. - :param rewards: list of rewards for time-steps t to T. - :param value_next: Value estimate for time-step T+1. - :param value_estimates: list of value estimates for time-steps t to T. - :param gamma: Discount factor. - :param lambd: GAE weighing factor. - :return: list of advantage estimates for time-steps t to T. - """ - value_estimates = np.append(value_estimates, value_next) - delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1] - advantage = discount_rewards(r=delta_t, gamma=gamma * lambd) - return advantage + @staticmethod + def get_trainer_name() -> str: + return TRAINER_NAME diff --git a/ml-agents/mlagents/trainers/run_experiment.py b/ml-agents/mlagents/trainers/run_experiment.py index a372fca3b7..8544b673bc 100644 --- a/ml-agents/mlagents/trainers/run_experiment.py +++ b/ml-agents/mlagents/trainers/run_experiment.py @@ -4,6 +4,8 @@ from mlagents.trainers.settings import RunOptions from mlagents.trainers.cli_utils import load_config +from mlagents.plugins.trainer_type import register_trainer_plugins + def parse_command_line(argv: Optional[List[str]] = None) -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -21,6 +23,7 @@ def main(): """ args = parse_command_line() expt_config = load_config(args.experiment_config_path) + _, _ = register_trainer_plugins() run_cli(RunOptions.from_dict(expt_config)) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 9cdc1df5a5..a7d566859d 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -1,20 +1,22 @@ import numpy as np -from typing import Dict, List, Mapping, NamedTuple, cast, Tuple, Optional +from typing import Dict, List, NamedTuple, cast, Tuple, Optional +import attr + from mlagents.torch_utils import torch, nn, default_device from mlagents_envs.logging_util import get_logger from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.settings import 
NetworkSettings -from mlagents.trainers.torch.networks import ValueNetwork -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.networks import ValueNetwork, SharedActorCritic +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil from mlagents_envs.timers import timed from mlagents_envs.base_env import ActionSpec, ObservationSpec from mlagents.trainers.exception import UnityTrainerException -from mlagents.trainers.settings import TrainerSettings, SACSettings +from mlagents.trainers.settings import TrainerSettings, OffPolicyHyperparamSettings from contextlib import ExitStack from mlagents.trainers.trajectory import ObsUtil @@ -23,6 +25,22 @@ logger = get_logger(__name__) +@attr.s(auto_attribs=True) +class SACSettings(OffPolicyHyperparamSettings): + batch_size: int = 128 + buffer_size: int = 50000 + buffer_init_steps: int = 0 + tau: float = 0.005 + steps_per_update: float = 1 + save_replay_buffer: bool = False + init_entcoef: float = 1.0 + reward_signal_steps_per_update: float = attr.ib() + + @reward_signal_steps_per_update.default + def _reward_signal_steps_per_update_default(self): + return self.steps_per_update + + class TorchSACOptimizer(TorchOptimizer): class PolicyValueNetwork(nn.Module): def __init__( @@ -105,19 +123,21 @@ def __init__(self, discrete, continuous): self.discrete = discrete self.continuous = continuous - def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): - super().__init__(policy, trainer_params) - reward_signal_configs = trainer_params.reward_signals + def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): + super().__init__(policy, trainer_settings) + reward_signal_configs = trainer_settings.reward_signals reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] - if policy.shared_critic: + if isinstance(policy.actor, SharedActorCritic): raise UnityTrainerException("SAC does not support SharedActorCritic") self._critic = ValueNetwork( reward_signal_names, policy.behavior_spec.observation_specs, policy.network_settings, ) + hyperparameters: SACSettings = cast( + SACSettings, trainer_settings.hyperparameters + ) - hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) self.tau = hyperparameters.tau self.init_entcoef = hyperparameters.init_entcoef @@ -133,7 +153,7 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): self.stream_names = list(self.reward_signals.keys()) # Use to reduce "survivor bonus" when using Curiosity or GAIL. 
- self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()] + self.gammas = [_val.gamma for _val in trainer_settings.reward_signals.values()] self.use_dones_in_backup = { name: int(not self.reward_signals[name].ignore_done) for name in self.stream_names @@ -521,12 +541,13 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: self.policy.actor.network_body ) self._critic.network_body.copy_normalization(self.policy.actor.network_body) - sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats( + sampled_actions, run_out, _, = self.policy.actor.get_action_and_stats( current_obs, masks=act_masks, memories=memories, sequence_length=self.policy.sequence_length, ) + log_probs = run_out["log_probs"] value_estimates, _ = self._critic.critic_pass( current_obs, value_memories, sequence_length=self.policy.sequence_length ) @@ -584,11 +605,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks) entropy_loss = self.sac_entropy_loss(log_probs, masks) - total_value_loss = q1_loss + q2_loss - if self.policy.shared_critic: - policy_loss += value_loss - else: - total_value_loss += value_loss + total_value_loss = q1_loss + q2_loss + value_loss decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) @@ -624,14 +641,6 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: return update_stats - def update_reward_signals( - self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int - ) -> Dict[str, float]: - update_stats: Dict[str, float] = {} - for name, update_buffer in reward_signal_minibatches.items(): - update_stats.update(self.reward_signals[name].update(update_buffer)) - return update_stats - def get_modules(self): modules = { "Optimizer:q_network": self.q_network, diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py index 9842697e15..56860c7381 100644 --- a/ml-agents/mlagents/trainers/sac/trainer.py +++ b/ml-agents/mlagents/trainers/sac/trainer.py @@ -2,31 +2,32 @@ # Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290 # and implemented in https://github.com/hill-a/stable-baselines -from collections import defaultdict -from typing import Dict, cast -import os +from typing import cast import numpy as np -from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint from mlagents_envs.logging_util import get_logger -from mlagents_envs.timers import timed from mlagents_envs.base_env import BehaviorSpec -from mlagents.trainers.buffer import BufferKey, RewardSignalUtil -from mlagents.trainers.policy import Policy -from mlagents.trainers.trainer.rl_trainer import RLTrainer +from mlagents.trainers.buffer import BufferKey +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.trainer.off_policy_trainer import OffPolicyTrainer from mlagents.trainers.policy.torch_policy import TorchPolicy -from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer +from mlagents.trainers.policy.policy import Policy +from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer, SACSettings from mlagents.trainers.trajectory import Trajectory, ObsUtil from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers -from mlagents.trainers.settings import TrainerSettings, SACSettings +from 
mlagents.trainers.settings import TrainerSettings + +from mlagents.trainers.torch_entities.networks import SimpleActor logger = get_logger(__name__) BUFFER_TRUNCATE_PERCENT = 0.8 +TRAINER_NAME = "sac" + -class SACTrainer(RLTrainer): +class SACTrainer(OffPolicyTrainer): """ The SACTrainer is an implementation of the SAC algorithm, with support for discrete actions and recurrent networks. @@ -54,15 +55,16 @@ def __init__( """ super().__init__( behavior_name, + reward_buff_cap, trainer_settings, training, load, + seed, artifact_path, - reward_buff_cap, ) self.seed = seed - self.policy: Policy = None # type: ignore + self.policy: TorchPolicy = None # type: ignore self.optimizer: TorchSACOptimizer = None # type: ignore self.hyperparameters: SACSettings = cast( SACSettings, trainer_settings.hyperparameters @@ -80,51 +82,6 @@ def __init__( self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer - def _checkpoint(self) -> ModelCheckpoint: - """ - Writes a checkpoint model to memory - Overrides the default to save the replay buffer. - """ - ckpt = super()._checkpoint() - if self.checkpoint_replay_buffer: - self.save_replay_buffer() - return ckpt - - def save_model(self) -> None: - """ - Saves the final training model to memory - Overrides the default to save the replay buffer. - """ - super().save_model() - if self.checkpoint_replay_buffer: - self.save_replay_buffer() - - def save_replay_buffer(self) -> None: - """ - Save the training buffer's update buffer to a pickle file. - """ - filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5") - logger.info(f"Saving Experience Replay Buffer to {filename}...") - with open(filename, "wb") as file_object: - self.update_buffer.save_to_file(file_object) - logger.info( - f"Saved Experience Replay Buffer ({os.path.getsize(filename)} bytes)." - ) - - def load_replay_buffer(self) -> None: - """ - Loads the last saved replay buffer from a file. - """ - filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5") - logger.info(f"Loading Experience Replay Buffer from {filename}...") - with open(filename, "rb+") as file_object: - self.update_buffer.load_from_file(file_object) - logger.debug( - "Experience replay buffer has {} experiences.".format( - self.update_buffer.num_experiences - ) - ) - def _process_trajectory(self, trajectory: Trajectory) -> None: """ Takes a trajectory and processes it, putting it into the replay buffer. @@ -139,7 +96,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: # Update the normalization if self.is_training: - self.policy.update_normalization(agent_buffer_trajectory) + self.policy.actor.update_normalization(agent_buffer_trajectory) self.optimizer.critic.update_normalization(agent_buffer_trajectory) # Evaluate all reward functions for reporting purposes @@ -184,44 +141,12 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: if trajectory.done_reached: self._update_end_episode_stats(agent_id, self.optimizer) - def _is_ready_update(self) -> bool: - """ - Returns whether or not the trainer has enough elements to run update model - :return: A boolean corresponding to whether or not _update_policy() can be run - """ - return ( - self.update_buffer.num_experiences >= self.hyperparameters.batch_size - and self._step >= self.hyperparameters.buffer_init_steps - ) - - @timed - def _update_policy(self) -> bool: - """ - Update the SAC policy and reward signals. The reward signal generators are updated using different mini batches. 
- By default we imitate http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated - N times, then the reward signals are updated N times. - :return: Whether or not the policy was updated. - """ - policy_was_updated = self._update_sac_policy() - self._update_reward_signals() - return policy_was_updated - - def maybe_load_replay_buffer(self): - # Load the replay buffer if load - if self.load and self.checkpoint_replay_buffer: - try: - self.load_replay_buffer() - except (AttributeError, FileNotFoundError): - logger.warning( - "Replay buffer was unable to load, starting from scratch." - ) - logger.debug( - "Loaded update buffer with {} sequences".format( - self.update_buffer.num_experiences - ) - ) + def create_optimizer(self) -> TorchOptimizer: + return TorchSACOptimizer( # type: ignore + cast(TorchPolicy, self.policy), self.trainer_settings # type: ignore + ) # type: ignore - def create_torch_policy( + def create_policy( self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec ) -> TorchPolicy: """ @@ -230,141 +155,19 @@ def create_torch_policy( :param behavior_spec: specifications for policy construction :return policy """ + actor_cls = SimpleActor + actor_kwargs = {"conditional_sigma": True, "tanh_squash": True} + policy = TorchPolicy( self.seed, behavior_spec, - self.trainer_settings, - condition_sigma_on_obs=True, - tanh_squash=True, - separate_critic=True, + self.trainer_settings.network_settings, + actor_cls, + actor_kwargs, ) self.maybe_load_replay_buffer() return policy - def _update_sac_policy(self) -> bool: - """ - Uses update_buffer to update the policy. We sample the update_buffer and update - until the steps_per_update ratio is met. - """ - has_updated = False - self.cumulative_returns_since_policy_update.clear() - n_sequences = max( - int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 - ) - - batch_update_stats: Dict[str, list] = defaultdict(list) - while ( - self._step - self.hyperparameters.buffer_init_steps - ) / self.update_steps > self.steps_per_update: - logger.debug(f"Updating SAC policy at step {self._step}") - buffer = self.update_buffer - if self.update_buffer.num_experiences >= self.hyperparameters.batch_size: - sampled_minibatch = buffer.sample_mini_batch( - self.hyperparameters.batch_size, - sequence_length=self.policy.sequence_length, - ) - # Get rewards for each reward - for name, signal in self.optimizer.reward_signals.items(): - sampled_minibatch[RewardSignalUtil.rewards_key(name)] = ( - signal.evaluate(sampled_minibatch) * signal.strength - ) - - update_stats = self.optimizer.update(sampled_minibatch, n_sequences) - for stat_name, value in update_stats.items(): - batch_update_stats[stat_name].append(value) - - self.update_steps += 1 - - for stat, stat_list in batch_update_stats.items(): - self._stats_reporter.add_stat(stat, np.mean(stat_list)) - has_updated = True - - if self.optimizer.bc_module: - update_stats = self.optimizer.bc_module.update() - for stat, val in update_stats.items(): - self._stats_reporter.add_stat(stat, val) - - # Truncate update buffer if neccessary. Truncate more than we need to to avoid truncating - # a large buffer at each update. - if self.update_buffer.num_experiences > self.hyperparameters.buffer_size: - self.update_buffer.truncate( - int(self.hyperparameters.buffer_size * BUFFER_TRUNCATE_PERCENT) - ) - return has_updated - - def _update_reward_signals(self) -> None: - """ - Iterate through the reward signals and update them. 
Unlike in PPO, - do it separate from the policy so that it can be done at a different - interval. - This function should only be used to simulate - http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated - N times, then the reward signals are updated N times. Normally, the reward signal - and policy are updated in parallel. - """ - buffer = self.update_buffer - n_sequences = max( - int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 - ) - batch_update_stats: Dict[str, list] = defaultdict(list) - while ( - self._step - self.hyperparameters.buffer_init_steps - ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update: - # Get minibatches for reward signal update if needed - reward_signal_minibatches = {} - for name in self.optimizer.reward_signals.keys(): - logger.debug(f"Updating {name} at step {self._step}") - if name != "extrinsic": - reward_signal_minibatches[name] = buffer.sample_mini_batch( - self.hyperparameters.batch_size, - sequence_length=self.policy.sequence_length, - ) - update_stats = self.optimizer.update_reward_signals( - reward_signal_minibatches, n_sequences - ) - for stat_name, value in update_stats.items(): - batch_update_stats[stat_name].append(value) - self.reward_signal_update_steps += 1 - - for stat, stat_list in batch_update_stats.items(): - self._stats_reporter.add_stat(stat, np.mean(stat_list)) - - def create_sac_optimizer(self) -> TorchSACOptimizer: - return TorchSACOptimizer( # type: ignore - cast(TorchPolicy, self.policy), self.trainer_settings # type: ignore - ) # type: ignore - - def add_policy( - self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy - ) -> None: - """ - Adds policy to trainer. - """ - if self.policy: - logger.warning( - "Your environment contains multiple teams, but {} doesn't support adversarial games. 
Enable self-play to \ - train adversarial games.".format( - self.__class__.__name__ - ) - ) - self.policy = policy - self.policies[parsed_behavior_id.behavior_id] = policy - self.optimizer = self.create_sac_optimizer() - for _reward_signal in self.optimizer.reward_signals.keys(): - self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) - - self.model_saver.register(self.policy) - self.model_saver.register(self.optimizer) - self.model_saver.initialize_or_load() - - # Needed to resume loads properly - self._step = policy.get_current_step() - # Assume steps were updated at the correct ratio before - self.update_steps = int(max(1, self._step / self.steps_per_update)) - self.reward_signal_update_steps = int( - max(1, self._step / self.reward_signal_steps_per_update) - ) - def get_policy(self, name_behavior_id: str) -> Policy: """ Gets policy from trainer associated with name_behavior_id @@ -372,3 +175,7 @@ def get_policy(self, name_behavior_id: str) -> Policy: """ return self.policy + + @staticmethod + def get_trainer_name() -> str: + return TRAINER_NAME diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py index 12200d4b89..7cff991ba2 100644 --- a/ml-agents/mlagents/trainers/settings.py +++ b/ml-agents/mlagents/trainers/settings.py @@ -30,6 +30,7 @@ from mlagents_envs.side_channel.environment_parameters_channel import ( EnvironmentParametersChannel, ) +from mlagents.plugins import all_trainer_settings, all_trainer_types logger = logging_util.get_logger(__name__) @@ -44,23 +45,9 @@ def check_and_structure(key: str, value: Any, class_type: type) -> Any: return cattr.structure(value, attr_fields_dict[key].type) -class TrainerType(Enum): - PPO: str = "ppo" - SAC: str = "sac" - POCA: str = "poca" - - def to_settings(self) -> type: - _mapping = { - TrainerType.PPO: PPOSettings, - TrainerType.SAC: SACSettings, - TrainerType.POCA: POCASettings, - } - return _mapping[self] - - -def check_hyperparam_schedules(val: Dict, trainer_type: TrainerType) -> Dict: +def check_hyperparam_schedules(val: Dict, trainer_type: str) -> Dict: # Check if beta and epsilon are set. If not, set to match learning rate schedule. 
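A small sketch of the defaulting performed by the branch below: for PPO-style trainers, a config that specifies only learning_rate_schedule gets matching beta and epsilon schedules filled in. The dict is a hypothetical raw config; values are still plain strings here because structuring into ScheduleType happens afterwards:

cfg = {"learning_rate_schedule": "constant"}
cfg = check_hyperparam_schedules(cfg, "ppo")
# Both schedules now mirror the learning-rate schedule.
assert cfg["beta_schedule"] == "constant"
assert cfg["epsilon_schedule"] == "constant"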
- if trainer_type is TrainerType.PPO or trainer_type is TrainerType.POCA: + if trainer_type == "ppo" or trainer_type == "poca": if "beta_schedule" not in val.keys() and "learning_rate_schedule" in val.keys(): val["beta_schedule"] = val["learning_rate_schedule"] if ( @@ -175,34 +162,18 @@ class HyperparamSettings: @attr.s(auto_attribs=True) -class PPOSettings(HyperparamSettings): - beta: float = 5.0e-3 - epsilon: float = 0.2 - lambd: float = 0.95 +class OnPolicyHyperparamSettings(HyperparamSettings): num_epoch: int = 3 - learning_rate_schedule: ScheduleType = ScheduleType.LINEAR - beta_schedule: ScheduleType = ScheduleType.LINEAR - epsilon_schedule: ScheduleType = ScheduleType.LINEAR @attr.s(auto_attribs=True) -class SACSettings(HyperparamSettings): +class OffPolicyHyperparamSettings(HyperparamSettings): batch_size: int = 128 buffer_size: int = 50000 buffer_init_steps: int = 0 - tau: float = 0.005 steps_per_update: float = 1 save_replay_buffer: bool = False - init_entcoef: float = 1.0 - reward_signal_steps_per_update: float = attr.ib() - - @reward_signal_steps_per_update.default - def _reward_signal_steps_per_update_default(self): - return self.steps_per_update - - -# POCA uses the same hyperparameters as PPO -POCASettings = PPOSettings + reward_signal_steps_per_update: float = 4 # INTRINSIC REWARD SIGNALS ############################################################# @@ -643,12 +614,12 @@ def _team_change_default(self): @attr.s(auto_attribs=True) class TrainerSettings(ExportableSettings): default_override: ClassVar[Optional["TrainerSettings"]] = None - trainer_type: TrainerType = TrainerType.PPO + trainer_type: str = "ppo" hyperparameters: HyperparamSettings = attr.ib() @hyperparameters.default def _set_default_hyperparameters(self): - return self.trainer_type.to_settings()() + return all_trainer_settings[self.trainer_type]() network_settings: NetworkSettings = attr.ib(factory=NetworkSettings) reward_signals: Dict[RewardSignalType, RewardSignalSettings] = attr.ib( @@ -722,12 +693,20 @@ def structure(d: Mapping, t: type) -> Any: d_copy[key] = check_hyperparam_schedules( val, d_copy["trainer_type"] ) - d_copy[key] = strict_to_cls( - d_copy[key], TrainerType(d_copy["trainer_type"]).to_settings() - ) + try: + d_copy[key] = strict_to_cls( + d_copy[key], all_trainer_settings[d_copy["trainer_type"]] + ) + except KeyError: + raise TrainerConfigError( + f"Settings for trainer type {d_copy['trainer_type']} were not found" + ) elif key == "max_steps": d_copy[key] = int(float(val)) # In some legacy configs, max steps was specified as a float + elif key == "trainer_type": + if val not in all_trainer_types.keys(): + raise TrainerConfigError(f"Invalid trainer type {val} was found") else: d_copy[key] = check_and_structure(key, val, t) return t(**d_copy) @@ -968,7 +947,9 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions": return final_runoptions @staticmethod - def from_dict(options_dict: Dict[str, Any]) -> "RunOptions": + def from_dict( + options_dict: Dict[str, Any], + ) -> "RunOptions": # If a default settings was specified, set the TrainerSettings class override if ( "default_settings" in options_dict.keys() diff --git a/ml-agents/mlagents/trainers/tests/dummy_config.py b/ml-agents/mlagents/trainers/tests/dummy_config.py index a7f74afa26..e4a71b0ee1 100644 --- a/ml-agents/mlagents/trainers/tests/dummy_config.py +++ b/ml-agents/mlagents/trainers/tests/dummy_config.py @@ -4,24 +4,24 @@ import copy import os from mlagents.trainers.settings import ( - POCASettings, TrainerSettings, - 
PPOSettings, - SACSettings, GAILSettings, CuriositySettings, RewardSignalSettings, NetworkSettings, - TrainerType, RewardSignalType, ScheduleType, ) +from mlagents.trainers.ppo.optimizer_torch import PPOSettings +from mlagents.trainers.sac.optimizer_torch import SACSettings +from mlagents.trainers.poca.optimizer_torch import POCASettings + CONTINUOUS_DEMO_PATH = os.path.dirname(os.path.abspath(__file__)) + "/test.demo" DISCRETE_DEMO_PATH = os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo" _PPO_CONFIG = TrainerSettings( - trainer_type=TrainerType.PPO, + trainer_type="ppo", hyperparameters=PPOSettings( learning_rate=5.0e-3, learning_rate_schedule=ScheduleType.CONSTANT, @@ -35,7 +35,7 @@ ) _SAC_CONFIG = TrainerSettings( - trainer_type=TrainerType.SAC, + trainer_type="sac", hyperparameters=SACSettings( learning_rate=5.0e-3, learning_rate_schedule=ScheduleType.CONSTANT, @@ -52,7 +52,7 @@ ) _POCA_CONFIG = TrainerSettings( - trainer_type=TrainerType.POCA, + trainer_type="poca", hyperparameters=POCASettings( learning_rate=5.0e-3, learning_rate_schedule=ScheduleType.CONSTANT, diff --git a/ml-agents/mlagents/trainers/tests/mock_brain.py b/ml-agents/mlagents/trainers/tests/mock_brain.py index 6658e6bb43..c3857429de 100644 --- a/ml-agents/mlagents/trainers/tests/mock_brain.py +++ b/ml-agents/mlagents/trainers/tests/mock_brain.py @@ -2,7 +2,7 @@ import numpy as np from mlagents.trainers.buffer import AgentBuffer, AgentBufferKey -from mlagents.trainers.torch.action_log_probs import LogProbsTuple +from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience from mlagents_envs.base_env import ( DecisionSteps, @@ -87,6 +87,7 @@ def make_fake_trajectory( num_other_agents_in_group: int = 0, group_reward: float = 0.0, is_terminal: bool = True, + team_id: int = 0, ) -> Trajectory: """ Makes a fake trajectory of length length. 
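The new team_id parameter threads a team suffix into the fake trajectory's behavior_id, matching the name?team=N convention parsed by BehaviorIdentifiers. A hedged usage sketch; obs_specs and action_spec are assumed to be prepared elsewhere in the test, and the remaining arguments keep the helper's defaults:

traj = make_fake_trajectory(
    length=10,
    observation_specs=obs_specs,  # assumed: a List[ObservationSpec]
    action_spec=action_spec,      # assumed: an ActionSpec
    team_id=1,
)
# The fake brain name now carries the team suffix.
assert traj.behavior_id == "test_brain?team=1"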
If max_step_complete, @@ -125,7 +126,7 @@ def make_fake_trajectory( max_step = False memory = np.ones(memory_size, dtype=np.float32) agent_id = "test_agent" - behavior_id = "test_brain" + behavior_id = "test_brain?team=" + str(team_id) group_status = [] for _ in range(num_other_agents_in_group): group_status.append(AgentStatus(obs, reward, action, done)) diff --git a/ml-agents/mlagents/trainers/tests/results/ppo/run_logs/training_status.json b/ml-agents/mlagents/trainers/tests/results/ppo/run_logs/training_status.json index 16291b488f..04eaa876ae 100644 --- a/ml-agents/mlagents/trainers/tests/results/ppo/run_logs/training_status.json +++ b/ml-agents/mlagents/trainers/tests/results/ppo/run_logs/training_status.json @@ -10,7 +10,7 @@ }, "metadata": { "stats_format_version": "0.3.0", - "mlagents_version": "0.29.0.dev0", + "mlagents_version": "0.29.0", "torch_version": "1.8.1" } -} \ No newline at end of file +} diff --git a/ml-agents/mlagents/trainers/tests/test_agent_processor.py b/ml-agents/mlagents/trainers/tests/test_agent_processor.py index 4d38b42f32..43e1008f02 100644 --- a/ml-agents/mlagents/trainers/tests/test_agent_processor.py +++ b/ml-agents/mlagents/trainers/tests/test_agent_processor.py @@ -9,7 +9,7 @@ AgentManagerQueue, ) from mlagents.trainers.action_info import ActionInfo -from mlagents.trainers.torch.action_log_probs import LogProbsTuple +from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple from mlagents.trainers.trajectory import Trajectory from mlagents.trainers.stats import StatsReporter, StatsSummary from mlagents.trainers.behavior_id_utils import get_global_agent_id diff --git a/ml-agents/mlagents/trainers/tests/test_config_conversion.py b/ml-agents/mlagents/trainers/tests/test_config_conversion.py index 557da132a5..b9a49f6cb2 100644 --- a/ml-agents/mlagents/trainers/tests/test_config_conversion.py +++ b/ml-agents/mlagents/trainers/tests/test_config_conversion.py @@ -2,12 +2,10 @@ import pytest from mlagents.trainers.upgrade_config import convert_behaviors, remove_nones, convert -from mlagents.trainers.settings import ( - TrainerType, - PPOSettings, - SACSettings, - RewardSignalType, -) +from mlagents.trainers.settings import RewardSignalType +from mlagents.trainers.ppo.trainer import PPOSettings, TRAINER_NAME as PPO_TRAINER_NAME +from mlagents.trainers.sac.trainer import SACSettings, TRAINER_NAME as SAC_TRAINER_NAME + BRAIN_NAME = "testbehavior" @@ -162,9 +160,9 @@ @pytest.mark.parametrize("use_recurrent", [True, False]) -@pytest.mark.parametrize("trainer_type", [TrainerType.PPO, TrainerType.SAC]) +@pytest.mark.parametrize("trainer_type", [PPO_TRAINER_NAME, SAC_TRAINER_NAME]) def test_convert_behaviors(trainer_type, use_recurrent): - if trainer_type == TrainerType.PPO: + if trainer_type == PPO_TRAINER_NAME: trainer_config = PPO_CONFIG trainer_settings_type = PPOSettings else: diff --git a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py index fb346245f1..89b5ae9810 100644 --- a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py +++ b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py @@ -8,6 +8,10 @@ from mlagents.trainers.tests.test_buffer import construct_fake_buffer from mlagents.trainers.agent_processor import AgentManagerQueue from mlagents.trainers.settings import TrainerSettings +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents_envs.base_env import BehaviorSpec +from mlagents.trainers.policy import Policy +from 
mlagents.trainers.behavior_id_utils import BehaviorIdentifiers from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes from mlagents_envs.base_env import ActionSpec import os.path @@ -43,7 +47,14 @@ def checkpoint_path(brain_name, step): mock_model_saver.save_checkpoint.side_effect = checkpoint_path self.model_saver = mock_model_saver - def create_torch_policy(self, parsed_behavior_id, behavior_spec): + def create_optimizer(self) -> TorchOptimizer: + return mock.Mock() + + def create_policy( + self, + parsed_behavior_id: BehaviorIdentifiers, + behavior_spec: BehaviorSpec, + ) -> Policy: return mock.Mock() def _process_trajectory(self, trajectory): diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py index 9d14fc5cc7..b700147879 100644 --- a/ml-agents/mlagents/trainers/tests/test_settings.py +++ b/ml-agents/mlagents/trainers/tests/test_settings.py @@ -10,8 +10,6 @@ RunOptions, TrainerSettings, NetworkSettings, - PPOSettings, - SACSettings, RewardSignalType, RewardSignalSettings, CuriositySettings, @@ -21,19 +19,25 @@ UniformSettings, GaussianSettings, MultiRangeUniformSettings, - TrainerType, deep_update_dict, strict_to_cls, ScheduleType, ) +from mlagents.trainers.ppo.trainer import PPOSettings, TRAINER_NAME as PPO_TRAINER_NAME +from mlagents.trainers.sac.trainer import SACSettings, TRAINER_NAME as SAC_TRAINER_NAME + from mlagents.trainers.exception import TrainerConfigError +TRAINER_SETTING_TYPES = {"ppo": PPOSettings, "sac": SACSettings} + def check_if_different(testobj1: object, testobj2: object) -> None: assert testobj1 is not testobj2 if attr.has(testobj1.__class__) and attr.has(testobj2.__class__): for key, val in attr.asdict(testobj1, recurse=False).items(): - if isinstance(val, dict) or isinstance(val, list) or attr.has(val): + if ( + isinstance(val, dict) or isinstance(val, list) or attr.has(val) + ) and val != {}: # Note: this check doesn't check the contents of mutables. 
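Because trainer_type is now a plain string, a test can resolve the expected hyperparameter class from the trainer name, mirroring the TRAINER_SETTING_TYPES mapping defined above. A sketch, assuming the built-in trainer plugins are registered:

# The default hyperparameters for a trainer-name string should be an instance
# of the settings class co-located with that trainer's optimizer module.
settings = TrainerSettings(trainer_type="sac")
assert isinstance(settings.hyperparameters, TRAINER_SETTING_TYPES["sac"])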
check_if_different(val, attr.asdict(testobj2, recurse=False)[key]) @@ -121,19 +125,20 @@ def test_trainersettings_structure(): Test structuring method for TrainerSettings """ trainersettings_dict = { - "trainer_type": "sac", + "trainer_type": SAC_TRAINER_NAME, "hyperparameters": {"batch_size": 1024}, "max_steps": 1.0, "reward_signals": {"curiosity": {"encoding_size": 64}}, } trainer_settings = TrainerSettings.structure(trainersettings_dict, TrainerSettings) + # check_trainer_setting_types([trainer_settings], TRAINER_SETTING_TYPES) assert isinstance(trainer_settings.hyperparameters, SACSettings) - assert trainer_settings.trainer_type == TrainerType.SAC + assert trainer_settings.trainer_type == SAC_TRAINER_NAME assert isinstance(trainer_settings.max_steps, int) assert RewardSignalType.CURIOSITY in trainer_settings.reward_signals # Check invalid trainer type - with pytest.raises(ValueError): + with pytest.raises(TrainerConfigError): trainersettings_dict = { "trainer_type": "puppo", "hyperparameters": {"batch_size": 1024}, @@ -144,7 +149,7 @@ def test_trainersettings_structure(): # Check invalid hyperparameter with pytest.raises(TrainerConfigError): trainersettings_dict = { - "trainer_type": "ppo", + "trainer_type": PPO_TRAINER_NAME, "hyperparameters": {"notahyperparam": 1024}, "max_steps": 1.0, } @@ -166,7 +171,7 @@ def test_trainersettingsschedules_structure(): Test structuring method for Trainer Settings Schedule """ trainersettings_dict = { - "trainer_type": "ppo", + "trainer_type": PPO_TRAINER_NAME, "hyperparameters": { "learning_rate_schedule": "linear", "beta_schedule": "constant", @@ -550,7 +555,9 @@ def test_default_settings(): # Change the overridden fields back, and check if the rest are equal. test1_settings.max_steps = 1 - test1_settings.network_settings.hidden_units == default_settings_cls.network_settings.hidden_units + test1_settings.network_settings.hidden_units = ( + default_settings_cls.network_settings.hidden_units + ) check_if_different(test1_settings, default_settings_cls) diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py index 0db44e6ba1..f0d5d79331 100644 --- a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py +++ b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py @@ -136,4 +136,4 @@ def test_advance_adds_experiences_to_trainer_and_trains( env_mock.get_steps.assert_called_once() env_mock.process_steps.assert_called_once() # May have been called many times due to thread - trainer_mock.advance.call_count > 0 + # assert trainer_mock.advance.call_count > 0 diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_util.py b/ml-agents/mlagents/trainers/tests/test_trainer_util.py index eeaaaa7b94..3d90e46fb0 100644 --- a/ml-agents/mlagents/trainers/tests/test_trainer_util.py +++ b/ml-agents/mlagents/trainers/tests/test_trainer_util.py @@ -42,7 +42,7 @@ def mock_constructor( trainer_settings, training, load, - seed, + p_seed, artifact_path, ): assert brain == brain_name @@ -50,7 +50,7 @@ def mock_constructor( assert reward_buff_cap == expected_reward_buff_cap assert training == train_model assert load == load_model - assert seed == seed + assert p_seed == seed assert artifact_path == os.path.join(output_path, brain_name) with patch.object(PPOTrainer, "__init__", mock_constructor): diff --git a/ml-agents/mlagents/trainers/tests/test_trainers.py b/ml-agents/mlagents/trainers/tests/test_trainers.py index 55694211e1..b64b0973b2 100644 --- 
a/ml-agents/mlagents/trainers/tests/test_trainers.py +++ b/ml-agents/mlagents/trainers/tests/test_trainers.py @@ -67,9 +67,9 @@ def test_ppo_trainer_update_normalization(ppo_config): trajectory_queue0.put(trajectory) # mocking out update_normalization in both the policy and critic with patch( - "mlagents.trainers.torch.networks.ValueNetwork.update_normalization" + "mlagents.trainers.torch_entities.networks.ValueNetwork.update_normalization" ) as optimizer_update_normalization_mock, patch( - "mlagents.trainers.policy.torch_policy.TorchPolicy.update_normalization" + "mlagents.trainers.torch_entities.networks.SimpleActor.update_normalization" ) as policy_update_normalization_mock: ppo_trainer.advance() optimizer_update_normalization_mock.assert_called_once() @@ -111,9 +111,9 @@ def test_sac_trainer_update_normalization(sac_config): trajectory_queue0.put(trajectory) # mocking out update_normalization in both the policy and critic with patch( - "mlagents.trainers.torch.networks.ValueNetwork.update_normalization" + "mlagents.trainers.torch_entities.networks.ValueNetwork.update_normalization" ) as optimizer_update_normalization_mock, patch( - "mlagents.trainers.policy.torch_policy.TorchPolicy.update_normalization" + "mlagents.trainers.torch_entities.networks.SimpleActor.update_normalization" ) as policy_update_normalization_mock: sac_trainer.advance() optimizer_update_normalization_mock.assert_called_once() @@ -157,7 +157,7 @@ def test_poca_trainer_update_normalization(poca_config): with patch( "mlagents.trainers.poca.optimizer_torch.TorchPOCAOptimizer.POCAValueNetwork.update_normalization" ) as optimizer_update_normalization_mock, patch( - "mlagents.trainers.policy.torch_policy.TorchPolicy.update_normalization" + "mlagents.trainers.torch_entities.networks.SimpleActor.update_normalization" ) as policy_update_normalization_mock: poca_trainer.advance() optimizer_update_normalization_mock.assert_called_once() diff --git a/ml-agents/mlagents/trainers/tests/torch/test_policy.py b/ml-agents/mlagents/trainers/tests/torch/test_policy.py deleted file mode 100644 index 995209e2c9..0000000000 --- a/ml-agents/mlagents/trainers/tests/torch/test_policy.py +++ /dev/null @@ -1,144 +0,0 @@ -import pytest - -from mlagents.torch_utils import torch -from mlagents.trainers.policy.torch_policy import TorchPolicy -from mlagents.trainers.tests import mock_brain as mb -from mlagents.trainers.settings import TrainerSettings, NetworkSettings -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.trajectory import ObsUtil -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.buffer import BufferKey - -VECTOR_ACTION_SPACE = 2 -VECTOR_OBS_SPACE = 8 -DISCRETE_ACTION_SPACE = [3, 3, 3, 2] -BUFFER_INIT_SAMPLES = 32 -NUM_AGENTS = 12 -EPSILON = 1e-7 - - -def create_policy_mock( - dummy_config: TrainerSettings, - use_rnn: bool = False, - use_discrete: bool = True, - use_visual: bool = False, - seed: int = 0, -) -> TorchPolicy: - mock_spec = mb.setup_test_behavior_specs( - use_discrete, - use_visual, - vector_action_space=DISCRETE_ACTION_SPACE - if use_discrete - else VECTOR_ACTION_SPACE, - vector_obs_space=VECTOR_OBS_SPACE, - ) - - trainer_settings = dummy_config - trainer_settings.keep_checkpoints = 3 - trainer_settings.network_settings.memory = ( - NetworkSettings.MemorySettings() if use_rnn else None - ) - policy = TorchPolicy(seed, mock_spec, trainer_settings) - return policy - - -@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) 
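The patch targets above move with the torch/ to torch_entities/ rename, and the policy-side normalization mock now points at SimpleActor rather than TorchPolicy. A minimal sketch of the updated pattern, with the trainer construction assumed:

from unittest.mock import patch

with patch(
    "mlagents.trainers.torch_entities.networks.SimpleActor.update_normalization"
) as policy_update_normalization_mock:
    trainer.advance()  # trainer assumed to be a constructed PPO/SAC/POCA trainer
    policy_update_normalization_mock.assert_called_once()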
-@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) -@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) -def test_policy_evaluate(rnn, visual, discrete): - # Test evaluate - policy = create_policy_mock( - TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual - ) - decision_step, terminal_step = mb.create_steps_from_behavior_spec( - policy.behavior_spec, num_agents=NUM_AGENTS - ) - - run_out = policy.evaluate(decision_step, list(decision_step.agent_id)) - if discrete: - run_out["action"].discrete.shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE)) - else: - assert run_out["action"].continuous.shape == (NUM_AGENTS, VECTOR_ACTION_SPACE) - - -@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) -@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) -@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) -def test_evaluate_actions(rnn, visual, discrete): - policy = create_policy_mock( - TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual - ) - buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size) - act_masks = ModelUtils.list_to_tensor(buffer[BufferKey.ACTION_MASK]) - agent_action = AgentAction.from_buffer(buffer) - np_obs = ObsUtil.from_buffer(buffer, len(policy.behavior_spec.observation_specs)) - tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs] - - memories = [ - ModelUtils.list_to_tensor(buffer[BufferKey.MEMORY][i]) - for i in range(0, len(buffer[BufferKey.MEMORY]), policy.sequence_length) - ] - if len(memories) > 0: - memories = torch.stack(memories).unsqueeze(0) - - log_probs, entropy = policy.evaluate_actions( - tensor_obs, - masks=act_masks, - actions=agent_action, - memories=memories, - seq_len=policy.sequence_length, - ) - if discrete: - _size = policy.behavior_spec.action_spec.discrete_size - else: - _size = policy.behavior_spec.action_spec.continuous_size - - assert log_probs.flatten().shape == (64, _size) - assert entropy.shape == (64,) - - -@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) -@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) -@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) -def test_sample_actions(rnn, visual, discrete): - policy = create_policy_mock( - TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual - ) - buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size) - act_masks = ModelUtils.list_to_tensor(buffer[BufferKey.ACTION_MASK]) - - np_obs = ObsUtil.from_buffer(buffer, len(policy.behavior_spec.observation_specs)) - tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs] - - memories = [ - ModelUtils.list_to_tensor(buffer[BufferKey.MEMORY][i]) - for i in range(0, len(buffer[BufferKey.MEMORY]), policy.sequence_length) - ] - if len(memories) > 0: - memories = torch.stack(memories).unsqueeze(0) - - (sampled_actions, log_probs, entropies, memories) = policy.sample_actions( - tensor_obs, masks=act_masks, memories=memories, seq_len=policy.sequence_length - ) - if discrete: - assert log_probs.all_discrete_tensor.shape == ( - 64, - sum(policy.behavior_spec.action_spec.discrete_branches), - ) - else: - assert log_probs.continuous_tensor.shape == ( - 64, - policy.behavior_spec.action_spec.continuous_size, - ) - assert entropies.shape == (64,) - - if rnn: - assert memories.shape == (1, 1, policy.m_size) - - -def test_step_overflow(): - policy = 
create_policy_mock(TrainerSettings()) - policy.set_step(2**31 - 1) - assert policy.get_current_step() == 2**31 - 1 # step = 2147483647 - policy.increment_step(3) - assert policy.get_current_step() == 2**31 + 2 # step = 2147483650 diff --git a/ml-agents/mlagents/trainers/torch/components/bc/__init__.py b/ml-agents/mlagents/trainers/tests/torch_entities/__init__.py similarity index 100% rename from ml-agents/mlagents/trainers/torch/components/bc/__init__.py rename to ml-agents/mlagents/trainers/tests/torch_entities/__init__.py diff --git a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py b/ml-agents/mlagents/trainers/tests/torch_entities/saver/test_saver.py similarity index 87% rename from ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py rename to ml-agents/mlagents/trainers/tests/torch_entities/saver/test_saver.py index ac89c56ca7..ef7eebe56d 100644 --- a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/saver/test_saver.py @@ -6,9 +6,9 @@ import numpy as np from mlagents.torch_utils import torch, default_device from mlagents.trainers.policy.torch_policy import TorchPolicy -from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer -from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer -from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer +from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer, PPOSettings +from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer, SACSettings +from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer, POCASettings from mlagents.trainers.model_saver.torch_model_saver import ( TorchModelSaver, DEFAULT_CHECKPOINT_NAME, @@ -17,13 +17,10 @@ TrainerSettings, NetworkSettings, EncoderType, - PPOSettings, - SACSettings, - POCASettings, ) from mlagents.trainers.tests import mock_brain as mb -from mlagents.trainers.tests.torch.test_policy import create_policy_mock -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.tests.torch_entities.test_policy import create_policy_mock +from mlagents.trainers.torch_entities.utils import ModelUtils def test_register(tmp_path): @@ -36,7 +33,7 @@ def test_register(tmp_path): assert model_saver.policy is None trainer_params = TrainerSettings() - policy = create_policy_mock(trainer_params) + policy = create_policy_mock(trainer_params.network_settings) opt.get_modules = mock.Mock(return_value={}) model_saver.register(policy) assert model_saver.policy is not None @@ -46,7 +43,7 @@ def test_load_save_policy(tmp_path): path1 = os.path.join(tmp_path, "runid1") path2 = os.path.join(tmp_path, "runid2") trainer_params = TrainerSettings() - policy = create_policy_mock(trainer_params) + policy = create_policy_mock(trainer_params.network_settings) model_saver = TorchModelSaver(trainer_params, path1) model_saver.register(policy) model_saver.initialize_or_load(policy) @@ -58,7 +55,7 @@ def test_load_save_policy(tmp_path): # Try load from this path model_saver2 = TorchModelSaver(trainer_params, path1, load=True) - policy2 = create_policy_mock(trainer_params) + policy2 = create_policy_mock(trainer_params.network_settings) model_saver2.register(policy2) model_saver2.initialize_or_load(policy2) _compare_two_policies(policy, policy2) @@ -67,7 +64,7 @@ def test_load_save_policy(tmp_path): # Try initialize from path 1 trainer_params.init_path = os.path.join(path1, DEFAULT_CHECKPOINT_NAME) model_saver3 = TorchModelSaver(trainer_params, path2) - policy3 = 
create_policy_mock(trainer_params) + policy3 = create_policy_mock(trainer_params.network_settings) model_saver3.register(policy3) model_saver3.initialize_or_load(policy3) _compare_two_policies(policy2, policy3) @@ -82,7 +79,7 @@ def test_load_policy_different_hidden_units(tmp_path, vis_encode_type): trainer_params.network_settings = NetworkSettings( hidden_units=12, vis_encode_type=EncoderType(vis_encode_type) ) - policy = create_policy_mock(trainer_params, use_visual=True) + policy = create_policy_mock(trainer_params.network_settings, use_visual=True) conv_params = [mod for mod in policy.actor.parameters() if len(mod.shape) > 2] model_saver = TorchModelSaver(trainer_params, path1) @@ -99,7 +96,7 @@ def test_load_policy_different_hidden_units(tmp_path, vis_encode_type): hidden_units=10, vis_encode_type=EncoderType(vis_encode_type) ) model_saver2 = TorchModelSaver(trainer_params2, path1, load=True) - policy2 = create_policy_mock(trainer_params2, use_visual=True) + policy2 = create_policy_mock(trainer_params2.network_settings, use_visual=True) conv_params2 = [mod for mod in policy2.actor.parameters() if len(mod.shape) > 2] # asserts convolutions have different parameters before load for conv1, conv2 in zip(conv_params, conv_params2): @@ -133,7 +130,7 @@ def test_load_save_optimizer(tmp_path, optimizer): trainer_settings = TrainerSettings() trainer_settings.hyperparameters = HyperparametersClass() - policy = create_policy_mock(trainer_settings, use_discrete=False) + policy = create_policy_mock(trainer_settings.network_settings, use_discrete=False) optimizer = OptimizerClass(policy, trainer_settings) # save at path 1 @@ -146,7 +143,7 @@ def test_load_save_optimizer(tmp_path, optimizer): model_saver.save_checkpoint("MockBrain", 2000) # create a new optimizer and policy - policy2 = create_policy_mock(trainer_settings, use_discrete=False) + policy2 = create_policy_mock(trainer_settings.network_settings, use_discrete=False) optimizer2 = OptimizerClass(policy2, trainer_settings) # load weights @@ -180,12 +177,14 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None: tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs] with torch.no_grad(): - _, log_probs1, _, _ = policy1.sample_actions( + _, stat_dict1, _ = policy1.actor.get_action_and_stats( tensor_obs, masks=masks, memories=memories ) - _, log_probs2, _, _ = policy2.sample_actions( + _, stat_dict2, _ = policy2.actor.get_action_and_stats( tensor_obs, masks=masks, memories=memories ) + log_probs1 = stat_dict1["log_probs"] + log_probs2 = stat_dict2["log_probs"] np.testing.assert_array_equal( ModelUtils.to_numpy(log_probs1.all_discrete_tensor), ModelUtils.to_numpy(log_probs2.all_discrete_tensor), @@ -218,7 +217,10 @@ def test_checkpoint_conversion(tmpdir, rnn, visual, discrete): dummy_config = TrainerSettings() model_path = os.path.join(tmpdir, "Mock_Brain") policy = create_policy_mock( - dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual + dummy_config.network_settings, + use_rnn=rnn, + use_discrete=discrete, + use_visual=visual, ) trainer_params = TrainerSettings() model_saver = TorchModelSaver(trainer_params, model_path) diff --git a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver_reward_providers.py b/ml-agents/mlagents/trainers/tests/torch_entities/saver/test_saver_reward_providers.py similarity index 89% rename from ml-agents/mlagents/trainers/tests/torch/saver/test_saver_reward_providers.py rename to 
ml-agents/mlagents/trainers/tests/torch_entities/saver/test_saver_reward_providers.py index 908d0158aa..3ffe61d153 100644 --- a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver_reward_providers.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/saver/test_saver_reward_providers.py @@ -4,9 +4,9 @@ import numpy as np from mlagents_envs.logging_util import WARNING -from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer -from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer -from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer +from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer, PPOSettings +from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer, SACSettings +from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer, POCASettings from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver from mlagents.trainers.settings import ( TrainerSettings, @@ -14,12 +14,9 @@ CuriositySettings, GAILSettings, RNDSettings, - PPOSettings, - SACSettings, - POCASettings, ) -from mlagents.trainers.tests.torch.test_policy import create_policy_mock -from mlagents.trainers.tests.torch.test_reward_providers.utils import ( +from mlagents.trainers.tests.torch_entities.test_policy import create_policy_mock +from mlagents.trainers.tests.torch_entities.test_reward_providers.utils import ( create_agent_buffer, ) @@ -49,7 +46,7 @@ def test_reward_provider_save(tmp_path, optimizer): RewardSignalType.GAIL: GAILSettings(demo_path=DEMO_PATH), RewardSignalType.RND: RNDSettings(), } - policy = create_policy_mock(trainer_settings, use_discrete=False) + policy = create_policy_mock(trainer_settings.network_settings, use_discrete=False) optimizer = OptimizerClass(policy, trainer_settings) # save at path 1 @@ -63,7 +60,7 @@ def test_reward_provider_save(tmp_path, optimizer): # create a new optimizer and policy optimizer2 = OptimizerClass(policy, trainer_settings) - policy2 = create_policy_mock(trainer_settings, use_discrete=False) + policy2 = create_policy_mock(trainer_settings.network_settings, use_discrete=False) # load weights model_saver2 = TorchModelSaver(trainer_settings, path1, load=True) @@ -116,7 +113,7 @@ def test_load_different_reward_provider(caplog, tmp_path, optimizer): RewardSignalType.RND: RNDSettings(), } - policy = create_policy_mock(trainer_settings, use_discrete=False) + policy = create_policy_mock(trainer_settings.network_settings, use_discrete=False) optimizer = OptimizerClass(policy, trainer_settings) # save at path 1 @@ -136,7 +133,7 @@ def test_load_different_reward_provider(caplog, tmp_path, optimizer): } # create a new optimizer and policy - policy2 = create_policy_mock(trainer_settings2, use_discrete=False) + policy2 = create_policy_mock(trainer_settings2.network_settings, use_discrete=False) optimizer2 = OptimizerClass(policy2, trainer_settings2) # load weights diff --git a/ml-agents/mlagents/trainers/tests/torch/test.demo b/ml-agents/mlagents/trainers/tests/torch_entities/test.demo similarity index 100% rename from ml-agents/mlagents/trainers/tests/torch/test.demo rename to ml-agents/mlagents/trainers/tests/torch_entities/test.demo diff --git a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_action_model.py similarity index 96% rename from ml-agents/mlagents/trainers/tests/torch/test_action_model.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_action_model.py index 7c28a0b7d0..cfcaf6a067 100644 --- 
a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_action_model.py @@ -1,9 +1,9 @@ import pytest from mlagents.torch_utils import torch -from mlagents.trainers.torch.action_model import ActionModel, DistInstances -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.distributions import ( +from mlagents.trainers.torch_entities.action_model import ActionModel, DistInstances +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.distributions import ( GaussianDistInstance, CategoricalDistInstance, ) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_agent_action.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_agent_action.py similarity index 97% rename from ml-agents/mlagents/trainers/tests/torch/test_agent_action.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_agent_action.py index d8d4cee76a..771872d7b1 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_agent_action.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_agent_action.py @@ -2,7 +2,7 @@ from mlagents.torch_utils import torch from mlagents.trainers.buffer import AgentBuffer, BufferKey -from mlagents.trainers.torch.agent_action import AgentAction +from mlagents.trainers.torch_entities.agent_action import AgentAction def test_agent_action_group_from_buffer(): diff --git a/ml-agents/mlagents/trainers/tests/torch/test_attention.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py similarity index 97% rename from ml-agents/mlagents/trainers/tests/torch/test_attention.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py index d2db1773cc..f7344a647b 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_attention.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py @@ -2,9 +2,9 @@ from mlagents.torch_utils import torch import numpy as np -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.torch.layers import linear_layer, LinearEncoder -from mlagents.trainers.torch.attention import ( +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.torch_entities.layers import linear_layer, LinearEncoder +from mlagents.trainers.torch_entities.attention import ( MultiHeadAttention, EntityEmbedding, ResidualSelfAttention, diff --git a/ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_bcmodule.py similarity index 92% rename from ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_bcmodule.py index e1d7dc6f11..0364ddd478 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_bcmodule.py @@ -1,11 +1,11 @@ +import os +from typing import Dict, Any from unittest.mock import MagicMock import pytest import mlagents.trainers.tests.mock_brain as mb - -import os - from mlagents.trainers.policy.torch_policy import TorchPolicy -from mlagents.trainers.torch.components.bc.module import BCModule +from mlagents.trainers.torch_entities.components.bc.module import BCModule +from mlagents.trainers.torch_entities.networks import SimpleActor from mlagents.trainers.settings import ( TrainerSettings, BehavioralCloningSettings, @@ -19,8 +19,16 @@ def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample): 
trainer_config.network_settings.memory = ( NetworkSettings.MemorySettings() if use_rnn else None ) + actor_kwargs: Dict[str, Any] = { + "conditional_sigma": False, + "tanh_squash": tanhresample, + } policy = TorchPolicy( - 0, mock_behavior_specs, trainer_config, tanhresample, tanhresample + 0, + mock_behavior_specs, + trainer_config.network_settings, + SimpleActor, + actor_kwargs, ) bc_module = BCModule( policy, @@ -128,6 +136,8 @@ def test_bcmodule_dc_visual_update(is_sac): # Test with discrete control, visual observations and RNN + + @pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) def test_bcmodule_rnn_dc_update(is_sac): mock_specs = mb.create_mock_banana_behavior_specs() diff --git a/ml-agents/mlagents/trainers/tests/torch/test_conditioning.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_conditioning.py similarity index 92% rename from ml-agents/mlagents/trainers/tests/torch/test_conditioning.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_conditioning.py index 4e7d2f0656..7cfab0251e 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_conditioning.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_conditioning.py @@ -2,8 +2,8 @@ from mlagents.torch_utils import torch import numpy as np -from mlagents.trainers.torch.layers import linear_layer -from mlagents.trainers.torch.conditioning import ConditionalEncoder +from mlagents.trainers.torch_entities.layers import linear_layer +from mlagents.trainers.torch_entities.conditioning import ConditionalEncoder def test_conditional_layer_initialization(): diff --git a/ml-agents/mlagents/trainers/tests/torch/test_decoders.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_decoders.py similarity index 94% rename from ml-agents/mlagents/trainers/tests/torch/test_decoders.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_decoders.py index 1b40745854..8ff54c561f 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_decoders.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_decoders.py @@ -1,7 +1,7 @@ import pytest from mlagents.torch_utils import torch -from mlagents.trainers.torch.decoders import ValueHeads +from mlagents.trainers.torch_entities.decoders import ValueHeads def test_valueheads(): diff --git a/ml-agents/mlagents/trainers/tests/torch/test_distributions.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_distributions.py similarity index 98% rename from ml-agents/mlagents/trainers/tests/torch/test_distributions.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_distributions.py index f004403260..cfe1f0e888 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_distributions.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_distributions.py @@ -1,7 +1,7 @@ import pytest from mlagents.torch_utils import torch -from mlagents.trainers.torch.distributions import ( +from mlagents.trainers.torch_entities.distributions import ( GaussianDistribution, MultiCategoricalDistribution, GaussianDistInstance, diff --git a/ml-agents/mlagents/trainers/tests/torch/test_encoders.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_encoders.py similarity index 96% rename from ml-agents/mlagents/trainers/tests/torch/test_encoders.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_encoders.py index a6cab7b037..d3c1ebb33e 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_encoders.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_encoders.py @@ -2,7 +2,7 @@ from unittest 
import mock import pytest -from mlagents.trainers.torch.encoders import ( +from mlagents.trainers.torch_entities.encoders import ( VectorInput, Normalizer, SmallVisualEncoder, @@ -50,7 +50,7 @@ def test_normalizer(): assert val == pytest.approx(0.707, abs=0.001) -@mock.patch("mlagents.trainers.torch.encoders.Normalizer") +@mock.patch("mlagents.trainers.torch_entities.encoders.Normalizer") def test_vector_encoder(mock_normalizer): mock_normalizer_inst = mock.Mock() mock_normalizer.return_value = mock_normalizer_inst diff --git a/ml-agents/mlagents/trainers/tests/torch/test_ghost.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_ghost.py similarity index 100% rename from ml-agents/mlagents/trainers/tests/torch/test_ghost.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_ghost.py diff --git a/ml-agents/mlagents/trainers/tests/torch/test_hybrid.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_hybrid.py similarity index 100% rename from ml-agents/mlagents/trainers/tests/torch/test_hybrid.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_hybrid.py diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_layers.py similarity index 97% rename from ml-agents/mlagents/trainers/tests/torch/test_layers.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_layers.py index d0fc30d989..b0dafd4881 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_layers.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_layers.py @@ -1,6 +1,6 @@ from mlagents.torch_utils import torch -from mlagents.trainers.torch.layers import ( +from mlagents.trainers.torch_entities.layers import ( Swish, linear_layer, lstm_layer, diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_networks.py similarity index 88% rename from ml-agents/mlagents/trainers/tests/torch/test_networks.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_networks.py index 6932cc42eb..7cf8120806 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_networks.py @@ -1,8 +1,8 @@ import pytest from mlagents.torch_utils import torch -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.networks import ( +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.networks import ( NetworkBody, MultiAgentNetworkBody, ValueNetwork, @@ -273,10 +273,11 @@ def test_valuenetwork(): @pytest.mark.parametrize("lstm", [True, False]) def test_actor_critic(lstm, shared): obs_size = 4 + vis_obs_size = (84, 84, 3) network_settings = NetworkSettings( memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True ) - obs_spec = create_observation_specs_with_shapes([(obs_size,)]) + obs_spec = create_observation_specs_with_shapes([(obs_size,), vis_obs_size]) act_size = 2 mask = torch.ones([1, act_size * 2]) stream_names = [f"stream_name{n}" for n in range(4)] @@ -289,17 +290,23 @@ def test_actor_critic(lstm, shared): actor = SimpleActor(obs_spec, network_settings, action_spec) critic = ValueNetwork(stream_names, obs_spec, network_settings) if lstm: - sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) + sample_vis_obs = torch.ones( + (network_settings.memory.sequence_length, 84, 84, 3), dtype=torch.float32 + ) + sample_obs = 
torch.ones((network_settings.memory.sequence_length, obs_size)) memories = torch.ones( (1, network_settings.memory.sequence_length, actor.memory_size) ) else: + sample_vis_obs = 0.1 * torch.ones((1, 84, 84, 3), dtype=torch.float32) sample_obs = torch.ones((1, obs_size)) memories = torch.tensor([]) # memories isn't always set to None, the network should be able to # deal with that. # Test critic pass - value_out, memories_out = critic.critic_pass([sample_obs], memories=memories) + value_out, memories_out = critic.critic_pass( + [sample_obs] + [sample_vis_obs], memories=memories + ) for stream in stream_names: if lstm: assert value_out[stream].shape == (network_settings.memory.sequence_length,) @@ -308,20 +315,43 @@ def test_actor_critic(lstm, shared): assert value_out[stream].shape == (1,) # Test get action stats and_value - action, log_probs, entropies, mem_out = actor.get_action_and_stats( - [sample_obs], memories=memories, masks=mask + action, run_out, mem_out = actor.get_action_and_stats( + [sample_obs] + [sample_vis_obs], memories=memories, masks=mask ) + log_probs = run_out["log_probs"] + entropy = run_out["entropy"] + + eval_run_out = actor.get_stats( + [sample_obs] + [sample_vis_obs], action, memories=memories, masks=mask + ) + eval_log_probs = eval_run_out["log_probs"] + eval_entropy = eval_run_out["entropy"] + if lstm: assert action.continuous_tensor.shape == (64, 2) + assert log_probs.continuous_tensor.shape == (64, 2) + assert entropy.shape == (64,) + assert eval_log_probs.continuous_tensor.shape == (64, 2) + assert eval_entropy.shape == (64,) + else: assert action.continuous_tensor.shape == (1, 2) + assert log_probs.continuous_tensor.shape == (1, 2) + assert entropy.shape == (1,) + assert eval_log_probs.continuous_tensor.shape == (1, 2) + assert eval_entropy.shape == (1,) assert len(action.discrete_list) == 2 - for _disc in action.discrete_list: + for _disc, _disc_prob, _eval_disc_prob in zip( + action.discrete_list, log_probs.discrete_list, eval_log_probs.discrete_list + ): if lstm: assert _disc.shape == (64, 1) + assert _eval_disc_prob.shape == (64,) else: assert _disc.shape == (1, 1) + assert _disc_prob.shape == (1,) + assert _eval_disc_prob.shape == (1,) if mem_out is not None: assert mem_out.shape == memories.shape diff --git a/ml-agents/mlagents/trainers/tests/torch/test_poca.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_poca.py similarity index 95% rename from ml-agents/mlagents/trainers/tests/torch/test_poca.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_poca.py index 01080bf27b..4ce0e59a3e 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_poca.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_poca.py @@ -1,6 +1,6 @@ from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers import pytest - +from typing import Dict, Any import numpy as np import attr @@ -22,6 +22,7 @@ curiosity_dummy_config, gail_dummy_config, ) +from mlagents.trainers.torch_entities.networks import SimpleActor from mlagents.trainers.agent_processor import AgentManagerQueue from mlagents.trainers.settings import TrainerSettings @@ -64,7 +65,13 @@ def create_test_poca_optimizer(dummy_config, use_rnn, use_discrete, use_visual): if use_rnn else None ) - policy = TorchPolicy(0, mock_specs, trainer_settings, "test", False) + actor_kwargs: Dict[str, Any] = { + "conditional_sigma": False, + "tanh_squash": False, + } + policy = TorchPolicy( + 0, mock_specs, trainer_settings.network_settings, SimpleActor, actor_kwargs + ) optimizer = 
TorchPOCAOptimizer(policy, trainer_settings)
     return optimizer
@@ -294,7 +301,7 @@ def test_poca_optimizer_update_gail(gail_dummy_config, dummy_config):  # noqa: F
 def test_poca_end_episode():
-    name_behavior_id = "test_trainer"
+    name_behavior_id = "test_brain?team=0"
     trainer = POCATrainer(
         name_behavior_id,
         10,
@@ -310,8 +317,8 @@ def test_poca_end_episode():
     parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
     mock_policy = trainer.create_policy(parsed_behavior_id, behavior_spec)
     trainer.add_policy(parsed_behavior_id, mock_policy)
-    trajectory_queue = AgentManagerQueue("testbrain")
-    policy_queue = AgentManagerQueue("testbrain")
+    trajectory_queue = AgentManagerQueue("test_brain?team=0")
+    policy_queue = AgentManagerQueue("test_brain?team=0")
     trainer.subscribe_trajectory_queue(trajectory_queue)
     trainer.publish_policy_queue(policy_queue)
     time_horizon = 10
diff --git a/ml-agents/mlagents/trainers/tests/torch_entities/test_policy.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_policy.py
new file mode 100644
index 0000000000..4358c8ed75
--- /dev/null
+++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_policy.py
@@ -0,0 +1,66 @@
+import pytest
+
+from mlagents.trainers.policy.torch_policy import TorchPolicy
+from mlagents.trainers.tests import mock_brain as mb
+from mlagents.trainers.settings import NetworkSettings
+from mlagents.trainers.torch_entities.networks import SimpleActor
+
+VECTOR_ACTION_SPACE = 2
+VECTOR_OBS_SPACE = 8
+DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
+BUFFER_INIT_SAMPLES = 32
+NUM_AGENTS = 12
+EPSILON = 1e-7
+
+
+def create_policy_mock(
+    dummy_config: NetworkSettings,
+    use_rnn: bool = False,
+    use_discrete: bool = True,
+    use_visual: bool = False,
+    seed: int = 0,
+) -> TorchPolicy:
+    mock_spec = mb.setup_test_behavior_specs(
+        use_discrete,
+        use_visual,
+        vector_action_space=DISCRETE_ACTION_SPACE
+        if use_discrete
+        else VECTOR_ACTION_SPACE,
+        vector_obs_space=VECTOR_OBS_SPACE,
+    )
+
+    network_settings = dummy_config
+    network_settings.memory = NetworkSettings.MemorySettings() if use_rnn else None
+    actor_kwargs = {
+        "conditional_sigma": False,
+        "tanh_squash": False,
+    }
+    policy = TorchPolicy(seed, mock_spec, network_settings, SimpleActor, actor_kwargs)
+    return policy
+
+
+@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
+@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
+@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
+def test_policy_evaluate(rnn, visual, discrete):
+    # Test evaluate
+    policy = create_policy_mock(
+        NetworkSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
+    )
+    decision_step, terminal_step = mb.create_steps_from_behavior_spec(
+        policy.behavior_spec, num_agents=NUM_AGENTS
+    )
+
+    run_out = policy.evaluate(decision_step, list(decision_step.agent_id))
+    if discrete:
+        assert run_out["action"].discrete.shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
+    else:
+        assert run_out["action"].continuous.shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)
+
+
+def test_step_overflow():
+    policy = create_policy_mock(NetworkSettings())
+    policy.set_step(2**31 - 1)
+    assert policy.get_current_step() == 2**31 - 1  # step = 2147483647
+    policy.increment_step(3)
+    assert policy.get_current_step() == 2**31 + 2  # step = 2147483650
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_ppo.py
similarity index 96%
rename from ml-agents/mlagents/trainers/tests/torch/test_ppo.py
rename to ml-agents/mlagents/trainers/tests/torch_entities/test_ppo.py index 743bd824b7..d4e71d4e15 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_ppo.py @@ -5,6 +5,7 @@ from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer from mlagents.trainers.policy.torch_policy import TorchPolicy +from mlagents.trainers.torch_entities.networks import SimpleActor from mlagents.trainers.tests import mock_brain as mb from mlagents.trainers.tests.mock_brain import copy_buffer_fields from mlagents.trainers.tests.test_trajectory import make_fake_trajectory @@ -50,7 +51,13 @@ def create_test_ppo_optimizer(dummy_config, use_rnn, use_discrete, use_visual): if use_rnn else None ) - policy = TorchPolicy(0, mock_specs, trainer_settings, "test", False) + actor_kwargs = { + "conditional_sigma": False, + "tanh_squash": False, + } + policy = TorchPolicy( + 0, mock_specs, trainer_settings.network_settings, SimpleActor, actor_kwargs + ) optimizer = TorchPPOOptimizer(policy, trainer_settings) return optimizer diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_curiosity.py similarity index 95% rename from ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_curiosity.py index 9409e221ea..cc219d0320 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_curiosity.py @@ -2,16 +2,16 @@ import pytest from mlagents.torch_utils import torch from mlagents.trainers.buffer import BufferKey -from mlagents.trainers.torch.components.reward_providers import ( +from mlagents.trainers.torch_entities.components.reward_providers import ( CuriosityRewardProvider, create_reward_provider, ) from mlagents_envs.base_env import BehaviorSpec, ActionSpec from mlagents.trainers.settings import CuriositySettings, RewardSignalType -from mlagents.trainers.tests.torch.test_reward_providers.utils import ( +from mlagents.trainers.tests.torch_entities.test_reward_providers.utils import ( create_agent_buffer, ) -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes SEED = [42] diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_extrinsic.py similarity index 95% rename from ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_extrinsic.py index 385891b709..77e77d806f 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_extrinsic.py @@ -1,13 +1,13 @@ from mlagents.trainers.buffer import BufferKey import pytest import numpy as np -from mlagents.trainers.torch.components.reward_providers import ( +from mlagents.trainers.torch_entities.components.reward_providers import ( ExtrinsicRewardProvider, create_reward_provider, ) from mlagents_envs.base_env import BehaviorSpec, ActionSpec from mlagents.trainers.settings import 
RewardSignalSettings, RewardSignalType -from mlagents.trainers.tests.torch.test_reward_providers.utils import ( +from mlagents.trainers.tests.torch_entities.test_reward_providers.utils import ( create_agent_buffer, ) from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_gail.py similarity index 92% rename from ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_gail.py index a26d6a6a8e..d14c8a8439 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_gail.py @@ -4,16 +4,16 @@ from unittest.mock import patch from mlagents.torch_utils import torch import os -from mlagents.trainers.torch.components.reward_providers import ( +from mlagents.trainers.torch_entities.components.reward_providers import ( GAILRewardProvider, create_reward_provider, ) from mlagents_envs.base_env import BehaviorSpec, ActionSpec from mlagents.trainers.settings import GAILSettings, RewardSignalType -from mlagents.trainers.tests.torch.test_reward_providers.utils import ( +from mlagents.trainers.tests.torch_entities.test_reward_providers.utils import ( create_agent_buffer, ) -from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.gail_reward_provider import ( DiscriminatorNetwork, ) from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes @@ -74,7 +74,7 @@ def test_factory(behavior_spec: BehaviorSpec) -> None: ) @pytest.mark.parametrize("use_actions", [False, True]) @patch( - "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" + "mlagents.trainers.torch_entities.components.reward_providers.gail_reward_provider.demo_to_buffer" ) def test_reward_decreases( demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int @@ -129,7 +129,7 @@ def test_reward_decreases( ) @pytest.mark.parametrize("use_actions", [False, True]) @patch( - "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" + "mlagents.trainers.torch_entities.components.reward_providers.gail_reward_provider.demo_to_buffer" ) def test_reward_decreases_vail( demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_rnd.py similarity index 94% rename from ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_rnd.py index bde126aee5..555f567af5 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/test_rnd.py @@ -1,13 +1,13 @@ import numpy as np import pytest from mlagents.torch_utils import torch -from mlagents.trainers.torch.components.reward_providers import ( +from mlagents.trainers.torch_entities.components.reward_providers import ( RNDRewardProvider, create_reward_provider, ) from mlagents_envs.base_env import BehaviorSpec, ActionSpec from 
mlagents.trainers.settings import RNDSettings, RewardSignalType -from mlagents.trainers.tests.torch.test_reward_providers.utils import ( +from mlagents.trainers.tests.torch_entities.test_reward_providers.utils import ( create_agent_buffer, ) from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/utils.py similarity index 100% rename from ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_reward_providers/utils.py diff --git a/ml-agents/mlagents/trainers/tests/torch/test_sac.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_sac.py similarity index 91% rename from ml-agents/mlagents/trainers/tests/torch/test_sac.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_sac.py index d87c076859..ff39caeb71 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_sac.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_sac.py @@ -4,6 +4,7 @@ from mlagents.trainers.buffer import BufferKey, RewardSignalUtil from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer from mlagents.trainers.policy.torch_policy import TorchPolicy +from mlagents.trainers.torch_entities.networks import SimpleActor from mlagents.trainers.tests import mock_brain as mb from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.tests.dummy_config import ( # noqa: F401 @@ -39,7 +40,13 @@ def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual): if use_rnn else None ) - policy = TorchPolicy(0, mock_brain, trainer_settings) + actor_kwargs = { + "conditional_sigma": False, + "tanh_squash": False, + } + policy = TorchPolicy( + 0, mock_brain, trainer_settings.network_settings, SimpleActor, actor_kwargs + ) optimizer = TorchSACOptimizer(policy, trainer_settings) return optimizer @@ -103,9 +110,7 @@ def test_sac_update_reward_signals( update_buffer[RewardSignalUtil.rewards_key("curiosity")] = update_buffer[ BufferKey.ENVIRONMENT_REWARDS ] - return_stats = optimizer.update_reward_signals( - {"curiosity": update_buffer}, num_sequences=update_buffer.num_experiences - ) + return_stats = optimizer.update_reward_signals(update_buffer) required_stats = ["Losses/Curiosity Forward Loss", "Losses/Curiosity Inverse Loss"] for stat in required_stats: assert stat in return_stats.keys() diff --git a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_simple_rl.py similarity index 98% rename from ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_simple_rl.py index 923491c36b..0c5b509682 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_simple_rl.py @@ -151,7 +151,8 @@ def test_2d_ppo(action_sizes): @pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)]) @pytest.mark.parametrize("num_visual", [1, 2]) -def test_visual_ppo(num_visual, action_sizes): +@pytest.mark.parametrize("shared_critic", [True, False]) +def test_visual_ppo(shared_critic, num_visual, action_sizes): env = SimpleEnvironment( [BRAIN_NAME], action_sizes=action_sizes, @@ -160,7 +161,9 @@ def test_visual_ppo(num_visual, action_sizes): step_size=0.2, ) new_hyperparams = attr.evolve( - PPO_TORCH_CONFIG.hyperparameters, 
learning_rate=3.0e-4 + PPO_TORCH_CONFIG.hyperparameters, + learning_rate=3.0e-4, + shared_critic=shared_critic, ) config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams) check_environment_trains(env, {BRAIN_NAME: config}) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_utils.py similarity index 98% rename from ml-agents/mlagents/trainers/tests/torch/test_utils.py rename to ml-agents/mlagents/trainers/tests/torch_entities/test_utils.py index ee00553756..943e3a6668 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_utils.py @@ -3,9 +3,9 @@ import numpy as np from mlagents.trainers.settings import EncoderType, ScheduleType -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.exception import UnityTrainerException -from mlagents.trainers.torch.encoders import VectorInput +from mlagents.trainers.torch_entities.encoders import VectorInput from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes diff --git a/ml-agents/mlagents/trainers/tests/torch/testdcvis.demo b/ml-agents/mlagents/trainers/tests/torch_entities/testdcvis.demo similarity index 100% rename from ml-agents/mlagents/trainers/tests/torch/testdcvis.demo rename to ml-agents/mlagents/trainers/tests/torch_entities/testdcvis.demo diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py deleted file mode 100644 index 21f73cb24b..0000000000 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( # noqa F401 - BaseRewardProvider, -) -from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import ( # noqa F401 - ExtrinsicRewardProvider, -) -from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import ( # noqa F401 - CuriosityRewardProvider, -) -from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( # noqa F401 - GAILRewardProvider, -) -from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import ( # noqa F401 - RNDRewardProvider, -) -from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import ( # noqa F401 - create_reward_provider, -) diff --git a/ml-agents/mlagents/trainers/torch_entities/__init__.py b/ml-agents/mlagents/trainers/torch_entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-agents/mlagents/trainers/torch/action_flattener.py b/ml-agents/mlagents/trainers/torch_entities/action_flattener.py similarity index 92% rename from ml-agents/mlagents/trainers/torch/action_flattener.py rename to ml-agents/mlagents/trainers/torch_entities/action_flattener.py index 556e844ffb..beb529c963 100644 --- a/ml-agents/mlagents/trainers/torch/action_flattener.py +++ b/ml-agents/mlagents/trainers/torch_entities/action_flattener.py @@ -2,8 +2,8 @@ from mlagents.torch_utils import torch from mlagents_envs.base_env import ActionSpec -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.utils 
import ModelUtils class ActionFlattener: diff --git a/ml-agents/mlagents/trainers/torch/action_log_probs.py b/ml-agents/mlagents/trainers/torch_entities/action_log_probs.py similarity index 95% rename from ml-agents/mlagents/trainers/torch/action_log_probs.py rename to ml-agents/mlagents/trainers/torch_entities/action_log_probs.py index 71d6598ad2..b72e7bb223 100644 --- a/ml-agents/mlagents/trainers/torch/action_log_probs.py +++ b/ml-agents/mlagents/trainers/torch_entities/action_log_probs.py @@ -2,7 +2,7 @@ from mlagents.torch_utils import torch import numpy as np -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.buffer import AgentBuffer, BufferKey from mlagents_envs.base_env import _ActionTupleBase @@ -23,6 +23,13 @@ def discrete_dtype(self) -> np.dtype: """ return np.float32 + @staticmethod + def empty_log_probs() -> "LogProbsTuple": + """ + Generates a dummy LogProbsTuple + """ + return LogProbsTuple() + class ActionLogProbs(NamedTuple): """ diff --git a/ml-agents/mlagents/trainers/torch/action_model.py b/ml-agents/mlagents/trainers/torch_entities/action_model.py similarity index 97% rename from ml-agents/mlagents/trainers/torch/action_model.py rename to ml-agents/mlagents/trainers/torch_entities/action_model.py index b4c3282de5..7b88c0262d 100644 --- a/ml-agents/mlagents/trainers/torch/action_model.py +++ b/ml-agents/mlagents/trainers/torch_entities/action_model.py @@ -1,13 +1,13 @@ from typing import List, Tuple, NamedTuple, Optional from mlagents.torch_utils import torch, nn -from mlagents.trainers.torch.distributions import ( +from mlagents.trainers.torch_entities.distributions import ( DistInstance, DiscreteDistInstance, GaussianDistribution, MultiCategoricalDistribution, ) -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs from mlagents_envs.base_env import ActionSpec @@ -68,7 +68,7 @@ def __init__( # During training, clipping is done in TorchPolicy, but we need to clip before ONNX # export as well. 
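A note on the rename just below: `clip_action` (formerly `_clip_action_on_export`) guards the export-time squashing of unbounded continuous actions. A minimal standalone sketch of that clamp-and-rescale, with illustrative values (not the toolkit's code, just the transform it applies):

import torch

def clip_for_export(continuous_out: torch.Tensor) -> torch.Tensor:
    # Clamp to [-3, 3], then rescale so exported models emit actions in [-1, 1].
    return torch.clamp(continuous_out, -3, 3) / 3

print(clip_for_export(torch.tensor([[-4.0, 0.6, 3.5]])))
# tensor([[-1.0000,  0.2000,  1.0000]])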
- self._clip_action_on_export = not tanh_squash + self.clip_action = not tanh_squash self._deterministic = deterministic def _sample_action(self, dists: DistInstances) -> AgentAction: @@ -181,7 +181,7 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten continuous_out = dists.continuous.exported_model_output() action_out_deprecated = continuous_out deterministic_continuous_out = dists.continuous.deterministic_sample() - if self._clip_action_on_export: + if self.clip_action: continuous_out = torch.clamp(continuous_out, -3, 3) / 3 action_out_deprecated = continuous_out deterministic_continuous_out = ( diff --git a/ml-agents/mlagents/trainers/torch/agent_action.py b/ml-agents/mlagents/trainers/torch_entities/agent_action.py similarity index 99% rename from ml-agents/mlagents/trainers/torch/agent_action.py rename to ml-agents/mlagents/trainers/torch_entities/agent_action.py index dfa7c356e0..1ecc995a55 100644 --- a/ml-agents/mlagents/trainers/torch/agent_action.py +++ b/ml-agents/mlagents/trainers/torch_entities/agent_action.py @@ -4,7 +4,7 @@ from mlagents.torch_utils import torch from mlagents.trainers.buffer import AgentBuffer, BufferKey -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents_envs.base_env import ActionTuple diff --git a/ml-agents/mlagents/trainers/torch/attention.py b/ml-agents/mlagents/trainers/torch_entities/attention.py similarity index 98% rename from ml-agents/mlagents/trainers/torch/attention.py rename to ml-agents/mlagents/trainers/torch_entities/attention.py index 253e232dea..ba34e01995 100644 --- a/ml-agents/mlagents/trainers/torch/attention.py +++ b/ml-agents/mlagents/trainers/torch_entities/attention.py @@ -1,13 +1,13 @@ from mlagents.torch_utils import torch import warnings from typing import Tuple, Optional, List -from mlagents.trainers.torch.layers import ( +from mlagents.trainers.torch_entities.layers import ( LinearEncoder, Initialization, linear_layer, LayerNorm, ) -from mlagents.trainers.torch.model_serialization import exporting_to_onnx +from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx from mlagents.trainers.exception import UnityTrainerException diff --git a/ml-agents/mlagents/trainers/torch_entities/components/__init__.py b/ml-agents/mlagents/trainers/torch_entities/components/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-agents/mlagents/trainers/torch_entities/components/bc/__init__.py b/ml-agents/mlagents/trainers/torch_entities/components/bc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-agents/mlagents/trainers/torch/components/bc/module.py b/ml-agents/mlagents/trainers/torch_entities/components/bc/module.py similarity index 94% rename from ml-agents/mlagents/trainers/torch/components/bc/module.py rename to ml-agents/mlagents/trainers/torch_entities/components/bc/module.py index 508c9505cd..ab454409c6 100644 --- a/ml-agents/mlagents/trainers/torch/components/bc/module.py +++ b/ml-agents/mlagents/trainers/torch_entities/components/bc/module.py @@ -5,9 +5,9 @@ from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.demo_loader import demo_to_buffer from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs -from mlagents.trainers.torch.utils import ModelUtils +from 
mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.utils import ModelUtils from mlagents.trainers.trajectory import ObsUtil from mlagents.trainers.buffer import AgentBuffer @@ -168,12 +168,13 @@ def _update_batch( if self.policy.use_recurrent: memories = torch.zeros(1, self.n_sequences, self.policy.m_size) - selected_actions, log_probs, _, _ = self.policy.sample_actions( + selected_actions, run_out, _ = self.policy.actor.get_action_and_stats( tensor_obs, masks=act_masks, memories=memories, - seq_len=self.policy.sequence_length, + sequence_length=self.policy.sequence_length, ) + log_probs = run_out["log_probs"] bc_loss = self._behavioral_cloning_loss( selected_actions, log_probs, expert_actions ) diff --git a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/__init__.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/__init__.py new file mode 100644 index 0000000000..696d978b8a --- /dev/null +++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/__init__.py @@ -0,0 +1,18 @@ +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( # noqa F401 + BaseRewardProvider, +) +from mlagents.trainers.torch_entities.components.reward_providers.extrinsic_reward_provider import ( # noqa F401 + ExtrinsicRewardProvider, +) +from mlagents.trainers.torch_entities.components.reward_providers.curiosity_reward_provider import ( # noqa F401 + CuriosityRewardProvider, +) +from mlagents.trainers.torch_entities.components.reward_providers.gail_reward_provider import ( # noqa F401 + GAILRewardProvider, +) +from mlagents.trainers.torch_entities.components.reward_providers.rnd_reward_provider import ( # noqa F401 + RNDRewardProvider, +) +from mlagents.trainers.torch_entities.components.reward_providers.reward_provider_factory import ( # noqa F401 + create_reward_provider, +) diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/base_reward_provider.py similarity index 100% rename from ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py rename to ml-agents/mlagents/trainers/torch_entities/components/reward_providers/base_reward_provider.py diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/curiosity_reward_provider.py similarity index 95% rename from ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py rename to ml-agents/mlagents/trainers/torch_entities/components/reward_providers/curiosity_reward_provider.py index 1814b22ca5..b4cbf34dd9 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py +++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/curiosity_reward_provider.py @@ -3,18 +3,18 @@ from mlagents.torch_utils import torch, default_device from mlagents.trainers.buffer import AgentBuffer, BufferKey -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( BaseRewardProvider, ) from mlagents.trainers.settings import CuriositySettings from mlagents_envs.base_env import 
BehaviorSpec from mlagents_envs import logging_util -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_flattener import ActionFlattener -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.torch.networks import NetworkBody -from mlagents.trainers.torch.layers import LinearEncoder, linear_layer +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_flattener import ActionFlattener +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.torch_entities.networks import NetworkBody +from mlagents.trainers.torch_entities.layers import LinearEncoder, linear_layer from mlagents.trainers.trajectory import ObsUtil logger = logging_util.get_logger(__name__) diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/extrinsic_reward_provider.py similarity index 94% rename from ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py rename to ml-agents/mlagents/trainers/torch_entities/components/reward_providers/extrinsic_reward_provider.py index 4d58e7c3b1..b0b847463c 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py +++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/extrinsic_reward_provider.py @@ -2,7 +2,7 @@ from typing import Dict from mlagents.trainers.buffer import AgentBuffer, BufferKey -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( BaseRewardProvider, ) from mlagents_envs.base_env import BehaviorSpec diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py similarity index 95% rename from ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py rename to ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py index b7aa8d272c..0ae77ba143 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py +++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py @@ -3,17 +3,17 @@ from mlagents.torch_utils import torch, default_device from mlagents.trainers.buffer import AgentBuffer, BufferKey -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( BaseRewardProvider, ) from mlagents.trainers.settings import GAILSettings from mlagents_envs.base_env import BehaviorSpec from mlagents_envs import logging_util -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_flattener import ActionFlattener -from mlagents.trainers.torch.networks import NetworkBody -from mlagents.trainers.torch.layers import linear_layer, Initialization +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.torch_entities.agent_action import AgentAction +from mlagents.trainers.torch_entities.action_flattener import ActionFlattener +from 
mlagents.trainers.torch_entities.networks import NetworkBody +from mlagents.trainers.torch_entities.layers import linear_layer, Initialization from mlagents.trainers.demo_loader import demo_to_buffer from mlagents.trainers.trajectory import ObsUtil diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/reward_provider_factory.py similarity index 72% rename from ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py rename to ml-agents/mlagents/trainers/torch_entities/components/reward_providers/reward_provider_factory.py index db0e6ffe19..825fc49006 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py +++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/reward_provider_factory.py @@ -3,19 +3,19 @@ from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( BaseRewardProvider, ) -from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.extrinsic_reward_provider import ( ExtrinsicRewardProvider, ) -from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.curiosity_reward_provider import ( CuriosityRewardProvider, ) -from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.gail_reward_provider import ( GAILRewardProvider, ) -from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.rnd_reward_provider import ( RNDRewardProvider, ) diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/rnd_reward_provider.py similarity index 92% rename from ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py rename to ml-agents/mlagents/trainers/torch_entities/components/reward_providers/rnd_reward_provider.py index 8408b08b8d..bda1424ab5 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py +++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/rnd_reward_provider.py @@ -3,15 +3,15 @@ from mlagents.torch_utils import torch from mlagents.trainers.buffer import AgentBuffer -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( BaseRewardProvider, ) from mlagents.trainers.settings import RNDSettings from mlagents_envs.base_env import BehaviorSpec from mlagents_envs import logging_util -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.torch.networks import NetworkBody +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.torch_entities.networks import NetworkBody from mlagents.trainers.trajectory import ObsUtil logger = logging_util.get_logger(__name__) diff --git 
a/ml-agents/mlagents/trainers/torch/conditioning.py b/ml-agents/mlagents/trainers/torch_entities/conditioning.py similarity index 98% rename from ml-agents/mlagents/trainers/torch/conditioning.py rename to ml-agents/mlagents/trainers/torch_entities/conditioning.py index 8255748f6d..65f622eba3 100644 --- a/ml-agents/mlagents/trainers/torch/conditioning.py +++ b/ml-agents/mlagents/trainers/torch_entities/conditioning.py @@ -2,7 +2,7 @@ from typing import List import math -from mlagents.trainers.torch.layers import ( +from mlagents.trainers.torch_entities.layers import ( linear_layer, Swish, Initialization, diff --git a/ml-agents/mlagents/trainers/torch/decoders.py b/ml-agents/mlagents/trainers/torch_entities/decoders.py similarity index 91% rename from ml-agents/mlagents/trainers/torch/decoders.py rename to ml-agents/mlagents/trainers/torch_entities/decoders.py index 44b69e60e5..30f196a455 100644 --- a/ml-agents/mlagents/trainers/torch/decoders.py +++ b/ml-agents/mlagents/trainers/torch_entities/decoders.py @@ -1,7 +1,7 @@ from typing import List, Dict from mlagents.torch_utils import torch, nn -from mlagents.trainers.torch.layers import linear_layer +from mlagents.trainers.torch_entities.layers import linear_layer class ValueHeads(nn.Module): diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch_entities/distributions.py similarity index 99% rename from ml-agents/mlagents/trainers/torch/distributions.py rename to ml-agents/mlagents/trainers/torch_entities/distributions.py index 91fd1002a5..47fd0d0847 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch_entities/distributions.py @@ -3,7 +3,7 @@ from mlagents.torch_utils import torch, nn import numpy as np import math -from mlagents.trainers.torch.layers import linear_layer, Initialization +from mlagents.trainers.torch_entities.layers import linear_layer, Initialization EPSILON = 1e-7 # Small value to avoid divide by zero diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch_entities/encoders.py similarity index 98% rename from ml-agents/mlagents/trainers/torch/encoders.py rename to ml-agents/mlagents/trainers/torch_entities/encoders.py index ef8117f29c..32b944ddfc 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch_entities/encoders.py @@ -1,9 +1,9 @@ from typing import Tuple, Optional, Union -from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish +from mlagents.trainers.torch_entities.layers import linear_layer, Initialization, Swish from mlagents.torch_utils import torch, nn -from mlagents.trainers.torch.model_serialization import exporting_to_onnx +from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx class Normalizer(nn.Module): diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch_entities/layers.py similarity index 98% rename from ml-agents/mlagents/trainers/torch/layers.py rename to ml-agents/mlagents/trainers/torch_entities/layers.py index 5edf3acbdf..e5a598edf6 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch_entities/layers.py @@ -2,7 +2,7 @@ import abc from typing import Tuple from enum import Enum -from mlagents.trainers.torch.model_serialization import exporting_to_onnx +from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx class Swish(torch.nn.Module): diff --git 
a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch_entities/model_serialization.py similarity index 100% rename from ml-agents/mlagents/trainers/torch/model_serialization.py rename to ml-agents/mlagents/trainers/torch_entities/model_serialization.py diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch_entities/networks.py similarity index 95% rename from ml-agents/mlagents/trainers/torch/networks.py rename to ml-agents/mlagents/trainers/torch_entities/networks.py index be8fb4b732..555268075c 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch_entities/networks.py @@ -1,21 +1,20 @@ -from typing import Callable, List, Dict, Tuple, Optional, Union +from typing import Callable, List, Dict, Tuple, Optional, Union, Any import abc from mlagents.torch_utils import torch, nn from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType -from mlagents.trainers.torch.action_model import ActionModel -from mlagents.trainers.torch.agent_action import AgentAction -from mlagents.trainers.torch.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.action_model import ActionModel +from mlagents.trainers.torch_entities.agent_action import AgentAction from mlagents.trainers.settings import NetworkSettings, EncoderType, ConditioningType -from mlagents.trainers.torch.utils import ModelUtils -from mlagents.trainers.torch.decoders import ValueHeads -from mlagents.trainers.torch.layers import LSTM, LinearEncoder -from mlagents.trainers.torch.encoders import VectorInput +from mlagents.trainers.torch_entities.utils import ModelUtils +from mlagents.trainers.torch_entities.decoders import ValueHeads +from mlagents.trainers.torch_entities.layers import LSTM, LinearEncoder +from mlagents.trainers.torch_entities.encoders import VectorInput from mlagents.trainers.buffer import AgentBuffer from mlagents.trainers.trajectory import ObsUtil -from mlagents.trainers.torch.conditioning import ConditionalEncoder -from mlagents.trainers.torch.attention import ( +from mlagents.trainers.torch_entities.conditioning import ConditionalEncoder +from mlagents.trainers.torch_entities.attention import ( EntityEmbedding, ResidualSelfAttention, get_zero_entities_mask, @@ -88,7 +87,7 @@ def update_normalization(self, buffer: AgentBuffer) -> None: obs = ObsUtil.from_buffer(buffer, len(self.processors)) for vec_input, enc in zip(obs, self.processors): if isinstance(enc, VectorInput): - enc.update_normalization(torch.as_tensor(vec_input)) + enc.update_normalization(torch.as_tensor(vec_input.to_ndarray())) def copy_normalization(self, other_encoder: "ObservationEncoder") -> None: if self.normalize: @@ -519,7 +518,7 @@ def get_action_and_stats( masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, sequence_length: int = 1, - ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]: + ) -> Tuple[AgentAction, Dict[str, Any], torch.Tensor]: """ Returns sampled actions. If memory is enabled, return the memories as well. @@ -539,7 +538,7 @@ def get_stats( masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, sequence_length: int = 1, - ) -> Tuple[ActionLogProbs, torch.Tensor]: + ) -> Dict[str, Any]: """ Returns log_probs for actions and entropies. If memory is enabled, return the memories as well. 
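The signature changes above replace the old (action, log_probs, entropies, memories) tuple with an (action, run_out, memories) triple, where run_out is a plain dict. A hedged sketch of how a caller consumes the new interface; `actor`, `obs_tensors`, `masks`, and `memories` are assumed to already exist:

action, run_out, memories_out = actor.get_action_and_stats(
    obs_tensors, masks=masks, memories=memories, sequence_length=1
)
log_probs = run_out["log_probs"]    # ActionLogProbs
entropy = run_out["entropy"]        # torch.Tensor
env_action = run_out["env_action"]  # pre-clipped ActionTuple sent to the environment

# Re-evaluating stored actions now returns a dict as well:
eval_run_out = actor.get_stats(
    obs_tensors, action, masks=masks, memories=memories, sequence_length=1
)
eval_log_probs = eval_run_out["log_probs"]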
@@ -633,13 +632,22 @@ def get_action_and_stats( masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, sequence_length: int = 1, - ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]: + ) -> Tuple[AgentAction, Dict[str, Any], torch.Tensor]: encoding, memories = self.network_body( inputs, memories=memories, sequence_length=sequence_length ) action, log_probs, entropies = self.action_model(encoding, masks) - return action, log_probs, entropies, memories + run_out = {} + # This is the clipped action which is not saved to the buffer + # but is exclusively sent to the environment. + run_out["env_action"] = action.to_action_tuple( + clip=self.action_model.clip_action + ) + run_out["log_probs"] = log_probs + run_out["entropy"] = entropies + + return action, run_out, memories def get_stats( self, @@ -648,13 +656,16 @@ def get_stats( masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, sequence_length: int = 1, - ) -> Tuple[ActionLogProbs, torch.Tensor]: + ) -> Dict[str, Any]: encoding, actor_mem_outs = self.network_body( inputs, memories=memories, sequence_length=sequence_length ) - log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) - return log_probs, entropies + log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) + run_out = {} + run_out["log_probs"] = log_probs + run_out["entropy"] = entropies + return run_out def forward( self, diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch_entities/utils.py similarity index 95% rename from ml-agents/mlagents/trainers/torch/utils.py rename to ml-agents/mlagents/trainers/torch_entities/utils.py index 4feb70018f..048ce8b591 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch_entities/utils.py @@ -1,9 +1,9 @@ from typing import List, Optional, Tuple, Dict from mlagents.torch_utils import torch, nn -from mlagents.trainers.torch.layers import LinearEncoder, Initialization +from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization import numpy as np -from mlagents.trainers.torch.encoders import ( +from mlagents.trainers.torch_entities.encoders import ( SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder, @@ -12,7 +12,10 @@ VectorInput, ) from mlagents.trainers.settings import EncoderType, ScheduleType -from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention +from mlagents.trainers.torch_entities.attention import ( + EntityEmbedding, + ResidualSelfAttention, +) from mlagents.trainers.exception import UnityTrainerException from mlagents_envs.base_env import ObservationSpec, DimensionProperty @@ -317,9 +320,24 @@ def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: :param tensor: Tensor which needs mean computation. :param masks: Boolean tensor of masks with same dimension as tensor. 
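One observation on the masked_mean rewrite that follows: `Tensor.T` reverses all dimensions only for 2-D tensors and is deprecated for other ranks in recent PyTorch, whose warning suggests exactly the `permute(*torch.arange(ndim - 1, -1, -1))` idiom adopted here. A small self-contained check of what the reversal does (shapes are illustrative):

import torch

t = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
# Reverse every dimension -- the generalization of `t.T` to any rank:
reversed_t = t.permute(*torch.arange(t.ndim - 1, -1, -1))
assert reversed_t.shape == (4, 3, 2)

# The reversal puts the batch dimension last, so a (batch,)-shaped mask
# broadcasts correctly; masking out the second batch element zeroes it:
masks = torch.tensor([1.0, 0.0])
assert (reversed_t * masks)[..., 1].sum() == 0.0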
""" - return (tensor.T * masks).sum() / torch.clamp( - (torch.ones_like(tensor.T) * masks).float().sum(), min=1.0 - ) + if tensor.ndim == 0: + return (tensor * masks).sum() / torch.clamp( + (torch.ones_like(tensor) * masks).float().sum(), min=1.0 + ) + else: + return ( + tensor.permute(*torch.arange(tensor.ndim - 1, -1, -1)) * masks + ).sum() / torch.clamp( + ( + torch.ones_like( + tensor.permute(*torch.arange(tensor.ndim - 1, -1, -1)) + ) + * masks + ) + .float() + .sum(), + min=1.0, + ) @staticmethod def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None: diff --git a/ml-agents/mlagents/trainers/trainer/off_policy_trainer.py b/ml-agents/mlagents/trainers/trainer/off_policy_trainer.py new file mode 100644 index 0000000000..92ae496692 --- /dev/null +++ b/ml-agents/mlagents/trainers/trainer/off_policy_trainer.py @@ -0,0 +1,263 @@ +# ## ML-Agent Learning (SAC) +# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290 +# and implemented in https://github.com/hill-a/stable-baselines + +from collections import defaultdict +from typing import Dict, cast +import os + +import numpy as np +from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint + +from mlagents_envs.logging_util import get_logger +from mlagents_envs.timers import timed +from mlagents.trainers.buffer import RewardSignalUtil +from mlagents.trainers.policy import Policy +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.trainer.rl_trainer import RLTrainer +from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers +from mlagents.trainers.settings import TrainerSettings, OffPolicyHyperparamSettings + +logger = get_logger(__name__) + +BUFFER_TRUNCATE_PERCENT = 0.8 + + +class OffPolicyTrainer(RLTrainer): + """ + The SACTrainer is an implementation of the SAC algorithm, with support + for discrete actions and recurrent networks. + """ + + def __init__( + self, + behavior_name: str, + reward_buff_cap: int, + trainer_settings: TrainerSettings, + training: bool, + load: bool, + seed: int, + artifact_path: str, + ): + """ + Responsible for collecting experiences and training an off-policy model. + :param behavior_name: The name of the behavior associated with trainer config + :param reward_buff_cap: Max reward history to track in the reward buffer + :param trainer_settings: The parameters for the trainer. + :param training: Whether the trainer is set for training. + :param load: Whether the model should be loaded. + :param seed: The seed the model will be initialized with + :param artifact_path: The directory within which to store artifacts from this trainer. + """ + super().__init__( + behavior_name, + trainer_settings, + training, + load, + artifact_path, + reward_buff_cap, + ) + + self.seed = seed + self.policy: Policy = None # type: ignore + self.optimizer: TorchOptimizer = None # type: ignore + self.hyperparameters: OffPolicyHyperparamSettings = cast( + OffPolicyHyperparamSettings, trainer_settings.hyperparameters + ) + + self._step = 0 + + # Don't divide by zero + self.update_steps = 1 + self.reward_signal_update_steps = 1 + + self.steps_per_update = self.hyperparameters.steps_per_update + self.reward_signal_steps_per_update = ( + self.hyperparameters.reward_signal_steps_per_update + ) + + self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer + + def _checkpoint(self) -> ModelCheckpoint: + """ + Writes a checkpoint model to memory + Overrides the default to save the replay buffer. 
+ """ + ckpt = super()._checkpoint() + if self.checkpoint_replay_buffer: + self.save_replay_buffer() + return ckpt + + def save_model(self) -> None: + """ + Saves the final training model to memory + Overrides the default to save the replay buffer. + """ + super().save_model() + if self.checkpoint_replay_buffer: + self.save_replay_buffer() + + def save_replay_buffer(self) -> None: + """ + Save the training buffer's update buffer to a pickle file. + """ + filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5") + logger.info(f"Saving Experience Replay Buffer to {filename}...") + with open(filename, "wb") as file_object: + self.update_buffer.save_to_file(file_object) + logger.info( + f"Saved Experience Replay Buffer ({os.path.getsize(filename)} bytes)." + ) + + def load_replay_buffer(self) -> None: + """ + Loads the last saved replay buffer from a file. + """ + filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5") + logger.info(f"Loading Experience Replay Buffer from {filename}...") + with open(filename, "rb+") as file_object: + self.update_buffer.load_from_file(file_object) + logger.debug( + "Experience replay buffer has {} experiences.".format( + self.update_buffer.num_experiences + ) + ) + + def _is_ready_update(self) -> bool: + """ + Returns whether or not the trainer has enough elements to run update model + :return: A boolean corresponding to whether or not _update_policy() can be run + """ + return ( + self.update_buffer.num_experiences >= self.hyperparameters.batch_size + and self._step >= self.hyperparameters.buffer_init_steps + ) + + def maybe_load_replay_buffer(self): + # Load the replay buffer if load + if self.load and self.checkpoint_replay_buffer: + try: + self.load_replay_buffer() + except (AttributeError, FileNotFoundError): + logger.warning( + "Replay buffer was unable to load, starting from scratch." + ) + logger.debug( + "Loaded update buffer with {} sequences".format( + self.update_buffer.num_experiences + ) + ) + + def add_policy( + self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy + ) -> None: + """ + Adds policy to trainer. + """ + if self.policy: + logger.warning( + "Your environment contains multiple teams, but {} doesn't support adversarial games. Enable self-play to \ + train adversarial games.".format( + self.__class__.__name__ + ) + ) + self.policy = policy + self.policies[parsed_behavior_id.behavior_id] = policy + self.optimizer = self.create_optimizer() + for _reward_signal in self.optimizer.reward_signals.keys(): + self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) + + self.model_saver.register(self.policy) + self.model_saver.register(self.optimizer) + self.model_saver.initialize_or_load() + + # Needed to resume loads properly + self._step = policy.get_current_step() + # Assume steps were updated at the correct ratio before + self.update_steps = int(max(1, self._step / self.steps_per_update)) + self.reward_signal_update_steps = int( + max(1, self._step / self.reward_signal_steps_per_update) + ) + + @timed + def _update_policy(self) -> bool: + """ + Uses update_buffer to update the policy. We sample the update_buffer and update + until the steps_per_update ratio is met. 
+ """ + has_updated = False + self.cumulative_returns_since_policy_update.clear() + n_sequences = max( + int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 + ) + + batch_update_stats: Dict[str, list] = defaultdict(list) + while ( + self._step - self.hyperparameters.buffer_init_steps + ) / self.update_steps > self.steps_per_update: + logger.debug(f"Updating SAC policy at step {self._step}") + buffer = self.update_buffer + if self.update_buffer.num_experiences >= self.hyperparameters.batch_size: + sampled_minibatch = buffer.sample_mini_batch( + self.hyperparameters.batch_size, + sequence_length=self.policy.sequence_length, + ) + # Get rewards for each reward + for name, signal in self.optimizer.reward_signals.items(): + sampled_minibatch[RewardSignalUtil.rewards_key(name)] = ( + signal.evaluate(sampled_minibatch) * signal.strength + ) + + update_stats = self.optimizer.update(sampled_minibatch, n_sequences) + for stat_name, value in update_stats.items(): + batch_update_stats[stat_name].append(value) + + self.update_steps += 1 + + for stat, stat_list in batch_update_stats.items(): + self._stats_reporter.add_stat(stat, np.mean(stat_list)) + has_updated = True + + if self.optimizer.bc_module: + update_stats = self.optimizer.bc_module.update() + for stat, val in update_stats.items(): + self._stats_reporter.add_stat(stat, val) + + # Truncate update buffer if neccessary. Truncate more than we need to to avoid truncating + # a large buffer at each update. + if self.update_buffer.num_experiences > self.hyperparameters.buffer_size: + self.update_buffer.truncate( + int(self.hyperparameters.buffer_size * BUFFER_TRUNCATE_PERCENT) + ) + # TODO: revisit this update + self._update_reward_signals() + return has_updated + + def _update_reward_signals(self) -> None: + """ + Iterate through the reward signals and update them. Unlike in PPO, + do it separate from the policy so that it can be done at a different + interval. + This function should only be used to simulate + http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated + N times, then the reward signals are updated N times. Normally, the reward signal + and policy are updated in parallel. 
+ """ + buffer = self.update_buffer + batch_update_stats: Dict[str, list] = defaultdict(list) + while ( + self._step - self.hyperparameters.buffer_init_steps + ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update: + # Get minibatches for reward signal update if needed + minibatch = buffer.sample_mini_batch( + self.hyperparameters.batch_size, + sequence_length=self.policy.sequence_length, + ) + update_stats = self.optimizer.update_reward_signals(minibatch) + + for stat_name, value in update_stats.items(): + batch_update_stats[stat_name].append(value) + self.reward_signal_update_steps += 1 + + for stat, stat_list in batch_update_stats.items(): + self._stats_reporter.add_stat(stat, np.mean(stat_list)) diff --git a/ml-agents/mlagents/trainers/trainer/on_policy_trainer.py b/ml-agents/mlagents/trainers/trainer/on_policy_trainer.py new file mode 100644 index 0000000000..879640a0e5 --- /dev/null +++ b/ml-agents/mlagents/trainers/trainer/on_policy_trainer.py @@ -0,0 +1,144 @@ +# # Unity ML-Agents Toolkit +# ## ML-Agent Learning (PPO) +# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347 + +from collections import defaultdict +from typing import cast + +import numpy as np + +from mlagents_envs.logging_util import get_logger +from mlagents.trainers.buffer import BufferKey +from mlagents.trainers.trainer.rl_trainer import RLTrainer +from mlagents.trainers.policy import Policy +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer +from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers +from mlagents.trainers.settings import TrainerSettings, OnPolicyHyperparamSettings + +logger = get_logger(__name__) + + +class OnPolicyTrainer(RLTrainer): + """The PPOTrainer is an implementation of the PPO algorithm.""" + + def __init__( + self, + behavior_name: str, + reward_buff_cap: int, + trainer_settings: TrainerSettings, + training: bool, + load: bool, + seed: int, + artifact_path: str, + ): + """ + Responsible for collecting experiences and training an on-policy model. + :param behavior_name: The name of the behavior associated with trainer config + :param reward_buff_cap: Max reward history to track in the reward buffer + :param trainer_settings: The parameters for the trainer. + :param training: Whether the trainer is set for training. + :param load: Whether the model should be loaded. + :param seed: The seed the model will be initialized with + :param artifact_path: The directory within which to store artifacts from this trainer. + """ + super().__init__( + behavior_name, + trainer_settings, + training, + load, + artifact_path, + reward_buff_cap, + ) + self.hyperparameters = cast( + OnPolicyHyperparamSettings, self.trainer_settings.hyperparameters + ) + self.seed = seed + self.policy: Policy = None # type: ignore + self.optimizer: TorchOptimizer = None # type: ignore + + def _is_ready_update(self): + """ + Returns whether or not the trainer has enough elements to run update model + :return: A boolean corresponding to whether or not update_model() can be run + """ + size_of_buffer = self.update_buffer.num_experiences + return size_of_buffer > self.hyperparameters.buffer_size + + def _update_policy(self): + """ + Uses demonstration_buffer to update the policy. + The reward signal generators must be updated in this method at their own pace. + """ + buffer_length = self.update_buffer.num_experiences + self.cumulative_returns_since_policy_update.clear() + + # Make sure batch_size is a multiple of sequence length. 
During training, we + # will need to reshape the data into a batch_size x sequence_length tensor. + batch_size = ( + self.hyperparameters.batch_size + - self.hyperparameters.batch_size % self.policy.sequence_length + ) + # Make sure there is at least one sequence + batch_size = max(batch_size, self.policy.sequence_length) + + n_sequences = max( + int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 + ) + + advantages = np.array( + self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32 + ) + self.update_buffer[BufferKey.ADVANTAGES].set( + (advantages - advantages.mean()) / (advantages.std() + 1e-10) + ) + num_epoch = self.hyperparameters.num_epoch + batch_update_stats = defaultdict(list) + for _ in range(num_epoch): + self.update_buffer.shuffle(sequence_length=self.policy.sequence_length) + buffer = self.update_buffer + max_num_batch = buffer_length // batch_size + for i in range(0, max_num_batch * batch_size, batch_size): + minibatch = buffer.make_mini_batch(i, i + batch_size) + update_stats = self.optimizer.update(minibatch, n_sequences) + update_stats.update(self.optimizer.update_reward_signals(minibatch)) + for stat_name, value in update_stats.items(): + batch_update_stats[stat_name].append(value) + + for stat, stat_list in batch_update_stats.items(): + self._stats_reporter.add_stat(stat, np.mean(stat_list)) + + if self.optimizer.bc_module: + update_stats = self.optimizer.bc_module.update() + for stat, val in update_stats.items(): + self._stats_reporter.add_stat(stat, val) + self._clear_update_buffer() + return True + + def add_policy( + self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy + ) -> None: + """ + Adds policy to trainer. + :param parsed_behavior_id: Behavior identifiers that the policy should belong to. + :param policy: Policy to associate with name_behavior_id. + """ + if self.policy: + logger.warning( + "Your environment contains multiple teams, but {} doesn't support adversarial games. 
Enable self-play to \ + train adversarial games.".format( + self.__class__.__name__ + ) + ) + self.policy = policy + self.policies[parsed_behavior_id.behavior_id] = policy + + self.optimizer = self.create_optimizer() + for _reward_signal in self.optimizer.reward_signals.keys(): + self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) + + self.model_saver.register(self.policy) + self.model_saver.register(self.optimizer) + self.model_saver.initialize_or_load() + + # Needed to resume loads properly + self._step = policy.get_current_step() diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py index cc6cc5773c..57454900a0 100644 --- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py +++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py @@ -14,17 +14,14 @@ from mlagents_envs.logging_util import get_logger from mlagents_envs.timers import timed from mlagents.trainers.optimizer import Optimizer +from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.buffer import AgentBuffer, BufferKey from mlagents.trainers.trainer import Trainer -from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( +from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import ( BaseRewardProvider, ) from mlagents_envs.timers import hierarchical_timer -from mlagents_envs.base_env import BehaviorSpec -from mlagents.trainers.policy.policy import Policy -from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver -from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers from mlagents.trainers.agent_processor import AgentManagerQueue from mlagents.trainers.trajectory import Trajectory from mlagents.trainers.settings import TrainerSettings @@ -110,20 +107,10 @@ def _is_ready_update(self): """ return False - def create_policy( - self, - parsed_behavior_id: BehaviorIdentifiers, - behavior_spec: BehaviorSpec, - create_graph: bool = False, - ) -> Policy: - return self.create_torch_policy(parsed_behavior_id, behavior_spec) - @abc.abstractmethod - def create_torch_policy( - self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec - ) -> TorchPolicy: + def create_optimizer(self) -> TorchOptimizer: """ - Create a Policy object that uses the PyTorch backend. + Creates an Optimizer object """ pass diff --git a/ml-agents/mlagents/trainers/trainer/trainer.py b/ml-agents/mlagents/trainers/trainer/trainer.py index f51be84169..58a339efd2 100644 --- a/ml-agents/mlagents/trainers/trainer/trainer.py +++ b/ml-agents/mlagents/trainers/trainer/trainer.py @@ -128,10 +128,9 @@ def create_policy( self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec, - create_graph: bool = False, ) -> Policy: """ - Creates policy + Creates a Policy object """ pass @@ -144,12 +143,13 @@ def add_policy( """ pass - @abc.abstractmethod def get_policy(self, name_behavior_id: str) -> Policy: """ - Gets policy from trainer. + Gets policy associated with name_behavior_id + :param name_behavior_id: Fully qualified behavior name + :return: Policy associated with name_behavior_id """ - pass + return self.policies[name_behavior_id] @abc.abstractmethod def advance(self) -> None: @@ -177,3 +177,7 @@ def subscribe_trajectory_queue( :param trajectory_queue: Trajectory queue to read from. 
""" self.trajectory_queues.append(trajectory_queue) + + @staticmethod + def get_trainer_name() -> str: + raise NotImplementedError diff --git a/ml-agents/mlagents/trainers/trainer/trainer_factory.py b/ml-agents/mlagents/trainers/trainer/trainer_factory.py index 90f1aabef0..ffb7741513 100644 --- a/ml-agents/mlagents/trainers/trainer/trainer_factory.py +++ b/ml-agents/mlagents/trainers/trainer/trainer_factory.py @@ -5,12 +5,10 @@ from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager from mlagents.trainers.exception import TrainerConfigError from mlagents.trainers.trainer import Trainer -from mlagents.trainers.ppo.trainer import PPOTrainer -from mlagents.trainers.sac.trainer import SACTrainer -from mlagents.trainers.poca.trainer import POCATrainer from mlagents.trainers.ghost.trainer import GhostTrainer from mlagents.trainers.ghost.controller import GhostController -from mlagents.trainers.settings import TrainerSettings, TrainerType +from mlagents.trainers.settings import TrainerSettings +from mlagents.plugins import all_trainer_types logger = get_logger(__name__) @@ -101,10 +99,10 @@ def _initialize_trainer( min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name) trainer: Trainer = None # type: ignore # will be set to one of these, or raise - trainer_type = trainer_settings.trainer_type - if trainer_type == TrainerType.PPO: - trainer = PPOTrainer( + try: + trainer_type = all_trainer_types[trainer_settings.trainer_type] + trainer = trainer_type( brain_name, min_lesson_length, trainer_settings, @@ -113,29 +111,11 @@ def _initialize_trainer( seed, trainer_artifact_path, ) - elif trainer_type == TrainerType.POCA: - trainer = POCATrainer( - brain_name, - min_lesson_length, - trainer_settings, - train_model, - load_model, - seed, - trainer_artifact_path, - ) - elif trainer_type == TrainerType.SAC: - trainer = SACTrainer( - brain_name, - min_lesson_length, - trainer_settings, - train_model, - load_model, - seed, - trainer_artifact_path, - ) - else: + + except KeyError: raise TrainerConfigError( - f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}' + f"The trainer config contains an unknown trainer type " + f"{trainer_settings.trainer_type} for brain {brain_name}" ) if trainer_settings.self_play is not None: diff --git a/ml-agents/mlagents/trainers/trainer/trainer_utils.py b/ml-agents/mlagents/trainers/trainer/trainer_utils.py new file mode 100644 index 0000000000..ad94bd35f0 --- /dev/null +++ b/ml-agents/mlagents/trainers/trainer/trainer_utils.py @@ -0,0 +1,45 @@ +import numpy as np + + +def discount_rewards(r, gamma=0.99, value_next=0.0): + """ + Computes discounted sum of future rewards for use in updating value estimate. + :param r: List of rewards. + :param gamma: Discount factor. + :param value_next: T+1 value estimate for returns calculation. + :return: discounted sum of future rewards as list. + """ + discounted_r = np.zeros_like(r) + running_add = value_next + for t in reversed(range(0, r.size)): + running_add = running_add * gamma + r[t] + discounted_r[t] = running_add + return discounted_r + + +def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95): + """ + Computes generalized advantage estimate for use in updating policy. + :param rewards: list of rewards for time-steps t to T. + :param value_next: Value estimate for time-step T+1. + :param value_estimates: list of value estimates for time-steps t to T. + :param gamma: Discount factor. 
diff --git a/ml-agents/mlagents/trainers/trainer/trainer_utils.py b/ml-agents/mlagents/trainers/trainer/trainer_utils.py
new file mode 100644
index 0000000000..ad94bd35f0
--- /dev/null
+++ b/ml-agents/mlagents/trainers/trainer/trainer_utils.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+
+def discount_rewards(r, gamma=0.99, value_next=0.0):
+    """
+    Computes discounted sum of future rewards for use in updating value estimate.
+    :param r: List of rewards.
+    :param gamma: Discount factor.
+    :param value_next: T+1 value estimate for returns calculation.
+    :return: discounted sum of future rewards as list.
+    """
+    discounted_r = np.zeros_like(r)
+    running_add = value_next
+    for t in reversed(range(0, r.size)):
+        running_add = running_add * gamma + r[t]
+        discounted_r[t] = running_add
+    return discounted_r
+
+
+def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95):
+    """
+    Computes generalized advantage estimate for use in updating policy.
+    :param rewards: list of rewards for time-steps t to T.
+    :param value_estimates: list of value estimates for time-steps t to T.
+    :param value_next: Value estimate for time-step T+1.
+    :param gamma: Discount factor.
+    :param lambd: GAE weighting factor.
+    :return: list of advantage estimates for time-steps t to T.
+    """
+    value_estimates = np.append(value_estimates, value_next)
+    delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1]
+    advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
+    return advantage
+
+
+def lambda_return(r, value_estimates, gamma=0.99, lambd=0.8, value_next=0.0):
+    returns = np.zeros_like(r)
+    returns[-1] = r[-1] + gamma * value_next
+    for t in reversed(range(0, r.size - 1)):
+        returns[t] = (
+            gamma * lambd * returns[t + 1]
+            + r[t]
+            + (1 - lambd) * gamma * value_estimates[t + 1]
+        )
+    return returns
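In this new utility module, `discount_rewards` computes the discounted return-to-go, `get_gae` discounts the TD residuals delta_t = r_t + gamma*V(s_{t+1}) - V(s_t) by gamma*lambd, and `lambda_return` builds TD(lambda) targets for the value function. A small worked example, with numbers chosen so the results are easy to verify by hand:

```python
import numpy as np
from mlagents.trainers.trainer.trainer_utils import discount_rewards, get_gae

# Illustrative numbers only: three steps of reward plus value estimates.
rewards = np.array([1.0, 1.0, 1.0], dtype=np.float32)
values = np.array([0.5, 0.5, 0.5], dtype=np.float32)

# discount_rewards gives the (here undiscounted) return-to-go:
print(discount_rewards(rewards, gamma=1.0, value_next=0.0))  # [3.0, 2.0, 1.0]

# With gamma=1.0 and lambd=1.0, GAE reduces to return-to-go minus V(s_t):
adv = get_gae(rewards, values, value_next=0.0, gamma=1.0, lambd=1.0)
print(adv)  # [2.5, 1.5, 0.5]
```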
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index 8700da3cbe..69da1e5694 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -137,7 +137,6 @@ def _create_trainer_and_manager(
         policy = trainer.create_policy(
             parsed_behavior_id,
             env_manager.training_behaviors[name_behavior_id],
-            create_graph=True,
         )
         trainer.add_policy(parsed_behavior_id, policy)
 
diff --git a/ml-agents/mlagents/trainers/training_analytics_side_channel.py b/ml-agents/mlagents/trainers/training_analytics_side_channel.py
index e0e379325e..84edad7175 100644
--- a/ml-agents/mlagents/trainers/training_analytics_side_channel.py
+++ b/ml-agents/mlagents/trainers/training_analytics_side_channel.py
@@ -147,7 +147,7 @@ def training_started(self, behavior_name: str, config: TrainerSettings) -> None:
         raw_config = self._sanitize_trainer_settings(config)
         msg = TrainingBehaviorInitialized(
             behavior_name=self._hash(behavior_name),
-            trainer_type=config.trainer_type.value,
+            trainer_type=config.trainer_type,
             extrinsic_reward_enabled=(
                 RewardSignalType.EXTRINSIC in config.reward_signals
             ),
diff --git a/ml-agents/mlagents/trainers/trajectory.py b/ml-agents/mlagents/trainers/trajectory.py
index 8b5a246518..0a08bc24b4 100644
--- a/ml-agents/mlagents/trainers/trajectory.py
+++ b/ml-agents/mlagents/trainers/trajectory.py
@@ -8,7 +8,7 @@
     BufferKey,
 )
 from mlagents_envs.base_env import ActionTuple
-from mlagents.trainers.torch.action_log_probs import LogProbsTuple
+from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
 
 
 class AgentStatus(NamedTuple):
diff --git a/ml-agents/mlagents/trainers/upgrade_config.py b/ml-agents/mlagents/trainers/upgrade_config.py
index e1c8a05ad7..d07ce0016d 100644
--- a/ml-agents/mlagents/trainers/upgrade_config.py
+++ b/ml-agents/mlagents/trainers/upgrade_config.py
@@ -7,9 +7,10 @@
 import yaml
 from typing import Dict, Any, Optional
 import argparse
-from mlagents.trainers.settings import TrainerSettings, NetworkSettings, TrainerType
+from mlagents.trainers.settings import TrainerSettings, NetworkSettings
 from mlagents.trainers.cli_utils import load_config
 from mlagents.trainers.exception import TrainerConfigError
+from mlagents.plugins import all_trainer_settings
 
 
 # Take an existing trainer config (e.g. trainer_config.yaml) and turn it into the new format.
@@ -32,7 +33,7 @@ def convert_behaviors(old_trainer_config: Dict[str, Any]) -> Dict[str, Any]:
             )
             new_config = {}
             new_config["trainer_type"] = trainer_type
-            hyperparam_cls = TrainerType(trainer_type).to_settings()
+            hyperparam_cls = all_trainer_settings[trainer_type]
             # Try to absorb as much as possible into the hyperparam_cls
             new_config["hyperparameters"] = cattr.structure(config, hyperparam_cls)
diff --git a/ml-agents/pydoc-config.yaml b/ml-agents/pydoc-config.yaml
new file mode 100644
index 0000000000..8fbfe199dc
--- /dev/null
+++ b/ml-agents/pydoc-config.yaml
@@ -0,0 +1,16 @@
+# config to specify which modules will be used to render api docs from ml-agents package
+folder: docs
+modules:
+  - name: mlagents
+    file_name: Python-On-Off-Policy-Trainer-Documentation.md
+    submodules:
+      - trainers.trainer.on_policy_trainer
+      - trainers.trainer.off_policy_trainer
+      - trainers.trainer.rl_trainer
+      - trainers.trainer.trainer
+      - trainers.settings
+  - name: mlagents
+    file_name: Python-Optimizer-Documentation.md
+    submodules:
+      - trainers.optimizer.torch_optimizer
+      - trainers.optimizer.optimizer
\ No newline at end of file
diff --git a/ml-agents/setup.py b/ml-agents/setup.py
index ce7dcae32c..d1b8c6fe64 100644
--- a/ml-agents/setup.py
+++ b/ml-agents/setup.py
@@ -3,7 +3,7 @@
 from setuptools import setup, find_packages
 from setuptools.command.install import install
 
-from mlagents.plugins import ML_AGENTS_STATS_WRITER
+from mlagents.plugins import ML_AGENTS_STATS_WRITER, ML_AGENTS_TRAINER_TYPE
 import mlagents.trainers
 
 VERSION = mlagents.trainers.__version__
@@ -48,8 +48,9 @@ def run(self):
         "Intended Audience :: Developers",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "License :: OSI Approved :: Apache Software License",
-        "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
     ],
     # find_namespace_packages will recurse through the directories and find all the packages
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
@@ -67,7 +68,7 @@ def run(self):
         # https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/Installation.md#windows-installing-pytorch
         # Torch only working on python 3.9 for 1.8.0 and above. Details see:
         # https://github.com/pytorch/pytorch/issues/50014
-        "torch>=1.8.0,<1.9.0;(platform_system!='Windows' and python_version>='3.9')",
+        "torch>=1.8.0,<=1.11.0;(platform_system!='Windows' and python_version>='3.9')",
         "torch>=1.6.0,<1.9.0;(platform_system!='Windows' and python_version<'3.9')",
         "tensorboard>=1.15",
         # cattrs 1.1.0 dropped support for python 3.6, but 1.0.0 doesn't work for python 3.9
@@ -76,9 +77,9 @@ def run(self):
         "cattrs>=1.1.0,<1.7; python_version>='3.8'",
         "attrs>=19.3.0",
         'pypiwin32==223;platform_system=="Windows"',
-        "importlib_metadata; python_version<'3.8'",
+        "importlib_metadata==4.4; python_version<'3.8'",
     ],
-    python_requires=">=3.7.2,<3.10.0",
+    python_requires=">=3.8.13,<=3.10.8",
     entry_points={
         "console_scripts": [
             "mlagents-learn=mlagents.trainers.learn:main",
@@ -88,6 +89,9 @@ def run(self):
         ML_AGENTS_STATS_WRITER: [
             "default=mlagents.plugins.stats_writer:get_default_stats_writers"
         ],
+        ML_AGENTS_TRAINER_TYPE: [
+            "default=mlagents.plugins.trainer_type:get_default_trainer_types"
+        ],
     },
     # TODO: Remove this once mypy stops having spurious setuptools issues.
     cmdclass={"verify": VerifyVersionCommand},  # type: ignore
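The new `ML_AGENTS_TRAINER_TYPE` entry-point group in setup.py is what feeds the `all_trainer_types` and `all_trainer_settings` registries used above: ml-agents registers its own defaults through `mlagents.plugins.trainer_type:get_default_trainer_types`, and an external package could register more. A hypothetical plugin setup.py might look like the sketch below; the group string and the shape of the registration function are assumptions based on the existing stats-writer plugin pattern, not part of this diff:

```python
# Hypothetical setup.py for a third-party trainer plugin.
from setuptools import setup, find_packages

setup(
    name="mlagents-mytrainer",  # placeholder package name
    version="0.1.0",
    packages=find_packages(),
    install_requires=["mlagents"],
    entry_points={
        # Group name assumed to match ML_AGENTS_TRAINER_TYPE.
        "mlagents.trainer_type": [
            # Assumed to point at a function that returns the trainer and
            # settings classes to register, analogous to
            # get_default_trainer_types for PPO/SAC/POCA.
            "mytrainer=mlagents_mytrainer.plugin:get_mytrainer_types"
        ]
    },
)
```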
cmdclass={"verify": VerifyVersionCommand}, # type: ignore diff --git a/ml-agents/tests/yamato/yamato_utils.py b/ml-agents/tests/yamato/yamato_utils.py index 1b793ff321..9ebc91043f 100644 --- a/ml-agents/tests/yamato/yamato_utils.py +++ b/ml-agents/tests/yamato/yamato_utils.py @@ -45,7 +45,7 @@ def run_standalone_build( unity_exe = get_unity_executable_path() print(f"Running BuildStandalonePlayer via {unity_exe}") - # enum values from https://docs.unity3d.com/2020.3/Documentation/ScriptReference/BuildTarget.html + # enum values from https://docs.unity3d.com/2021.3/Documentation/ScriptReference/BuildTarget.html build_target_to_enum: Mapping[Optional[str], str] = { "mac": "StandaloneOSX", "osx": "StandaloneOSX", diff --git a/test_constraints_max_version.txt b/test_constraints_max_version.txt index e985203838..18d655ee7e 100644 --- a/test_constraints_max_version.txt +++ b/test_constraints_max_version.txt @@ -1,3 +1,3 @@ # pip constraints to use the *highest* versions allowed in ml-agents/setup.py # For projects with upper bounds, we should periodically update this list to the latest -torch==1.8.0 \ No newline at end of file +torch==1.11.0 diff --git a/test_constraints_mid_version.txt b/test_constraints_mid_version.txt index 0accf06143..ea9fa2e27f 100644 --- a/test_constraints_mid_version.txt +++ b/test_constraints_mid_version.txt @@ -1,2 +1,2 @@ # pip constraints to use a version in the middle of allowed ranges in ml-agents/setup.py -torch==1.7.0 \ No newline at end of file +torch==1.8.0 diff --git a/test_constraints_min_version.txt b/test_constraints_min_version.txt index 9a9de347cd..a6a3046846 100644 --- a/test_constraints_min_version.txt +++ b/test_constraints_min_version.txt @@ -1,2 +1,2 @@ # pip constraints to use the *lowest* versions allowed in ml-agents/setup.py -torch==1.6.0 \ No newline at end of file +torch==1.7.0 diff --git a/utils/validate_inits.py b/utils/validate_inits.py index 43f3250d62..39bd762c99 100755 --- a/utils/validate_inits.py +++ b/utils/validate_inits.py @@ -13,7 +13,7 @@ class NonTrivialPEP420PackageFinder(PEP420PackageFinder): """ @staticmethod - def _looks_like_package(path): + def _looks_like_package(path, package_name=None): glob_path = os.path.join(path, "*.py") return any(glob.iglob(glob_path)) diff --git a/utils/validate_meta_files.py b/utils/validate_meta_files.py index f0616ced85..38a3837540 100644 --- a/utils/validate_meta_files.py +++ b/utils/validate_meta_files.py @@ -28,6 +28,7 @@ def main(): "Documentation~", ".github", ".yamato", + "Samples", } num_matched = 0