-
Notifications
You must be signed in to change notification settings - Fork 538
/
Copy pathcompression_clif_aux.cc
124 lines (99 loc) · 4.08 KB
/
compression_clif_aux.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "compression/python/compression_clif_aux.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"compression/python/compression_clif_aux.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h"
#include "hwy/highway.h"
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first.
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include <string>
#include <utility>
#include <vector>

#include "compression/io.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Target-independent interface to the SIMD writer implementation. It lets the
// non-SIMD `SbsWriter` forward calls to an implementation that is selected at
// runtime without knowing which SIMD target it was compiled for.
class WriterInterface {
 public:
  WriterInterface() = default;
  // Polymorphic base: copying through a base reference would slice the
  // derived object, so copying is disabled (this also suppresses moves).
  WriterInterface(const WriterInterface&) = delete;
  WriterInterface& operator=(const WriterInterface&) = delete;
  virtual ~WriterInterface() = default;

  // Adds the tensor `weights` under `name` (SbsWriterImpl stores it as SFP).
  virtual void Insert(std::string name, absl::Span<const float> weights) = 0;
  // Adds the tensor `weights` under `name` (SbsWriterImpl stores it as NUQ).
  virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
  // Adds the tensor `weights` under `name` (SbsWriterImpl stores it as bf16).
  virtual void InsertBfloat16(std::string name,
                              absl::Span<const float> weights) = 0;
  // Adds the list of scales to be written alongside the tensors.
  virtual void AddScales(const std::vector<float>& scales) = 0;
  // Writes everything added so far to the file at `path`.
  virtual void Write(std::string path) = 0;
};
} // namespace gcpp
#endif // GEMMA_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
class SbsWriterImpl : public WriterInterface {
public:
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
void Insert(std::string name, absl::Span<const float> weights) override {
const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
sfp_streams_.push_back(std::vector<SfpStream>(out_size));
compressor_.Insert<SfpStream>(name.data(), weights.data(), weights.size(),
working_set_, out_size,
sfp_streams_.back().data(), 0, pool_);
}
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
nuq_streams_.push_back(std::vector<NuqStream>(out_size));
compressor_.Insert<NuqStream>(name.data(), weights.data(), weights.size(),
working_set_, out_size,
nuq_streams_.back().data(), 0, pool_);
}
void InsertBfloat16(std::string name,
absl::Span<const float> weights) override {
const size_t out_size =
CompressedArraySize<hwy::bfloat16_t>(weights.size());
bf16_streams_.push_back(std::vector<hwy::bfloat16_t>(out_size));
compressor_.Insert<hwy::bfloat16_t>(name.data(), weights.data(),
weights.size(), working_set_, out_size,
bf16_streams_.back().data(), 0, pool_);
}
void AddScales(const std::vector<float>& scales) override {
HWY_ASSERT(scales_.empty());
scales_ = scales;
compressor_.AddScales(scales_.data(), scales_.size());
}
void Write(std::string path) override {
compressor_.WriteAll(pool_, gcpp::Path(path));
}
hwy::ThreadPool pool_;
Compressor compressor_;
CompressWorkingSet working_set_;
std::vector<std::vector<SfpStream>> sfp_streams_;
std::vector<std::vector<NuqStream>> nuq_streams_;
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
std::vector<float> scales_;
};
// Factory for the writer compiled for the current SIMD target. This is the
// per-target function that HWY_EXPORT / HWY_DYNAMIC_DISPATCH select among.
// The returned object is heap-allocated and owned by the caller.
WriterInterface* NewSbsWriter() {
  SbsWriterImpl* const writer = new SbsWriterImpl();
  return writer;
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {

// Table of the per-target NewSbsWriter implementations for runtime dispatch.
HWY_EXPORT(NewSbsWriter);

// Builds the implementation for the best SIMD target available at runtime.
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
SbsWriter::~SbsWriter() = default;

// The public methods forward to the implementation. Strings arrive by value,
// so they are moved rather than copied a second time.
void SbsWriter::Insert(std::string name, absl::Span<const float> weights) {
  impl_->Insert(std::move(name), weights);
}
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
  impl_->InsertNUQ(std::move(name), weights);
}
void SbsWriter::InsertBfloat16(std::string name,
                               absl::Span<const float> weights) {
  impl_->InsertBfloat16(std::move(name), weights);
}
void SbsWriter::AddScales(const std::vector<float>& scales) {
  impl_->AddScales(scales);
}
void SbsWriter::Write(std::string path) { impl_->Write(std::move(path)); }

}  // namespace gcpp
#endif // HWY_ONCE