Skip to content

Commit 541a8b0

Browse files
lapp0rlouf
authored andcommitted
update logits in place for GuideLogitsProcessor
1 parent ef4e819 commit 541a8b0

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

outlines/processors/structured.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ def process_logits(
102102

103103
sequence_states.append(self._guide_states[curr_state_key])
104104

105-
mask = torch.full_like(logits, -math.inf)
105+
mask = torch.ones_like(logits, dtype=torch.bool)
106106
for i, guide_state in enumerate(sequence_states):
107107
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
108-
mask[i, allowed_tokens] = logits[i, allowed_tokens]
108+
mask[i, allowed_tokens] = False
109+
logits.masked_fill_(mask, float("-inf"))
109110

110-
return mask
111+
return logits
111112

112113
def copy(self) -> "GuideLogitsProcessor":
113114
"""Return a copy of the logits processor."""

0 commit comments

Comments
 (0)