1+ """
2+ LossDecoders are meant to be mappings between the representation being learned,
3+ and the representation or tensor that is fed directly into the loss. In many
4+ cases, these are the same, and this will just be a NoOp.
5+
6+ Some cases where it is different:
7+ - When you are using a Projection Head in your contrastive loss, and comparing
8+ similarities of vectors that are k >=1 nonlinear layers downstream from the
9+ actual representation you'll use in later tasks
10+ - When you're learning a VAE, and the loss is determined by how effectively you
11+ can reconstruct the image from a representation vector, the LossDecoder will
12+ handle that representation -> image mapping
13+ - When you're predicting actions given current and next state, you'll want to
14+ predict those actions given both the representation of the current state, and
15+ also information about the next state. This occasional need for extra
16+ information beyond the central context state is why we have `extra_context`
17+ as an optional bit of data that pair constructors can return, to be passed
18+ forward for use here
19+ """
import copy
import functools

import gym.spaces as spaces
import numpy as np
import torch
from torch.distributions import MultivariateNormal
import torch.nn as nn
import torch.nn.functional as F

from il_representations.algos.utils import independent_multivariate_normal

12- """
13- LossDecoders are meant to be mappings between the representation being learned,
14- and the representation or tensor that is fed directly into the loss. In many cases, these are the
15- same, and this will just be a NoOp.
16-
17- Some cases where it is different:
18- - When you are using a Projection Head in your contrastive loss, and comparing similarities of vectors that are
19- k >=1 nonlinear layers downstream from the actual representation you'll use in later tasks
20- - When you're learning a VAE, and the loss is determined by how effectively you can reconstruct the image
21- from a representation vector, the LossDecoder will handle that representation -> image mapping
22- - When you're predicting actions given current and next state, you'll want to predict those actions given
23- both the representation of the current state, and also information about the next state. This occasional
24- need for extra information beyond the central context state is why we have `extra_context` as an optional
25- bit of data that pair constructors can return, to be passed forward for use here
26- """
# TODO: change "shape" to "dim" throughout this file and the code


class LossDecoder(nn.Module):
    def __init__(self, representation_dim, projection_shape, sample=False):
        ...  # (__init__ body elided in this excerpt)

    # ... (intervening methods elided in this excerpt) ...

    def get_vector(self, z_dist):
        return z_dist.loc

    def ones_like_projection_dim(self, x):
        return torch.ones(size=(x.shape[0], self.projection_dim),
                          device=x.device)


class NoOp(LossDecoder):
    def forward(self, z, traj_info, extra_context=None):
        return z


class ProjectionHead(LossDecoder):
    def __init__(self, representation_dim, projection_shape, sample=False,
                 learn_scale=False):
        super(ProjectionHead, self).__init__(representation_dim,
                                             projection_shape, sample)

        self.shared_mlp = nn.Sequential(nn.Linear(self.representation_dim, 256),
                                        nn.ReLU(),
                                        nn.Linear(256, 256),
                                        nn.ReLU())
        self.mean_layer = nn.Linear(256, self.projection_dim)

        if learn_scale:
            ...  # (elided in this excerpt: scale_layer setup for both branches)

    def forward(self, z_dist, traj_info, extra_context=None):
        z = self.get_vector(z_dist)
        shared_repr = self.shared_mlp(z)
        return independent_multivariate_normal(
            loc=self.mean_layer(shared_repr),
            scale=torch.exp(self.scale_layer(shared_repr)))


class MomentumProjectionHead(LossDecoder):
    def __init__(self, representation_dim, projection_shape, sample=False,
                 momentum_weight=0.99, learn_scale=False):
        super(MomentumProjectionHead, self).__init__(representation_dim,
                                                     projection_shape,
                                                     sample=sample)
        self.context_decoder = ProjectionHead(representation_dim,
                                              projection_shape,
                                              sample=sample,
                                              learn_scale=learn_scale)
        self.target_decoder = copy.deepcopy(self.context_decoder)
        for param in self.target_decoder.parameters():
            param.requires_grad = False
        # ... (elided in this excerpt: rest of __init__, including storing
        # `momentum_weight`, and the decode_context method) ...

    def decode_target(self, z_dist, traj_info, extra_context=None):
        """
        Encode the target/keys using the momentum-updated key encoder. Had
        some thought of making _momentum_update_key_encoder a backwards hook,
        but that seemed overly complex for an initial POC.
        :param z_dist:
        :return:
        """
        with torch.no_grad():
            self._momentum_update_key_encoder()
            decoded_z_dist = self.target_decoder(z_dist, traj_info,
                                                 extra_context=extra_context)
        return MultivariateNormal(
            loc=decoded_z_dist.loc.detach(),
            covariance_matrix=decoded_z_dist.covariance_matrix.detach())

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
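        # Polyak/EMA update: param_k <- m * param_k + (1 - m) * param_q, so
        # the frozen target decoder slowly tracks the trained context decoder.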
        for param_q, param_k in zip(self.context_decoder.parameters(),
                                    self.target_decoder.parameters()):
            param_k.data = (param_k.data * self.momentum_weight +
                            param_q.data * (1. - self.momentum_weight))


class BYOLProjectionHead(MomentumProjectionHead):
    def __init__(self, representation_dim, projection_shape,
                 momentum_weight=0.99, sample=False):
        super(BYOLProjectionHead, self).__init__(representation_dim,
                                                 projection_shape,
                                                 sample=sample,
                                                 momentum_weight=momentum_weight)
        self.context_predictor = ProjectionHead(projection_shape,
                                                projection_shape)

    def forward(self, z_dist, traj_info, extra_context=None):
        internal_dist = super().forward(z_dist, traj_info,
                                        extra_context=extra_context)
        prediction_dist = self.context_predictor(internal_dist, traj_info,
                                                 extra_context=None)
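        # BYOL's loss compares directions rather than magnitudes, so the
        # predicted mean is L2-normalised before being handed to the loss.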
        return independent_multivariate_normal(
            loc=F.normalize(prediction_dist.loc, dim=1),
            scale=prediction_dist.scale)

    def decode_target(self, z_dist, traj_info, extra_context=None):
        with torch.no_grad():
            ...  # (method body elided in this excerpt)


class ActionConditionedVectorDecoder(LossDecoder):
    def __init__(self, representation_dim, projection_shape, action_space,
                 sample=False, action_encoding_dim=128,
                 action_encoder_layers=1, learn_scale=False,
                 action_embedding_dim=5, use_lstm=False):
        super(ActionConditionedVectorDecoder, self).__init__(
            representation_dim, projection_shape, sample=sample)
        self.learn_scale = learn_scale

        # Machinery for turning raw actions into vectors. If actions are
        # discrete, this is done via an Embedding; if actions are
        # continuous/box, via a simple flattening.
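        # (For example: with Discrete(4), each integer action id becomes a
        # 5-dim embedding vector under the default action_embedding_dim; a
        # Box action is flattened to np.prod(action_space.shape) floats.)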
        if isinstance(action_space, spaces.Discrete):
            self.action_processor = nn.Embedding(
                num_embeddings=action_space.n,
                embedding_dim=action_embedding_dim)
            processed_action_dim = action_embedding_dim
            self.action_shape = ()  # discrete actions are just numbers
        elif isinstance(action_space, spaces.Box):
            self.action_processor = functools.partial(torch.flatten,
                                                      start_dim=2)
            processed_action_dim = np.prod(action_space.shape)
            self.action_shape = action_space.shape
        else:
            raise NotImplementedError(
                "Action conditioning is only currently implemented for "
                "Discrete and Box action spaces")

        # Machinery for aggregating information from an arbitrary number of
        # actions into a single vector, either through an LSTM, or by simply
        # averaging the vector representations of the k actions together
        if use_lstm:
            self.action_encoder = nn.LSTM(processed_action_dim,
                                          action_encoding_dim,
                                          action_encoder_layers,
                                          batch_first=True)
        else:
            self.action_encoder = None
            action_encoding_dim = processed_action_dim

        # Machinery for mapping a concatenated (context representation,
        # action representation) pair into a projection
        self.action_conditioned_projection = nn.Linear(
            representation_dim + action_encoding_dim, projection_shape)

        # If learning a scale/std-deviation parameter, declare a layer for
        # that; otherwise, return a unit-constant vector
        if self.learn_scale:
            self.scale_projection = nn.Linear(
                representation_dim + action_encoding_dim, projection_shape)
        else:
            self.scale_projection = self.ones_like_projection_dim

    # ... (intervening code elided in this excerpt) ...

    def decode_context(self, z_dist, traj_info, extra_context=None):
        # (start of method elided: it computes `z`, `batch_dim` and
        # `action_encoding_vector` from the representation and the actions)
        assert action_encoding_vector.shape[0] == batch_dim, \
            action_encoding_vector.shape

        # Concatenate context representation and action representation and
        # map to a merged representation
        merged_vector = torch.cat([z, action_encoding_vector], dim=1)
        mean_projection = self.action_conditioned_projection(merged_vector)
        scale = self.scale_projection(merged_vector)
        return independent_multivariate_normal(loc=mean_projection,
                                               scale=scale)
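

# A construction-only sketch (illustrative: the forward/decode plumbing that
# pulls actions out of `traj_info`/`extra_context` is elided above, so only
# the wiring set up in __init__ is exercised here):
def _action_conditioned_demo():
    decoder = ActionConditionedVectorDecoder(representation_dim=64,
                                             projection_shape=32,
                                             action_space=spaces.Discrete(4),
                                             use_lstm=True)
    # With these sizes: each discrete action id is embedded as a 5-dim vector
    # (action_embedding_dim), an LSTM summarises a (batch, k, 5) action
    # sequence into a 128-dim encoding (action_encoding_dim), and the
    # concatenated (64 + 128)-dim vector is projected down to 32 dims.
    return decoder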