@@ -42,33 +42,35 @@ static void conv1_cond_init(float *mem, int len, int dilation, int *init)
4242 * init = 1 ;
4343}
4444
45- void DRED_rdovae_decode_all (const RDOVAEDec * model , float * features , const float * state , const float * latents , int nb_latents )
45+ void DRED_rdovae_decode_all (const RDOVAEDec * model , float * features , const float * state , const float * latents , int nb_latents , int arch )
4646{
4747 int i ;
4848 RDOVAEDecState dec ;
4949 memset (& dec , 0 , sizeof (dec ));
50- dred_rdovae_dec_init_states (& dec , model , state );
50+ dred_rdovae_dec_init_states (& dec , model , state , arch );
5151 for (i = 0 ; i < 2 * nb_latents ; i += 2 )
5252 {
5353 dred_rdovae_decode_qframe (
5454 & dec ,
5555 model ,
5656 & features [2 * i * DRED_NUM_FEATURES ],
57- & latents [(i /2 )* DRED_LATENT_DIM ]);
57+ & latents [(i /2 )* DRED_LATENT_DIM ],
58+ arch );
5859 }
5960}
6061
6162void dred_rdovae_dec_init_states (
6263 RDOVAEDecState * h , /* io: state buffer handle */
6364 const RDOVAEDec * model ,
64- const float * initial_state /* i: initial state */
65+ const float * initial_state , /* i: initial state */
66+ int arch
6567 )
6668{
6769 float hidden [DEC_HIDDEN_INIT_OUT_SIZE ];
6870 float state_init [DEC_GRU1_STATE_SIZE + DEC_GRU2_STATE_SIZE + DEC_GRU3_STATE_SIZE + DEC_GRU4_STATE_SIZE + DEC_GRU5_STATE_SIZE ];
6971 int counter = 0 ;
70- compute_generic_dense (& model -> dec_hidden_init , hidden , initial_state , ACTIVATION_TANH );
71- compute_generic_dense (& model -> dec_gru_init , state_init , hidden , ACTIVATION_TANH );
72+ compute_generic_dense (& model -> dec_hidden_init , hidden , initial_state , ACTIVATION_TANH , arch );
73+ compute_generic_dense (& model -> dec_gru_init , state_init , hidden , ACTIVATION_TANH , arch );
7274 OPUS_COPY (h -> gru1_state , state_init , DEC_GRU1_STATE_SIZE );
7375 counter += DEC_GRU1_STATE_SIZE ;
7476 OPUS_COPY (h -> gru2_state , & state_init [counter ], DEC_GRU2_STATE_SIZE );
@@ -86,51 +88,52 @@ void dred_rdovae_decode_qframe(
8688 RDOVAEDecState * dec_state , /* io: state buffer handle */
8789 const RDOVAEDec * model ,
8890 float * qframe , /* o: quadruple feature frame (four concatenated frames in reverse order) */
89- const float * input /* i: latent vector */
91+ const float * input , /* i: latent vector */
92+ int arch
9093 )
9194{
9295 float buffer [DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
9396 + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE ];
9497 int output_index = 0 ;
9598
9699 /* run encoder stack and concatenate output in buffer*/
97- compute_generic_dense (& model -> dec_dense1 , & buffer [output_index ], input , ACTIVATION_TANH );
100+ compute_generic_dense (& model -> dec_dense1 , & buffer [output_index ], input , ACTIVATION_TANH , arch );
98101 output_index += DEC_DENSE1_OUT_SIZE ;
99102
100- compute_generic_gru (& model -> dec_gru1_input , & model -> dec_gru1_recurrent , dec_state -> gru1_state , buffer );
103+ compute_generic_gru (& model -> dec_gru1_input , & model -> dec_gru1_recurrent , dec_state -> gru1_state , buffer , arch );
101104 OPUS_COPY (& buffer [output_index ], dec_state -> gru1_state , DEC_GRU1_OUT_SIZE );
102105 output_index += DEC_GRU1_OUT_SIZE ;
103106 conv1_cond_init (dec_state -> conv1_state , output_index , 1 , & dec_state -> initialized );
104- compute_generic_conv1d (& model -> dec_conv1 , & buffer [output_index ], dec_state -> conv1_state , buffer , output_index , ACTIVATION_TANH );
107+ compute_generic_conv1d (& model -> dec_conv1 , & buffer [output_index ], dec_state -> conv1_state , buffer , output_index , ACTIVATION_TANH , arch );
105108 output_index += DEC_CONV1_OUT_SIZE ;
106109
107- compute_generic_gru (& model -> dec_gru2_input , & model -> dec_gru2_recurrent , dec_state -> gru2_state , buffer );
110+ compute_generic_gru (& model -> dec_gru2_input , & model -> dec_gru2_recurrent , dec_state -> gru2_state , buffer , arch );
108111 OPUS_COPY (& buffer [output_index ], dec_state -> gru2_state , DEC_GRU2_OUT_SIZE );
109112 output_index += DEC_GRU2_OUT_SIZE ;
110113 conv1_cond_init (dec_state -> conv2_state , output_index , 1 , & dec_state -> initialized );
111- compute_generic_conv1d (& model -> dec_conv2 , & buffer [output_index ], dec_state -> conv2_state , buffer , output_index , ACTIVATION_TANH );
114+ compute_generic_conv1d (& model -> dec_conv2 , & buffer [output_index ], dec_state -> conv2_state , buffer , output_index , ACTIVATION_TANH , arch );
112115 output_index += DEC_CONV2_OUT_SIZE ;
113116
114- compute_generic_gru (& model -> dec_gru3_input , & model -> dec_gru3_recurrent , dec_state -> gru3_state , buffer );
117+ compute_generic_gru (& model -> dec_gru3_input , & model -> dec_gru3_recurrent , dec_state -> gru3_state , buffer , arch );
115118 OPUS_COPY (& buffer [output_index ], dec_state -> gru3_state , DEC_GRU3_OUT_SIZE );
116119 output_index += DEC_GRU3_OUT_SIZE ;
117120 conv1_cond_init (dec_state -> conv3_state , output_index , 1 , & dec_state -> initialized );
118- compute_generic_conv1d (& model -> dec_conv3 , & buffer [output_index ], dec_state -> conv3_state , buffer , output_index , ACTIVATION_TANH );
121+ compute_generic_conv1d (& model -> dec_conv3 , & buffer [output_index ], dec_state -> conv3_state , buffer , output_index , ACTIVATION_TANH , arch );
119122 output_index += DEC_CONV3_OUT_SIZE ;
120123
121- compute_generic_gru (& model -> dec_gru4_input , & model -> dec_gru4_recurrent , dec_state -> gru4_state , buffer );
124+ compute_generic_gru (& model -> dec_gru4_input , & model -> dec_gru4_recurrent , dec_state -> gru4_state , buffer , arch );
122125 OPUS_COPY (& buffer [output_index ], dec_state -> gru4_state , DEC_GRU4_OUT_SIZE );
123126 output_index += DEC_GRU4_OUT_SIZE ;
124127 conv1_cond_init (dec_state -> conv4_state , output_index , 1 , & dec_state -> initialized );
125- compute_generic_conv1d (& model -> dec_conv4 , & buffer [output_index ], dec_state -> conv4_state , buffer , output_index , ACTIVATION_TANH );
128+ compute_generic_conv1d (& model -> dec_conv4 , & buffer [output_index ], dec_state -> conv4_state , buffer , output_index , ACTIVATION_TANH , arch );
126129 output_index += DEC_CONV4_OUT_SIZE ;
127130
128- compute_generic_gru (& model -> dec_gru5_input , & model -> dec_gru5_recurrent , dec_state -> gru5_state , buffer );
131+ compute_generic_gru (& model -> dec_gru5_input , & model -> dec_gru5_recurrent , dec_state -> gru5_state , buffer , arch );
129132 OPUS_COPY (& buffer [output_index ], dec_state -> gru5_state , DEC_GRU5_OUT_SIZE );
130133 output_index += DEC_GRU5_OUT_SIZE ;
131134 conv1_cond_init (dec_state -> conv5_state , output_index , 1 , & dec_state -> initialized );
132- compute_generic_conv1d (& model -> dec_conv5 , & buffer [output_index ], dec_state -> conv5_state , buffer , output_index , ACTIVATION_TANH );
135+ compute_generic_conv1d (& model -> dec_conv5 , & buffer [output_index ], dec_state -> conv5_state , buffer , output_index , ACTIVATION_TANH , arch );
133136 output_index += DEC_CONV5_OUT_SIZE ;
134137
135- compute_generic_dense (& model -> dec_output , qframe , buffer , ACTIVATION_LINEAR );
138+ compute_generic_dense (& model -> dec_output , qframe , buffer , ACTIVATION_LINEAR , arch );
136139}
0 commit comments