Hello there, consider the code below:
```python
def get_pos_enc(self, seq_len, dtype, device):
    ...
    for bidx in range(0, self.net_config.n_block):
        # For each block with bidx > 0, we need two types pos_encs:
        #   - Attn(pooled-q, unpooled-kv)
        #   - Attn(pooled-q, pooled-kv)

        #### First type: Attn(pooled-q, unpooled-kv)
        if bidx > 0:
            # HERE, the pos_id has been changed in the `Second type` below
            pooled_pos_id = self.stride_pool_pos(pos_id, bidx)

            # construct rel_pos_id
            q_stride = self.net_config.pooling_size ** bidx
            k_stride = self.net_config.pooling_size ** (bidx - 1)
            rel_pos_id = self.construct_rel_pos_seq(
                q_pos=pooled_pos_id, q_stride=q_stride,
                k_pos=pos_id, k_stride=k_stride)

            # gather relative positional encoding
            rel_pos_id = rel_pos_id[:, None] + zero_offset
            rel_pos_id = rel_pos_id.expand(rel_pos_id.size(0), d_model)
            pos_enc_2 = torch.gather(pos_enc, 0, rel_pos_id)
        else:
            pos_enc_2 = None

        #### Second type: Attn(pooled-q, pooled-kv)
        # construct rel_pos_id
        pos_id = pooled_pos_id
        stride = self.net_config.pooling_size ** bidx
        rel_pos_id = self.construct_rel_pos_seq(
            q_pos=pos_id, q_stride=stride,
            k_pos=pos_id, k_stride=stride)

        # gather relative positional encoding
        rel_pos_id = rel_pos_id[:, None] + zero_offset
        rel_pos_id = rel_pos_id.expand(rel_pos_id.size(0), d_model)
        pos_enc_1 = torch.gather(pos_enc, 0, rel_pos_id)

        pos_enc_list.append([pos_enc_1, pos_enc_2])
    return pos_enc_list
```

Here, the `pos_id` used in the first type has already been changed (pooled) by the `Second type` step of the previous iteration, starting right after the first block. I'm a little confused about that.
Can anyone explain this? Thanks.
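To make the question concrete, here is a simplified trace of how `pos_id` and `pooled_pos_id` evolve across blocks when the bookkeeping from the loop above is replayed on plain Python lists. It is only a sketch under stated assumptions: `stride_pool_pos_simple` is a hypothetical stand-in for `stride_pool_pos` (whose body is not shown here), assuming `pooling_size = 2` and no special `<cls>` handling, and no encodings are gathered.

```python
def stride_pool_pos_simple(pos_id):
    """Hypothetical stand-in for `stride_pool_pos`: keep every other position
    (assumes pooling_size = 2 and no separate <cls> handling)."""
    return pos_id[::2]


def trace_pos_ids(seq_len, n_block):
    """Replay the pos_id / pooled_pos_id updates of the loop in the issue.

    Returns one (bidx, first, second) tuple per block, where `first` and
    `second` are (q_positions, kv_positions) pairs, or None for block 0.
    """
    pos_id = list(range(seq_len))
    pooled_pos_id = pos_id
    trace = []
    for bidx in range(n_block):
        if bidx > 0:
            # First type: q positions are pooled from pos_id, and the kv
            # positions are pos_id as left by the previous iteration.
            pooled_pos_id = stride_pool_pos_simple(pos_id)
            first = (pooled_pos_id, pos_id)
        else:
            first = None
        # Second type: pos_id is overwritten here, so the next iteration's
        # first type sees the positions already pooled for this block.
        pos_id = pooled_pos_id
        second = (pos_id, pos_id)
        trace.append((bidx, first, second))
    return trace


for bidx, first, second in trace_pos_ids(seq_len=8, n_block=3):
    print(f"block {bidx}: first={first} second={second}")
# prints:
# block 0: first=None second=([0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7])
# block 1: first=([0, 2, 4, 6], [0, 1, 2, 3, 4, 5, 6, 7]) second=([0, 2, 4, 6], [0, 2, 4, 6])
# block 2: first=([0, 4], [0, 2, 4, 6]) second=([0, 4], [0, 4])
```

In this trace, the kv positions of the first type at block `bidx` have been pooled `bidx - 1` times, which lines up with `k_stride = pooling_size ** (bidx - 1)` in the snippet.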