-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmemory.py
More file actions
55 lines (42 loc) · 1.74 KB
/
memory.py
File metadata and controls
55 lines (42 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from collections import deque, namedtuple
import random
import numpy as np
class Memory:
    """Fixed-capacity experience-replay buffer with uniform random sampling.

    Transitions are stored as ``Sample`` named tuples in a ``deque``; once
    ``capacity`` transitions are held, appending a new one evicts the oldest.

    Args:
        batch_size: Number of transitions returned by each :meth:`sample` call.
        capacity: Maximum number of stored transitions (``None`` means
            unbounded, per ``deque(maxlen=None)``). Defaults to 10000, which is
            the limit the previous manual ``popleft`` check actually enforced
            (the old ``maxlen=100000`` could never be reached).
    """

    # Built once for the class instead of once per instance; still reachable
    # as ``self.Sample`` for existing callers.
    Sample = namedtuple(
        'Sample', ['current_obs', 'current_action', 'next_obs', 'reward', 'done']
    )

    def __init__(self, batch_size=32, capacity=10000):
        self.buffer = deque(maxlen=capacity)
        self.batch_size = batch_size

    def add(self, curr_obs, curr_action, next_obs, reward, done=False):
        """Store one transition, evicting the oldest one if the buffer is full."""
        # A bounded deque drops its leftmost item automatically on append,
        # so no manual popleft() is needed.
        self.buffer.append(self.Sample(curr_obs, curr_action, next_obs, reward, done))

    def sample(self):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(current_obs, current_action, next_obs, reward, done)`` of
            NumPy arrays, each with ``batch_size`` entries along axis 0.
            Observations and rewards are ``float32``, actions ``int32``;
            ``done`` keeps the dtype NumPy infers from the stored flags.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if self.batch_size > len(self.buffer):
            raise ValueError(
                f"cannot sample {self.batch_size} transitions from a buffer "
                f"holding {len(self.buffer)}"
            )
        batch = random.sample(list(self.buffer), self.batch_size)

        # Transpose the list of Sample tuples into one array per field; the
        # batch length now follows self.batch_size instead of a fixed 32.
        current_obs = np.asarray([s.current_obs for s in batch], dtype=np.float32)
        current_action = np.asarray([s.current_action for s in batch], dtype=np.int32)
        next_obs = np.asarray([s.next_obs for s in batch], dtype=np.float32)
        reward = np.asarray([s.reward for s in batch], dtype=np.float32)
        done = np.asarray([s.done for s in batch])
        return current_obs, current_action, next_obs, reward, done
if __name__ == "__main__":
    # Smoke test: fill the buffer with random 6x7 observations, then draw one
    # batch and report the shape/dtype of every returned array.
    memory = Memory()
    for _ in range(100):
        curr_obs = np.random.rand(6, 7)
        # NOTE(review): actions are drawn from {-1, 0, 1, 2} as before —
        # confirm a negative action id is intended.
        curr_action = random.randint(-1, 2)
        next_obs = np.random.rand(6, 7)
        reward = random.randint(0, 4)
        memory.add(curr_obs, curr_action, next_obs, reward, True)

    field_names = ('current_obs', 'current_action', 'next_obs', 'reward', 'done')
    for field_name, array in zip(field_names, memory.sample()):
        print(field_name, array.shape, array.dtype)