
Commit dc08541

Huge batch of formatting fixes
1 parent 13b13ac commit dc08541

27 files changed: +792 -605 lines changed

src/il_representations/algos/__init__.py

Lines changed: 103 additions & 78 deletions
Large diffs are not rendered by default.

src/il_representations/algos/augmenters.py

Lines changed: 5 additions & 10 deletions
@@ -1,21 +1,16 @@
-import enum
-from torchvision import transforms
-from imitation.augment.color import ColorSpace  # noqa: F401
-from imitation.augment.convenience import StandardAugmentations
-from il_representations.algos.utils import gaussian_blur
-import torch
-from abc import ABC, abstractmethod
-import PIL
 """
 These are pretty basic: when constructed, they take in a list of augmentations, and
 either augment just the context, or both the context and the target, depending on the algorithm.
 """
+from abc import ABC, abstractmethod
+
+from imitation.augment.color import ColorSpace  # noqa: F401
+from imitation.augment.convenience import StandardAugmentations
 
 
 class Augmenter(ABC):
     def __init__(self, augmenter_spec, color_space):
-        augment_op = StandardAugmentations.from_string_spec(
-            augmenter_spec, color_space)
+        augment_op = StandardAugmentations.from_string_spec(augmenter_spec, color_space)
         self.augment_op = augment_op
 
     @abstractmethod
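For orientation, a minimal sketch of how an augmenter in this style might be subclassed and called. It assumes the op returned by StandardAugmentations.from_string_spec is callable on a batch of image tensors; the class name, the __call__ signature, the spec string, and ColorSpace.RGB are illustrative assumptions, not taken from the repo.

import torch
from imitation.augment.color import ColorSpace
from imitation.augment.convenience import StandardAugmentations


class ContextOnlyAugmenter:
    """Hypothetical subclass: augments only the context batch, passes targets through."""

    def __init__(self, augmenter_spec, color_space):
        self.augment_op = StandardAugmentations.from_string_spec(augmenter_spec, color_space)

    def __call__(self, contexts, targets):
        # Perturb only the context images; targets are returned unchanged.
        return self.augment_op(contexts), targets


# Usage sketch: spec string and tensor shapes are made up for illustration.
augmenter = ContextOnlyAugmenter("translate,rotate", ColorSpace.RGB)
contexts = torch.rand(8, 3, 64, 64)
targets = torch.rand(8, 3, 64, 64)
aug_contexts, targets = augmenter(contexts, targets)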

src/il_representations/algos/base_learner.py

Lines changed: 3 additions & 4 deletions
@@ -1,4 +1,5 @@
 import gym
+
 from il_representations.algos.utils import set_global_seeds
 
 
@@ -13,12 +14,10 @@ def __init__(self, env):
         # if EncoderSimplePolicyHead is refactored.
         if isinstance(self.action_space, gym.spaces.Discrete):
             self.action_size = env.action_space.n
-        elif (isinstance(self.action_space, gym.spaces.Box)
-              and len(self.action_space.shape) == 1):
+        elif (isinstance(self.action_space, gym.spaces.Box) and len(self.action_space.shape) == 1):
             self.action_size, = self.action_space.shape
         else:
-            raise NotImplementedError(
-                f"can't handle action space {self.action_space}")
+            raise NotImplementedError(f"can't handle action space {self.action_space}")
 
     def set_random_seed(self, seed):
         if seed is None:
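As a quick illustration of the branch above, the same space handling applied to two toy gym spaces (the spaces themselves are arbitrary examples, not from the repo):

import gym

# Discrete spaces expose the number of actions directly.
discrete_space = gym.spaces.Discrete(4)
action_size = discrete_space.n  # -> 4

# 1-D Box spaces contribute their single shape entry.
box_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))
action_size, = box_space.shape  # -> 3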

src/il_representations/algos/batch_extenders.py

Lines changed: 20 additions & 13 deletions
@@ -1,14 +1,17 @@
+"""
+BatchExtenders are used in situations where you want to pass a batch forward
+for loss that is different than the batch seen by your encoder. The currently
+implemented situation where this is the case is Momentum, where you want to
+pass forward a bunch of negatives from prior encoding runs to increase the
+difficulty of your prediction task. One might also imagine this being useful
+for doing trajectory-mixing in a RNN case where batches naturally need to be
+all from a small number of trajectories, but this isn't yet implemented.
+"""
 from abc import ABC, abstractmethod
+
 import torch
-from torch.distributions import Normal
+
 from il_representations.algos.utils import independent_multivariate_normal
-"""
-BatchExtenders are used in situations where you want to pass a batch forward for loss that is different than the
-batch seen by your encoder. The currently implemented situation where this is the case is Momentum, where you want
-to pass forward a bunch of negatives from prior encoding runs to increase the difficulty of your prediction task.
-One might also imagine this being useful for doing trajectory-mixing in a RNN case where batches naturally need
-to be all from a small number of trajectories, but this isn't yet implemented.
-"""
 
 
 class BatchExtender(ABC):
@@ -33,20 +36,24 @@ def __init__(self, queue_dim, queue_size, sample=False):
         self.queue_ptr = 0
 
     def __call__(self, context_dist, target_dist):
-        # Call up current contents of the queue, duplicate. Add targets to the queue,
-        # potentially overriding old information in the process. Return targets concatenated to contents of queue
+        # Call up current contents of the queue, duplicate. Add targets to the
+        # queue, potentially overriding old information in the process. Return
+        # targets concatenated to contents of queue
         targets_loc = target_dist.loc
         targets_covariance = target_dist.covariance_matrix
         device = targets_loc.device
         assert targets_loc.device == targets_covariance.device
 
-        # Pull out the diagonals of our MultivariateNormal covariance matrices, so we don't store all the extra 0s
-        targets_scale = torch.stack([batch_element_matrix.diag() for batch_element_matrix in targets_covariance])
+        # Pull out the diagonals of our MultivariateNormal covariance matrices,
+        # so we don't store all the extra 0s
+        targets_scale = torch.stack(
+            [batch_element_matrix.diag() for batch_element_matrix in targets_covariance])
 
         batch_size = targets_loc.shape[0]
         queue_targets_scale = self.queue_scale.clone().detach().to(device)
         queue_targets_loc = self.queue_loc.clone().detach().to(device)
-        # TODO: Currently requires the queue size to be a multiple of the batch size. Don't require that.
+        # TODO: Currently requires the queue size to be a multiple of the batch
+        # size. Don't require that.
         self.queue_loc[self.queue_ptr:self.queue_ptr + batch_size] = targets_loc
         self.queue_scale[self.queue_ptr:self.queue_ptr + batch_size] = targets_scale
         self.queue_ptr = (self.queue_ptr + batch_size) % self.queue_size
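A stripped-down sketch of the circular-buffer bookkeeping that __call__ performs on the distribution means; the queue and batch sizes here are arbitrary, and the real class additionally stores per-dimension scales and returns distributions rather than raw tensors.

import torch

queue_size, queue_dim, batch_size = 8, 4, 2
queue_loc = torch.zeros(queue_size, queue_dim)
queue_ptr = 0

for _ in range(3):
    targets_loc = torch.randn(batch_size, queue_dim)
    # Overwrite the oldest slots with the newest targets, then advance the
    # pointer modulo the queue size -- which is why the queue size currently
    # has to be a multiple of the batch size.
    queue_loc[queue_ptr:queue_ptr + batch_size] = targets_loc
    queue_ptr = (queue_ptr + batch_size) % queue_size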

src/il_representations/algos/decoders.py

Lines changed: 103 additions & 57 deletions
@@ -1,31 +1,36 @@
-import functools
-import torch.nn as nn
+"""
+LossDecoders are meant to be mappings between the representation being learned,
+and the representation or tensor that is fed directly into the loss. In many
+cases, these are the same, and this will just be a NoOp.
+
+Some cases where it is different:
+- When you are using a Projection Head in your contrastive loss, and comparing
+  similarities of vectors that are k >=1 nonlinear layers downstream from the
+  actual representation you'll use in later tasks
+- When you're learning a VAE, and the loss is determined by how effectively you
+  can reconstruct the image from a representation vector, the LossDecoder will
+  handle that representation -> image mapping
+- When you're predicting actions given current and next state, you'll want to
+  predict those actions given both the representation of the current state, and
+  also information about the next state. This occasional need for extra
+  information beyond the central context state is why we have `extra_context`
+  as an optional bit of data that pair constructors can return, to be passed
+  forward for use here
+"""
 import copy
-import torch
-import torch.nn.functional as F
-from torch.distributions import MultivariateNormal
-from il_representations.algos.utils import independent_multivariate_normal
+import functools
+
 import gym.spaces as spaces
 import numpy as np
+import torch
+from torch.distributions import MultivariateNormal
+import torch.nn as nn
+import torch.nn.functional as F
 
+from il_representations.algos.utils import independent_multivariate_normal
 
-"""
-LossDecoders are meant to be mappings between the representation being learned,
-and the representation or tensor that is fed directly into the loss. In many cases, these are the
-same, and this will just be a NoOp.
-
-Some cases where it is different:
-- When you are using a Projection Head in your contrastive loss, and comparing similarities of vectors that are
-  k >=1 nonlinear layers downstream from the actual representation you'll use in later tasks
-- When you're learning a VAE, and the loss is determined by how effectively you can reconstruct the image
-  from a representation vector, the LossDecoder will handle that representation -> image mapping
-- When you're predicting actions given current and next state, you'll want to predict those actions given
-  both the representation of the current state, and also information about the next state. This occasional
-  need for extra information beyond the central context state is why we have `extra_context` as an optional
-  bit of data that pair constructors can return, to be passed forward for use here
-"""
+# TODO change shape to dim throughout this file and the code
 
-#TODO change shape to dim throughout this file and the code
 
 class LossDecoder(nn.Module):
     def __init__(self, representation_dim, projection_shape, sample=False):
@@ -51,20 +56,23 @@ def get_vector(self, z_dist):
         return z_dist.loc
 
     def ones_like_projection_dim(self, x):
-        return torch.ones(size=(x.shape[0], self.projection_dim,), device=x.device)
+        return torch.ones(size=(
+            x.shape[0],
+            self.projection_dim,
+        ), device=x.device)
+
 
 class NoOp(LossDecoder):
     def forward(self, z, traj_info, extra_context=None):
         return z
 
+
 class ProjectionHead(LossDecoder):
     def __init__(self, representation_dim, projection_shape, sample=False, learn_scale=False):
         super(ProjectionHead, self).__init__(representation_dim, projection_shape, sample)
 
-        self.shared_mlp = nn.Sequential(nn.Linear(self.representation_dim, 256),
-                                        nn.ReLU(),
-                                        nn.Linear(256, 256),
-                                        nn.ReLU())
+        self.shared_mlp = nn.Sequential(nn.Linear(self.representation_dim, 256), nn.ReLU(),
+                                        nn.Linear(256, 256), nn.ReLU())
         self.mean_layer = nn.Linear(256, self.projection_dim)
 
         if learn_scale:
7583
def forward(self, z_dist, traj_info, extra_context=None):
7684
z = self.get_vector(z_dist)
7785
shared_repr = self.shared_mlp(z)
78-
return independent_multivariate_normal(loc=self.mean_layer(shared_repr), scale=torch.exp(self.scale_layer(shared_repr)))
86+
return independent_multivariate_normal(loc=self.mean_layer(shared_repr),
87+
scale=torch.exp(self.scale_layer(shared_repr)))
7988

8089

8190
class MomentumProjectionHead(LossDecoder):
82-
def __init__(self, representation_dim, projection_shape, sample=False, momentum_weight=0.99, learn_scale=False):
83-
super(MomentumProjectionHead, self).__init__(representation_dim, projection_shape, sample=sample)
84-
self.context_decoder = ProjectionHead(representation_dim, projection_shape,
85-
sample=sample, learn_scale=learn_scale)
91+
def __init__(self,
92+
representation_dim,
93+
projection_shape,
94+
sample=False,
95+
momentum_weight=0.99,
96+
learn_scale=False):
97+
super(MomentumProjectionHead, self).__init__(representation_dim,
98+
projection_shape,
99+
sample=sample)
100+
self.context_decoder = ProjectionHead(representation_dim,
101+
projection_shape,
102+
sample=sample,
103+
learn_scale=learn_scale)
86104
self.target_decoder = copy.deepcopy(self.context_decoder)
87105
for param in self.target_decoder.parameters():
88106
param.requires_grad = False
@@ -93,32 +111,39 @@ def decode_context(self, z_dist, traj_info, extra_context=None):
 
     def decode_target(self, z_dist, traj_info, extra_context=None):
         """
-        Encoder target/keys using momentum-updated key encoder. Had some thought of making _momentum_update_key_encoder
-        a backwards hook, but seemed overly complex for an initial POC
+        Encoder target/keys using momentum-updated key encoder. Had some
+        thought of making _momentum_update_key_encoder a backwards hook, but
+        seemed overly complex for an initial POC
         :param x:
         :return:
         """
         with torch.no_grad():
             self._momentum_update_key_encoder()
             decoded_z_dist = self.target_decoder(z_dist, traj_info, extra_context=extra_context)
-            return MultivariateNormal(loc=decoded_z_dist.loc.detach(), covariance_matrix=decoded_z_dist.covariance_matrix.detach())
+            return MultivariateNormal(loc=decoded_z_dist.loc.detach(),
+                                      covariance_matrix=decoded_z_dist.covariance_matrix.detach())
 
     @torch.no_grad()
     def _momentum_update_key_encoder(self):
-        for param_q, param_k in zip(self.context_decoder.parameters(), self.target_decoder.parameters()):
-            param_k.data = param_k.data * self.momentum_weight + param_q.data * (1. - self.momentum_weight)
+        for param_q, param_k in zip(self.context_decoder.parameters(),
+                                    self.target_decoder.parameters()):
+            param_k.data = param_k.data * self.momentum_weight + param_q.data * (
+                1. - self.momentum_weight)
 
 
 class BYOLProjectionHead(MomentumProjectionHead):
     def __init__(self, representation_dim, projection_shape, momentum_weight=0.99, sample=False):
-        super(BYOLProjectionHead, self).__init__(representation_dim, projection_shape,
-                                                 sample=sample, momentum_weight=momentum_weight)
+        super(BYOLProjectionHead, self).__init__(representation_dim,
+                                                 projection_shape,
+                                                 sample=sample,
+                                                 momentum_weight=momentum_weight)
         self.context_predictor = ProjectionHead(projection_shape, projection_shape)
 
     def forward(self, z_dist, traj_info, extra_context=None):
         internal_dist = super().forward(z_dist, traj_info, extra_context=extra_context)
         prediction_dist = self.context_predictor(internal_dist, traj_info, extra_context=None)
-        return independent_multivariate_normal(loc=F.normalize(prediction_dist.loc, dim=1), scale=prediction_dist.scale)
+        return independent_multivariate_normal(loc=F.normalize(prediction_dist.loc, dim=1),
+                                               scale=prediction_dist.scale)
 
     def decode_target(self, z_dist, traj_info, extra_context=None):
         with torch.no_grad():
@@ -128,39 +153,60 @@ def decode_target(self, z_dist, traj_info, extra_context=None):
 
 
 class ActionConditionedVectorDecoder(LossDecoder):
-    def __init__(self, representation_dim, projection_shape, action_space, sample=False, action_encoding_dim=128,
-                 action_encoder_layers=1, learn_scale=False, action_embedding_dim=5, use_lstm=False):
-        super(ActionConditionedVectorDecoder, self).__init__(representation_dim, projection_shape, sample=sample)
+    def __init__(self,
+                 representation_dim,
+                 projection_shape,
+                 action_space,
+                 sample=False,
+                 action_encoding_dim=128,
+                 action_encoder_layers=1,
+                 learn_scale=False,
+                 action_embedding_dim=5,
+                 use_lstm=False):
+        super(ActionConditionedVectorDecoder, self).__init__(representation_dim,
+                                                             projection_shape,
+                                                             sample=sample)
         self.learn_scale = learn_scale
 
-        # Machinery for turning raw actions into vectors. If actions are discrete, this is done via an Embedding.
+        # Machinery for turning raw actions into vectors. If actions are
+        # discrete, this is done via an Embedding.
         # If actions are continuous/box, this is done via a simple flattening.
         if isinstance(action_space, spaces.Discrete):
-            self.action_processor = nn.Embedding(num_embeddings=action_space.n, embedding_dim=action_embedding_dim)
+            self.action_processor = nn.Embedding(num_embeddings=action_space.n,
+                                                 embedding_dim=action_embedding_dim)
             processed_action_dim = action_embedding_dim
             self.action_shape = ()  # discrete actions are just numbers
         elif isinstance(action_space, spaces.Box):
-            self.action_processor = functools.partial(torch.flatten,
-                                                      start_dim=2)
+            self.action_processor = functools.partial(torch.flatten, start_dim=2)
             processed_action_dim = np.prod(action_space.shape)
             self.action_shape = action_space.shape
         else:
-            raise NotImplementedError("Action conditioning is only currently implemented for Discrete and Box action spaces")
+            raise NotImplementedError(
+                "Action conditioning is only currently implemented for Discrete and Box action "
+                "spaces")
 
-        # Machinery for aggregating information from an arbitrary number of actions into a single vector,
-        # either through a LSTM, or by simply averaging the vector representations of the k states together
+        # Machinery for aggregating information from an arbitrary number of
+        # actions into a single vector, either through a LSTM, or by simply
+        # averaging the vector representations of the k states together
        if use_lstm:
-            self.action_encoder = nn.LSTM(processed_action_dim, action_encoding_dim, action_encoder_layers, batch_first=True)
+            self.action_encoder = nn.LSTM(processed_action_dim,
+                                          action_encoding_dim,
+                                          action_encoder_layers,
+                                          batch_first=True)
         else:
             self.action_encoder = None
             action_encoding_dim = processed_action_dim
 
-        # Machinery for mapping a concatenated (context representation, action representation) into a projection
-        self.action_conditioned_projection = nn.Linear(representation_dim + action_encoding_dim, projection_shape)
+        # Machinery for mapping a concatenated (context representation, action
+        # representation) into a projection
+        self.action_conditioned_projection = nn.Linear(representation_dim + action_encoding_dim,
+                                                       projection_shape)
 
-        # If learning scale/std deviation parameter, declare a layer for that, otherwise, return a unit-constant vector
+        # If learning scale/std deviation parameter, declare a layer for that,
+        # otherwise, return a unit-constant vector
         if self.learn_scale:
-            self.scale_projection = nn.Linear(representation_dim + action_encoding_dim, projection_shape)
+            self.scale_projection = nn.Linear(representation_dim + action_encoding_dim,
+                                              projection_shape)
         else:
            self.scale_projection = self.ones_like_projection_dim
 
@@ -192,9 +238,9 @@ def decode_context(self, z_dist, traj_info, extra_context=None):
         assert action_encoding_vector.shape[0] == batch_dim, \
             action_encoding_vector.shape
 
-        # Concatenate context representation and action representation and map to a merged representation
+        # Concatenate context representation and action representation and map
+        # to a merged representation
         merged_vector = torch.cat([z, action_encoding_vector], dim=1)
         mean_projection = self.action_conditioned_projection(merged_vector)
         scale = self.scale_projection(merged_vector)
         return independent_multivariate_normal(loc=mean_projection, scale=scale)
-
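The _momentum_update_key_encoder reformatted above is a plain exponential moving average of the online ("context") decoder's parameters into the frozen target copy. A self-contained sketch with toy linear layers; the layer sizes and momentum value here are arbitrary, only the update rule is taken from the diff.

import copy

import torch
import torch.nn as nn

momentum_weight = 0.99
context_decoder = nn.Linear(16, 8)
target_decoder = copy.deepcopy(context_decoder)
for param in target_decoder.parameters():
    param.requires_grad = False  # target is never updated by gradient descent


@torch.no_grad()
def momentum_update():
    # target <- m * target + (1 - m) * context, the same EMA rule as in the diff
    for param_q, param_k in zip(context_decoder.parameters(),
                                target_decoder.parameters()):
        param_k.data = param_k.data * momentum_weight + param_q.data * (1. - momentum_weight)


momentum_update()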
