From 7b3517bb0fcf23aaf1d432508f7afa8fb389f7fd Mon Sep 17 00:00:00 2001 From: Ryan Strauss Date: Fri, 29 Dec 2023 16:06:52 -0500 Subject: [PATCH] Add Go and Chess envs --- dopamax/environments/__init__.py | 2 +- dopamax/environments/pgx/__init__.py | 1 + dopamax/environments/pgx/chess.py | 23 ++++++++++++++++ dopamax/environments/pgx/go.py | 40 ++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 dopamax/environments/pgx/chess.py create mode 100644 dopamax/environments/pgx/go.py diff --git a/dopamax/environments/__init__.py b/dopamax/environments/__init__.py index e96a60f..8ccb931 100644 --- a/dopamax/environments/__init__.py +++ b/dopamax/environments/__init__.py @@ -4,5 +4,5 @@ from .cartpole import CartPole from .mountain_car import MountainCar from .mountain_car_continuous import MountainCarContinuous -from .pgx import ConnectFour, TicTacToe +from .pgx import ConnectFour, TicTacToe, Go9x9, Go19x19 from .utils import make_env diff --git a/dopamax/environments/pgx/__init__.py b/dopamax/environments/pgx/__init__.py index 7f461ed..4e8fbbb 100644 --- a/dopamax/environments/pgx/__init__.py +++ b/dopamax/environments/pgx/__init__.py @@ -1,2 +1,3 @@ from .connect_four import ConnectFour +from .go import Go9x9, Go19x19 from .tic_tac_toe import TicTacToe diff --git a/dopamax/environments/pgx/chess.py b/dopamax/environments/pgx/chess.py new file mode 100644 index 0000000..54a9ade --- /dev/null +++ b/dopamax/environments/pgx/chess.py @@ -0,0 +1,23 @@ +import pgx +from chex import dataclass + +from dopamax.environments.pgx.base import PGXEnvironment +from dopamax.environments.utils import register + +_NAME = "Chess" + + +@register(_NAME) +@dataclass(frozen=True) +class Chess(PGXEnvironment): + def __init__(self): + pgx_env = pgx.make("chess") + super(Chess, self).__init__(_pgx_env=pgx_env) + + @property + def max_episode_length(self) -> int: + return 512 # From AlphaZero paper + + @property + def name(self) -> str: + return _NAME diff --git a/dopamax/environments/pgx/go.py b/dopamax/environments/pgx/go.py new file mode 100644 index 0000000..e49aa1e --- /dev/null +++ b/dopamax/environments/pgx/go.py @@ -0,0 +1,40 @@ +import pgx +from chex import dataclass + +from dopamax.environments.pgx.base import PGXEnvironment +from dopamax.environments.utils import register + +_NAME_9x9 = "Go9x9" +_NAME_19x19 = "Go19x19" + + +@register(_NAME_9x9) +@dataclass(frozen=True) +class Go9x9(PGXEnvironment): + def __init__(self): + pgx_env = pgx.make("go_9x9") + super(Go9x9, self).__init__(_pgx_env=pgx_env) + + @property + def max_episode_length(self) -> int: + return 9 * 9 * 2 + + @property + def name(self) -> str: + return _NAME_9x9 + + +@register(_NAME_19x19) +@dataclass(frozen=True) +class Go19x19(PGXEnvironment): + def __init__(self): + pgx_env = pgx.make("go_19x19") + super(Go19x19, self).__init__(_pgx_env=pgx_env) + + @property + def max_episode_length(self) -> int: + return 19 * 19 * 2 + + @property + def name(self) -> str: + return _NAME_19x19