Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ This repository contains unofficial code reproducing Agent57, which outperformed
# Directory File
1. **agent.py**

define agent to play a supecific environment.
define agent to play a specific environment.

2. **buffer.py**

define buffer to store experiences with priorites.
define buffer to store experiences with priorities.

3. **learner.py**

Expand Down
10 changes: 5 additions & 5 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Agent:
betas (list): list of beta which decide weights between intrinsic qvalues and extrinsic qvalues
gammas (list): list of gamma which is discount rate
epsilon (float): coefficient for epsilon greedy
eta (float): coefficient for priority caluclation
eta (float): coefficient for priority calculation
lamda (float): coefficient for retrace operation
burnin_length (int): length of burnin to calculate qvalues
unroll_length (int): length of unroll to calculate qvalues
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self,
env_name (str): name of environment
n_frames (int): number of images to be stacked
epsilon (float): coefficient for epsilon soft-max
eta (float): coefficient for priority caluclation
eta (float): coefficient for priority calculation
lamda (float): coefficient for retrace operation
burnin_length (int): length of burnin to calculate qvalues
unroll_length (int): length of unroll to calculate qvalues
Expand Down Expand Up @@ -131,7 +131,7 @@ def sync_weights_and_rollout(self, in_q_weight, ex_q_weight, embed_weight, lifel
lifelong_weight : weight of lifelong network
Returns:
priority (list): priority of segments when pulling segments from sum tree
segments : parts of expecimences
segments : parts of experiences
self.pid : process id
"""

Expand Down Expand Up @@ -190,7 +190,7 @@ def _rollout(self):

segments = episode_buffer.pull_segments()

self.states, self.actions, self.in_rewards, self.ex_rewards, self.dones, self.j, self.next_states, \
self.states, self.actions, self.ex_rewards, self.in_rewards, self.dones, self.j, self.next_states, \
in_h0, in_c0, ex_h0, ex_c0, self.prev_in_rewards, self.prev_ex_rewards, self.prev_actions = segments2contents(segments, self.burnin_len)

# (unroll_len+1, batch_size, action_space)
Expand All @@ -215,7 +215,7 @@ def _rollout(self):

def get_qvalues(self, q_network, h, c):
"""
get qvalues from expeiences using q network
get qvalues from experiences using q network
Args:
q_network : network to get Q values
h (torch.tensor): LSTM hidden state
Expand Down
4 changes: 2 additions & 2 deletions buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class SegmentReplayBuffer:
buffer_size [int]: size of buffer
priorities : SumTree object to determine priority
segment_buffer [list]: buffer of segment which size is buffer_size
weight_expo [float]: exponetial value to smooth weights
weight_expo [float]: exponential value to smooth weights
eta [float]: coefficient to reduce priority
count [int] : index of priorities list
full [bool] : flag whether segment buffer is full or not
Expand All @@ -100,7 +100,7 @@ def __init__(self, buffer_size, weight_expo, eta=0.9):
"""
Args
buffer_size [int]: size of buffer
weight_expo [float]: exponetial value to smooth weights
weight_expo [float]: exponential value to smooth weights
eta [float, Optional]: coefficient to reduce priority
"""

Expand Down
12 changes: 6 additions & 6 deletions learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Learner:
criterion : loss function of embedding classifier
betas (list): list of beta which decide weights between intrinsic qvalues and extrinsic qvalues
gammas (list): list of gamma which is discount rate
eta (float): coefficient for priority caluclation
eta (float): coefficient for priority calculation
lamda (float): coefficient for retrace operation
burnin_length (int): length of burnin to calculate qvalues
unroll_length (int): length of unroll to calculate qvalues
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self,
Args:
env_name (str): name of environment
n_frames (int): number of images to be stacked
eta (float): coefficient for priority caluclation
eta (float): coefficient for priority calculation
lamda (float): coefficient for retrace operation
num_arms (int): number of multi arms
burnin_length (int): length of burnin to calculate qvalues
Expand Down Expand Up @@ -141,13 +141,13 @@ def set_device(self):

def define_network(self):
"""
define network and get initial parameter to copy to angents
define network and get initial parameter to copy to agents
"""

frame = self.frame_process_func(self.env.reset())
frames = [frame] * self.n_frames

# (1, n_frams, 32, 32)
# (1, n_frames, 32, 32)
state = torch.tensor(np.stack(frames, axis=0)[None, ...]).float()
h = torch.zeros(1, 1, self.in_online_q_network.lstm.hidden_size).float()
c = torch.zeros(1, 1, self.ex_online_q_network.lstm.hidden_size).float()
Expand Down Expand Up @@ -287,7 +287,7 @@ def qnet_update(self, weights, segments):
segments: a coherent body of experience of some length
"""

self.states, self.actions, self.in_rewards, self.ex_rewards, self.dones, self.j, self.next_states, in_h0, in_c0, ex_h0, ex_c0, \
self.states, self.actions, self.ex_rewards, self.in_rewards, self.dones, self.j, self.next_states, in_h0, in_c0, ex_h0, ex_c0, \
self.prev_in_rewards, self.prev_ex_rewards, self.prev_actions = segments2contents(segments, burnin_len=self.burnin_len, is_grad=True, device=self.device)

self.in_online_q_network.train()
Expand Down Expand Up @@ -357,7 +357,7 @@ def qnet_update(self, weights, segments):

def get_qvalues(self, q_network, h, c):
"""
get qvalues from expeiences using specific q network
get qvalues from experiences using specific q network
Args:
q_network : network to get Q values
h (torch.tensor): LSTM hidden state
Expand Down
42 changes: 23 additions & 19 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,24 @@
from tester import Tester
from buffer import SegmentReplayBuffer
from learner import Learner
from utils import seed_evrything
from utils import seed_everything


weight_dir = "result/weights"
os.makedirs(weight_dir, exist_ok=True)


def main(args):
if os.path.exists("log"):
shutil.rmtree("log")
os.makedirs("log")

seed_evrything(args.seed)
seed_everything(args.seed)
ray.init(ignore_reinit_error=True, local_mode=False)

total_s = time.time()
in_q_loss_history, ex_q_loss_history, embed_loss_history, lifelong_loss_history, score_history = [], [], [], [], []

learner = Learner.remote(env_name=args.env_name,
target_update_period=args.target_update_period,
eta=args.eta,
Expand All @@ -43,9 +44,10 @@ def main(args):
ex_q_clip_grad=args.ex_q_clip_grad,
embed_clip_grad=args.embed_clip_grad,
lifelong_clip_grad=args.lifelong_clip_grad)

in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, original_lifelong_weight = ray.get(learner.define_network.remote())


in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, original_lifelong_weight = ray.get(
learner.define_network.remote())

# put weights for agents to refer them
in_q_weight = ray.put(in_q_weight)
ex_q_weight = ray.put(ex_q_weight)
Expand All @@ -72,7 +74,7 @@ def main(args):
original_lifelong_weight=original_lifelong_weight)
for i in range(args.num_agents)]

replay_buffer = SegmentReplayBuffer(buffer_size=args.buffer_size, weight_expo=args.weight_expo)
replay_buffer = SegmentReplayBuffer(buffer_size=args.buffer_size, weight_expo=args.weight_expo)

tester = Tester.remote(env_name=args.env_name,
n_frames=args.n_frames,
Expand All @@ -93,11 +95,11 @@ def main(args):

for i in range(args.n_agent_burnin):
s = time.time()
# finised agent, working agents

# finished agent, working agents
finished, wip_agents = ray.wait(wip_agents, num_returns=1)
priorities, segments, pid = ray.get(finished[0])

replay_buffer.add(priorities, segments)
wip_agents.extend([agents[pid].sync_weights_and_rollout.remote(in_q_weight=in_q_weight,
ex_q_weight=ex_q_weight,
Expand All @@ -107,7 +109,7 @@ def main(args):
f.write(f"{i}th Agent's time[sec]: {time.time() - s:.5f}\n")

print("="*100)

minibatchs = [replay_buffer.sample_minibatch(batch_size=args.batch_size) for _ in range(args.update_iter)]
wip_learner = learner.update_network.remote(minibatchs)
wip_tester = tester.test_play.remote(in_q_weight=in_q_weight,
Expand All @@ -123,7 +125,7 @@ def main(args):
while learner_cycles <= args.n_learner_cycle:
agent_cycles += 1
s = time.time()

# get agent's experience
finished, wip_agents = ray.wait(wip_agents, num_returns=1)
priorities, segments, pid = ray.get(finished[0])
Expand All @@ -132,17 +134,18 @@ def main(args):
ex_q_weight=ex_q_weight,
embed_weight=embed_weight,
lifelong_weight=trained_lifelong_weight)])

n_segment_added += len(segments)

finished_learner, _ = ray.wait([wip_learner], timeout=0)

if finished_learner:
in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, indices, priorities, in_q_loss, ex_q_loss, embed_loss, lifelong_loss = ray.get(finished_learner[0])

in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, indices, priorities, in_q_loss, ex_q_loss, embed_loss, lifelong_loss = ray.get(
finished_learner[0])

replay_buffer.update_priority(indices, priorities)
minibatchs = [replay_buffer.sample_minibatch(batch_size=args.batch_size) for _ in range(args.update_iter)]

wip_learner = learner.update_network.remote(minibatchs)

in_q_weight = ray.put(in_q_weight)
Expand All @@ -151,7 +154,8 @@ def main(args):
trained_lifelong_weight = ray.put(trained_lifelong_weight)

with open(f"log/loss_history.txt", mode="a") as f:
f.write(f"{learner_cycles}th results => Agent cycle: {agent_cycles}, Added: {n_segment_added}, InQLoss: {in_q_loss:.4f}, ExQLoss: {ex_q_loss:.4f}, EmbeddingLoss: {embed_loss:.4f}, LifeLongLoss: {lifelong_loss:.8f} \n")
f.write(f"{learner_cycles}th results => Agent cycle: {agent_cycles}, Added: {n_segment_added}, InQLoss: {
in_q_loss:.4f}, ExQLoss: {ex_q_loss:.4f}, EmbeddingLoss: {embed_loss:.4f}, LifeLongLoss: {lifelong_loss:.8f} \n")

in_q_loss_history.append((learner_cycles-1, in_q_loss))
ex_q_loss_history.append((learner_cycles-1, ex_q_loss))
Expand All @@ -163,7 +167,7 @@ def main(args):
score_history.append((learner_cycles-args.switch_test_cycle, test_score))
with open(f"log/score_history.txt", mode="a") as f:
f.write(f"Cycle: {learner_cycles}, Score: {test_score}\n")

wip_tester = tester.test_play.remote(in_q_weight=in_q_weight,
ex_q_weight=ex_q_weight,
embed_weight=embed_weight,
Expand Down Expand Up @@ -248,7 +252,7 @@ def main(args):
parser.add_argument('--L', default=5, type=int)
parser.add_argument('--num_arms', default=32, type=int)
# Base
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--gamma', default=0.997, type=float)
parser.add_argument('--eta', default=0.9, type=float)
parser.add_argument('--lamda', default=0.95, type=float)
Expand Down
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F


# Convolutional Encording
# Convolutional Encoding
class ConvEncoder(nn.Module):
"""
Encoder with convolution
Expand Down Expand Up @@ -32,7 +32,7 @@ def forward(self, x):
Args:
x (torch.tensor): input [b, n_frames, 84, 84]
Returns
x (torch.tensor): ouput [b, units]
x (torch.tensor): output [b, units]
"""

x = F.relu(self.conv1(x)) # (b, 32, 20, 20)
Expand Down Expand Up @@ -129,7 +129,7 @@ def forward(self, inputs):
Args:
input (torch.tensor): state [b, n_frames, 84, 84]
Returns:
embeded state [b, emebed_units]
embedded state [b, embed_units]
"""

return F.relu(self.conv_encoder(inputs))
Expand All @@ -154,7 +154,7 @@ def __init__(self, action_space, hidden=128):
def forward(self, input1, input2):
"""
Args:
embedded state (torch.tensor): state [b, embed_units]
embeded state (torch.tensor): state [b, embed_units]
Returns:
action probability [b, action_space]
"""
Expand Down
2 changes: 1 addition & 1 deletion tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@ray.remote(num_cpus=1)
class Tester:
"""
calculate score to evaluate peformance
calculate score to evaluate performance
Attributes:
env_name (str): name of environment
n_frames (int): number of images to be stacked
Expand Down
Loading