board state as multi one-hot
denizetkar committed Aug 21, 2024
1 parent 32651a1 commit 54b570c
Showing 6 changed files with 34 additions and 26 deletions.
2 changes: 0 additions & 2 deletions TODO.md
@@ -22,8 +22,6 @@

- Parallelize traj collection: `ParallelEnv` or `MultiSyncDataCollector` ???

---

- Maybe turn categorical vars to one-hot vectors (?)

```python
14 changes: 10 additions & 4 deletions src/actor_critic.py
@@ -19,7 +19,10 @@

def create_action_nets(base_env: ChessEnv, final_env: EnvBase, default_device: torch.device):
action_dims = [a.n for a in final_env.full_action_spec[base_env.action_key].values()]
obs_total_dims = sum([final_env.full_observation_spec[key].shape[-1] for key in base_env.observation_keys])
batch_shape = final_env.full_observation_spec[ChessEnv.OBSERVATION_KEY].shape
obs_total_dims = sum(
[final_env.full_observation_spec[key].shape[len(batch_shape) :].numel() for key in base_env.observation_keys]
)
# Only thing to persist to disk is `action_nets`
action_nets = nn.ModuleDict(
{
@@ -76,7 +79,7 @@ def create_actor(
actor = ProbabilisticActor(
module=identity_module,
spec=final_env.full_action_spec,
in_keys={"observations": base_env.OBSERVATION_KEY, "mask": "action_mask"},
in_keys={"observations": ChessEnv.OBSERVATION_KEY, "mask": "action_mask"},
# We shouldn't have to specify the `out_keys` below but otherwise `ProbabilisticActor.__init__` complains about
# not having `distribution_map` in `distribution_kwargs`. It doesn't allow to customize `CompositeDistribution`
# that doesn't take `distribution_map` arg. WTF?
@@ -101,8 +104,11 @@ def create_actor(


def create_critic(base_env: ChessEnv, final_env: EnvBase, default_device: torch.device):
obs_without_turn_keys = [key for key in base_env.observation_keys if key != (base_env.OBSERVATION_KEY, "turn")]
obs_without_turn_total_dims = sum([final_env.full_observation_spec[key].shape[-1] for key in obs_without_turn_keys])
obs_without_turn_keys = [key for key in base_env.observation_keys if key != (ChessEnv.OBSERVATION_KEY, "turn")]
batch_shape = final_env.full_observation_spec[ChessEnv.OBSERVATION_KEY].shape
obs_without_turn_total_dims = sum(
[final_env.full_observation_spec[key].shape[len(batch_shape) :].numel() for key in obs_without_turn_keys]
)
critic_modules = (
TensorDictModule(
MLP(
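
A note on the `obs_total_dims` computation in the hunks above: it now strips the spec's batch dims and counts every remaining feature element instead of taking only the last dimension. A minimal sketch with hypothetical spec shapes (the real ones come from `final_env.full_observation_spec`; `hypothetical_planes` is made up for illustration):

```python
import torch

batch_shape = torch.Size([2])  # e.g. specs batched by a 2-worker env
spec_shapes = {
    "piece_at_pos": torch.Size([2, 64 * 12]),      # flat multi one-hot features
    "turn": torch.Size([2, 1]),
    "hypothetical_planes": torch.Size([2, 8, 8]),  # a non-flat feature shape
}

# Last-dim only: 768 + 1 + 8 = 777 (undercounts the 8x8 spec).
last_dim_total = sum(shape[-1] for shape in spec_shapes.values())

# Batch dims stripped, then all feature elements counted: 768 + 1 + 64 = 833.
feature_numel_total = sum(
    shape[len(batch_shape):].numel() for shape in spec_shapes.values()
)

print(last_dim_total, feature_numel_total)  # 777 833
```
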
23 changes: 12 additions & 11 deletions src/chess_env.py
@@ -7,7 +7,7 @@
from tensordict import TensorDictBase, TensorDict
from torchrl.envs import EnvBase, ParallelEnv
from torchrl.envs.transforms import TransformedEnv, ObservationNorm, Compose, Transform
from torchrl.data import CompositeSpec, DiscreteTensorSpec, BoundedTensorSpec
from torchrl.data import CompositeSpec, DiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, BoundedTensorSpec

from custom_tensor_specs import DependentDiscreteTensorsSpec
from custom_transforms import DiscreteToContinuousTransform, DoubleToFloat
@@ -31,12 +31,13 @@ def __init__(
allow_done_after_reset=allow_done_after_reset,
)
self.n_agents = 2
self.total_piece_cnt = 2 * len(chess.PIECE_TYPES)
self.observation_spec = CompositeSpec(
{
ChessEnv.OBSERVATION_KEY: CompositeSpec(
# 0 for no piece at that position, 1-6 for black, 7-12 for white pieces
piece_at_pos=DiscreteTensorSpec(2 * len(chess.PIECE_TYPES) + 1, shape=(64,)),
turn=DiscreteTensorSpec(self.n_agents, shape=(1,)),
piece_at_pos=MultiOneHotDiscreteTensorSpec([self.total_piece_cnt] * 64, dtype=torch.bool),
turn=DiscreteTensorSpec(self.n_agents, shape=(1,), dtype=torch.bool),
)
},
action_mask=DiscreteTensorSpec(2, shape=(64, 64), dtype=torch.bool),
@@ -84,11 +85,11 @@ def observation_td(self):
piece_positions, piece_types = [], []
for piece_pos, piece in self.board.piece_map().items():
piece_positions.append(piece_pos)
piece_types.append(int(piece.color) * len(chess.PIECE_TYPES) + piece.piece_type)
piece_types.append(int(piece.color) * len(chess.PIECE_TYPES) + piece.piece_type - 1)
piece_positions = torch.tensor(piece_positions, device=obs_td.device)
piece_types = torch.tensor(piece_types, device=obs_td.device)
obs_td[ChessEnv.OBSERVATION_KEY, "piece_at_pos"][piece_positions] = piece_types
obs_td[ChessEnv.OBSERVATION_KEY, "turn"].fill_(int(self.board.turn))
obs_td[ChessEnv.OBSERVATION_KEY, "piece_at_pos"][piece_positions * self.total_piece_cnt + piece_types] = True
obs_td[ChessEnv.OBSERVATION_KEY, "turn"][...] = self.board.turn

legal_moves = []
for move in self.board.legal_moves:
@@ -103,8 +104,8 @@ def done_td(self):
def done_td(self):
d_td = self.full_done_spec.zero()
is_done = self.board.is_game_over()
d_td["done"].fill_(is_done)
d_td["agents", "done"].fill_(is_done)
d_td["done"][...] = is_done
d_td["agents", "done"][...] = is_done
return d_td

@property
Expand All @@ -115,10 +116,9 @@ def reward_td(self):
return r_td

if outcome.winner is None:
r_td[self.reward_key].fill_(-1.0)
r_td[self.reward_key][...] = -1.0
else:
winner, loser = int(outcome.winner), int(not outcome.winner)
indexes = torch.tensor([winner, loser], device=r_td.device)
indexes = [int(outcome.winner), int(not outcome.winner)]
r_td[self.reward_key][indexes] = torch.tensor([[1.0], [-1.0]], device=r_td.device)
return r_td
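
The reward hunk above replaces the `fill_` calls and the intermediate index tensor with plain list indexing into the `(n_agents, 1)` reward tensor. A minimal sketch of that assignment pattern, with an illustrative `winner` value:

```python
import torch

reward = torch.zeros(2, 1)  # hypothetical per-agent reward, shape (n_agents=2, 1)
winner = True               # e.g. outcome.winner from python-chess (True = white)

indexes = [int(winner), int(not winner)]         # [winner_row, loser_row]
reward[indexes] = torch.tensor([[1.0], [-1.0]])  # winner +1, loser -1
print(reward)  # tensor([[-1.], [ 1.]]) when white (row 1) wins
```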

@@ -202,4 +202,5 @@ def create_tenv():
return tenv

penv = ParallelEnv(2, create_tenv)
# penv = create_tenv()
pass
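
For reference, here is a standalone sketch of the board encoding that the new `observation_spec` and `observation_td` implement in vectorized form: each of the 64 squares owns a block of 12 boolean slots, a piece of a given color and type sets exactly one slot in its square's block (piece classes run 0-11 after the `- 1` shift), and empty squares stay all-zero. The helper name below is illustrative, not part of the repo:

```python
import chess
import torch

TOTAL_PIECE_CNT = 2 * len(chess.PIECE_TYPES)  # 12 piece classes (6 per color)

def board_to_multi_one_hot(board: chess.Board) -> torch.Tensor:
    """Flat bool vector of length 64 * 12; empty squares stay all-zero."""
    encoding = torch.zeros(64 * TOTAL_PIECE_CNT, dtype=torch.bool)
    for square, piece in board.piece_map().items():
        # piece.color is False/True (black/white) and piece.piece_type is 1..6,
        # so the per-square class index runs 0..11 after the `- 1` shift.
        piece_class = int(piece.color) * len(chess.PIECE_TYPES) + piece.piece_type - 1
        encoding[square * TOTAL_PIECE_CNT + piece_class] = True
    return encoding

encoded = board_to_multi_one_hot(chess.Board())
print(encoded.shape, int(encoded.sum()))  # torch.Size([768]) 32
```
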
4 changes: 2 additions & 2 deletions src/play.py
@@ -21,8 +21,8 @@ def get_move(td: TensorDict):


if __name__ == "__main__":
obs_transforms_save_path = "./lightning_logs/version_0/checkpoints/epoch=74-step=37575-obs_transforms.pt"
action_nets_save_path = "./lightning_logs/version_0/checkpoints/epoch=74-step=37575-action_nets-PPO.pt"
obs_transforms_save_path = "./lightning_logs/version_1/checkpoints/epoch=34-step=17535-obs_transforms.pt"
action_nets_save_path = "./lightning_logs/version_1/checkpoints/epoch=34-step=17535-action_nets-PPO.pt"
default_device = torch.device("cpu")

env = ChessEnv()
7 changes: 5 additions & 2 deletions src/pretrain.py
@@ -16,6 +16,7 @@
from tensordict import TensorDict

import lightning as L
import torch.nn.functional as F
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import Dataset, random_split, DataLoader
import numpy as np
@@ -52,7 +53,7 @@ def map_game_to_winner_moves(row: pd.Series):
piece_at_pos = [0] * 64
for square, piece in game.piece_map().items():
index = square # Use the square index directly
piece_at_pos[index] = int(piece.color) * len(chess.PIECE_TYPES) + piece.piece_type
piece_at_pos[index] = int(piece.color) * len(chess.PIECE_TYPES) + piece.piece_type - 1

game_data.append(
{"piece_at_pos": piece_at_pos, "turn": [int(winner)], "move": [move.from_square, move.to_square]}
@@ -152,6 +153,8 @@ def __init__(self, *args: Any, max_epochs: int, lr: float = 1e-4, min_lr: float

def get_ce_loss(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
piece_at_pos, turn, move = batch
# Turn piece_at_pos from categorical to one-hot
piece_at_pos = F.one_hot(piece_at_pos, num_classes=2 * len(chess.PIECE_TYPES)).flatten(-2, -1)
observations = TensorDict(
{ChessEnv.OBSERVATION_KEY: {"piece_at_pos": piece_at_pos, "turn": turn}}
).auto_batch_size_()
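
The `get_ce_loss` change above converts the dataset's categorical board labels into the same flat layout the environment now emits, by one-hot encoding each square and flattening the last two dims. A minimal shape-check sketch with random labels (batch size 3 and `num_classes=12`, i.e. `2 * len(chess.PIECE_TYPES)`, chosen for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical batch of categorical boards: 3 samples x 64 squares,
# each square labeled with one of 12 piece classes.
piece_at_pos = torch.randint(0, 12, (3, 64))

one_hot = F.one_hot(piece_at_pos, num_classes=12)  # (3, 64, 12)
flat = one_hot.flatten(-2, -1)                     # (3, 768)
print(one_hot.shape, flat.shape)
```
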
@@ -194,7 +197,7 @@ def configure_optimizers(self):
# Hyperparameters
max_epochs = 200
batch_size = 1024
lr, min_lr = 5e-4, 1e-6
lr, min_lr = 2e-4, 1e-6
gradient_clip_val = 10.0

piece_at_pos, turn, move = winner_moves_df_to_np_arrays(load_winner_moves_df())
10 changes: 5 additions & 5 deletions src/train.py
@@ -54,9 +54,9 @@
lmbda = 0.95
entropy_eps = 1e-4

obs_transforms_save_path = "./lightning_logs/version_0/checkpoints/epoch=74-step=37575-obs_transforms.pt"
action_nets_save_path = "./lightning_logs/version_0/checkpoints/epoch=74-step=37575-action_nets.pt"
critic_save_path = "./lightning_logs/version_0/checkpoints/epoch=74-step=37575-critic.pt"
obs_transforms_save_path = "./lightning_logs/version_1/checkpoints/epoch=34-step=17535-obs_transforms.pt"
action_nets_save_path = "./lightning_logs/version_1/checkpoints/epoch=34-step=17535-action_nets.pt"
critic_save_path = "./lightning_logs/version_1/checkpoints/epoch=34-step=17535-critic.pt"

env = ChessEnv()
transforms = Compose(
@@ -136,7 +136,7 @@ def create_tenv():
subdata: TensorDict = replay_buffer.sample()
loss_vals: TensorDict = loss_module(subdata)

turn_data = transforms.inv(subdata)[env.OBSERVATION_KEY, "turn"]
turn_data = transforms.inv(subdata)[ChessEnv.OBSERVATION_KEY, "turn"].type(torch.long)
losses: TensorDict = (
loss_vals.select("loss_objective").auto_batch_size_().gather(index=turn_data, dim=-1).mean()
)
optim.step()
optim.zero_grad()

turn_data = transforms.inv(td_data)[env.OBSERVATION_KEY, "turn"].unsqueeze(-2)
turn_data = transforms.inv(td_data)[ChessEnv.OBSERVATION_KEY, "turn"].type(torch.long).unsqueeze(-2)
selected_rewards = td_data.get(("next", "agents", "episode_reward")).gather(index=turn_data, dim=-2)
episode_reward_mean = selected_rewards.mean().item()
episode_reward_neg_sum = selected_rewards.type(torch.long).clip(max=0.0).sum().item()
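
Both `gather` calls above use the turn flag as an integer index; since the new spec stores `turn` as `torch.bool`, it is cast to `torch.long` first. A minimal sketch of the reward selection with made-up data (5 steps, 2 agents):

```python
import torch

episode_reward = torch.randn(5, 2, 1)  # hypothetical (steps, n_agents, 1) rewards

# The "turn" observation is a bool flag; cast to long so it can act as an index.
turn = torch.randint(0, 2, (5, 1)).bool()
index = turn.type(torch.long).unsqueeze(-2)           # (5, 1, 1)

# Pick, for every step, the reward row of the agent whose turn it was.
selected = episode_reward.gather(index=index, dim=-2)  # (5, 1, 1)
print(selected.shape)
```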
