Skip to content

Commit 3e31dbe

Browse files
committed
cleanup search beam logic
#18
1 parent 5c12985 commit 3e31dbe

File tree

2 files changed

+0
-20
lines changed

2 files changed

+0
-20
lines changed

nn/search.py

-12
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,6 @@ class SearchFuncInterface:
1111
Interface for search.
1212
"""
1313

14-
def get_beam(self) -> nn.SearchBeam:
15-
"""
16-
Return a new beam instance. dependency and name is still unset and will be set outside.
17-
This overwrites whatever is returned by :func:`choice`,
18-
and copy_as_prev_frame() is used to set the initial (prev) state beam.
19-
"""
20-
raise NotImplementedError
21-
2214
def choice(self, *, probs: nn.Tensor, probs_type: str) -> nn.Tensor:
2315
"""
2416
Given an output tensor (logits or log prop), returns a beam of chosen indices.
@@ -43,10 +35,6 @@ def __init__(self, beam_size: int, max_seq_len: nn.Tensor):
4335
self.beam_size = beam_size
4436
self.max_seq_len = max_seq_len
4537

46-
def get_beam(self):
47-
"""beam"""
48-
return nn.SearchBeam(beam_size=self.beam_size)
49-
5038
def choice(self, *, probs: nn.Tensor, probs_type: str) -> nn.Tensor:
5139
"""nn.choice"""
5240
return nn.choice(

nn/transformer.py

-8
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,6 @@ def __call__(self, source: nn.Tensor, *,
380380
assert target_spatial_axis, f"{self}: Target spatial axis must be specified when target is given"
381381
loop = nn.Loop(axis=target_spatial_axis)
382382
loop.state = state if state else self.default_initial_state()
383-
beam = None
384-
if search:
385-
beam = search.get_beam()
386-
beam.name = f"{nn.NameCtx.current_ctx().get_abs_name()}/target"
387-
beam.dependency = beam.copy_as_prev_frame()
388-
for x in loop.state.deep_tensors():
389-
x.data.beam = beam.dependency
390383
with loop:
391384
prev_target_embed = self.target_embedding(loop.state.target)
392385
output, loop.state.decoder = self.decoder(
@@ -397,7 +390,6 @@ def __call__(self, source: nn.Tensor, *,
397390
if search:
398391
search.apply_loop(loop)
399392
choice = search.choice(probs=logits, probs_type="logits")
400-
choice.data.beam = beam
401393
loop.state.target = choice
402394
loop.end(loop.state.target == self.target_eos_symbol, include_eos=False)
403395
else:

0 commit comments

Comments
 (0)