Skip to content

Commit 81b8704

Browse files
authored
Avoid Copying Fit Models (#219)
* Modify Prediction to avoid copying the model and fit when used in a chained call such as: model.fit(dataset).predict(features); * use _t helpers
1 parent 8605e60 commit 81b8704

File tree

4 files changed

+40
-24
lines changed

4 files changed

+40
-24
lines changed

include/albatross/src/core/declarations.hpp

+17
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ using mapbox::util::variant;
2323

2424
namespace albatross {
2525

26+
/*
27+
* We frequently inspect for definitions of functions which
28+
* must be defined for const references to objects
29+
* (so that repeated evaluations return the same thing
30+
* and so the computations are not repeatedly copying.)
31+
* This type conversion utility will turn a type `T` into `const T&`
32+
*/
33+
template <class T> struct const_ref {
34+
typedef std::add_lvalue_reference_t<std::add_const_t<T>> type;
35+
};
36+
37+
template <typename T> using const_ref_t = typename const_ref<T>::type;
38+
2639
/*
2740
* Model
2841
*/
@@ -35,6 +48,10 @@ template <typename T> struct PredictTypeIdentity;
3548
template <typename ModelType, typename FeatureType, typename FitType>
3649
class Prediction;
3750

51+
template <typename ModelType, typename FeatureType, typename FitType>
52+
using PredictionReference =
53+
Prediction<const_ref_t<ModelType>, FeatureType, const_ref_t<FitType>>;
54+
3855
template <typename ModelType, typename FitType> class FitModel;
3956

4057
template <typename Derived> class Fit {};

include/albatross/src/core/fit_model.hpp

+16-7
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,32 @@ template <typename ModelType, typename Fit> class FitModel {
3030
FitModel(const ModelType &model, Fit &&fit)
3131
: model_(model), fit_(std::move(fit)) {}
3232

33+
// When FitModel is an lvalue we store a reference to the fit
34+
// inside the resulting Prediction class.
35+
template <typename PredictFeatureType>
36+
const PredictionReference<ModelType, PredictFeatureType, Fit>
37+
predict(const std::vector<PredictFeatureType> &features) const & {
38+
return PredictionReference<ModelType, PredictFeatureType, Fit>(model_, fit_,
39+
features);
40+
}
41+
42+
// When FitModel is an rvalue the Fit will be a temporary so
43+
// we move it into the Prediction class to be stored there.
3344
template <typename PredictFeatureType>
3445
Prediction<ModelType, PredictFeatureType, Fit>
35-
predict(const std::vector<PredictFeatureType> &features) const {
36-
return Prediction<ModelType, PredictFeatureType, Fit>(model_, fit_,
37-
features);
46+
predict(const std::vector<PredictFeatureType> &features) && {
47+
return Prediction<ModelType, PredictFeatureType, Fit>(
48+
std::move(model_), std::move(fit_), features);
3849
}
3950

4051
template <typename PredictFeatureType>
41-
Prediction<ModelType, Measurement<PredictFeatureType>, Fit>
42-
predict_with_measurement_noise(
52+
auto predict_with_measurement_noise(
4353
const std::vector<PredictFeatureType> &features) const {
4454
std::vector<Measurement<PredictFeatureType>> measurements;
4555
for (const auto &f : features) {
4656
measurements.emplace_back(Measurement<PredictFeatureType>(f));
4757
}
48-
return Prediction<ModelType, Measurement<PredictFeatureType>, Fit>(
49-
model_, fit_, measurements);
58+
return predict(measurements);
5059
}
5160

5261
Fit get_fit() const { return fit_; }

include/albatross/src/core/prediction.hpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,17 @@ class JointPredictor {
113113
template <typename ModelType, typename FeatureType, typename FitType>
114114
class Prediction {
115115

116+
using PlainModelType = typename std::decay<ModelType>::type;
117+
using PlainFitType = typename std::decay<FitType>::type;
118+
116119
public:
117-
Prediction(const ModelType &model, const FitType &fit,
120+
Prediction(const PlainModelType &model, const PlainFitType &fit,
118121
const std::vector<FeatureType> &features)
119122
: model_(model), fit_(fit), features_(features) {}
120123

121-
Prediction(const ModelType &model, const FitType &fit,
122-
std::vector<FeatureType> &&features)
123-
: model_(model), fit_(fit), features_(std::move(features)) {}
124+
Prediction(PlainModelType &&model, PlainFitType &&fit,
125+
const std::vector<FeatureType> &features)
126+
: model_(std::move(model)), fit_(std::move(fit)), features_(features) {}
124127

125128
// Mean
126129
template <typename DummyType = FeatureType,

include/albatross/src/details/traits.hpp

-13
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,6 @@
1515

1616
namespace albatross {
1717

18-
/*
19-
* We frequently inspect for definitions of functions which
20-
* must be defined for const references to objects
21-
* (so that repeated evaluations return the same thing
22-
* and so the computations are not repeatedly copying.)
23-
* This type conversion utility will turn a type `T` into `const T&`
24-
*/
25-
template <class T> struct const_ref {
26-
typedef
27-
typename std::add_lvalue_reference<typename std::add_const<T>::type>::type
28-
type;
29-
};
30-
3118
/*
3219
* This little trick was borrowed from cereal, you can think of it as
3320
* a function that will always return false ... but that doesn't

0 commit comments

Comments
 (0)