From 8157118bd8538c7224daad4339d768ba2381eec1 Mon Sep 17 00:00:00 2001 From: Devendra Tewari Date: Tue, 24 Nov 2020 08:18:36 -0300 Subject: [PATCH] specify model and device id via cli --- README.md | 2 +- src/assistant/run_assistant_audio.cc | 30 ++++++++++++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 2cfb3be..7331443 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ aplay ./response.wav --rate=16000 --format=S16_LE On a Linux workstation, you can alternatively use ALSA for audio input: ```bash -./run_assistant_audio --credentials ./credentials.json +./run_assistant_audio --credentials ./credentials.json --model_id default --device_id default ``` You can use a text-based query instead of audio. This allows you to continually enter text queries to the Assistant. diff --git a/src/assistant/run_assistant_audio.cc b/src/assistant/run_assistant_audio.cc index e1bd2d4..7eb010e 100644 --- a/src/assistant/run_assistant_audio.cc +++ b/src/assistant/run_assistant_audio.cc @@ -93,19 +93,25 @@ void PrintUsage() { << "--credentials " << "[--api_endpoint ] " << "[--locale ]" - << "[--html_out ]" << std::endl; + << "[--html_out ]" + << "[--model_id ]" + << "[--device_id ]" << std::endl; } bool GetCommandLineFlags(int argc, char** argv, std::string* credentials_file_path, std::string* api_endpoint, std::string* locale, - std::string* html_out_command) { + std::string* html_out_command, + std::string* model_id, + std::string* device_id) { const struct option long_options[] = { {"credentials", required_argument, nullptr, 'c'}, {"api_endpoint", required_argument, nullptr, 'e'}, {"locale", required_argument, nullptr, 'l'}, {"verbose", no_argument, nullptr, 'v'}, {"html_out", required_argument, nullptr, 'h'}, + {"model_id", required_argument, nullptr, 'm'}, + {"device_id", required_argument, nullptr, 'd'}, {nullptr, 0, nullptr, 0}}; *api_endpoint = ASSISTANT_ENDPOINT; while (true) { @@ -131,6 +137,12 @@ bool GetCommandLineFlags(int argc, char** argv, case 'h': *html_out_command = optarg; break; + case 'm': + *model_id = optarg; + break; + case 'd': + *device_id = optarg; + break; default: PrintUsage(); return false; @@ -140,7 +152,7 @@ bool GetCommandLineFlags(int argc, char** argv, } int main(int argc, char** argv) { - std::string credentials_file_path, api_endpoint, locale, html_out_command; + std::string credentials_file_path, api_endpoint, locale, html_out_command, model_id, device_id; #ifndef ENABLE_ALSA std::cerr << "ALSA audio input is not supported on this platform." << std::endl; @@ -151,9 +163,15 @@ int main(int argc, char** argv) { // https://github.com/grpc/grpc/issues/11366#issuecomment-328595941 grpc_init(); if (!GetCommandLineFlags(argc, argv, &credentials_file_path, &api_endpoint, - &locale, &html_out_command)) { + &locale, &html_out_command, &model_id, &device_id)) { return -1; } + if (device_id.empty()) { + device_id.assign(kDeviceInstanceId); + } + if (model_id.empty()) { + model_id.assign(kDeviceModelId); + } while (true) { // Create an AssistRequest @@ -170,8 +188,8 @@ int main(int argc, char** argv) { assist_config->mutable_dialog_state_in()->set_language_code(locale); // Set the DeviceConfig of the AssistRequest - assist_config->mutable_device_config()->set_device_id(kDeviceInstanceId); - assist_config->mutable_device_config()->set_device_model_id(kDeviceModelId); + assist_config->mutable_device_config()->set_device_id(device_id); + assist_config->mutable_device_config()->set_device_model_id(model_id); // Set parameters for audio output assist_config->mutable_audio_out_config()->set_encoding(