train against a random agent (sanity check)
denizetkar committed Aug 22, 2024
1 parent 8875954 commit 39c5205
Showing 4 changed files with 11 additions and 14 deletions.
2 changes: 0 additions & 2 deletions src/actor_critic.py
@@ -9,8 +9,6 @@
 from chess_env import ChessEnv
 from torchrl.envs import EnvBase
 
-# from torchrl.envs.utils import check_env_specs
-
 from torchrl.modules import MLP, ProbabilisticActor
 from custom_modules import ExpandNewDimension, NegConcat
 from tensordict.nn import InteractionType
5 changes: 1 addition & 4 deletions src/play.py
@@ -4,8 +4,6 @@
 from chess_env import ChessEnv
 from torchrl.envs.transforms import TransformedEnv, Compose
 
-# import chess
-
 
 if __name__ == "__main__":
     obs_transforms_save_path = "./lightning_logs/version_1/checkpoints/epoch=34-step=17535-obs_transforms.pt"

@@ -32,8 +30,7 @@
         td_step = tenv.step(td_actor)
         td = td_step["next"]
 
-        # move = chess.Move(td_step[env.action_key, "0"].item(), td_step[env.action_key, "1"].item())
-        # print(f"Move: {move}")
+        # print(f"Move: {env.board.move_stack[-1]}")
         # print(env.board)
 
     outcome = env.board.outcome()
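The replaced debug lines no longer rebuild a chess.Move from raw action indices; python-chess already records every pushed move on Board.move_stack. A minimal standalone sketch of that API, using only the python-chess package and none of this repo's code:

import chess

board = chess.Board()
board.push_san("e4")  # white's move
board.push_san("e5")  # black's reply

# Board.move_stack holds every move pushed so far, most recent last,
# which is exactly what the new debug print reads.
print(board.move_stack[-1])  # e7e5
print(board)                 # ASCII rendering of the current position

Since the environment pushes each move onto env.board internally, move_stack[-1] is always the move just played, with no need to decode the action tensors by hand.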
17 changes: 10 additions & 7 deletions src/train.py
@@ -13,8 +13,7 @@
 from chess_env import ChessEnv
 from torchrl.envs.transforms import TransformedEnv, Compose
 from torchrl.envs import ParallelEnv, RewardSum
 
-# from torchrl.envs.utils import check_env_specs
-
+from torchrl.envs.utils import check_env_specs
 
 from actor_critic import load_action_nets, save_action_nets, create_actor, load_critic, save_critic, create_logits_fn

@@ -64,14 +63,18 @@
     RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
 )
 
-def create_tenv():
-    env = ChessEnv()
+def create_tenv(index: int | None = None):
+    # If you want self-play, pass `rand_player_idx=None`
+    env = ChessEnv(rand_player_idx=index)
     tenv = TransformedEnv(env, transforms.clone(), cache_specs=False, device=default_device)
-    # check_env_specs(tenv)
+    check_env_specs(tenv)
     return tenv
 
-penv = ParallelEnv(n_envs, create_tenv)
-# penv = create_tenv()
+penv = ParallelEnv(
+    num_workers=n_envs,
+    create_env_fn=[create_tenv] * n_envs,
+    create_env_kwargs=[{"index": i} for i in range(n_envs)],
+)
 
 action_nets = load_action_nets(env, penv, default_device, action_nets_save_path)
 actor = create_actor(env, penv, default_device, create_logits_fn(env, action_nets))
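Two things change here. First, ParallelEnv accepts one constructor per worker, and passing create_env_kwargs as a list gives each of the n_envs workers its own arguments, so worker i builds ChessEnv(rand_player_idx=i). Second, check_env_specs is now actually run: it performs a short rollout and verifies that the env's declared specs match what reset and step really return, cheap insurance for hand-built specs. ChessEnv itself is not part of this commit; the sketch below is only a guess at what rand_player_idx plausibly does, and the index % 2 mapping and attribute names are assumptions, not code from src/chess_env.py:

class ChessEnv:  # simplified stand-in, not the real src/chess_env.py
    def __init__(self, rand_player_idx: int | None = None):
        # None  -> self-play: the learned policy moves for both colors.
        # int i -> one color (assumed i % 2: 0 = white, 1 = black) is played
        #          by a uniformly random legal policy, the "random agent"
        #          sanity check this commit trains against.
        self.rand_player = None if rand_player_idx is None else rand_player_idx % 2

    def opponent_moves_now(self, turn_idx: int) -> bool:
        # True when the random opponent, not the learner, is to move.
        return self.rand_player is not None and turn_idx == self.rand_player

Under that assumed mapping, even-indexed workers give the random agent one color and odd-indexed workers the other, so the learner plays both sides across the batch of parallel environments.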
1 change: 0 additions & 1 deletion src/utils.py
@@ -14,7 +14,6 @@ def _get_next_mask(
     action_indices = tuple(actions)
     indices = batch_indices + action_indices
 
-    # self.mask cannot be None here
     filtered_mask = mask[indices]
 
     legal_action_mask = filtered_mask.any(dim=list(range(-num_actions_not_taken + 1, 0)))
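The deleted comment referred to a self.mask attribute, which no longer matches the code: mask arrives as a plain argument here. For context, the any() reduction collapses every action dimension not yet fixed, leaving a legality mask over just the next sub-action to choose. A toy illustration with invented shapes (the real mask comes from the chess env):

import torch

# Pretend mask[from_square, to_square] marks legal (from, to) pairs
# on a tiny 3x3 "board".
mask = torch.zeros(3, 3, dtype=torch.bool)
mask[0, 1] = True
mask[2, 0] = True

# Choosing the "from" sub-action first: a "from" square is a legal next
# choice iff some "to" square completes a legal move, i.e. any() over
# the remaining trailing dimension.
print(mask.any(dim=-1))  # tensor([ True, False,  True])

In the real code the same idea runs batched: batch_indices plus the already-chosen sub-actions index into the mask, and the any() covers all remaining trailing dimensions at once.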
