Skip to content

Commit 788dbc9

Browse files
committed
substrait yaml file parser and function lookup support
1 parent 44bb6f7 commit 788dbc9

22 files changed

+2677
-2
lines changed

core/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
14+
add_subdirectory(common)
15+
add_subdirectory(type)
16+
add_subdirectory(function)

core/common/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
add_library(
14+
substrait_common
15+
Exceptions.cpp)
16+
17+
target_link_libraries(
18+
substrait_common
19+
fmt)
20+

core/common/Exceptions.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
#include "common/Exceptions.h"
16+
#include "fmt/format.h"
17+
18+
namespace io::substrait::common {
19+
20+
SubstraitException::SubstraitException(
21+
std::string exceptionCode,
22+
std::string& exceptionMessage,
23+
Type exceptionType,
24+
std::string exceptionName)
25+
: msg_(fmt::format(
26+
"Exception: {}\nError Code: {}\nType: {}\nReason: {}\n"
27+
"Function: {}\nFile: {}\n:Line: {}\n",
28+
exceptionName,
29+
exceptionCode,
30+
exceptionType == Type::kSystem ? "system" : "user",
31+
exceptionMessage,
32+
__FUNCTION__,
33+
__FILE__,
34+
std::to_string(__LINE__))) {}
35+
36+
} // namespace io::substrait::common

core/function/CMakeLists.txt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
set(FUNCTION_SRCS
14+
Function.cpp
15+
Extension.cpp
16+
../../include/function/FunctionMapping.h
17+
../../include/function/FunctionSignature.h
18+
FunctionLookup.cpp)
19+
20+
add_library(substrait_function ${FUNCTION_SRCS})
21+
22+
target_link_libraries(
23+
substrait_function
24+
substrait_type
25+
yaml-cpp)
26+
27+
if (${BUILD_TESTING})
28+
add_subdirectory(tests)
29+
endif ()

core/function/Extension.cpp

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
#include "function/Extension.h"
16+
#include "yaml-cpp/yaml.h"
17+
18+
bool decodeFunctionVariant(
19+
const YAML::Node& node,
20+
io::substrait::FunctionVariant& function) {
21+
const auto& returnType = node["return"];
22+
if (returnType && returnType.IsScalar()) {
23+
/// Return type can be an expression.
24+
const auto& returnExpr = returnType.as<std::string>();
25+
std::stringstream ss(returnExpr);
26+
27+
// TODO: currently we only parse the last sentence of type definition, use
28+
// ANTLR in future.
29+
std::string lastReturnType;
30+
while (std::getline(ss, lastReturnType, '\n')) {
31+
}
32+
function.returnType = io::substrait::Type::decode(lastReturnType);
33+
}
34+
const auto& args = node["args"];
35+
if (args && args.IsSequence()) {
36+
for (auto& arg : args) {
37+
if (arg["options"]) { // enum argument
38+
auto enumArgument = std::make_shared<io::substrait::EnumArgument>(
39+
arg.as<io::substrait::EnumArgument>());
40+
function.arguments.emplace_back(enumArgument);
41+
} else if (arg["value"]) { // value argument
42+
auto valueArgument = std::make_shared<io::substrait::ValueArgument>(
43+
arg.as<io::substrait::ValueArgument>());
44+
function.arguments.emplace_back(valueArgument);
45+
} else { // type argument
46+
auto typeArgument = std::make_shared<io::substrait::TypeArgument>(
47+
arg.as<io::substrait::TypeArgument>());
48+
function.arguments.emplace_back(typeArgument);
49+
}
50+
}
51+
}
52+
53+
const auto& variadic = node["variadic"];
54+
if (variadic) {
55+
auto& min = variadic["min"];
56+
auto& max = variadic["max"];
57+
if (min) {
58+
function.variadic = std::make_optional<io::substrait::FunctionVariadic>(
59+
{min.as<int>(),
60+
max ? std::make_optional<int>(max.as<int>()) : std::nullopt});
61+
} else {
62+
function.variadic = std::nullopt;
63+
}
64+
} else {
65+
function.variadic = std::nullopt;
66+
}
67+
68+
return true;
69+
}
70+
71+
template <>
72+
struct YAML::convert<io::substrait::EnumArgument> {
73+
static bool decode(const Node& node, io::substrait::EnumArgument& argument) {
74+
// 'options' is required property
75+
const auto& options = node["options"];
76+
if (options && options.IsSequence()) {
77+
auto& required = node["required"];
78+
argument.required = required && required.as<bool>();
79+
return true;
80+
} else {
81+
return false;
82+
}
83+
}
84+
};
85+
86+
template <>
87+
struct YAML::convert<io::substrait::ValueArgument> {
88+
static bool decode(const Node& node, io::substrait::ValueArgument& argument) {
89+
const auto& value = node["value"];
90+
if (value && value.IsScalar()) {
91+
auto valueType = value.as<std::string>();
92+
argument.type = io::substrait::Type::decode(valueType);
93+
return true;
94+
}
95+
return false;
96+
}
97+
};
98+
99+
template <>
100+
struct YAML::convert<io::substrait::TypeArgument> {
101+
static bool decode(
102+
const YAML::Node& node,
103+
io::substrait::TypeArgument& argument) {
104+
// no properties need to populate for type argument, just return true if
105+
// 'type' element exists.
106+
if (node["type"]) {
107+
return true;
108+
}
109+
return false;
110+
}
111+
};
112+
113+
template <>
114+
struct YAML::convert<io::substrait::ScalarFunctionVariant> {
115+
static bool decode(
116+
const Node& node,
117+
io::substrait::ScalarFunctionVariant& function) {
118+
return decodeFunctionVariant(node, function);
119+
};
120+
};
121+
122+
template <>
123+
struct YAML::convert<io::substrait::AggregateFunctionVariant> {
124+
static bool decode(
125+
const Node& node,
126+
io::substrait::AggregateFunctionVariant& function) {
127+
const auto& res = decodeFunctionVariant(node, function);
128+
if (res) {
129+
const auto& intermediate = node["intermediate"];
130+
if (intermediate) {
131+
function.intermediate =
132+
io::substrait::ParameterizedType::decode(intermediate.as<std::string>());
133+
}
134+
}
135+
return res;
136+
}
137+
};
138+
139+
template <>
140+
struct YAML::convert<io::substrait::TypeVariant> {
141+
static bool decode(const Node& node, io::substrait::TypeVariant& typeAnchor) {
142+
const auto& name = node["name"];
143+
if (name && name.IsScalar()) {
144+
typeAnchor.name = name.as<std::string>();
145+
return true;
146+
}
147+
return false;
148+
}
149+
};
150+
151+
namespace io::substrait {
152+
153+
std::shared_ptr<Extension> Extension::load(const std::string& basePath) {
154+
static const std::vector<std::string> extensionFiles{
155+
"functions_aggregate_approx.yaml",
156+
"functions_aggregate_generic.yaml",
157+
"functions_arithmetic.yaml",
158+
"functions_arithmetic_decimal.yaml",
159+
"functions_boolean.yaml",
160+
"functions_comparison.yaml",
161+
"functions_datetime.yaml",
162+
"functions_logarithmic.yaml",
163+
"functions_rounding.yaml",
164+
"functions_string.yaml",
165+
"functions_set.yaml",
166+
};
167+
return load(basePath, extensionFiles);
168+
}
169+
170+
std::shared_ptr<Extension> Extension::load(
171+
const std::string& basePath,
172+
const std::vector<std::string>& extensionFiles) {
173+
std::vector<std::string> yamlExtensionFiles;
174+
yamlExtensionFiles.reserve(extensionFiles.size());
175+
for (auto& extensionFile : extensionFiles) {
176+
auto const pos = basePath.find_last_of('/');
177+
const auto& extensionUri = basePath.substr(0, pos) + "/" + extensionFile;
178+
yamlExtensionFiles.emplace_back(extensionUri);
179+
}
180+
return load(yamlExtensionFiles);
181+
}
182+
183+
std::shared_ptr<Extension> Extension::load(
184+
const std::vector<std::string>& extensionFiles) {
185+
auto extension = std::make_shared<Extension>();
186+
for (const auto& extensionUri : extensionFiles) {
187+
const auto& node = YAML::LoadFile(extensionUri);
188+
189+
const auto& scalarFunctions = node["scalar_functions"];
190+
if (scalarFunctions && scalarFunctions.IsSequence()) {
191+
for (auto& scalarFunctionNode : scalarFunctions) {
192+
const auto functionName = scalarFunctionNode["name"].as<std::string>();
193+
for (auto& scalaFunctionVariantNode : scalarFunctionNode["impls"]) {
194+
auto scalarFunctionVariant =
195+
scalaFunctionVariantNode.as<ScalarFunctionVariant>();
196+
scalarFunctionVariant.name = functionName;
197+
scalarFunctionVariant.uri = extensionUri;
198+
extension->addScalarFunctionVariant(
199+
std::make_shared<ScalarFunctionVariant>(scalarFunctionVariant));
200+
}
201+
}
202+
}
203+
204+
const auto& aggregateFunctions = node["aggregate_functions"];
205+
if (aggregateFunctions && aggregateFunctions.IsSequence()) {
206+
for (auto& aggregateFunctionNode : aggregateFunctions) {
207+
const auto functionName =
208+
aggregateFunctionNode["name"].as<std::string>();
209+
for (auto& aggregateFunctionVariantNode :
210+
aggregateFunctionNode["impls"]) {
211+
auto aggregateFunctionVariant =
212+
aggregateFunctionVariantNode.as<AggregateFunctionVariant>();
213+
aggregateFunctionVariant.name = functionName;
214+
aggregateFunctionVariant.uri = extensionUri;
215+
extension->addAggregateFunctionVariant(
216+
std::make_shared<AggregateFunctionVariant>(
217+
aggregateFunctionVariant));
218+
}
219+
}
220+
}
221+
222+
const auto& types = node["types"];
223+
if (types && types.IsSequence()) {
224+
for (auto& type : types) {
225+
auto typeAnchor = type.as<TypeVariant>();
226+
typeAnchor.uri = extensionUri;
227+
extension->addTypeVariant(std::make_shared<TypeVariant>(typeAnchor));
228+
}
229+
}
230+
}
231+
return extension;
232+
}
233+
234+
void Extension::addWindowFunctionVariant(
235+
const FunctionVariantPtr& functionVariant) {
236+
const auto& functionVariants =
237+
windowFunctionVariantMap_.find(functionVariant->name);
238+
if (functionVariants != windowFunctionVariantMap_.end()) {
239+
auto& variants = functionVariants->second;
240+
variants.emplace_back(functionVariant);
241+
} else {
242+
std::vector<FunctionVariantPtr> variants;
243+
variants.emplace_back(functionVariant);
244+
windowFunctionVariantMap_.insert(
245+
{functionVariant->name, std::move(variants)});
246+
}
247+
}
248+
249+
void Extension::addTypeVariant(const TypeVariantPtr& functionVariant) {
250+
typeVariantMap_.insert({functionVariant->name, functionVariant});
251+
}
252+
253+
TypeVariantPtr Extension::lookupType(const std::string& typeName) const {
254+
auto typeVariantIter = typeVariantMap_.find(typeName);
255+
if (typeVariantIter != typeVariantMap_.end()) {
256+
return typeVariantIter->second;
257+
}
258+
return nullptr;
259+
}
260+
261+
void Extension::addScalarFunctionVariant(
262+
const FunctionVariantPtr& functionVariant) {
263+
const auto& functionVariants =
264+
scalarFunctionVariantMap_.find(functionVariant->name);
265+
if (functionVariants != scalarFunctionVariantMap_.end()) {
266+
auto& variants = functionVariants->second;
267+
variants.emplace_back(functionVariant);
268+
} else {
269+
std::vector<FunctionVariantPtr> variants;
270+
variants.emplace_back(functionVariant);
271+
scalarFunctionVariantMap_.insert(
272+
{functionVariant->name, std::move(variants)});
273+
}
274+
}
275+
276+
void Extension::addAggregateFunctionVariant(
277+
const FunctionVariantPtr& functionVariant) {
278+
const auto& functionVariants =
279+
aggregateFunctionVariantMap_.find(functionVariant->name);
280+
if (functionVariants != aggregateFunctionVariantMap_.end()) {
281+
auto& variants = functionVariants->second;
282+
variants.emplace_back(functionVariant);
283+
} else {
284+
std::vector<FunctionVariantPtr> variants;
285+
variants.emplace_back(functionVariant);
286+
aggregateFunctionVariantMap_.insert(
287+
{functionVariant->name, std::move(variants)});
288+
}
289+
}
290+
291+
} // namespace io::substrait

0 commit comments

Comments
 (0)