diff --git a/README.md b/README.md
index f54d6d6..ebf592d 100644
--- a/README.md
+++ b/README.md
@@ -4,11 +4,11 @@ This repository contains unofficial code reproducing Agent57, which outperformed
 # Directory File
 1. **agent.py**
-    define agent to play a supecific environment.
+    define agent to play a specific environment.
 2. **buffer.py**
-    define buffer to store experiences with priorites.
+    define buffer to store experiences with priorities.
 3. **learner.py**
diff --git a/agent.py b/agent.py
index aec6eb3..06ebbb6 100644
--- a/agent.py
+++ b/agent.py
@@ -35,7 +35,7 @@ class Agent:
         betas (list): list of beta which decide weights between intrinsic qvalues and extrinsic qvalues
         gammas (list): list of gamma which is discount rate
         epsilon (float): coefficient for epsilon greedy
-        eta (float): coefficient for priority caluclation
+        eta (float): coefficient for priority calculation
         lamda (float): coefficient for retrace operation
         burnin_length (int): length of burnin to calculate qvalues
         unroll_length (int): length of unroll to calculate qvalues
@@ -71,7 +71,7 @@ def __init__(self,
             env_name (str): name of environment
             n_frames (int): number of images to be stacked
             epsilon (float): coefficient for epsilon soft-max
-            eta (float): coefficient for priority caluclation
+            eta (float): coefficient for priority calculation
             lamda (float): coefficient for retrace operation
             burnin_length (int): length of burnin to calculate qvalues
             unroll_length (int): length of unroll to calculate qvalues
@@ -131,7 +131,7 @@ def sync_weights_and_rollout(self, in_q_weight, ex_q_weight, embed_weight, lifel
             lifelong_weight : weight of lifelong network
         Returns:
             priority (list): priority of segments when pulling segments from sum tree
-            segments : parts of expecimences
+            segments : parts of experiences
             self.pid : process id
         """
@@ -190,7 +190,7 @@ def _rollout(self):
         segments = episode_buffer.pull_segments()
 
-        self.states, self.actions, self.in_rewards, self.ex_rewards, self.dones, self.j, self.next_states, \
+        self.states, self.actions, self.ex_rewards, self.in_rewards, self.dones, self.j, self.next_states, \
             in_h0, in_c0, ex_h0, ex_c0, self.prev_in_rewards, self.prev_ex_rewards, self.prev_actions = segments2contents(segments, self.burnin_len)
 
         # (unroll_len+1, batch_size, action_space)
@@ -215,7 +215,7 @@ def get_qvalues(self, q_network, h, c):
         """
-        get qvalues from expeiences using q network
+        get qvalues from experiences using q network
         Args:
             q_network : network to get Q values
             h (torch.tensor): LSTM hidden state
diff --git a/buffer.py b/buffer.py
index 4d7596a..213b897 100644
--- a/buffer.py
+++ b/buffer.py
@@ -90,7 +90,7 @@ class SegmentReplayBuffer:
         buffer_size [int]: size of buffer
         priorities : SumTree object to determine priority
         segment_buffer [list]: buffer of segment which size is buffer_size
-        weight_expo [float]: exponetial value to smooth weights
+        weight_expo [float]: exponential value to smooth weights
        eta [float] : coefficient for reduce priority
        count [int] : index of priorities list
        full [bool] : flag whether segment buffer is full or not
@@ -100,7 +100,7 @@ def __init__(self, buffer_size, weight_expo, eta=0.9):
        """
        Args
            buffer_size [int]: size of buffer
-            weight_expo [float]: exponetial value to smooth weights
+            weight_expo [float]: exponential value to smooth weights
            eta [float, Optical]: coefficient for reduce priority
        """
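For context on `eta` and `weight_expo` above: Agent57 inherits R2D2-style prioritized replay, where a segment's priority mixes the max and mean absolute TD error and sampling weights smooth out the bias from non-uniform sampling. The sketch below is hedged — the function names are illustrative and the repository's exact formulas may differ:

```python
# Hedged sketch of the conventional R2D2-style use of eta and weight_expo;
# these helpers are illustrative, not the repo's API.
import numpy as np

def segment_priority(td_errors, eta=0.9):
    """Priority of one segment: eta * max|TD| + (1 - eta) * mean|TD|."""
    abs_td = np.abs(np.asarray(td_errors, dtype=np.float64))
    return eta * abs_td.max() + (1.0 - eta) * abs_td.mean()

def importance_weights(sampled_priorities, total_priority, buffer_len, weight_expo=0.6):
    """Importance-sampling weights (N * P(i)) ** (-weight_expo), normalized to max 1."""
    probs = np.asarray(sampled_priorities, dtype=np.float64) / total_priority
    weights = (buffer_len * probs) ** (-weight_expo)
    return weights / weights.max()

print(segment_priority([0.1, 0.5, 0.2]))                        # 0.9*0.5 + 0.1*mean
print(importance_weights([0.5, 0.1], total_priority=1.0, buffer_len=100))
```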
diff --git a/learner.py b/learner.py
index 6f01c4b..44b79fd 100644
--- a/learner.py
+++ b/learner.py
@@ -43,7 +43,7 @@ class Learner:
         criterion : loss function of embedding classifier
         betas (list): list of beta which decide weights between intrinsic qvalues and extrinsic qvalues
         gammas (list): list of gamma which is discount rate
-        eta (float): coefficient for priority caluclation
+        eta (float): coefficient for priority calculation
         lamda (float): coefficient for retrace operation
         burnin_length (int): length of burnin to calculate qvalues
         unroll_length (int): length of unroll to calculate qvalues
@@ -72,7 +72,7 @@ def __init__(self,
         Args:
             env_name (str): name of environment
             n_frames (int): number of images to be stacked
-            eta (float): coefficient for priority caluclation
+            eta (float): coefficient for priority calculation
             lamda (float): coefficient for retrace operation
             num_arms (int): number of multi arms
             burnin_length (int): length of burnin to calculate qvalues
@@ -141,13 +141,13 @@ def set_device(self):
     def define_network(self):
         """
-        define network and get initial parameter to copy to angents
+        define network and get initial parameter to copy to agents
         """
         frame = self.frame_process_func(self.env.reset())
         frames = [frame] * self.n_frames
 
-        # (1, n_frams, 32, 32)
+        # (1, n_frames, 32, 32)
         state = torch.tensor(np.stack(frames, axis=0)[None, ...]).float()
         h = torch.zeros(1, 1, self.in_online_q_network.lstm.hidden_size).float()
         c = torch.zeros(1, 1, self.ex_online_q_network.lstm.hidden_size).float()
@@ -287,7 +287,7 @@ def qnet_update(self, weights, segments):
             segments: a coherent body of experience of some length
         """
-        self.states, self.actions, self.in_rewards, self.ex_rewards, self.dones, self.j, self.next_states, in_h0, in_c0, ex_h0, ex_c0, \
+        self.states, self.actions, self.ex_rewards, self.in_rewards, self.dones, self.j, self.next_states, in_h0, in_c0, ex_h0, ex_c0, \
             self.prev_in_rewards, self.prev_ex_rewards, self.prev_actions = segments2contents(segments, burnin_len=self.burnin_len, is_grad=True, device=self.device)
 
         self.in_online_q_network.train()
@@ -357,7 +357,7 @@ def get_qvalues(self, q_network, h, c):
         """
-        get qvalues from expeiences using specific q network
+        get qvalues from experiences using specific q network
         Args:
             q_network : network to get Q values
             h (torch.tensor): LSTM hidden state
diff --git a/main.py b/main.py
index d5b414e..daca078 100644
--- a/main.py
+++ b/main.py
@@ -10,23 +10,24 @@ from tester import Tester
 from buffer import SegmentReplayBuffer
 from learner import Learner
-from utils import seed_evrything
+from utils import seed_everything
 
 weight_dir = "result/weights"
 os.makedirs(weight_dir, exist_ok=True)
 
+
 def main(args):
     if os.path.exists("log"):
         shutil.rmtree("log")
     os.makedirs("log")
 
-    seed_evrything(args.seed)
+    seed_everything(args.seed)
 
     ray.init(ignore_reinit_error=True, local_mode=False)
     total_s = time.time()
 
     in_q_loss_history, ex_q_loss_history, embed_loss_history, lifelong_loss_history, score_history = [], [], [], [], []
-    
+
     learner = Learner.remote(env_name=args.env_name,
                              target_update_period=args.target_update_period,
                              eta=args.eta,
@@ -43,9 +44,10 @@ def main(args):
                              ex_q_clip_grad=args.ex_q_clip_grad,
                              embed_clip_grad=args.embed_clip_grad,
                              lifelong_clip_grad=args.lifelong_clip_grad)
-    
-    in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, original_lifelong_weight = ray.get(learner.define_network.remote())
-    
+
+    in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, original_lifelong_weight = ray.get(
+        learner.define_network.remote())
+
     # put weights for agents to refer them
     in_q_weight = ray.put(in_q_weight)
     ex_q_weight = ray.put(ex_q_weight)
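For readers unfamiliar with the `ray.put` pattern used in the hunk above: the learner's weights are stored in the Ray object store once, and every agent call then receives the same object reference rather than a fresh copy per actor. A minimal, hedged sketch (toy actor, not the repo's classes):

```python
# Minimal sketch of sharing weights through the Ray object store.
# ToyAgent and its rollout method are illustrative stand-ins.
import ray
import numpy as np

ray.init(ignore_reinit_error=True)

@ray.remote
class ToyAgent:
    def rollout(self, weight_ref):
        # Ray resolves ObjectRefs passed as top-level arguments, so this is the array itself.
        weight = weight_ref
        return float(np.sum(weight))

weights = np.ones((4, 4))
weight_ref = ray.put(weights)                  # stored once, shared by all agents
agents = [ToyAgent.remote() for _ in range(3)]
print(ray.get([a.rollout.remote(weight_ref) for a in agents]))   # [16.0, 16.0, 16.0]
```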
@@ -72,7 +74,7 @@ def main(args):
                            original_lifelong_weight=original_lifelong_weight)
              for i in range(args.num_agents)]
 
-    replay_buffer = SegmentReplayBuffer(buffer_size=args.buffer_size, weight_expo=args.weight_expo) 
+    replay_buffer = SegmentReplayBuffer(buffer_size=args.buffer_size, weight_expo=args.weight_expo)
 
     tester = Tester.remote(env_name=args.env_name,
                            n_frames=args.n_frames,
@@ -93,11 +95,11 @@
     for i in range(args.n_agent_burnin):
         s = time.time()
-        
-        # finised agent, working agents
+
+        # finished agent, working agents
         finished, wip_agents = ray.wait(wip_agents, num_returns=1)
         priorities, segments, pid = ray.get(finished[0])
-        
+
         replay_buffer.add(priorities, segments)
         wip_agents.extend([agents[pid].sync_weights_and_rollout.remote(in_q_weight=in_q_weight,
                                                                        ex_q_weight=ex_q_weight,
@@ -107,7 +109,7 @@
             f.write(f"{i}th Agent's time[sec]: {time.time() - s:.5f}\n")
     print("="*100)
-    
+
     minibatchs = [replay_buffer.sample_minibatch(batch_size=args.batch_size) for _ in range(args.update_iter)]
     wip_learner = learner.update_network.remote(minibatchs)
     wip_tester = tester.test_play.remote(in_q_weight=in_q_weight,
@@ -123,7 +125,7 @@
     while learner_cycles <= args.n_learner_cycle:
         agent_cycles += 1
         s = time.time()
-        
+
         # get agent's experience
         finished, wip_agents = ray.wait(wip_agents, num_returns=1)
         priorities, segments, pid = ray.get(finished[0])
@@ -132,17 +134,18 @@
                                                                        ex_q_weight=ex_q_weight,
                                                                        embed_weight=embed_weight,
                                                                        lifelong_weight=trained_lifelong_weight)])
-        
+
         n_segment_added += len(segments)
         finished_learner, _ = ray.wait([wip_learner], timeout=0)
 
         if finished_learner:
-            in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, indices, priorities, in_q_loss, ex_q_loss, embed_loss, lifelong_loss = ray.get(finished_learner[0])
-            
+            in_q_weight, ex_q_weight, embed_weight, trained_lifelong_weight, indices, priorities, in_q_loss, ex_q_loss, embed_loss, lifelong_loss = ray.get(
+                finished_learner[0])
+
             replay_buffer.update_priority(indices, priorities)
             minibatchs = [replay_buffer.sample_minibatch(batch_size=args.batch_size) for _ in range(args.update_iter)]
-            
+
             wip_learner = learner.update_network.remote(minibatchs)
 
             in_q_weight = ray.put(in_q_weight)
@@ -151,7 +154,8 @@
             trained_lifelong_weight = ray.put(trained_lifelong_weight)
 
             with open(f"log/loss_history.txt", mode="a") as f:
-                f.write(f"{learner_cycles}th results => Agent cycle: {agent_cycles}, Added: {n_segment_added}, InQLoss: {in_q_loss:.4f}, ExQLoss: {ex_q_loss:.4f}, EmbeddingLoss: {embed_loss:.4f}, LifeLongLoss: {lifelong_loss:.8f} \n")
+                f.write(f"{learner_cycles}th results => Agent cycle: {agent_cycles}, Added: {n_segment_added}, "
+                        f"InQLoss: {in_q_loss:.4f}, ExQLoss: {ex_q_loss:.4f}, EmbeddingLoss: {embed_loss:.4f}, LifeLongLoss: {lifelong_loss:.8f} \n")
 
             in_q_loss_history.append((learner_cycles-1, in_q_loss))
             ex_q_loss_history.append((learner_cycles-1, ex_q_loss))
@@ -163,7 +167,7 @@
                 score_history.append((learner_cycles-args.switch_test_cycle, test_score))
                 with open(f"log/score_history.txt", mode="a") as f:
                     f.write(f"Cycle: {learner_cycles}, Score: {test_score}\n")
-                
+
                 wip_tester = tester.test_play.remote(in_q_weight=in_q_weight,
                                                      ex_q_weight=ex_q_weight,
                                                      embed_weight=embed_weight,
@@ -248,7 +252,7 @@
     parser.add_argument('--L', default=5, type=int)
     parser.add_argument('--num_arms', default=32, type=int)
     # Base
-    parser.add_argument('--seed', default=0, type=int) 
+    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gamma', default=0.997, type=float)
    parser.add_argument('--eta', default=0.9, type=float)
    parser.add_argument('--lamda', default=0.95, type=float)
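The driver in main.py leans on `ray.wait` to overlap agent rollouts with learner updates: whichever rollout finishes first is pulled, its segments go into the replay buffer, and the same agent is relaunched immediately. A stripped-down, hedged sketch of that loop (toy actor and method names, not the repo's):

```python
# Simplified sketch of the asynchronous actor/learner loop used in main.py.
import random
import time
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class ToyAgent:
    def __init__(self, pid):
        self.pid = pid
    def rollout(self):
        time.sleep(random.random() * 0.1)                 # pretend to play an episode
        return [random.random()], ["segment"], self.pid   # priorities, segments, pid

agents = [ToyAgent.remote(i) for i in range(4)]
wip = [a.rollout.remote() for a in agents]
buffer = []

for _ in range(10):
    finished, wip = ray.wait(wip, num_returns=1)          # first rollout that is done
    priorities, segments, pid = ray.get(finished[0])
    buffer.extend(zip(priorities, segments))
    wip.append(agents[pid].rollout.remote())              # relaunch the same agent
print(len(buffer))
```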
diff --git a/model.py b/model.py
index 98a923e..ae302e9 100644
--- a/model.py
+++ b/model.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 
-# Convolutional Encording
+# Convolutional Encoding
 class ConvEncoder(nn.Module):
     """
     Encoder with convolution
@@ -32,7 +32,7 @@ def forward(self, x):
         Args:
             x (torch.tensor): input [b, n_frames, 84, 84]
         Returns
-            x (torch.tensor): ouput [b, units]
+            x (torch.tensor): output [b, units]
         """
         x = F.relu(self.conv1(x))    # (b, 32, 20, 20)
@@ -129,7 +129,7 @@ def forward(self, inputs):
         Args:
             input (torch.tensor): state [b, n_frames, 84, 84]
         Returns:
-            embeded state [b, emebed_units]
+            embeded state [b, embed_units]
         """
         return F.relu(self.conv_encoder(inputs))
@@ -154,7 +154,7 @@ def __init__(self, action_space, hidden=128):
     def forward(self, input1, input2):
         """
         Args:
-            embeded state (torch.tensor): state [b, emebed_units]
+            embeded state (torch.tensor): state [b, embed_units]
         Returns:
             action probability [b, action_space]
         """
diff --git a/tester.py b/tester.py
index 00c594e..ca7edbc 100644
--- a/tester.py
+++ b/tester.py
@@ -11,7 +11,7 @@
 @ray.remote(num_cpus=1)
 class Tester:
     """
-    calculate score to evaluate peformance
+    calculate score to evaluate performance
     Attributes:
         env_name (str): name of environment
         n_frames (int): number of images to be stacked
diff --git a/utils.py b/utils.py
index 5d4500f..91a32cb 100644
--- a/utils.py
+++ b/utils.py
@@ -16,7 +16,7 @@ def rescaling(x):
     Returns:
         rescaled value
     """
-    
+
     eps = 0.001
     return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.) - 1.) + eps * x
 
@@ -30,7 +30,7 @@ def inverse_rescaling(x):
     Returns:
         inverse rescaled value
     """
-    
+
     eps = 0.001
     return torch.sign(x) * (torch.square(((torch.sqrt(1. + 4. * eps * (torch.abs(x) + 1. + eps))) - 1.) / (2. * eps)) - 1.)
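`rescaling` and `inverse_rescaling` above are the value transforms h and h⁻¹ used by R2D2/Agent57 to compress large returns. A quick sanity check that they invert each other (run from the repo root so `utils` resolves):

```python
# h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x, and inverse_rescaling recovers x.
import torch
from utils import rescaling, inverse_rescaling   # the functions defined in this diff

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0, 100.0])
y = rescaling(x)
print(y)                        # compressed values, roughly sign(x) * sqrt(|x|)
print(inverse_rescaling(y))     # recovers ~[-10, -1, 0, 1, 10, 100]
```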
@@ -44,13 +44,13 @@ def get_preprocess_func(env_name):
     Returns:
         preprocess function corresponding to env_name
     """
-    
+
     if "Breakout" in env_name:
         return _preprocess_breakout
     elif "Pacman" in env_name:
         return _preprocess_mspackman
     else:
-        raise NotImplementedError(f"Frame processor not implemeted for {env_name}")
+        raise NotImplementedError(f"Frame processor not implemented for {env_name}")
 
 
 def _preprocess_breakout(frame, resize=84):
@@ -61,7 +61,7 @@
     Returns:
         preprocessed image
     """
-    
+
     image = Image.fromarray(frame)
     image = image.convert("L").crop((0, 34, 160, 200)).resize((resize, resize))
     image_scaled = np.array(image) / 255.0
@@ -76,7 +76,7 @@
     Returns:
         preprocessed image
     """
-    
+
     image = Image.fromarray(frame)
     image = image.convert("L").crop((0, 0, 160, 170)).resize((resize, resize))
     image_scaled = np.array(image) / 255.0
@@ -96,10 +96,10 @@
     elif "Pacman" in env_name:
         return 3
     else:
-        raise NotImplementedError(f"Frame processor not implemeted for {env_name}")
+        raise NotImplementedError(f"Frame processor not implemented for {env_name}")
 
 
-def seed_evrything(seed):
+def seed_everything(seed):
     """
     set seed
     Args:
@@ -109,8 +109,8 @@
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    
-    
+
+
 def create_beta_list(num_arms, beta=0.3):
     """
     create beta list for each arm
@@ -120,7 +120,7 @@
         betas (list): list of beta which decide weights between intrinsic qvalues and extrinsic qvalues
     NOTE: Values differ from those in the paper of Agent57.
     """
-    
+
     betas = [torch.tensor(0)]
     for i in range(1, num_arms-1):
         betas.append(beta * torch.sigmoid(torch.tensor(10 * (2*i / (num_arms-2) - 1))))
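The betas produced by `create_beta_list` ramp from 0 (purely extrinsic arm) up to the full `beta` via a sigmoid; as the docstring notes, the values differ from the Agent57 paper. A small check of the schedule — the final `append(torch.tensor(beta))` is assumed here, since the hunk above is cut off before the function ends:

```python
# Reproduces create_beta_list as shown above to illustrate the schedule of mixing weights.
import torch

def create_beta_list(num_arms, beta=0.3):
    betas = [torch.tensor(0)]
    for i in range(1, num_arms - 1):
        betas.append(beta * torch.sigmoid(torch.tensor(10 * (2 * i / (num_arms - 2) - 1))))
    betas.append(torch.tensor(beta))   # assumed: last arm gets the full beta (not shown in the hunk)
    return betas

betas = create_beta_list(num_arms=32)
print(float(betas[0]), float(betas[15]), float(betas[-1]))   # 0.0, ~0.15, 0.3
```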
""" - + gammas = [torch.tensor(gamma0)] for i in range(1, 7): gammas.append(gamma0 + (gamma1 - gamma0) * torch.sigmoid(torch.tensor(10 * (i - 3) / 3))) gammas.append(torch.tensor(gamma1)) - + for i in range(8, num_arms): t = (num_arms-i-1) * torch.log(torch.tensor(1-gamma1)) + (i-8) * torch.log(torch.tensor(1-gamma2)) gammas.append(1 - torch.exp(t / (num_arms-9))) - + return gammas @@ -168,7 +168,7 @@ def __init__(self, num_arms, window_size, epsilon, beta): epsilon (float): probability to select the index of the arms used in multi-armed bandit problem beta (float): weight between frequency and mean reward """ - + self.data = collections.deque(maxlen=window_size) self.num_arms = num_arms self.epsilon = epsilon @@ -181,22 +181,22 @@ def pull_index(self): Returns: index (float): index of arms """ - + if self.count < self.num_arms: index = self.count self.count += 1 - + else: if random.random() > self.epsilon: N = np.zeros(self.num_arms) mu = np.zeros(self.num_arms) - + for j, reward in self.data: N[j] += 1 mu[j] += reward mu = mu / (N + 1e-10) index = np.argmax(mu + self.beta * np.sqrt(1 / (N + 1e-6))) - + else: index = np.random.choice(self.num_arms) return index @@ -207,7 +207,7 @@ def push_data(self, datas): Args: datas :store index of arms and resulting reward """ - + self.data += [(j, reward) for j, reward in datas] @@ -221,18 +221,18 @@ def get_episodic_reward(x, M, k, c=0.001, epsilon=0.0001, cluster_distance=0.008 Returns: episodic reward (np.ndarray): reward based on how different from neighbors """ - - dist_list = [np.linalg.norm((m -x), ord=2) for m in M] + + dist_list = [np.linalg.norm((m - x), ord=2) for m in M] topk_dist_list = np.sort(dist_list)[:k] - dm = np.mean(topk_dist_list) - + dm = np.mean(topk_dist_list) + if dm == 0: return 1e-10 else: topk_dist_list = topk_dist_list / dm - topk_dist_list = np.where(topk_dist_list-cluster_distance<0, 0, topk_dist_list-cluster_distance) - + topk_dist_list = np.where(topk_dist_list-cluster_distance < 0, 0, topk_dist_list-cluster_distance) + K = epsilon / (epsilon + topk_dist_list) s = np.sqrt(np.sum(K)) + c @@ -251,16 +251,16 @@ def transformed_retrace_operator(delta, pi, actions, gamma, unroll_len, lamda, d C_{i, b} = \lambda * min(1, \frac{\pi(a_{i}|x_{i}^{b})}{\mu_{i}}) \delta_{j, b} = r_{j}^{b} + \gamma * \Sigma_{a \in A} {\pi(a|x_{j+1}^{b})}*h^{-1}(Q(x_{j+1}^{b}, a))-h^{-1}(Q(x_{j}^{b}, a_{j}^{b})) """ - + # (unroll_len, batch_size) - P_list = delta - + P_list = delta + # (unroll_len, batch_size) - C = torch.where(pi == actions, torch.tensor(lamda).to(device), torch.tensor(0.).to(device)) - + C = torch.where(pi == actions, torch.tensor(lamda).to(device), torch.tensor(0.).to(device)) + for t in range(unroll_len-2, -1, -1): P_list[t, :] += gamma * C[t+1, :] * P_list[t+1, :] - + return P_list @@ -301,10 +301,10 @@ def play_episode(frame_process_func, beta (float): coefficient to decide weights between intrinsic qvalues and extrinsic qvalues is_test (bool): flag indicating whether it is a test or not """ - + env = gym.make(env_name) frame = frame_process_func(env.reset()) - + # (n_frames, 84, 84) frames = collections.deque([frame] * n_frames, maxlen=n_frames) @@ -317,16 +317,16 @@ def play_episode(frame_process_func, done = False lives = get_initial_lives(env_name) - + M = collections.deque(maxlen=int(1e3)) ucb_datas = [] transitions = [] while not done: - + # batching (1, n_frames, 84, 84) state = torch.tensor(np.stack(frames, axis=0)[None, ...]).float() - + # intrinsic Qvalues (1, action_space) in_qvalue, (next_in_h, next_in_c) = 
@@ -301,10 +301,10 @@ def play_episode(frame_process_func,
         beta (float): coefficient to decide weights between intrinsic qvalues and extrinsic qvalues
         is_test (bool): flag indicating whether it is a test or not
     """
-    
+
     env = gym.make(env_name)
     frame = frame_process_func(env.reset())
-    
+
     # (n_frames, 84, 84)
     frames = collections.deque([frame] * n_frames, maxlen=n_frames)
@@ -317,16 +317,16 @@
     done = False
     lives = get_initial_lives(env_name)
-    
+
     M = collections.deque(maxlen=int(1e3))
     ucb_datas = []
     transitions = []
 
     while not done:
-        
+
         # batching (1, n_frames, 84, 84)
         state = torch.tensor(np.stack(frames, axis=0)[None, ...]).float()
-        
+
         # intrinsic Qvalues (1, action_space)
         in_qvalue, (next_in_h, next_in_c) = in_q_network(state,
                                                          states=(in_h, in_c),
@@ -350,17 +350,18 @@
         qvalue = rescaling(inverse_rescaling(ex_qvalue) + beta * inverse_rescaling(in_qvalue))
         action = np.argmax(qvalue.detach().numpy())
 
-        # step enviroment
+        # step environment
         next_frame, ex_reward, done, info = env.step(action)
         frames.append(frame_process_func(next_frame))
-        
+
         # batching (1, n_frames, 84, 84)
         next_state = np.stack(frames, axis=0)[None, ...]
         control_state = embedding_net(state).squeeze(0).detach().numpy()
-        error = np.square(original_lifelong_net(state).detach().numpy(), trained_lifelong_net(state).detach().numpy()).mean()
-        
-        if len(M) < k: 
+        error = np.square(original_lifelong_net(state).detach().numpy(),
+                          trained_lifelong_net(state).detach().numpy()).mean()
+
+        if len(M) < k:
             episodic_reward = 0
             std = 1
             avg = 1
@@ -368,28 +369,28 @@
             episodic_reward = get_episodic_reward(control_state, M, k)
             std = np.std(error_list)
             avg = np.mean(error_list)
-        
+
         curiosity = 1 + (error - avg) / (std + 1e-10)
-        
+
         # push data to Memory
-        M.append(control_state) 
+        M.append(control_state)
         error_list.append(error)
-        
+
         in_reward = episodic_reward * np.clip(curiosity, 1, L)
 
         if is_test:
             episode_reward += ex_reward
-        
+
         else:
             if lives != info["ale.lives"] or done:  # done==True when lose life
                 lives = info["ale.lives"]
-                transition = (prev_ex_reward, prev_in_reward, prev_action,
+                transition = (prev_in_reward, prev_ex_reward, prev_action,
                               state, action, in_h, in_c, ex_h, ex_c, j,
-                              True, ex_reward, in_reward, next_state)
+                              True, in_reward, ex_reward, next_state)
             else:
-                transition = (prev_ex_reward, prev_in_reward, prev_action,
+                transition = (prev_in_reward, prev_ex_reward, prev_action,
                               state, action, in_h, in_c, ex_h, ex_c, j,
-                              done, ex_reward, in_reward, next_state)
+                              done, in_reward, ex_reward, next_state)
             transitions.append(transition)
             ucb_datas.append((j, ex_reward))
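Two things in the hunk above are worth isolating: the extrinsic and intrinsic Q-values are combined in un-rescaled space with the arm's beta, and the intrinsic reward is the episodic (kNN) reward scaled by the lifelong curiosity factor clipped to [1, L]. (Note, as a hedged observation: `np.square(a, b)` treats `b` as the `out` argument, so the lifelong error here is the mean of a², not of (a − b)²; the intended expression is presumably `np.square(a - b).mean()`.) A self-contained sketch of the two combinations, using the repo's `utils` functions:

```python
# Isolated illustration of the Q-value mixing and intrinsic-reward scaling in play_episode.
import numpy as np
import torch
from utils import rescaling, inverse_rescaling   # from this repo's utils.py

ex_qvalue = torch.tensor([[0.8, 0.3]])
in_qvalue = torch.tensor([[0.1, 0.6]])
beta = 0.3
qvalue = rescaling(inverse_rescaling(ex_qvalue) + beta * inverse_rescaling(in_qvalue))
action = int(np.argmax(qvalue.detach().numpy()))

episodic_reward, error, avg, std, L = 0.5, 2.0, 1.5, 0.25, 5
curiosity = 1 + (error - avg) / (std + 1e-10)             # lifelong novelty multiplier (= 3.0)
in_reward = episodic_reward * np.clip(curiosity, 1, L)    # 0.5 * clip(3.0, 1, 5) = 1.5
print(action, in_reward)
```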
@@ -412,58 +413,62 @@ def segments2contents(segments, burnin_len, is_grad=False, device=torch.device("
     Returns:
         each content
     """
-    
+
     # (burnin_len+unroll_len, batch_size, n_frames, 84, 84)
-    states = torch.stack([torch.tensor(np.vstack(seg.states), requires_grad=is_grad) for seg in segments], dim=1).float().to(device)
-    
+    states = torch.stack([torch.tensor(np.vstack(seg.states), requires_grad=is_grad)
+                          for seg in segments], dim=1).float().to(device)
+
     # (burnin_len+unroll_len, batch_size)
     actions = torch.stack([torch.tensor(seg.actions) for seg in segments], dim=1).to(device)
-    
+
     # (burnin_len+unroll_len, batch_size)
-    ex_rewards = torch.stack([torch.tensor(seg.ex_rewards, requires_grad=is_grad) for seg in segments], dim=1).float().to(device)
-    
+    ex_rewards = torch.stack([torch.tensor(seg.ex_rewards, requires_grad=is_grad)
+                              for seg in segments], dim=1).float().to(device)
+
     # (burnin_len+unroll_len, batch_size)
-    in_rewards = torch.stack([torch.tensor(seg.in_rewards, requires_grad=is_grad) for seg in segments], dim=1).float().to(device)
-    
+    in_rewards = torch.stack([torch.tensor(seg.in_rewards, requires_grad=is_grad)
+                              for seg in segments], dim=1).float().to(device)
+
     # (unroll_len, batch_size)
     dones = torch.stack([torch.tensor(seg.dones[burnin_len:]) for seg in segments], dim=1).float().to(device)
 
     # (batch_size,)
     j = torch.stack([torch.tensor(seg.j) for seg in segments], dim=0).to(device)
-    
+
     # (batch_size, n_frames, 84, 84)
-    last_state = torch.stack([torch.tensor(np.vstack(seg.last_state), requires_grad=is_grad) for seg in segments], dim=0).float().to(device)
-    
+    last_state = torch.stack([torch.tensor(np.vstack(seg.last_state), requires_grad=is_grad)
+                              for seg in segments], dim=0).float().to(device)
+
     # (burnin_len+unroll_len, batch_size, n_frames, 84, 84)
     next_states = torch.cat([states, last_state[None, :]], dim=0)[1:].to(device)
 
     # (1, batch_size, hidden_size)
     in_h0 = torch.cat([seg.in_h_init for seg in segments], dim=1).float().to(device)
-    
+
     # (1, batch_size, hidden_size)
     in_c0 = torch.cat([seg.in_c_init for seg in segments], dim=1).float().to(device)
-    
+
     # (1, batch_size, hidden_size)
     ex_h0 = torch.cat([seg.ex_h_init for seg in segments], dim=1).float().to(device)
-    
+
     # (1, batch_size, hidden_size)
     ex_c0 = torch.cat([seg.ex_c_init for seg in segments], dim=1).float().to(device)
 
     # (batch_size)
     in_reward0 = torch.tensor([seg.prev_in_reward_init for seg in segments]).float().to(device)
-    
+
     # (batch_size)
     ex_reward0 = torch.tensor([seg.prev_ex_reward_init for seg in segments]).float().to(device)
-    
+
     # (burnin+unroll_len, batch_size)
     prev_in_rewards = torch.cat([in_reward0[None, :], in_rewards], dim=0)[:-1]
-    
+
     # (burnin+unroll_len, batch_size)
     prev_ex_rewards = torch.cat([ex_reward0[None, :], ex_rewards], dim=0)[:-1]
 
     # (batch_size)
     a0 = torch.tensor([seg.prev_a_init for seg in segments]).to(device)
-    
+
     # (burnin+unroll_len, batch_size)
     prev_actions = torch.cat([a0[None, :], actions], dim=0)[:-1]
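The whole hunk above follows one stacking pattern: each segment carries a length-T sequence, and stacking along `dim=1` produces time-major `(T, batch_size)` tensors so the LSTM can be burned in and unrolled over the whole batch at once. A toy illustration with a stand-in `Segment` namedtuple (not the repo's actual segment class):

```python
# The time-major stacking pattern segments2contents relies on, on toy data.
import collections
import torch

Segment = collections.namedtuple("Segment", ["actions", "ex_rewards"])
T = 6   # burnin_len + unroll_len
segments = [Segment(actions=list(range(T)), ex_rewards=[float(i)] * T) for i in range(3)]

actions = torch.stack([torch.tensor(seg.actions) for seg in segments], dim=1)
ex_rewards = torch.stack([torch.tensor(seg.ex_rewards) for seg in segments], dim=1).float()
print(actions.shape, ex_rewards.shape)   # torch.Size([6, 3]) torch.Size([6, 3])
```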