We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ef4e819 commit 541a8b0Copy full SHA for 541a8b0
outlines/processors/structured.py
@@ -102,12 +102,13 @@ def process_logits(
102
103
sequence_states.append(self._guide_states[curr_state_key])
104
105
- mask = torch.full_like(logits, -math.inf)
+ mask = torch.ones_like(logits, dtype=torch.bool)
106
for i, guide_state in enumerate(sequence_states):
107
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
108
- mask[i, allowed_tokens] = logits[i, allowed_tokens]
+ mask[i, allowed_tokens] = False
109
+ logits.masked_fill_(mask, float("-inf"))
110
- return mask
111
+ return logits
112
113
def copy(self) -> "GuideLogitsProcessor":
114
"""Return a copy of the logits processor."""
0 commit comments