Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
*~
*~
/bazel-*
/lyra/model_coeffs/_models.h
363 changes: 26 additions & 337 deletions README.md

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ maven_install(
)


load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "gflags",
urls = ["https://github.com/gflags/gflags/archive/refs/tags/v2.2.2.tar.gz"],
strip_prefix = "gflags-2.2.2",
sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
)


# Begin Tensorflow WORKSPACE subset required for TFLite

git_repository(
Expand Down
33 changes: 33 additions & 0 deletions dll/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
cc_binary(
name = "lyra_dll",
srcs = [
"dllmain.cc",
],
#data = [":tflite_testdata"],
linkopts = select({
"//lyra:android_config": ["-landroid"],
"//conditions:default": [],
}),
deps = [
"//lyra:lyra_config",
"//lyra:lyra_encoder",
"//lyra:lyra_decoder",
"//lyra:wav_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/flags:usage",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@com_google_glog//:glog",
"@gulrak_filesystem//:filesystem",
],
linkshared = 1,
copts = ["/DCOMPILING_DLL"],
target_compatible_with = [
"@platforms//cpu:x86_64",
"@platforms//os:windows",
],
)
96 changes: 96 additions & 0 deletions dll/dllmain.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include "Windows.h"

#include "lyra/lyra_config.h"
#include "lyra/lyra_encoder.h"
#include "lyra/lyra_decoder.h"
#include "lyra/model_coeffs/_models.h"

#define BYTES_PER_SAMPLE 2

BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserved)
{
switch (ul_reason_for_call)
{
case DLL_PROCESS_ATTACH:
case DLL_THREAD_ATTACH:
case DLL_THREAD_DETACH:
case DLL_PROCESS_DETACH:
break;
}
return TRUE;
}

std::unique_ptr<chromemedia::codec::LyraEncoder> m_Encoder = nullptr;
std::unique_ptr<chromemedia::codec::LyraDecoder> m_Decoder = nullptr;

extern "C" __declspec(dllexport) bool Initialize()
{
const int samplerate = 16000;
const int bitrate = 3200;

const chromemedia::codec::LyraModels models = GetEmbeddedLyraModels();

if (!m_Encoder)
{
m_Encoder = chromemedia::codec::LyraEncoder::Create(samplerate, 1, bitrate, false, models);
}
if (!m_Decoder)
{
m_Decoder = chromemedia::codec::LyraDecoder::Create(samplerate, 1, models);
}
return m_Encoder != nullptr && m_Decoder != nullptr;
}

extern "C" __declspec(dllexport) void Shutdown()
{
m_Encoder.reset();
m_Decoder.reset();
}

extern "C" __declspec(dllexport) void Encode(const int16_t* uncompressed, size_t uncompressed_size, uint8_t* compressed, size_t compressed_size)
{
const int num_samples_per_packet = m_Encoder->sample_rate_hz() / m_Encoder->frame_rate();
const int raw_frame_size = num_samples_per_packet * BYTES_PER_SAMPLE;

assert(uncompressed_size >= num_samples_per_packet);

std::vector<int16_t> uncompressed_vector(uncompressed, uncompressed + num_samples_per_packet);
std::optional<std::vector<uint8_t>> encoded = m_Encoder->Encode(uncompressed_vector);

if (!encoded.has_value())
{
return;
}

assert(encoded->size() == chromemedia::codec::BitrateToPacketSize(m_Encoder->bitrate()));
assert(compressed_size >= encoded->size());

memcpy_s(compressed, compressed_size, encoded->data(), encoded->size());
}

extern "C" __declspec(dllexport) void Decode(const int8_t* compressed, size_t compressed_size, uint16_t* uncompressed, size_t uncompressed_size)
{
const int num_samples_per_packet = m_Encoder->sample_rate_hz() / m_Encoder->frame_rate();
const int packet_size = chromemedia::codec::BitrateToPacketSize(m_Encoder->bitrate());

assert(compressed_size == packet_size);

bool valid = m_Decoder->SetEncodedPacket(absl::MakeSpan(reinterpret_cast<const uint8_t*>(compressed), compressed_size));
assert(valid == true);
if (!valid) return;

std::optional<std::vector<int16_t>> decoded = m_Decoder->DecodeSamples(num_samples_per_packet);
if (!decoded.has_value())
{
assert(decoded.has_value());
return;
}

assert(decoded->size() == num_samples_per_packet);
assert(uncompressed_size >= decoded->size());

memcpy_s(uncompressed, uncompressed_size * sizeof(uint16_t), decoded->data(), decoded->size() * sizeof(uint16_t));
}
13 changes: 5 additions & 8 deletions lyra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,11 @@ config_setting(
values = {"crosstool_top": "//external:android/crosstool"},
)

cc_library(
name = "architecture_utils",
hdrs = ["architecture_utils.h"],
deps = ["@gulrak_filesystem//:filesystem"],
)

cc_library(
name = "lyra_benchmark_lib",
srcs = ["lyra_benchmark_lib.cc"],
hdrs = ["lyra_benchmark_lib.h"],
deps = [
":architecture_utils",
":dsp_utils",
":feature_extractor_interface",
":generative_model_interface",
Expand Down Expand Up @@ -771,6 +764,7 @@ cc_library(
],
hdrs = [
"tflite_model_wrapper.h",
"lyra_embedded_models.h",
],
deps = [
"@com_google_absl//absl/memory",
Expand Down Expand Up @@ -814,7 +808,10 @@ cc_test(

cc_test(
name = "tflite_model_wrapper_test",
srcs = ["tflite_model_wrapper_test.cc"],
srcs = [
"tflite_model_wrapper_test.cc",
"model_coeffs/_models.h",
],
data = ["model_coeffs/lyragan.tflite"],
deps = [
":tflite_model_wrapper",
Expand Down
34 changes: 0 additions & 34 deletions lyra/architecture_utils.h

This file was deleted.

2 changes: 0 additions & 2 deletions lyra/cli_example/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ cc_binary(
}),
deps = [
":encoder_main_lib",
"//lyra:architecture_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/flags:usage",
Expand All @@ -139,7 +138,6 @@ cc_binary(
}),
deps = [
":decoder_main_lib",
"//lyra:architecture_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/flags:usage",
Expand Down
14 changes: 4 additions & 10 deletions lyra/cli_example/decoder_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
#include "absl/strings/string_view.h"
#include "glog/logging.h" // IWYU pragma: keep
#include "include/ghc/filesystem.hpp"
#include "lyra/architecture_utils.h"
#include "lyra/cli_example/decoder_main_lib.h"

#include "lyra/model_coeffs/_models.h"

ABSL_FLAG(std::string, encoded_path, "",
"Complete path to the file containing the encoded features.");
ABSL_FLAG(std::string, output_dir, "",
Expand All @@ -50,11 +51,6 @@ ABSL_FLAG(chromemedia::codec::PacketLossPattern, fixed_packet_loss_pattern,
"bursts will be rounded up to the nearest packet duration boundary. "
"If this flag contains a nonzero number of values we ignore "
"|packet_loss_rate| and |average_burst_length|.");
ABSL_FLAG(std::string, model_path, "lyra/model_coeffs",
"Path to directory containing TFLite files. For mobile this is the "
"absolute path, like "
"'/data/local/tmp/lyra/model_coeffs/'."
" For desktop this is the path relative to the binary.");

int main(int argc, char** argv) {
absl::SetProgramUsageMessage(argv[0]);
Expand All @@ -71,9 +67,7 @@ int main(int argc, char** argv) {
const float average_burst_length = absl::GetFlag(FLAGS_average_burst_length);
const chromemedia::codec::PacketLossPattern fixed_packet_loss_pattern =
absl::GetFlag(FLAGS_fixed_packet_loss_pattern);
const ghc::filesystem::path model_path =
chromemedia::codec::GetCompleteArchitecturePath(
absl::GetFlag(FLAGS_model_path));
const chromemedia::codec::LyraModels models = GetEmbeddedLyraModels();
if (!fixed_packet_loss_pattern.starts_.empty()) {
LOG(INFO) << "Using fixed packet loss pattern instead of gilbert model.";
}
Expand Down Expand Up @@ -102,7 +96,7 @@ int main(int argc, char** argv) {
if (!chromemedia::codec::DecodeFile(encoded_path, output_path, sample_rate_hz,
bitrate, randomize_num_samples_requested,
packet_loss_rate, average_burst_length,
fixed_packet_loss_pattern, model_path)) {
fixed_packet_loss_pattern, models)) {
LOG(ERROR) << "Could not decode " << encoded_path;
return -1;
}
Expand Down
4 changes: 2 additions & 2 deletions lyra/cli_example/decoder_main_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ bool DecodeFile(const ghc::filesystem::path& encoded_path,
int bitrate, bool randomize_num_samples_requested,
float packet_loss_rate, float average_burst_length,
const PacketLossPattern& fixed_packet_loss_pattern,
const ghc::filesystem::path& model_path) {
auto decoder = LyraDecoder::Create(sample_rate_hz, kNumChannels, model_path);
const LyraModels& models) {
auto decoder = LyraDecoder::Create(sample_rate_hz, kNumChannels, models);
if (decoder == nullptr) {
LOG(ERROR) << "Could not create lyra decoder.";
return false;
Expand Down
3 changes: 1 addition & 2 deletions lyra/cli_example/decoder_main_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ bool DecodeFeatures(const std::vector<uint8_t>& packet_stream, int packet_size,
std::vector<int16_t>* decoded_audio);

// Decodes an encoded features file into a wav file.
// Uses the model and quant files located under |model_path|.
// Given the file /tmp/lyra/file1.lyra exists and is a valid encoded file. For:
// |encoded_path| = "/tmp/lyra/file1.lyra"
// |output_path| = "/tmp/lyra/file1_decoded.lyra"
Expand All @@ -66,7 +65,7 @@ bool DecodeFile(const ghc::filesystem::path& encoded_path,
int bitrate, bool randomize_num_samples_requested,
float packet_loss_rate, float average_burst_length,
const PacketLossPattern& fixed_packet_loss_pattern,
const ghc::filesystem::path& model_path);
const chromemedia::codec::LyraModels& models);

} // namespace codec
} // namespace chromemedia
Expand Down
Loading