Commit d2a5779

feat: main and train CLI args

1 parent 4f05837 · commit d2a5779

18 files changed: +236 -121 lines

FlashAttentionTransformer_weights.pth (2.46 MB, binary file not shown)

RNNModel_weights.pth (-196 KB, binary file not shown)

TransformerModel_weights.pth (2.32 MB, binary file not shown)

exact_engine.py (+1 -1)

@@ -27,7 +27,7 @@ def generate_pong_states(num_steps=None):
 @inject.params(game_state=State)
 def _generate_pong_states(game_state: State = None):
     dt = 1 # Time step
-    ball_random_velocity = random_velocity_generator()
+    ball_random_velocity = random_velocity_generator(min=game_state.engineConfig.min_ball_velocity, max=game_state.engineConfig.max_ball_velocity)

     left_paddle = game_state.left_paddle
     right_paddle = game_state.right_paddle
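
The serve velocity is now bounded by the EngineConfig values added in game/configuration.py. The random_velocity_generator itself is not touched by this commit, so the following is only a hypothetical sketch of a generator matching the keyword signature used at the call site above:

    import random

    # Hypothetical sketch; the real random_velocity_generator lives elsewhere in the repo.
    def random_velocity_generator(min=0.005, max=0.025):
        """Return an (x, y) velocity whose per-axis magnitude lies in [min, max), sign chosen at random."""
        vx = random.uniform(min, max) * random.choice((-1, 1))
        vy = random.uniform(min, max) * random.choice((-1, 1))
        return vx, vy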

fuzzy_engine.py (+16 -12)

@@ -6,8 +6,9 @@
 import mlflow

 from game.state import State
+from main_arguments import MainArguments
 from models import ModelConfiguration
-from runtime_configuration import mlflow_model_path, model_path, classification_threshold, temperature_variance, mlflow_server_url
+from runtime_configuration import mlflow_model_path, classification_threshold, temperature_variance, mlflow_server_url
 import numpy as np
 from model_loaders import load_mlflow_model, load_pytorch_model

@@ -26,7 +27,7 @@
     print(f"Failed to connect to MLflow server at {mlflow_server_url}. Error: {e}")
     print("Will load models from local mlruns directory")

-config = ModelConfiguration()
+# config = ModelConfiguration()

 def generate_fuzzy_states(num_steps=None):
     state_generator = _generate_fuzzy_states()
@@ -37,33 +38,36 @@ def generate_fuzzy_states(num_steps=None):
         for step in range(num_steps):
             yield next(state_generator)

-@inject.params(game_state=State)
-def _generate_fuzzy_states(game_state=State):
+@inject.params(game_state=State, main_arguments=MainArguments)
+def _generate_fuzzy_states(game_state=State, main_arguments=MainArguments):
     dt = 1 # Time step
-
     # Either load model from mlflow run
     # model = load_mlflow_model(mlflow_model_path)

     # Or load the model from pth file containing weights
-    model = load_pytorch_model(model_path)
+    model = load_pytorch_model(f"{main_arguments.model_type}_weights.pth")
+

     model.eval()
-    window_size = config.input_sequence_length
+    window_size = main_arguments.input_sequence_length
     window = deque(maxlen=window_size)
-    window.extend(np.zeros((config.input_sequence_length, config.input_size), dtype=float))
+    window.extend(np.zeros((main_arguments.input_sequence_length, main_arguments.input_size), dtype=float))
     while True:
         game_state.left_paddle.update(dt)
         game_state.right_paddle.update(dt)
         paddle_data = game_state.left_paddle.vectorize_state() + game_state.right_paddle.vectorize_state()
-        temperature = torch.from_numpy(
-            np.random.uniform(1.0 - temperature_variance, 1.0 + temperature_variance, config.discrete_output_size)).to(
-            device=config.device)
-        ball_data, discrete_data = model(torch.tensor(np.array([window])).to(device=config.device, dtype=torch.float), temperature)
+        # temperature = torch.from_numpy(
+        #     np.random.uniform(1.0 - temperature_variance, 1.0 + temperature_variance, main_arguments.discrete_output_size) * 100).to(
+        #     device=main_arguments.device)
+        temperature = 1 # larger temperature is more creative
+        ball_data, discrete_data = model(torch.tensor(np.array([window])).to(device=main_arguments.device, dtype=torch.float), temperature)
+
         ball_data = ball_data.tolist()[0]
         discrete_probabilities = torch.sigmoid(discrete_data)

         classes = (discrete_probabilities > classification_threshold).int()
         classes = classes.tolist()[0]
+
         window.append(ball_data + paddle_data + classes)
         yield ball_data, paddle_data, classes[:4], classes[4:]
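
The generator keeps the last input_sequence_length feature vectors in a deque and feeds the whole window to the model at every step. A small, standalone illustration of how that fixed-length window behaves (sizes are made up; the real window uses main_arguments.input_sequence_length and input_size):

    from collections import deque

    import numpy as np

    window = deque(maxlen=3)
    window.extend(np.zeros((3, 2), dtype=float))  # pre-fill with zero rows, as _generate_fuzzy_states does
    window.append([1.0, 2.0])                     # newest frame in, oldest zero row dropped automatically
    batch = np.array([window])                    # shape (1, 3, 2): a batch of one sequence for the model
    print(batch.shape)
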
game/configuration.py (+4 -1)

@@ -1,6 +1,9 @@
 class EngineConfig:
     def __init__(self, ball_radius_percent=.01, paddle_width_percent=.01,
-                 paddle_height_percent=.2):
+                 paddle_height_percent=.2, min_ball_velocity=.005, max_ball_velocity=.025):
         self.ball_radius_percent = ball_radius_percent
         self.paddle_width_percent = paddle_width_percent
         self.paddle_height_percent = paddle_height_percent
+        self.max_ball_velocity = max_ball_velocity
+        self.min_ball_velocity = min_ball_velocity
+
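
For reference, the new keyword arguments default to .005 and .025 and can be overridden wherever EngineConfig is constructed (for example in configure_main in main.py); an illustrative instantiation:

    from game.configuration import EngineConfig

    # Illustrative values only: widen the allowed serve-speed range.
    config = EngineConfig(min_ball_velocity=0.01, max_ball_velocity=0.04)
    print(config.min_ball_velocity, config.max_ball_velocity)
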
main.py (+8 -8)

@@ -11,6 +11,8 @@
 from fuzzy_engine import generate_fuzzy_states
 import pygame
 import inject
+from models import ModelConfiguration
+from main_arguments import MainArguments

 # Initialize Pygame
 pygame.init()
@@ -100,12 +102,11 @@ def render_state(state, count, engine_config: EngineConfig = None, field: Field

 # Main loop to render the state

-@inject.params(generator="generator")
-def main(generator):
+@inject.params(main_arguments=MainArguments)
+def main(main_arguments: MainArguments):
     global screen, screen_width, screen_height
     running = True
-    # for index, state in enumerate(generate_fuzzy_states()):
-    for index, state in enumerate(generator()):
+    for index, state in enumerate(main_arguments.generator()):
         if not running:
             break
         for event in pygame.event.get():
@@ -131,6 +132,9 @@ def main(generator):
     pygame.quit()

 def configure_main(binder: inject.Binder):
+    main_arguments = MainArguments.get_arguments()
+    binder.bind(MainArguments, main_arguments)
+    binder.bind(ModelConfiguration, main_arguments)
     # immediately construct and bind an instance to the given key
     binder.bind(Field, Field(1.0, 1.0))
     binder.bind(EngineConfig, EngineConfig())
@@ -147,10 +151,6 @@ def configure_main(binder: inject.Binder):
     binder.bind_to_constructor(State, State)


-    # Choose the kind of generator desired
-    # binder.bind("generator", generate_pong_states)
-    binder.bind("generator", generate_fuzzy_states)
-
 if __name__ == "__main__":
     inject.configure(configure_main)
     main()
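
Binding the same parsed-arguments object under both keys matters because MainArguments subclasses ModelConfiguration: code that only asks for ModelConfiguration (model constructors, model_loaders) then sees the CLI overrides. A minimal standalone sketch of that wiring, mirroring configure_main:

    import inject

    from main_arguments import MainArguments
    from models import ModelConfiguration

    def configure(binder: inject.Binder):
        arguments = MainArguments.get_arguments()   # parse CLI flags once
        binder.bind(MainArguments, arguments)       # injected into main()
        binder.bind(ModelConfiguration, arguments)  # same object, seen as the model configuration

    inject.configure(configure)
    # Both keys resolve to the identical instance carrying the CLI values.
    assert inject.instance(ModelConfiguration) is inject.instance(MainArguments)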

main_arguments.py (+57)

@@ -0,0 +1,57 @@
+from dataclasses import dataclass
+import argparse
+import os
+import subprocess
+from models.rnn import RNNModel
+from models.transformer import TransformerModel
+from models.transformer_flashattn import FlashAttentionTransformer
+from models import ModelConfiguration
+from exact_engine import generate_pong_states
+from fuzzy_engine import generate_fuzzy_states
+
+generators = {
+    "exact": generate_pong_states,
+    "fuzzy": generate_fuzzy_states,
+}
+
+model_dictionary = {RNNModel.__name__: RNNModel,
+                    TransformerModel.__name__: TransformerModel,
+                    FlashAttentionTransformer.__name__: FlashAttentionTransformer}
+model_names = list(model_dictionary.keys())
+
+@dataclass
+class MainArguments(ModelConfiguration):
+    mlflow_server_url: str = "https://localhost:8080"
+
+    model_type: str = RNNModel.__name__
+
+    generator_type: str = list(generators.keys())[1]
+    generator = list(generators.values())[1]
+
+    # keep this parameter last
+    command: str = ""
+
+    @staticmethod
+    def get_arguments():
+        parser = argparse.ArgumentParser(description="Main configuration")
+
+        parser.add_argument("--mlflow_server_url", type=str, default="http://localhost:8080", help="mlflow server url")
+        parser.add_argument("--model_type", type=str, default=model_names[0], help="The model type to train", choices=model_names)
+        parser.add_argument("--generator_type", type=str, default=list(generators.keys())[1], help="The generator type to train", choices=list(generators.keys()))
+        parser.add_argument("--input_size", type=int, default=16, help="The input size of the model")
+        parser.add_argument("--hidden_size", type=int, default=128, help="The hidden size of the model")
+        parser.add_argument("--num_layers", type=int, default=2, help="The number of layers of the model")
+        parser.add_argument("--number_heads", type=int, default=16, help="The number of heads of the model (transformer model only)")
+        parser.add_argument("--input_sequence_length", type=int, default=10, help="The length of the input sequence")
+
+        args = parser.parse_args()
+        main_arguments = MainArguments()
+        for key, value in vars(args).items():
+            setattr(main_arguments, key, value)
+
+        main_arguments.generator = generators.get(args.__dict__["generator_type"])
+
+        main_arguments.command = str(subprocess.run(["ps", "-p", f"{os.getpid()}", "-o", "args", "--no-headers"], capture_output=True,
+                                                    text=True).stdout)
+
+        return main_arguments
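
Example of how the parsed object is meant to be used (flag names as defined above; the values shown assume a run such as python main.py --model_type TransformerModel --generator_type fuzzy --hidden_size 256):

    from main_arguments import MainArguments, generators

    arguments = MainArguments.get_arguments()    # parses sys.argv
    print(arguments.model_type, arguments.hidden_size)                   # e.g. TransformerModel 256
    print(arguments.generator is generators[arguments.generator_type])   # True
    print(arguments.command)                     # command line of the running process, captured via ps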

model_loaders.py (+6 -6)

@@ -1,9 +1,7 @@
 import torch
 import mlflow.pytorch
-from runtime_configuration import Model
 from models import ModelConfiguration
-
-config = ModelConfiguration()
+import inject


 def save_mlflow_model(model, path):
@@ -13,12 +11,14 @@ def save_pytorch_model(model, path):
     torch.save(model.state_dict(), path)


-def load_mlflow_model(path):
+@inject.params(config=ModelConfiguration)
+def load_mlflow_model(path, config: ModelConfiguration):
     model = mlflow.pytorch.load_model(path, map_location=torch.device(config.device))
     return model

-def load_pytorch_model(path):
+@inject.params(config=ModelConfiguration)
+def load_pytorch_model(path, config: ModelConfiguration):
     """Load weights into a pytorch model from the specified path to .pth file"""
-    model = Model().to(device=config.device)
+    model = config.model().to(device=config.device)
     model.load_state_dict(torch.load(path, weights_only=True, map_location=torch.device(config.device)))
     return model
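
With @inject.params the call site passes only the weights path; the configuration is supplied by the binding created in configure_main. Note that the rewritten loader calls config.model(), so the bound configuration is assumed to expose a model attribute holding the class to instantiate. An illustrative call:

    from model_loaders import load_pytorch_model

    # Assumes inject.configure(...) has already run and the bound configuration's
    # .model attribute points at the desired class (e.g. RNNModel).
    model = load_pytorch_model("RNNModel_weights.pth")
    model.eval()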

models/__init__.py (+2 -2)

@@ -1,6 +1,6 @@
 # only export public facing stuff from the package
 from .model_configuration import ModelConfiguration
 from .pong_dataset import PongDataset # should probably move to separate package
-from .rnn import RNNModel
-from .transformer import TransformerModel
+# from .rnn import RNNModel
+# from .transformer import TransformerModel
 # from .transformer_flashattn import FlashAttentionTransformer

models/base_pong_model.py (+4 -3)

@@ -2,14 +2,15 @@

 import torch
 from torch import nn as nn
-from . import ModelConfiguration
+import inject
+from models import ModelConfiguration

-config = ModelConfiguration()


 class BasePongModel(nn.Module, ABC):
-    def __init__(self):
+    def __init__(self, config: ModelConfiguration):
         super(BasePongModel, self).__init__()
+        self.config = config
         # Linear layer to expand input from 10 to 64 dimensions
         self.fc_feature_expansion = nn.Linear(config.input_size, config.hidden_size)
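
Models no longer read a module-level ModelConfiguration(); each one receives its configuration explicitly and keeps it as self.config. A minimal sketch of constructing a model under the new signature:

    from models.model_configuration import ModelConfiguration
    from models.rnn import RNNModel

    config = ModelConfiguration()
    model = RNNModel(config)          # BasePongModel.__init__ stores it as self.config
    print(model.config.hidden_size)   # 128
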
models/model_configuration.py (+15 -11)

@@ -1,17 +1,21 @@
 # general parameters
 import torch
+import dataclasses


+@dataclasses.dataclass
 class ModelConfiguration:
-    def __init__(self):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        # model parameters
-        self.input_size = 16
-        self.hidden_size = 128
-        self.output_size = 4
-        self.discrete_output_size = 6
-        self.num_layers = 2
-        self.number_heads = 8
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # model parameters
+    input_size = 16
+    hidden_size = 128
+    output_size = 4
+    discrete_output_size = 6
+    num_layers = 2
+    number_heads = 16

-        # training parameters
-        self.input_sequence_length=20
+    # training parameters
+    input_sequence_length=20
+
+    # def get_model_path(self):
+    #     return f"{self.model.__name__}_weights.pth"
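
A subtlety of the new dataclass form, shown in isolation below: attributes declared without type annotations (as all of the ones above are) remain ordinary class attributes rather than generated dataclass fields, which is consistent with MainArguments.get_arguments() copying CLI values onto the instance via setattr rather than through the constructor:

    import dataclasses

    @dataclasses.dataclass
    class Demo:
        name: str = "annotated"  # annotated -> generated __init__ field
        hidden_size = 128        # unannotated -> shared class attribute only

    print([f.name for f in dataclasses.fields(Demo)])  # ['name']
    demo = Demo()
    demo.hidden_size = 256   # per-instance override, the way get_arguments() applies CLI values
    print(Demo.hidden_size, demo.hidden_size)  # 128 256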

models/rnn.py (+3 -6)

@@ -1,15 +1,12 @@
 from torch import nn as nn

 from .base_pong_model import BasePongModel
-from . import ModelConfiguration
-
-config = ModelConfiguration()


 class RNNModel(BasePongModel):
-    def __init__(self):
-        super(RNNModel, self).__init__()
-        self.lstm = nn.LSTM(config.hidden_size, config.hidden_size, config.num_layers, batch_first=True, dropout=0.2)
+    def __init__(self, model_config):
+        super(RNNModel, self).__init__(model_config)
+        self.lstm = nn.LSTM(self.config.hidden_size, self.config.hidden_size, self.config.num_layers, batch_first=True, dropout=0.2)

     def _forward(self, x):
         out, _ = self.lstm(x)

models/transformer.py (+7 -10)

@@ -4,22 +4,19 @@
 from .base_pong_model import BasePongModel
 from . import ModelConfiguration

-config = ModelConfiguration()
-
-
 class TransformerModel(BasePongModel):
-    def __init__(self):
-        super(TransformerModel, self).__init__()
+    def __init__(self, model_config: ModelConfiguration):
+        super(TransformerModel, self).__init__(model_config)
         # Consider using decoder only with flash attention
-        self.positional_encoding = nn.Parameter(torch.zeros(1, 100, config.hidden_size))
+        self.positional_encoding = nn.Parameter(torch.zeros(1, 100, model_config.hidden_size))

         # self.transformer = nn.TransformerEncoder(
         self.transformer_list = nn.ModuleList([nn.TransformerEncoderLayer(
-            d_model=config.hidden_size,
-            nhead=config.number_heads,
-            dim_feedforward=config.hidden_size,
+            d_model=self.config.hidden_size,
+            nhead=self.config.number_heads,
+            dim_feedforward=self.config.hidden_size,
             batch_first=True,
-        ) for _ in range(config.num_layers)])
+        ) for _ in range(self.config.num_layers)])
             # num_layers=num_layers,
         # )
         # self.transformer = nn.TransformerDecoderLayer(

models/transformer_flashattn.py (+4 -6)

@@ -5,9 +5,7 @@
 from flash_attn.modules.mha import MHA
 from flash_attn import flash_attn_func
 from .base_pong_model import BasePongModel
-from . import ModelConfiguration

-config = ModelConfiguration()


 class Transformer(nn.Module):
@@ -47,7 +45,7 @@ def _generate_positional_encoding(max_seq_len, embed_dim):
 class TransformerLayer(nn.Module):
     def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
         super(TransformerLayer, self).__init__()
-        self.mha = MHA(embed_dim, num_heads, causal=True, use_flash_attn=True, return_residual=False)
+        self.mha = MHA(embed_dim, num_heads, causal=True, use_flash_attn=False, return_residual=False)
         # self.mha = MultiHeadAttention(embed_dim, num_heads)
         self.ffn = FeedForwardNetwork(embed_dim, ff_dim, dropout)
         self.norm1 = nn.LayerNorm(embed_dim)
@@ -112,7 +110,7 @@ def _forward(self, x: torch.Tensor):
         x = self.transformer(x)
         return x.mean(dim=1)

-    def __init__(self):
-        super(FlashAttentionTransformer, self).__init__()
-        self.transformer = Transformer(config.hidden_size, config.number_heads, config.num_layers, config.hidden_size, config.input_sequence_length, 0.2)
+    def __init__(self, model_config):
+        super(FlashAttentionTransformer, self).__init__(model_config)
+        self.transformer = Transformer(self.config.hidden_size, self.config.number_heads, self.config.num_layers, self.config.hidden_size, self.config.input_sequence_length, 0.2)

runtime_configuration.py (+6 -9)

@@ -1,18 +1,15 @@
-from models import TransformerModel, RNNModel
-# from models import FlashAttentionTransformer
+
+

 mlflow_server_url = "http://localhost:8080"

-# model to use during training and inference
-Model = RNNModel

 # path to mlflow model
 # use a run that corresponds with the desired model type (rnn/transformer/flashtransformer)
-mlflow_model_path = f"runs:/62a7d1ead3564c379cbacbff4ef7ac55/model_e99"
-
-# path to pytorch model weights
-model_path = f"{Model.__name__}_weights.pth"
+mlflow_model_path = f"runs:/6d4a6cb5a09c420ca834fe16795b16a3/model_e99"


 classification_threshold = 0.5
-temperature_variance = 0.0
+temperature_variance = 0.7
+
+# class RuntimeArguments:
