diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index e74d76d7..672f6530 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -303,8 +303,8 @@ bool TrainerInterface::IsValidSentencePiece( } template -void AddDPNoise(const TrainerSpec &trainer_spec, absl::SharedBitGen &generator, - T *to_update) { +void AddDPNoise(const TrainerSpec &trainer_spec, + random::SharedBitGen &generator, T *to_update) { if (trainer_spec.differential_privacy_noise_level() > 0) { float random_num = absl::Gaussian( generator, 0, trainer_spec.differential_privacy_noise_level()); @@ -480,7 +480,7 @@ util::Status TrainerInterface::LoadSentences() { for (int n = 0; n < num_workers; ++n) { pool->Schedule([&, n]() { // One per thread generator. - absl::SharedBitGen generator; + random::SharedBitGen generator; for (size_t i = n; i < sentences_.size(); i += num_workers) { AddDPNoise(trainer_spec_, generator, &(sentences_[i].second)); diff --git a/src/util.h b/src/util.h index cd843275..b305aa83 100644 --- a/src/util.h +++ b/src/util.h @@ -288,6 +288,11 @@ namespace random { std::mt19937 *GetRandomGenerator(); +class SharedBitGen { + public: + std::mt19937 *engine() { return GetRandomGenerator(); } +}; + template class ReservoirSampler { public: diff --git a/third_party/absl/random/distributions.h b/third_party/absl/random/distributions.h index 246ecb27..b559db9d 100644 --- a/third_party/absl/random/distributions.h +++ b/third_party/absl/random/distributions.h @@ -21,8 +21,8 @@ namespace absl { -template -T Gaussian(SharedBitGen &generator, T mean, T stddev) { +template +T Gaussian(G &generator, T mean, T stddev) { std::normal_distribution<> dist(mean, stddev); return dist(*generator.engine()); } diff --git a/third_party/absl/random/random.h b/third_party/absl/random/random.h index 3c3a21ed..d131d801 100644 --- a/third_party/absl/random/random.h +++ b/third_party/absl/random/random.h @@ -15,19 +15,4 @@ #ifndef ABSL_CONTAINER_RANDOM_H_ #define ABSL_CONTAINER_RANDOM_H_ -#include - -#include "../../../src/util.h" - -using sentencepiece::random::GetRandomGenerator; - -namespace absl { - -class SharedBitGen { - public: - std::mt19937 *engine() { return GetRandomGenerator(); } -}; - -} // namespace absl - #endif // ABSL_CONTAINER_RANDOM_H_