Skip to content

Commit 62f3a83

Browse files
committed
added NeuralNetwork and genetic algorithm functions
1 parent 1ccc2b5 commit 62f3a83

File tree

5 files changed

+158
-13
lines changed

5 files changed

+158
-13
lines changed

bird.js

+9-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ class Bird {
33
this.actualHeight = height - groundImg.height;
44
this.x = 50;
55
this.y = this.actualHeight / 2;
6-
this.radius = birdImg.width / 2;
6+
this.width = birdImg.width;
7+
this.height = birdImg.height;
78
this.gravity = 0.8;
89
this.upLift = -12;
910
this.velocity = 0;
@@ -16,7 +17,7 @@ class Bird {
1617

1718
if (brain instanceof NeuralNetwork) {
1819
this.brain = brain.copy();
19-
this.brain.mutate(mutate);
20+
this.brain.mutate(0.1);
2021
} else {
2122
// Parameters are number of inputs, number of units in hidden Layer, number of outputs
2223
this.brain = new NeuralNetwork(5, 8, 1);
@@ -27,6 +28,10 @@ class Bird {
2728
return new Bird(this.brain);
2829
}
2930

31+
// mutate(rate) {
32+
// this.brain.mutate(rate);
33+
// }
34+
3035
show() {
3136
image(birdImg, this.x, this.y);
3237
}
@@ -61,8 +66,7 @@ class Bird {
6166
// 5. bird's velocity
6267
inputs[4] = map(this.velocity, -12, 12, 0, 1);
6368

64-
// const action = this.brain.predict(inputs);
65-
const action = [0.3];
69+
const action = this.brain.predict(inputs);
6670
if (action[0] > 0.5) {
6771
this.jump();
6872
}
@@ -75,7 +79,7 @@ class Bird {
7579
}
7680

7781
bottomTopCollision() {
78-
return this.y + this.radius > this.actualHeight || this.y - this.radius < 0;
82+
return this.y + this.height / 2 > this.actualHeight || this.y - this.hieght / 2 < 0;
7983
}
8084

8185
update() {

geneticAlgorithm.js

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
function resetGame() {
2+
frameCounter = 0;
3+
pipes = [];
4+
}
5+
6+
function createNextGeneration() {
7+
resetGame();
8+
normalizeFitness(allBirds);
9+
aliveBirds = generate(allBirds);
10+
allBirds = aliveBirds.slice();
11+
}
12+
13+
function generate(oldBirds) {
14+
let newBirds = [];
15+
for (let i = 0; i < oldBirds.length; i++) {
16+
// Select a bird based on fitness
17+
let bird = poolSelection(oldBirds);
18+
newBirds[i] = bird;
19+
}
20+
return newBirds;
21+
}
22+
23+
function normalizeFitness(birds) {
24+
for (let i = 0; i < birds.length; i++) {
25+
birds[i].score = pow(birds[i].score, 2);
26+
}
27+
28+
let sum = 0;
29+
for (let i = 0; i < birds.length; i++) {
30+
sum += birds[i].score;
31+
}
32+
33+
for (let i = 0; i < birds.length; i++) {
34+
birds[i].fitness = birds[i].score / sum;
35+
}
36+
}
37+
38+
// An algorithm for picking one bird from an array
39+
// based on fitness
40+
function poolSelection(birds) {
41+
// Start at 0
42+
let index = 0;
43+
44+
// Pick a random number between 0 and 1
45+
let r = random(1);
46+
47+
// Keep subtracting probabilities until you get less than zero
48+
// Higher probabilities will be more likely to be fixed since they will
49+
// subtract a larger number towards zero
50+
while (r > 0) {
51+
r -= birds[index].fitness;
52+
// And move on to the next
53+
index += 1;
54+
}
55+
56+
// Go back one
57+
index -= 1;
58+
59+
// Make sure it's a copy!
60+
// (this includes mutation)
61+
return birds[index].copy();
62+
}

neuralNetwork.js

+76-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,78 @@
11
class NeuralNetwork {
2-
constructor() {
3-
2+
constructor(inputs, hiddenUnits, outputs, model = {}) {
3+
this.input_nodes = inputs;
4+
this.hidden_nodes = hiddenUnits;
5+
this.output_nodes = outputs;
6+
7+
if (model instanceof tf.Sequential) {
8+
this.model = model;
9+
10+
} else {
11+
this.model = this.createModel();
12+
}
413
}
5-
};
14+
15+
// Copy a model
16+
copy() {
17+
return tf.tidy(() => {
18+
const modelCopy = this.createModel();
19+
const weights = this.model.getWeights();
20+
const weightCopies = [];
21+
for (let i = 0; i < weights.length; i++) {
22+
weightCopies[i] = weights[i].clone();
23+
}
24+
modelCopy.setWeights(weightCopies);
25+
return new NeuralNetwork(this.input_nodes, this.hidden_nodes, this.output_nodes, modelCopy);
26+
});
27+
}
28+
29+
mutate(rate) {
30+
tf.tidy(() => {
31+
const weights = this.model.getWeights();
32+
const mutatedWeights = [];
33+
for (let i = 0; i < weights.length; i++) {
34+
let tensor = weights[i];
35+
let shape = weights[i].shape;
36+
let values = tensor.dataSync().slice();
37+
for (let j = 0; j < values.length; j++) {
38+
if (random(1) < rate) {
39+
let w = values[j];
40+
values[j] = w + randomGaussian();
41+
}
42+
}
43+
let newTensor = tf.tensor(values, shape);
44+
mutatedWeights[i] = newTensor;
45+
}
46+
this.model.setWeights(mutatedWeights);
47+
});
48+
}
49+
50+
dispose() {
51+
this.model.dispose();
52+
}
53+
54+
predict(inputs) {
55+
return tf.tidy(() => {
56+
const xs = tf.tensor2d([inputs]);
57+
const ys = this.model.predict(xs);
58+
const output = ys.dataSync();
59+
return output;
60+
});
61+
}
62+
63+
createModel() {
64+
const model = tf.sequential();
65+
const hiddenLayer = tf.layers.dense({
66+
units: this.hidden_nodes,
67+
inputShape: [this.input_nodes],
68+
activation: "relu"
69+
});
70+
model.add(hiddenLayer);
71+
const outputLayer = tf.layers.dense({
72+
units: this.output_nodes,
73+
activation: "sigmoid"
74+
});
75+
model.add(outputLayer);
76+
return model;
77+
}
78+
}

pipe.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class Pipe {
2323
}
2424

2525
checkCollision(bird) {
26-
if (bird.x + bird.radius >= this.x && bird.x - bird.radius <= this.x + this.width) {
27-
if (bird.y - bird.radius <= this.top || bird.y + bird.radius >= this.bottom) {
26+
if (bird.x + bird.width / 2 >= this.x && bird.x - bird.width / 2 <= this.x + this.width) {
27+
if (bird.y - bird.height / 2 <= this.top || bird.y + bird.height / 2 >= this.bottom) {
2828
return true;
2929
}
3030
}

sketch.js

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// The number of birds in each population
22
const totalPopulation = 300;
3-
3+
let generation = 0;
44
// Birds currently alived
55
let aliveBirds = [];
66

@@ -20,6 +20,7 @@ function preload() {
2020
}
2121

2222
function setup() {
23+
tf.setBackend("cpu");
2324
let canvas = createCanvas(bg.width, bg.height);
2425
canvas.parent("sketch");
2526
for (let i = 0; i < totalPopulation; i++) {
@@ -49,12 +50,11 @@ function draw() {
4950
}
5051
}
5152
if (bird.bottomTopCollision()) {
52-
console.log("dead");
5353
aliveBirds.splice(i, 1);
5454
}
5555
}
5656

57-
if (frameCounter % 75 === 0) {
57+
if (frameCounter % 50 === 0) {
5858
pipes.push(new Pipe());
5959
}
6060

@@ -66,5 +66,11 @@ function draw() {
6666
for (let i = 0; i < aliveBirds.length; i++) {
6767
aliveBirds[i].show();
6868
}
69+
if (aliveBirds.length == 0) {
70+
generation++;
71+
console.log("generation ", generation);
72+
createNextGeneration();
73+
}
74+
6975
image(groundImg, 0, height - groundImg.height);
7076
}

0 commit comments

Comments
 (0)