@@ -119,11 +119,12 @@ public Layer(LayerArgs args)
119119 /// Wraps `call`, applying pre- and post-processing steps.
120120 /// </summary>
121121 /// <param name="input"></param>
122+ /// <param name="state"></param>
122123 /// <param name="is_training"></param>
123124 /// <returns></returns>
124- public Tensor Apply ( Tensor inputs , bool is_training = false )
125+ public Tensors Apply ( Tensors inputs , Tensor state = null , bool is_training = false )
125126 {
126- Tensor outputs = null ;
127+ Tensors outputs = null ;
127128
128129 callContext = callContext ?? new ThreadLocal < CallContext > ( )
129130 {
@@ -148,7 +149,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
148149 if ( ! built )
149150 MaybeBuild ( inputs ) ;
150151
151- outputs = call ( inputs , is_training : is_training ) ;
152+ outputs = call ( inputs , state : state , is_training : is_training ) ;
152153
153154 outputs = _set_connectivity_metadata_ ( inputs , outputs ) ;
154155 _handle_activity_regularization ( inputs , outputs ) ;
@@ -161,36 +162,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
161162 return outputs ;
162163 }
163164
164- public Tensor [ ] Apply ( Tensor [ ] inputs , Tensor state , bool is_training = false )
165- {
166- Tensor [ ] outputs = null ;
167-
168- callContext = callContext ?? new ThreadLocal < CallContext > ( )
169- {
170- Value = new CallContext ( )
171- } ;
172-
173- var eager = tf . executing_eagerly ( ) ;
174- using var ctxManager = CallContext . enter ( ) ;
175-
176- string nameScope = "" ;
177- if ( eager )
178- nameScope = name ;
179- else
180- nameScope = _name_scope ( ) ;
181-
182- tf_with ( ops . name_scope ( nameScope ) , scope =>
183- {
184- if ( ! built )
185- MaybeBuild ( inputs [ 0 ] ) ;
186-
187- outputs = call ( inputs , is_training : is_training , state : state ) ;
188- } ) ;
189-
190- return outputs ;
191- }
192-
193- private Tensor _set_connectivity_metadata_ ( Tensor inputs , Tensor outputs )
165+ private Tensors _set_connectivity_metadata_ ( Tensors inputs , Tensors outputs )
194166 {
195167 /*var returnOutputs = new List<Tensor>();
196168 foreach(var x in outputs)
@@ -211,15 +183,15 @@ private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
211183 return outputs ;
212184 }
213185
214- private void _handle_activity_regularization ( Tensor inputs , Tensor outputs )
186+ private void _handle_activity_regularization ( Tensors inputs , Tensors outputs )
215187 {
216188 //if(_activity_regularizer != null)
217189 {
218190
219191 }
220192 }
221193
222- private void _set_mask_metadata ( Tensor inputs , Tensor outputs , Tensor previous_mask )
194+ private void _set_mask_metadata ( Tensors inputs , Tensors outputs , Tensors previous_mask )
223195 {
224196
225197 }
@@ -229,12 +201,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
229201 return null ;
230202 }
231203
232- protected virtual Tensor call ( Tensor inputs , bool is_training = false )
233- {
234- throw new NotImplementedException ( "" ) ;
235- }
236-
237- protected virtual Tensor [ ] call ( Tensor [ ] inputs , Tensor state , bool is_training = false )
204+ protected virtual Tensors call ( Tensors inputs , Tensor state = null , bool is_training = false )
238205 {
239206 throw new NotImplementedException ( "" ) ;
240207 }
@@ -244,15 +211,15 @@ protected virtual string _name_scope()
244211 return Name ;
245212 }
246213
247- protected void MaybeBuild ( Tensor inputs )
214+ protected void MaybeBuild ( Tensors inputs )
248215 {
249216 // Check input assumptions set before layer building, e.g. input rank.
250217 if ( built )
251218 return ;
252219 if ( DType == TF_DataType . DtInvalid )
253220 args . DType = inputs . dtype ;
254221
255- var input_shapes = inputs . TensorShape ;
222+ var input_shapes = inputs . shape ;
256223 build ( input_shapes ) ;
257224 built = true ;
258225 }
0 commit comments