-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
177 lines (138 loc) · 6.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Given a balls initial position, direction and
"""
from game.configuration import EngineConfig
from game.field import Field
from game.paddle import UserPaddleFactory, RandomPaddleFactory, PaddleFactory
from game.score import Score
from game.state import State
from game.ball import Ball
import pygame
import inject
from models import ModelConfiguration
from main_arguments import MainArguments
import torch
from exact_engine import generate_pong_states
from fuzzy_engine import generate_fuzzy_states
from models.base_pong_model import BasePongModel
from models.rnn import RNNModel
from models.transformer import TransformerModel
from models.transformer_flashattn import FlashAttentionTransformer
# Registry of selectable model classes, keyed by class name.
models = {
    model_cls.__name__: model_cls
    for model_cls in (RNNModel, TransformerModel, FlashAttentionTransformer)
}

# Registry of state generators, keyed by the engine name chosen on the command line.
generators = dict(
    exact=generate_pong_states,
    fuzzy=generate_fuzzy_states,
)
# Initialize Pygame
pygame.init()

# Game parameters: initial window size in pixels and RGB colours.
screen_width = 800
screen_height = 400
# Derived geometry, computed from the initial window size at import time.
screen_quarter = int(screen_width / 4)  # horizontal anchor used by the score layout
background_color = (0, 0, 0)  # Black
ball_color = (255, 255, 255)  # White
paddle_color = (255, 255, 255)  # White
paddle_color_collision = (255, 0, 0)  # Red: paddle colour while its collision flag is set
half_screen_width = screen_width / 2.0
half_screen_height = screen_height / 2.0

# Initialize Pygame screen (resizable window)
screen = pygame.display.set_mode((screen_width, screen_height), pygame.RESIZABLE)
pygame.display.set_caption("Pong State Renderer")
def translate(point):
    """Offset *point* by half the screen size.

    This moves an origin at the screen centre to pygame's top-left origin.
    Returns an ``(x, y)`` tuple.
    """
    px, py = point
    shifted_x = px + half_screen_width
    shifted_y = py + half_screen_height
    return shifted_x, shifted_y
@inject.params(field=Field)
def scale_to_screen(point, field: Field = None):
    """Map a field-space point to integer pixel coordinates.

    Field coordinates are centred on the field, so ``[-w/2, w/2]`` maps to
    ``[0, screen_width]`` (and likewise vertically). The result is truncated
    with ``int``. ``field`` is injected when not supplied.
    """
    px, py = point
    half_w = field.width / 2.0
    half_h = field.height / 2.0
    # Normalise to the 0..1 range relative to the field size.
    frac_x = (px + half_w) / field.width
    frac_y = (py + half_h) / field.height
    return int(frac_x * screen_width), int(frac_y * screen_height)
# Text rendering resources shared by the score and debug overlays.
text_color = (255, 255, 255)  # White
font_size = 24
# A None font file makes pygame use its built-in default font.
font = pygame.font.Font(None, font_size)
small_font = pygame.font.Font(None, 18)
@inject.params(scores=Score)
def update_scores(state, scores: Score = None):
    """Feed *state*'s score data and first two collision flags into *scores*.

    Args:
        state: 4-item sequence ``(ball_data, paddle_data, collision_data,
            score_data)``; only the last two parts are used here.
        scores: Score tracker; injected when not supplied.
    """
    _ball_data, _paddle_data, collision_data, score_data = state
    # Unpack each part separately rather than `score_data + collision_data[:2]`.
    # `+` only concatenates when both are the same sequence type (tuple+tuple,
    # list+list); mixed types raise TypeError, and for array-like data it would
    # add element-wise instead of concatenating.
    scores.update(*score_data, *collision_data[:2])
@inject.params(scores=Score)
def render_scores(scores: Score = None):
    """Draw each side's ``score | blocked`` counters near the top of the screen.

    The left text is centred on ``screen_quarter`` and the right text on
    ``screen_width - screen_quarter``. ``scores`` is injected when not supplied.
    """
    left_label = f"{scores.left_score} | {scores.left_blocked}"
    right_label = f"{scores.right_score} | {scores.right_blocked}"
    left_surface = font.render(left_label, True, text_color)
    right_surface = font.render(right_label, True, text_color)

    top = 10
    left_pos = (screen_quarter - left_surface.get_width() / 2, top)
    right_pos = (screen_width - screen_quarter - right_surface.get_width() / 2, top)
    screen.blit(left_surface, left_pos)
    screen.blit(right_surface, right_pos)
def render_field():
    """Draw a 1-pixel vertical line down the middle of the screen."""
    centre_x = screen_width / 2
    top_point = (centre_x, 0)
    bottom_point = (centre_x, screen_height)
    pygame.draw.line(screen, text_color, top_point, bottom_point, 1)
# Function to render the state
@inject.params(engine_config=EngineConfig, field=Field)
def render_state(state, count, engine_config: EngineConfig = None, field: Field = None):
    """Draw one complete frame for *state* and flip the display.

    Draws the ball, both paddles, the scores, the centre line and the frame
    counter. Each paddle uses ``paddle_color_collision`` while its collision
    flag is truthy.

    Args:
        state: 4-item sequence ``(ball_data, paddle_data, collision_data,
            score_data)``.
            - ``ball_data``: 4 items; the first two are the ball x/y.
            - ``paddle_data``: 6 items, ``(x1, y1, vy1, x2, y2, vy2)``.
            - ``collision_data``: 4 items; the first two are the left/right
              paddle flags.
        count: Frame index shown in the top-left debug overlay.
        engine_config: Size settings; injected when not supplied.
        field: Field dimensions; injected when not supplied.
    """
    ball_data, paddle_data, collision_data, _score_data = state
    ball_x, ball_y, _, _ = ball_data
    left_x, left_y, _left_vy, right_x, right_y, _right_vy = paddle_data
    left_hit, right_hit, _, _ = collision_data

    # Clear the screen
    screen.fill(background_color)

    # Ball radius is scaled from field units to pixels along the x axis.
    ball_centre = scale_to_screen((ball_x, ball_y))
    ball_radius = engine_config.ball_radius_percent * (screen_width / field.width)
    pygame.draw.circle(screen, ball_color, ball_centre, ball_radius, 0)

    # Paddle sizes converted from field units to pixels.
    # TODO: the paddle objects might expose this data directly.
    paddle_size = (
        engine_config.paddle_width_percent / field.width * screen_width,
        engine_config.paddle_height_percent / field.height * screen_height,
    )
    paddles = (
        ((left_x, left_y), left_hit),  # Left paddle
        ((right_x, right_y), right_hit),  # Right paddle
    )
    for position, hit in paddles:
        colour = paddle_color_collision if hit else paddle_color
        # The scaled (x, y) is used as the rect's top-left corner.
        pygame.draw.rect(screen, colour, scale_to_screen(position) + paddle_size)

    render_scores()
    render_field()

    frame_label = small_font.render(f"{count}", True, (0, 255, 0))
    screen.blit(frame_label, (0, 10))

    # Update the display
    pygame.display.flip()
# Main loop to render the state
def _process_events():
    """Drain pending pygame events and report whether rendering should continue.

    Returns False when the window is closed or Escape is pressed.

    A resize event re-creates the display surface and refreshes the
    module-level screen size values, including the derived quarter/half sizes.
    The helpers that read those values then draw against the new window
    dimensions.
    """
    global screen, screen_width, screen_height, screen_quarter
    global half_screen_width, half_screen_height
    keep_running = True
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            keep_running = False
        elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
            keep_running = False
        elif event.type == pygame.VIDEORESIZE:  # Handle window resizing
            screen = pygame.display.set_mode((event.w, event.h), pygame.RESIZABLE)
            screen_width = event.w
            screen_height = event.h
            # The derived sizes are otherwise only computed once at import
            # time; keep them in sync so score placement and translate() follow
            # the new window size.
            screen_quarter = int(screen_width / 4)
            half_screen_width = screen_width / 2.0
            half_screen_height = screen_height / 2.0
    return keep_running


@inject.params(main_arguments=MainArguments)
def main(main_arguments: MainArguments = None):
    """Render generated states until the generator is exhausted or the user quits.

    Args:
        main_arguments: Parsed arguments; injected when not supplied. Its
            ``generator_type`` selects an entry from ``generators``.

    Raises:
        ValueError: If ``generator_type`` is not a key of ``generators``.
    """
    try:
        generator_type = main_arguments.generator_type
        generate_states = generators.get(generator_type)
        if generate_states is None:
            raise ValueError(
                f"Unknown generator type {generator_type!r}; "
                f"expected one of {sorted(generators)}"
            )
        for index, state in enumerate(generate_states()):
            # Handle events before drawing so a quit request stops immediately
            # instead of rendering one more frame.
            if not _process_events():
                break
            update_scores(state)
            # Render the state
            render_state(state, index)
            # Add a delay (in milliseconds) to control the frame rate
            pygame.time.delay(30)
    finally:
        # Always shut pygame down, even if generation or rendering raised.
        pygame.quit()
def configure_main(binder: inject.Binder):
    """Register every dependency the renderer needs with the inject *binder*.

    Command-line arguments are parsed first. They select the model class and
    the torch device, and the same object is also bound as the model
    configuration.
    """
    args = MainArguments.get_arguments()
    model_class = models[args.model_type]

    # Bindings derived from the command line.
    binder.bind(MainArguments, args)
    binder.bind(ModelConfiguration, args)
    binder.bind("device", torch.device(args.device))
    binder.bind(BasePongModel, model_class)

    # Instances constructed right now and bound directly to their key.
    binder.bind(Field, Field(1.0, 1.0))
    binder.bind(EngineConfig, EngineConfig())
    binder.bind(PaddleFactory, UserPaddleFactory())
    binder.bind(Score, Score())

    # Objects whose construction needs other injected dependencies are
    # deferred: each constructor runs once, on first retrieval of its key.
    # They cannot be built here, because calling inject.instance() while the
    # binder is still being configured fails with an "injector not
    # configured" error.
    def make_left_paddle():
        return inject.instance(PaddleFactory).create_left_paddle()

    def make_right_paddle():
        return inject.instance(PaddleFactory).create_right_paddle()

    binder.bind_to_constructor("left_paddle", make_left_paddle)
    binder.bind_to_constructor("right_paddle", make_right_paddle)
    binder.bind_to_constructor(Ball, Ball)
    binder.bind_to_constructor(State, State)


if __name__ == "__main__":
    # Install the bindings first so main()'s injected parameters resolve.
    inject.configure(configure_main)
    main()