Skip to content

Commit

Permalink
Add Go and Chess envs
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Dec 29, 2023
1 parent 7ace59e commit 7b3517b
Showing 4 changed files with 65 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dopamax/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -4,5 +4,5 @@
from .cartpole import CartPole
from .mountain_car import MountainCar
from .mountain_car_continuous import MountainCarContinuous
from .pgx import ConnectFour, TicTacToe
from .pgx import ConnectFour, TicTacToe, Go9x9, Go19x19
from .utils import make_env
1 change: 1 addition & 0 deletions dopamax/environments/pgx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .connect_four import ConnectFour
from .go import Go9x9, Go19x19
from .tic_tac_toe import TicTacToe
23 changes: 23 additions & 0 deletions dopamax/environments/pgx/chess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pgx
from chex import dataclass

from dopamax.environments.pgx.base import PGXEnvironment
from dopamax.environments.utils import register

_NAME = "Chess"


@register(_NAME)
@dataclass(frozen=True)
class Chess(PGXEnvironment):
def __init__(self):
pgx_env = pgx.make("chess")
super(Chess, self).__init__(_pgx_env=pgx_env)

@property
def max_episode_length(self) -> int:
return 512 # From AlphaZero paper

@property
def name(self) -> str:
return _NAME
40 changes: 40 additions & 0 deletions dopamax/environments/pgx/go.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pgx
from chex import dataclass

from dopamax.environments.pgx.base import PGXEnvironment
from dopamax.environments.utils import register

_NAME_9x9 = "Go9x9"
_NAME_19x19 = "Go19x19"


@register(_NAME_9x9)
@dataclass(frozen=True)
class Go9x9(PGXEnvironment):
def __init__(self):
pgx_env = pgx.make("go_9x9")
super(Go9x9, self).__init__(_pgx_env=pgx_env)

@property
def max_episode_length(self) -> int:
return 9 * 9 * 2

@property
def name(self) -> str:
return _NAME_9x9


@register(_NAME_19x19)
@dataclass(frozen=True)
class Go19x19(PGXEnvironment):
def __init__(self):
pgx_env = pgx.make("go_19x19")
super(Go19x19, self).__init__(_pgx_env=pgx_env)

@property
def max_episode_length(self) -> int:
return 19 * 19 * 2

@property
def name(self) -> str:
return _NAME_19x19

0 comments on commit 7b3517b

Please sign in to comment.