beam.py

import copy

import numpy as np

import state
from mlm_training.model import Model


class Beam:
    """
    The beam for performing beam search.

    This beam stores the current states and can be cut down to a given
    size, keeping only the N most probable states.
    """

    def __init__(self):
        """
        Create a new beam, initially empty.
        """
        self.beam = []

    def __iter__(self):
        """
        Get an iterator for this beam.

        Returns
        =======
        iter : iterator
            An iterator for this beam, iterating over the states within it.
        """
        return self.beam.__iter__()

    def __len__(self):
        """
        Get the number of states in this beam.

        Returns
        =======
        length : int
            The number of states in this beam.
        """
        return len(self.beam)

    def add(self, state):
        """
        Add a State to this beam.

        Parameters
        ==========
        state : State
            The state to add to this beam.
        """
        self.beam.append(state)

    def get_top_state(self):
        """
        Get the most probable state from this beam.

        Returns
        =======
        best : State
            The most probable state from the beam, or None if the beam
            is empty.
        """
        best = None
        best_prob = float("-inf")

        for state in self.beam:
            if state.log_prob > best_prob:
                best = state
                best_prob = state.log_prob

        return best

    def add_initial_state(self, model, sess, P, iterative_pw=False):
        """
        Add an empty initial state to the beam. This is used once, before
        the beam search begins.

        Parameters
        ==========
        model : Model
            The language model to use for the transduction process.

        sess : tf.Session
            The tensorflow session of the loaded model.

        P : int
            The number of pitches.

        iterative_pw : boolean
            True to use iterative pitchwise processing (and save only a
            single hidden_state per State). False (default) otherwise.
        """
        if model.pitchwise and not iterative_pw:
            # A pitchwise model needs one initial hidden state per pitch,
            # so copy the single returned state P times (88 for piano).
            single_state = model.get_initial_state(sess, 1)[0]
            initial_state = [copy.copy(single_state) for _ in range(P)]
        else:
            initial_state = model.get_initial_state(sess, 1)[0]

        # With iterative pitchwise processing the prior is a single value;
        # otherwise it is a uniform 0.5 prior over all P pitches.
        prior = np.array([0.5]) if iterative_pw else np.ones(P) / 2

        new_state = state.State(P, model.with_onsets)
        new_state.update_from_lstm(initial_state, prior)
        self.beam.append(new_state)
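
    # Usage sketch (illustration only; `model` and `sess` are assumed to be
    # the loaded mlm_training Model and its tensorflow session, and P=88
    # matches the 88-pitch piano range noted above):
    #
    #     beam = Beam()
    #     beam.add_initial_state(model, sess, P=88)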

    def cut_to_size(self, beam_size, hash_length):
        """
        Remove all but the beam_size most probable states from this beam.

        Parameters
        ==========
        beam_size : int
            The maximum number of states to keep.

        hash_length : int
            The hash length to use. If two states do not differ in their
            past hash_length frames, only the more probable one is kept
            in the beam.
        """
        # Sort candidate states from most to least probable.
        beam = sorted(self.beam, key=lambda s: s.log_prob, reverse=True)

        self.beam = []
        piano_rolls = []

        for state in beam:
            if len(self.beam) == beam_size:
                break

            # Keep a state only if the last hash_length frames of its piano
            # roll differ from those of every state already kept.
            pr = state.get_piano_roll(max_length=hash_length)
            if not any((pr == x).all() for x in piano_rolls):
                self.beam.append(state)
                piano_rolls.append(pr)
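

# ---------------------------------------------------------------------------
# Usage sketch for the deduplicating cut (illustration only, not part of the
# original module). `DummyState` is a hypothetical stand-in for state.State
# that provides just the two members Beam relies on here: `log_prob` and
# `get_piano_roll`.
if __name__ == "__main__":
    class DummyState:
        def __init__(self, log_prob, piano_roll):
            self.log_prob = log_prob
            self.piano_roll = np.asarray(piano_roll)

        def get_piano_roll(self, max_length=None):
            # Return only the last max_length frames, as cut_to_size expects.
            return self.piano_roll[:, -max_length:] if max_length else self.piano_roll

    beam = Beam()
    # The first two states share the same recent piano roll, so the
    # hash-based deduplication should keep only the more probable one.
    beam.add(DummyState(-1.0, [[1, 0], [0, 1]]))
    beam.add(DummyState(-2.0, [[1, 0], [0, 1]]))
    beam.add(DummyState(-3.0, [[0, 1], [1, 0]]))

    beam.cut_to_size(beam_size=2, hash_length=2)
    print(len(beam))                      # -> 2
    print(beam.get_top_state().log_prob)  # -> -1.0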