Skip to content

Commit b886a99

Browse files
authored
Snake (#5)
* Installs pygame * Adds gamescore_plotter * Adds snake game * Adds QLearning * Lintfix * Lintfix hard * Solves lint errors * Adds help target * Adds pylint configuration
1 parent a94003b commit b886a99

21 files changed

+1964
-147
lines changed

.pylintrc

+569
Large diffs are not rendered by default.

Makefile

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
# Usage:
2-
# make # setup the project for the first time
3-
# make install # install pipenv dependencies
4-
# make packages # adds src/ files to be available as python modules
5-
61
# default target
72
.DEFAULT_GOAL := init
83

94
# targets that do not create a file
105
.PHONY: init activate test lint lint-src lintfix lintfixhard install lock clean
116

7+
# help target taken from https://gist.github.com/prwhite/8168133
8+
help: ## Shows help message
9+
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[$$()% 0-9a-zA-Z_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
10+
1211
PY_FILES := src/ cli/
1312

1413
init: activate packages install test

Pipfile

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ name = "pypi"
77
tensorflow = "*"
88
matplotlib = "*"
99
autopep8 = "*"
10+
pygame = "*"
1011

1112
[requires]
1213
python_version = "3.9"

Pipfile.lock

+233-142
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cli/.DS_Store

6 KB
Binary file not shown.

cli/play_snake.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""CLI to play the snake game manually"""
2+
import pygame
3+
4+
from game.snake import SnakeGame, player_to_snake_perspective
5+
6+
7+
def play_snake():
    """Run the snake game interactively, steering with the arrow keys.

    Each frame starts with the action ``"forward"``; pressing an arrow key
    replaces it with the result of ``player_to_snake_perspective`` for the
    snake's current direction. Closing the window or pressing ``q`` calls
    ``game.quit()`` and leaves the loop. The loop also ends when
    ``game.play_step`` reports that the game is over. The last known score
    is printed before pygame is shut down.
    """
    # Arrow key -> direction name handed to player_to_snake_perspective.
    # Built once here instead of on every key press.
    key_to_direction = {
        pygame.K_LEFT: "left",
        pygame.K_RIGHT: "right",
        pygame.K_UP: "up",
        pygame.K_DOWN: "down",
    }

    pygame.init()

    game = SnakeGame()

    speed = 20
    clock = pygame.time.Clock()
    stop = False
    # Bound before the loop so the final print cannot fail when the player
    # quits before the first step has been played.
    score = 0

    # game loop
    while True:

        # 1. collect user input (any key other than the arrows keeps forward)
        action = "forward"
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                game.quit()
                stop = True

            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_q:
                    game.quit()
                    stop = True

                if event.key in key_to_direction:
                    action = player_to_snake_perspective(
                        game.direction, key_to_direction[event.key])

        if stop:
            break

        # 2. advance the game one step, redraw and throttle the frame rate
        _, score, game_over = game.play_step(action)
        game.pygame_draw()
        clock.tick(speed)

        if game_over:
            break

    print('Final Score', score)

    pygame.quit()
59+
60+
61+
if __name__ == '__main__':
62+
play_snake()

cli/qlearning_snake.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""CLI to train/see QLearning in action solving the snake game"""
2+
import argparse
3+
import numpy as np
4+
import pygame
5+
6+
from reinforcement_learning.q_learning import (
7+
SnakeAgent, ai_direction_to_snake,
8+
QTrainer, linear_qnet, snake_state_11, snake_reward
9+
)
10+
from game.snake import SnakeGame
11+
from plotter import gamescore_plotter
12+
13+
14+
def buil_arg_parser():
    """Build the command-line parser for the Deep-QLearning snake trainer.

    The parser is returned without parsing anything; callers invoke
    ``parse_args`` themselves. Every option is optional, and all except
    ``--checkpoint-path`` carry a default.
    """
    cli = argparse.ArgumentParser(
        description="Use Deep-QLearning in the snake game",
        epilog="Built with <3 by Emmanuel Byrd at 8th Light Ltd.",
    )

    # (flag, type, metavar, help, extra keyword arguments) in display order.
    option_table = (
        ("--best-models-dir", str, "./model",
         "Folder to store the increasingly best models",
         {"default": "./model"}),
        ("--score-history", str, "./score_history",
         "Where to store the score history",
         {"default": "./score_history"}),
        ("--checkpoint-path", str, "./model/snake_5.pth",
         "Path of pre-trained model to start from",
         {}),
        ("--fps", int, "100",
         "Frames per second",
         {"default": 100}),
        ("--learning-rate", float, "1e-3",
         "QTrainer learning rate",
         {"default": 1e-3}),
        ("--gamma", float, "0.9",
         "QTrainer gamma value",
         {"default": 0.9}),
        ("--hidden-layer-size", int, "256",
         "Size of the hidden layer",
         {"default": 256}),
        ("--max-width", int, "400",
         "Maximum board width",
         {"default": 400}),
        ("--max-height", int, "300",
         "Maximum board height",
         {"default": 300}),
    )

    for flag, value_type, placeholder, description, extra in option_table:
        cli.add_argument(
            flag, metavar=placeholder, type=value_type, help=description,
            **extra)

    return cli
58+
59+
60+
def train(args):
    """Run the QLearning training loop on the snake game until the user quits.

    Every frame the agent picks an action for the current state, the game
    advances one step, and the agent is trained on that single transition
    and stores it. When a game ends the best model so far is saved if the
    score is a new record, a new board is chosen with ``scaling_board``, the
    agent trains on its stored memory, and the score history is plotted and
    written to ``args.score_history``.

    Returns when the window is closed or ``q`` is pressed.
    """
    pygame.init()

    tracker = ScoreTracker()

    high_score = 0

    trainer = QTrainer(generate_model(args),
                       learning_rate=args.learning_rate,
                       gamma=args.gamma)
    agent = SnakeAgent(trainer)
    game = SnakeGame(width=200, height=160)

    ticker = pygame.time.Clock()

    frames_this_game = 0

    while True:
        # state before the move
        state_before = snake_state_11(game)

        # chosen move: [0, 0, 0] -> left, right, forward
        move = agent.get_action(state_before)

        # apply the move
        eaten, score, done = game.play_step(ai_direction_to_snake(move))

        # drawing requires events to be consumed, e.g. pygame.event.get();
        # the same pass checks for a window close or the "q" key
        quit_requested = any(
            event.type == pygame.QUIT
            or (event.type == pygame.KEYDOWN and event.key == pygame.K_q)
            for event in pygame.event.get()
        )
        if quit_requested:
            pygame.quit()
            return

        # show AI training in real-time
        game.pygame_draw()
        ticker.tick(args.fps)

        reward = snake_reward(eaten, done)

        # Too many frames for the current snake length: end the game with a
        # penalty.
        frames_this_game += 1
        if frames_this_game > 30 * len(game.snake):
            eaten, done, reward = False, True, -10
            print("Stopping due to infinite loop strategy")

        state_after = snake_state_11(game)

        # train short memory on this single transition, then remember it
        agent.train_short_memory(state_before, move, reward, state_after, done)
        agent.remember(state_before, move, reward, state_after, done)

        if not done:
            continue

        # ---- end of a game ----
        if score > high_score:
            high_score = score
            agent.save_model(args.best_models_dir,
                             f'snake_{high_score}.pth')

        game = scaling_board(high_score, args.max_width, args.max_height)

        frames_this_game = 0
        agent.n_games += 1

        # train long memory (replay memory, or experience replay)
        agent.train_long_memory()

        print('Game', agent.n_games, 'Score', score, 'Record:', high_score)

        # show and persist the results
        tracker.add_new_score(score)
        tracker.show_hist()
        np.save(args.score_history, np.array(tracker.get_hist()))
138+
139+
140+
def generate_model(args):
    """Create the Q-network, optionally starting from saved weights.

    The network comes from ``linear_qnet(11, args.hidden_layer_size, 3)``.
    If ``args.checkpoint_path`` is truthy, its weights are loaded into the
    network before it is returned.
    """
    network = linear_qnet(11, args.hidden_layer_size, 3)

    checkpoint = args.checkpoint_path
    if not checkpoint:
        return network

    network.load_weights(checkpoint)
    return network
147+
148+
149+
def scaling_board(high_score, max_width, max_height):
    """Return a new ``SnakeGame`` whose board size depends on ``high_score``.

    Tiers, checked from the highest down:
      * above 5: ``SnakeGame()`` with its own default size
      * above 3: ``max_width`` x ``max_height``
      * above 1: 320 x 240
      * otherwise: 200 x 160
    """
    # NOTE(review): the top tier ignores max_width/max_height and uses the
    # SnakeGame defaults -- confirm this is intended.
    if high_score > 5:
        board_size = {}
    elif high_score > 3:
        board_size = {"width": max_width, "height": max_height}
    elif high_score > 1:
        board_size = {"width": 320, "height": 240}
    else:
        board_size = {"width": 200, "height": 160}

    return SnakeGame(**board_size)
161+
162+
163+
class ScoreTracker:
    """Collects per-game scores together with their running average."""

    def __init__(self):
        """Start with empty histories and a running total of zero."""
        self.plot_scores = []
        self.plot_mean_scores = []
        self.total_score = 0

    def add_new_score(self, score):
        """Record ``score`` and append the mean of every score seen so far."""
        self.plot_scores.append(score)
        self.total_score += score
        games_recorded = len(self.plot_scores)
        running_mean = self.total_score / games_recorded
        self.plot_mean_scores.append(running_mean)

    def show_hist(self):
        """Hand both histories to ``gamescore_plotter``."""
        scores, means = self.plot_scores, self.plot_mean_scores
        gamescore_plotter(scores, means)

    def get_hist(self):
        """Return ``[scores, mean_scores]``; the inner lists are not copied."""
        history = [self.plot_scores, self.plot_mean_scores]
        return history
185+
186+
187+
def main():
    """Parse the command-line arguments, run training, then report the end."""
    train(buil_arg_parser().parse_args())

    print("Finished.")
195+
196+
197+
if __name__ == "__main__":
198+
main()

src/game/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
"""All games should be playable by both humans and AI"""
2+
from .snake import *

src/game/snake/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
"""Simple game of snake"""
2+
from .snake import *

0 commit comments

Comments
 (0)