@@ -42,33 +42,35 @@ static void conv1_cond_init(float *mem, int len, int dilation, int *init)
42
42
* init = 1 ;
43
43
}
44
44
45
- void DRED_rdovae_decode_all (const RDOVAEDec * model , float * features , const float * state , const float * latents , int nb_latents )
45
+ void DRED_rdovae_decode_all (const RDOVAEDec * model , float * features , const float * state , const float * latents , int nb_latents , int arch )
46
46
{
47
47
int i ;
48
48
RDOVAEDecState dec ;
49
49
memset (& dec , 0 , sizeof (dec ));
50
- dred_rdovae_dec_init_states (& dec , model , state );
50
+ dred_rdovae_dec_init_states (& dec , model , state , arch );
51
51
for (i = 0 ; i < 2 * nb_latents ; i += 2 )
52
52
{
53
53
dred_rdovae_decode_qframe (
54
54
& dec ,
55
55
model ,
56
56
& features [2 * i * DRED_NUM_FEATURES ],
57
- & latents [(i /2 )* DRED_LATENT_DIM ]);
57
+ & latents [(i /2 )* DRED_LATENT_DIM ],
58
+ arch );
58
59
}
59
60
}
60
61
61
62
void dred_rdovae_dec_init_states (
62
63
RDOVAEDecState * h , /* io: state buffer handle */
63
64
const RDOVAEDec * model ,
64
- const float * initial_state /* i: initial state */
65
+ const float * initial_state , /* i: initial state */
66
+ int arch
65
67
)
66
68
{
67
69
float hidden [DEC_HIDDEN_INIT_OUT_SIZE ];
68
70
float state_init [DEC_GRU1_STATE_SIZE + DEC_GRU2_STATE_SIZE + DEC_GRU3_STATE_SIZE + DEC_GRU4_STATE_SIZE + DEC_GRU5_STATE_SIZE ];
69
71
int counter = 0 ;
70
- compute_generic_dense (& model -> dec_hidden_init , hidden , initial_state , ACTIVATION_TANH );
71
- compute_generic_dense (& model -> dec_gru_init , state_init , hidden , ACTIVATION_TANH );
72
+ compute_generic_dense (& model -> dec_hidden_init , hidden , initial_state , ACTIVATION_TANH , arch );
73
+ compute_generic_dense (& model -> dec_gru_init , state_init , hidden , ACTIVATION_TANH , arch );
72
74
OPUS_COPY (h -> gru1_state , state_init , DEC_GRU1_STATE_SIZE );
73
75
counter += DEC_GRU1_STATE_SIZE ;
74
76
OPUS_COPY (h -> gru2_state , & state_init [counter ], DEC_GRU2_STATE_SIZE );
@@ -86,51 +88,52 @@ void dred_rdovae_decode_qframe(
86
88
RDOVAEDecState * dec_state , /* io: state buffer handle */
87
89
const RDOVAEDec * model ,
88
90
float * qframe , /* o: quadruple feature frame (four concatenated frames in reverse order) */
89
- const float * input /* i: latent vector */
91
+ const float * input , /* i: latent vector */
92
+ int arch
90
93
)
91
94
{
92
95
float buffer [DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
93
96
+ DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE ];
94
97
int output_index = 0 ;
95
98
96
99
/* run encoder stack and concatenate output in buffer*/
97
- compute_generic_dense (& model -> dec_dense1 , & buffer [output_index ], input , ACTIVATION_TANH );
100
+ compute_generic_dense (& model -> dec_dense1 , & buffer [output_index ], input , ACTIVATION_TANH , arch );
98
101
output_index += DEC_DENSE1_OUT_SIZE ;
99
102
100
- compute_generic_gru (& model -> dec_gru1_input , & model -> dec_gru1_recurrent , dec_state -> gru1_state , buffer );
103
+ compute_generic_gru (& model -> dec_gru1_input , & model -> dec_gru1_recurrent , dec_state -> gru1_state , buffer , arch );
101
104
OPUS_COPY (& buffer [output_index ], dec_state -> gru1_state , DEC_GRU1_OUT_SIZE );
102
105
output_index += DEC_GRU1_OUT_SIZE ;
103
106
conv1_cond_init (dec_state -> conv1_state , output_index , 1 , & dec_state -> initialized );
104
- compute_generic_conv1d (& model -> dec_conv1 , & buffer [output_index ], dec_state -> conv1_state , buffer , output_index , ACTIVATION_TANH );
107
+ compute_generic_conv1d (& model -> dec_conv1 , & buffer [output_index ], dec_state -> conv1_state , buffer , output_index , ACTIVATION_TANH , arch );
105
108
output_index += DEC_CONV1_OUT_SIZE ;
106
109
107
- compute_generic_gru (& model -> dec_gru2_input , & model -> dec_gru2_recurrent , dec_state -> gru2_state , buffer );
110
+ compute_generic_gru (& model -> dec_gru2_input , & model -> dec_gru2_recurrent , dec_state -> gru2_state , buffer , arch );
108
111
OPUS_COPY (& buffer [output_index ], dec_state -> gru2_state , DEC_GRU2_OUT_SIZE );
109
112
output_index += DEC_GRU2_OUT_SIZE ;
110
113
conv1_cond_init (dec_state -> conv2_state , output_index , 1 , & dec_state -> initialized );
111
- compute_generic_conv1d (& model -> dec_conv2 , & buffer [output_index ], dec_state -> conv2_state , buffer , output_index , ACTIVATION_TANH );
114
+ compute_generic_conv1d (& model -> dec_conv2 , & buffer [output_index ], dec_state -> conv2_state , buffer , output_index , ACTIVATION_TANH , arch );
112
115
output_index += DEC_CONV2_OUT_SIZE ;
113
116
114
- compute_generic_gru (& model -> dec_gru3_input , & model -> dec_gru3_recurrent , dec_state -> gru3_state , buffer );
117
+ compute_generic_gru (& model -> dec_gru3_input , & model -> dec_gru3_recurrent , dec_state -> gru3_state , buffer , arch );
115
118
OPUS_COPY (& buffer [output_index ], dec_state -> gru3_state , DEC_GRU3_OUT_SIZE );
116
119
output_index += DEC_GRU3_OUT_SIZE ;
117
120
conv1_cond_init (dec_state -> conv3_state , output_index , 1 , & dec_state -> initialized );
118
- compute_generic_conv1d (& model -> dec_conv3 , & buffer [output_index ], dec_state -> conv3_state , buffer , output_index , ACTIVATION_TANH );
121
+ compute_generic_conv1d (& model -> dec_conv3 , & buffer [output_index ], dec_state -> conv3_state , buffer , output_index , ACTIVATION_TANH , arch );
119
122
output_index += DEC_CONV3_OUT_SIZE ;
120
123
121
- compute_generic_gru (& model -> dec_gru4_input , & model -> dec_gru4_recurrent , dec_state -> gru4_state , buffer );
124
+ compute_generic_gru (& model -> dec_gru4_input , & model -> dec_gru4_recurrent , dec_state -> gru4_state , buffer , arch );
122
125
OPUS_COPY (& buffer [output_index ], dec_state -> gru4_state , DEC_GRU4_OUT_SIZE );
123
126
output_index += DEC_GRU4_OUT_SIZE ;
124
127
conv1_cond_init (dec_state -> conv4_state , output_index , 1 , & dec_state -> initialized );
125
- compute_generic_conv1d (& model -> dec_conv4 , & buffer [output_index ], dec_state -> conv4_state , buffer , output_index , ACTIVATION_TANH );
128
+ compute_generic_conv1d (& model -> dec_conv4 , & buffer [output_index ], dec_state -> conv4_state , buffer , output_index , ACTIVATION_TANH , arch );
126
129
output_index += DEC_CONV4_OUT_SIZE ;
127
130
128
- compute_generic_gru (& model -> dec_gru5_input , & model -> dec_gru5_recurrent , dec_state -> gru5_state , buffer );
131
+ compute_generic_gru (& model -> dec_gru5_input , & model -> dec_gru5_recurrent , dec_state -> gru5_state , buffer , arch );
129
132
OPUS_COPY (& buffer [output_index ], dec_state -> gru5_state , DEC_GRU5_OUT_SIZE );
130
133
output_index += DEC_GRU5_OUT_SIZE ;
131
134
conv1_cond_init (dec_state -> conv5_state , output_index , 1 , & dec_state -> initialized );
132
- compute_generic_conv1d (& model -> dec_conv5 , & buffer [output_index ], dec_state -> conv5_state , buffer , output_index , ACTIVATION_TANH );
135
+ compute_generic_conv1d (& model -> dec_conv5 , & buffer [output_index ], dec_state -> conv5_state , buffer , output_index , ACTIVATION_TANH , arch );
133
136
output_index += DEC_CONV5_OUT_SIZE ;
134
137
135
- compute_generic_dense (& model -> dec_output , qframe , buffer , ACTIVATION_LINEAR );
138
+ compute_generic_dense (& model -> dec_output , qframe , buffer , ACTIVATION_LINEAR , arch );
136
139
}
0 commit comments