Skip to content

Commit cb76ea4

Browse files
committed
update batch decoder configs and rnnlm opts
1 parent 39feaed commit cb76ea4

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

src/decoder/decoder-batch.cpp

+25-13
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,31 @@ BatchDecoder::BatchDecoder(ChainModel *const model) : model_(model) {
1515
if (model_->wb_info != nullptr) options.enable_word_level = true;
1616
if (model_->rnnlm_info != nullptr) options.enable_rnnlm = true;
1717

18-
const char *usage = "Usage: decoder-batch.cpp [options]";
19-
kaldi::ParseOptions po(usage);
20-
kaldi::CuDevice::RegisterDeviceOptions(&po);
21-
kaldi::RegisterCuAllocatorOptions(&po);
22-
batched_decoder_config_.Register(&po);
23-
24-
const char *argv[] = {
25-
"decoder-batch.cpp",
26-
// std::string("").c_str(),
27-
NULL
28-
};
29-
30-
po.Read((sizeof(argv)/sizeof(argv[0])) - 1, argv);
18+
// kaldi::CuDevice::RegisterDeviceOptions(&po); // only need if using fp16 (can't access device_options_ directly)
19+
// kaldi::g_allocator_options // only need if need to customize cuda memory usage
20+
21+
batched_decoder_config_.cuda_online_pipeline_opts.use_gpu_feature_extraction = false;
22+
batched_decoder_config_.cuda_online_pipeline_opts.determinize_lattice = false;
23+
24+
// decoder options
25+
batched_decoder_config_.cuda_online_pipeline_opts.decoder_opts.default_beam = model_->model_spec.beam;
26+
batched_decoder_config_.cuda_online_pipeline_opts.decoder_opts.lattice_beam = model_->model_spec.lattice_beam;
27+
batched_decoder_config_.cuda_online_pipeline_opts.decoder_opts.max_active = model_->model_spec.max_active;
28+
29+
// feature pipeline options
30+
batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.feature_type = "mfcc";
31+
std::string model_dir = model_->model_spec.path;
32+
std::string conf_dir = join_path(model_dir, "conf");
33+
std::string mfcc_conf_filepath = join_path(conf_dir, "mfcc.conf");
34+
std::string ivector_conf_filepath = join_path(conf_dir, "ivector_extractor.conf");
35+
36+
batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.mfcc_config = mfcc_conf_filepath;
37+
batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.ivector_extraction_config = ivector_conf_filepath;
38+
batched_decoder_config_.cuda_online_pipeline_opts.feature_opts.silence_weighting_config.silence_weight = model_->model_spec.silence_weight;
39+
40+
// compute options
41+
batched_decoder_config_.cuda_online_pipeline_opts.compute_opts.acoustic_scale = model_->model_spec.acoustic_scale;
42+
batched_decoder_config_.cuda_online_pipeline_opts.compute_opts.frame_subsampling_factor = model_->model_spec.frame_subsampling_factor;
3143

3244
cuda_pipeline_ = NULL;
3345
}

src/model/model-chain.cpp

+3-16
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ ChainModel::ChainModel(const ModelSpec &model_spec) : model_spec(model_spec) {
5757
exists(join_path(rnnlm_dir, "final.raw")) &&
5858
exists(join_path(rnnlm_dir, "word_embedding.mat")) &&
5959
exists(join_path(rnnlm_dir, "G.fst"))) {
60+
61+
rnnlm_opts.bos_index = std::stoi(model_spec.bos_index);
62+
rnnlm_opts.eos_index = std::stoi(model_spec.eos_index);
6063

6164
lm_to_subtract_fst =
6265
std::unique_ptr<const fst::VectorFst<fst::StdArc>>(fst::ReadAndPrepareLmFst(join_path(rnnlm_dir, "G.fst")));
@@ -68,22 +71,6 @@ ChainModel::ChainModel(const ModelSpec &model_spec) : model_spec(model_spec) {
6871

6972
std::cout << "# Word Embeddings (RNNLM): " << word_embedding_mat.NumRows() << ENDL;
7073

71-
// hack: RNNLM compute opts only takes values from parsed options like in cmd-line
72-
const char *usage = "Usage: model.hpp [options]";
73-
kaldi::ParseOptions po(usage);
74-
rnnlm_opts.Register(&po);
75-
76-
std::string bos_opt = "--bos-symbol=" + model_spec.bos_index;
77-
std::string eos_opt = "--eos-symbol=" + model_spec.eos_index;
78-
79-
const char *argv[] = {
80-
"model.hpp",
81-
bos_opt.c_str(),
82-
eos_opt.c_str(),
83-
NULL
84-
};
85-
86-
po.Read((sizeof(argv)/sizeof(argv[0])) - 1, argv);
8774
rnnlm_info =
8875
make_uniq<const kaldi::rnnlm::RnnlmComputeStateInfo>(rnnlm_opts, rnnlm, word_embedding_mat);
8976
} else {

0 commit comments

Comments
 (0)