Skip to content

Commit 5c3b548

Browse files
Added a trampoline class PlaintextImpl_helper to override PlaintextImpl's virtual functions
1 parent e3894cb commit 5c3b548

File tree

1 file changed

+80
-1
lines changed

1 file changed

+80
-1
lines changed

src/lib/bindings.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,9 +1057,88 @@ void bind_keys(py::module &m)
10571057
.def(py::init<>());
10581058
}
10591059

1060+
// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class
1061+
class PlaintextImpl_helper : public PlaintextImpl
1062+
{
1063+
public:
1064+
using PlaintextImpl::PlaintextImpl; // inherited constructors
1065+
1066+
// the PlaintextImpl virtual functions' overrides
1067+
bool Encode() override {
1068+
PYBIND11_OVERRIDE_PURE(
1069+
bool, // return type
1070+
PlaintextImpl, // parent class
1071+
Encode // function name
1072+
// no arguments
1073+
);
1074+
}
1075+
bool Decode() override {
1076+
PYBIND11_OVERRIDE_PURE(
1077+
bool, // return type
1078+
PlaintextImpl, // parent class
1079+
Decode // function name
1080+
// no arguments
1081+
);
1082+
}
1083+
bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override {
1084+
PYBIND11_OVERRIDE(
1085+
bool, // return type
1086+
PlaintextImpl, // parent class
1087+
Decode, // function name
1088+
depth, scalingFactor, scalTech, executionMode // arguments
1089+
);
1090+
}
1091+
size_t GetLength() const override {
1092+
PYBIND11_OVERRIDE_PURE(
1093+
size_t, // return type
1094+
PlaintextImpl, // parent class
1095+
GetLength // function name
1096+
// no arguments
1097+
);
1098+
}
1099+
void SetLength(size_t newSize) override {
1100+
PYBIND11_OVERRIDE(
1101+
void, // return type
1102+
PlaintextImpl, // parent class
1103+
SetLength, // function name
1104+
newSize // arguments
1105+
);
1106+
}
1107+
double GetLogError() const override {
1108+
PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError);
1109+
}
1110+
double GetLogPrecision() const override {
1111+
PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision);
1112+
}
1113+
const std::string& GetStringValue() const override {
1114+
PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue);
1115+
}
1116+
const std::vector<int64_t>& GetCoefPackedValue() const override {
1117+
PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetCoefPackedValue);
1118+
}
1119+
const std::vector<int64_t>& GetPackedValue() const override {
1120+
PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetPackedValue);
1121+
}
1122+
const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
1123+
PYBIND11_OVERRIDE(const std::vector<std::complex<double>>&, PlaintextImpl, GetCKKSPackedValue);
1124+
}
1125+
std::vector<double> GetRealPackedValue() const override {
1126+
PYBIND11_OVERRIDE(std::vector<double>, PlaintextImpl, GetRealPackedValue);
1127+
}
1128+
void SetStringValue(const std::string& str) override {
1129+
PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str);
1130+
}
1131+
void SetIntVectorValue(const std::vector<int64_t>& vec) override {
1132+
PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec);
1133+
}
1134+
std::string GetFormattedValues(int64_t precision) const override {
1135+
PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision);
1136+
}
1137+
};
1138+
10601139
void bind_encodings(py::module &m)
10611140
{
1062-
py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
1141+
py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>, PlaintextImpl_helper>(m, "Plaintext")
10631142
.def("GetScalingFactor", &PlaintextImpl::GetScalingFactor,
10641143
ptx_GetScalingFactor_docs)
10651144
.def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,

0 commit comments

Comments
 (0)