diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py index eb352664..92404560 100644 --- a/dlrm_s_pytorch.py +++ b/dlrm_s_pytorch.py @@ -474,8 +474,7 @@ def interact_features(self, x, ly): if self.arch_interaction_op == "dot": # concatenate dense and sparse features - (batch_size, d) = x.shape - T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) + T = torch.stack([x] + ly, dim=1) # perform a dot product Z = torch.bmm(T, torch.transpose(T, 1, 2)) # append dense feature with the interactions (into a row vector)