forked from elastic/ml-cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModelTestHelpers.h
175 lines (146 loc) · 6.43 KB
/
ModelTestHelpers.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#ifndef INCLUDED_ml_model_ModelTestHelpers_h
#define INCLUDED_ml_model_ModelTestHelpers_h
#include <core/CJsonStatePersistInserter.h>
#include <core/CJsonStateRestoreTraverser.h>

#include <model/CDataGatherer.h>
#include <model/CSearchKey.h>
#include <model/ModelTypes.h>

#include <boost/test/unit_test.hpp>

#include <memory>
#include <sstream>
#include <string>
#include <string_view>
namespace ml {
namespace model {
// Default-constructed search key passed to the gatherers built by the helpers
// in this header.
// NOTE(review): namespace-scope `const` objects have internal linkage, so each
// translation unit including this header gets its own copy of KEY and
// EMPTY_STRING.
const CSearchKey KEY;
// Empty string used for every field name/value that the helpers in this
// header leave unset.
const std::string EMPTY_STRING;
//! Check that \p origGatherer survives a persist/restore round trip.
//!
//! The gatherer is persisted to JSON and restored into a new CDataGatherer.
//! The test then requires that:
//!   -# the restored gatherer has the same checksum as the original, and
//!   -# persisting the restored gatherer reproduces exactly the same JSON,
//!      i.e. persistence is idempotent.
//!
//! \param[in] params The model parameters used to construct the restored gatherer.
//! \param[in] origGatherer The gatherer to round trip.
//! \param[in] category The analysis category used to construct the restored gatherer.
//!
//! Declared inline rather than static so the header can be included by many
//! test translation units without duplicating the definition. This also
//! avoids unused-function warnings in units which don't call it.
inline void testPersistence(const SModelParams& params,
                            const CDataGatherer& origGatherer,
                            model_t::EAnalysisCategory category) {
    std::ostringstream origJsonStrm;
    core::CJsonStatePersistInserter::persist(
        origJsonStrm, [&origGatherer](core::CJsonStatePersistInserter& inserter) {
            origGatherer.acceptPersistInserter(inserter);
        });
    // Take a single copy of the persisted state: each call to str() copies
    // the whole buffer.
    const std::string origJson{origJsonStrm.str()};
    LOG_DEBUG(<< "gatherer JSON size " << origJson.size());
    LOG_TRACE(<< "gatherer JSON representation:\n" << origJson);

    // Restore the JSON into a new gatherer. The traverser expects the state
    // JSON to be wrapped in an embedded document.
    std::istringstream restoreStrm{"{\"topLevel\" : " + origJson + "}"};
    core::CJsonStateRestoreTraverser traverser(restoreStrm);
    CDataGatherer restoredGatherer(category, model_t::E_None, params, EMPTY_STRING,
                                   EMPTY_STRING, EMPTY_STRING, EMPTY_STRING,
                                   EMPTY_STRING, {}, KEY, traverser);
    BOOST_REQUIRE_EQUAL(origGatherer.checksum(), restoredGatherer.checksum());

    // The JSON representation of the restored gatherer should be identical
    // to the original's.
    std::ostringstream newJsonStrm;
    core::CJsonStatePersistInserter::persist(
        newJsonStrm, [&restoredGatherer](core::CJsonStatePersistInserter& inserter) {
            restoredGatherer.acceptPersistInserter(inserter);
        });
    BOOST_REQUIRE_EQUAL(origJson, newJsonStrm.str());
}
//! Check the basic attributes of a gatherer which is expected to know about
//! exactly one person, called "p", and no attributes.
//!
//! \param[in] gatherer The gatherer to check.
//! \param[in] startTime The expected start time of the current bucket.
//! \param[in] bucketLength The expected bucket length.
//!
//! Declared inline rather than static so the header can be included by many
//! test translation units without duplicating the definition. This also
//! avoids unused-function warnings in units which don't call it.
inline void testGathererAttributes(const CDataGatherer& gatherer,
                                   core_t::TTime startTime,
                                   core_t::TTime bucketLength) {
    BOOST_REQUIRE_EQUAL(1, gatherer.numberActivePeople());
    BOOST_REQUIRE_EQUAL(1, gatherer.numberByFieldValues());
    BOOST_REQUIRE_EQUAL(std::string("p"), gatherer.personName(0));
    // NOTE(review): id 1 does not correspond to a known person here; this
    // expects personName to report "-" for it. Confirm against
    // CDataGatherer::personName.
    BOOST_REQUIRE_EQUAL(std::string("-"), gatherer.personName(1));

    // Brace-initialised so the value is never read indeterminate.
    std::size_t pid{0};
    BOOST_TEST_REQUIRE(gatherer.personId("p", pid));
    BOOST_REQUIRE_EQUAL(0, pid);
    BOOST_TEST_REQUIRE(!gatherer.personId("a.n.other p", pid));

    BOOST_REQUIRE_EQUAL(0, gatherer.numberActiveAttributes());
    BOOST_REQUIRE_EQUAL(0, gatherer.numberOverFieldValues());
    BOOST_REQUIRE_EQUAL(startTime, gatherer.currentBucketStartTime());
    BOOST_REQUIRE_EQUAL(bucketLength, gatherer.bucketLength());
}
class CDataGathererBuilder {
public:
using TFeatureVec = CDataGatherer::TFeatureVec;
using TStrVec = CDataGatherer::TStrVec;
public:
CDataGathererBuilder(model_t::EAnalysisCategory gathererType,
const TFeatureVec& features,
const SModelParams& params,
const CSearchKey& searchKey,
const core_t::TTime startTime)
: m_Features(features), m_Params(params), m_StartTime(startTime),
m_SearchKey(searchKey), m_GathererType(gathererType) {}
CDataGatherer build() const {
return {m_GathererType,
m_SummaryMode,
m_Params,
m_SummaryCountFieldName,
m_PartitionFieldValue,
m_PersonFieldName,
m_AttributeFieldName,
m_ValueFieldName,
m_InfluenceFieldNames,
m_SearchKey,
m_Features,
m_StartTime,
m_SampleCountOverride};
}
std::shared_ptr<CDataGatherer> buildSharedPtr() const {
return std::make_shared<CDataGatherer>(
m_GathererType, m_SummaryMode, m_Params, m_SummaryCountFieldName,
m_PartitionFieldValue, m_PersonFieldName, m_AttributeFieldName,
m_ValueFieldName, m_InfluenceFieldNames, m_SearchKey, m_Features,
m_StartTime, m_SampleCountOverride);
}
CDataGathererBuilder& partitionFieldValue(std::string_view partitionFieldValue) {
m_PartitionFieldValue = partitionFieldValue;
return *this;
}
CDataGathererBuilder& personFieldName(std::string_view personFieldName) {
m_PersonFieldName = personFieldName;
return *this;
}
CDataGathererBuilder& valueFieldName(std::string_view valueFieldName) {
m_ValueFieldName = valueFieldName;
return *this;
}
CDataGathererBuilder& influenceFieldNames(const TStrVec& influenceFieldName) {
m_InfluenceFieldNames = influenceFieldName;
return *this;
}
CDataGathererBuilder& attributeFieldName(std::string_view attributeFieldName) {
m_AttributeFieldName = attributeFieldName;
return *this;
}
CDataGathererBuilder& gathererType(model_t::EAnalysisCategory gathererType) {
m_GathererType = gathererType;
return *this;
}
CDataGathererBuilder& sampleCountOverride(std::size_t sampleCount) {
m_SampleCountOverride = static_cast<int>(sampleCount);
return *this;
}
private:
const TFeatureVec& m_Features;
const SModelParams& m_Params;
core_t::TTime m_StartTime;
const CSearchKey& m_SearchKey;
model_t::EAnalysisCategory m_GathererType;
model_t::ESummaryMode m_SummaryMode{model_t::E_None};
std::string m_SummaryCountFieldName{EMPTY_STRING};
std::string m_PartitionFieldValue{EMPTY_STRING};
std::string m_PersonFieldName{EMPTY_STRING};
std::string m_AttributeFieldName{EMPTY_STRING};
std::string m_ValueFieldName{EMPTY_STRING};
TStrVec m_InfluenceFieldNames;
int m_SampleCountOverride{0};
};
}
}
#endif // INCLUDED_ml_model_ModelTestHelpers_h