Skip to content

Commit bdf731f

Browse files
jckingcopybara-github
authored andcommitted
Implement hand-rolled variant for cel::StructValue
PiperOrigin-RevId: 740908079
1 parent 32a793d commit bdf731f

13 files changed

+603
-445
lines changed

common/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ cc_library(
603603
"//extensions/protobuf/internal:map_reflection",
604604
"//extensions/protobuf/internal:qualify",
605605
"//internal:casts",
606+
"//internal:empty_descriptors",
606607
"//internal:json",
607608
"//internal:manual",
608609
"//internal:message_equality",

common/values/message_value.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,13 @@ common_internal::ValueVariant MessageValue::ToValueVariant() && {
294294

295295
common_internal::StructValueVariant MessageValue::ToStructValueVariant()
296296
const& {
297-
return absl::get<ParsedMessageValue>(variant_);
297+
return common_internal::StructValueVariant(
298+
absl::get<ParsedMessageValue>(variant_));
298299
}
299300

300301
common_internal::StructValueVariant MessageValue::ToStructValueVariant() && {
301-
return absl::get<ParsedMessageValue>(std::move(variant_));
302+
return common_internal::StructValueVariant(
303+
absl::get<ParsedMessageValue>(std::move(variant_)));
302304
}
303305

304306
} // namespace cel

common/values/parsed_message_value.cc

+64-32
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
#include <cstdint>
1818
#include <limits>
1919
#include <string>
20+
#include <type_traits>
2021
#include <utility>
2122
#include <vector>
2223

24+
#include "google/protobuf/empty.pb.h"
2325
#include "absl/base/nullability.h"
2426
#include "absl/base/optimization.h"
2527
#include "absl/log/absl_check.h"
@@ -33,6 +35,7 @@
3335
#include "common/memory.h"
3436
#include "common/value.h"
3537
#include "extensions/protobuf/internal/qualify.h"
38+
#include "internal/empty_descriptors.h"
3639
#include "internal/json.h"
3740
#include "internal/message_equality.h"
3841
#include "internal/status_macros.h"
@@ -42,16 +45,37 @@
4245
#include "google/protobuf/descriptor.h"
4346
#include "google/protobuf/io/zero_copy_stream.h"
4447
#include "google/protobuf/message.h"
48+
#include "google/protobuf/message_lite.h"
4549

4650
namespace cel {
4751

52+
namespace {
53+
4854
using ::cel::well_known_types::ValueReflection;
4955

56+
template <typename T>
57+
std::enable_if_t<std::is_base_of_v<google::protobuf::Message, T>,
58+
absl::Nonnull<const google::protobuf::Message*>>
59+
EmptyParsedMessageValue() {
60+
return &T::default_instance();
61+
}
62+
63+
template <typename T>
64+
std::enable_if_t<
65+
std::conjunction_v<std::is_base_of<google::protobuf::MessageLite, T>,
66+
std::negation<std::is_base_of<google::protobuf::Message, T>>>,
67+
absl::Nonnull<const google::protobuf::Message*>>
68+
EmptyParsedMessageValue() {
69+
return internal::GetEmptyDefaultInstance();
70+
}
71+
72+
} // namespace
73+
74+
ParsedMessageValue::ParsedMessageValue()
75+
: value_(EmptyParsedMessageValue<google::protobuf::Empty>()),
76+
arena_(nullptr) {}
77+
5078
bool ParsedMessageValue::IsZeroValue() const {
51-
ABSL_DCHECK(*this);
52-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
53-
return true;
54-
}
5579
const auto* reflection = GetReflection();
5680
if (!reflection->GetUnknownFields(*value_).empty()) {
5781
return false;
@@ -62,9 +86,6 @@ bool ParsedMessageValue::IsZeroValue() const {
6286
}
6387

6488
std::string ParsedMessageValue::DebugString() const {
65-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
66-
return "INVALID";
67-
}
6889
return absl::StrCat(*value_);
6990
}
7091

@@ -75,11 +96,6 @@ absl::Status ParsedMessageValue::SerializeTo(
7596
ABSL_DCHECK(descriptor_pool != nullptr);
7697
ABSL_DCHECK(message_factory != nullptr);
7798
ABSL_DCHECK(output != nullptr);
78-
ABSL_DCHECK(*this);
79-
80-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
81-
return absl::OkStatus();
82-
}
8399

84100
if (!value_->SerializePartialToZeroCopyStream(output)) {
85101
return absl::UnknownError(
@@ -97,16 +113,11 @@ absl::Status ParsedMessageValue::ConvertToJson(
97113
ABSL_DCHECK(json != nullptr);
98114
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
99115
google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
100-
ABSL_DCHECK(*this);
101116

102117
ValueReflection value_reflection;
103118
CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
104119
google::protobuf::Message* json_object = value_reflection.MutableStructValue(json);
105120

106-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
107-
json_object->Clear();
108-
return absl::OkStatus();
109-
}
110121
return internal::MessageToJson(*value_, descriptor_pool, message_factory,
111122
json_object);
112123
}
@@ -120,12 +131,7 @@ absl::Status ParsedMessageValue::ConvertToJsonObject(
120131
ABSL_DCHECK(json != nullptr);
121132
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
122133
google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);
123-
ABSL_DCHECK(*this);
124134

125-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
126-
json->Clear();
127-
return absl::OkStatus();
128-
}
129135
return internal::MessageToJson(*value_, descriptor_pool, message_factory,
130136
json);
131137
}
@@ -135,7 +141,11 @@ absl::Status ParsedMessageValue::Equal(
135141
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
136142
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
137143
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
138-
ABSL_DCHECK(*this);
144+
ABSL_DCHECK(descriptor_pool != nullptr);
145+
ABSL_DCHECK(message_factory != nullptr);
146+
ABSL_DCHECK(arena != nullptr);
147+
ABSL_DCHECK(result != nullptr);
148+
139149
if (auto other_message = other.AsParsedMessage(); other_message) {
140150
CEL_ASSIGN_OR_RETURN(
141151
auto equal, internal::MessageEquals(*value_, **other_message,
@@ -154,10 +164,8 @@ absl::Status ParsedMessageValue::Equal(
154164

155165
ParsedMessageValue ParsedMessageValue::Clone(
156166
absl::Nonnull<google::protobuf::Arena*> arena) const {
157-
ABSL_DCHECK(*this);
158-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
159-
return ParsedMessageValue();
160-
}
167+
ABSL_DCHECK(arena != nullptr);
168+
161169
if (arena_ == arena) {
162170
return *this;
163171
}
@@ -171,6 +179,11 @@ absl::Status ParsedMessageValue::GetFieldByName(
171179
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
172180
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
173181
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
182+
ABSL_DCHECK(descriptor_pool != nullptr);
183+
ABSL_DCHECK(message_factory != nullptr);
184+
ABSL_DCHECK(arena != nullptr);
185+
ABSL_DCHECK(result != nullptr);
186+
174187
const auto* descriptor = GetDescriptor();
175188
const auto* field = descriptor->FindFieldByName(name);
176189
if (field == nullptr) {
@@ -190,6 +203,11 @@ absl::Status ParsedMessageValue::GetFieldByNumber(
190203
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
191204
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
192205
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
206+
ABSL_DCHECK(descriptor_pool != nullptr);
207+
ABSL_DCHECK(message_factory != nullptr);
208+
ABSL_DCHECK(arena != nullptr);
209+
ABSL_DCHECK(result != nullptr);
210+
193211
const auto* descriptor = GetDescriptor();
194212
if (number < std::numeric_limits<int32_t>::min() ||
195213
number > std::numeric_limits<int32_t>::max()) {
@@ -238,10 +256,10 @@ absl::Status ParsedMessageValue::ForEachField(
238256
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
239257
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
240258
absl::Nonnull<google::protobuf::Arena*> arena) const {
241-
ABSL_DCHECK(*this);
242-
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
243-
return absl::OkStatus();
244-
}
259+
ABSL_DCHECK(descriptor_pool != nullptr);
260+
ABSL_DCHECK(message_factory != nullptr);
261+
ABSL_DCHECK(arena != nullptr);
262+
245263
std::vector<const google::protobuf::FieldDescriptor*> fields;
246264
const auto* reflection = GetReflection();
247265
reflection->ListFields(*value_, &fields);
@@ -322,7 +340,13 @@ absl::Status ParsedMessageValue::Qualify(
322340
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
323341
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result,
324342
absl::Nonnull<int*> count) const {
325-
ABSL_DCHECK(*this);
343+
ABSL_DCHECK(!qualifiers.empty());
344+
ABSL_DCHECK(descriptor_pool != nullptr);
345+
ABSL_DCHECK(message_factory != nullptr);
346+
ABSL_DCHECK(arena != nullptr);
347+
ABSL_DCHECK(result != nullptr);
348+
ABSL_DCHECK(count != nullptr);
349+
326350
if (ABSL_PREDICT_FALSE(qualifiers.empty())) {
327351
return absl::InvalidArgumentError("invalid select qualifier path.");
328352
}
@@ -357,13 +381,21 @@ absl::Status ParsedMessageValue::GetField(
357381
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
358382
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
359383
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
384+
ABSL_DCHECK(field != nullptr);
385+
ABSL_DCHECK(descriptor_pool != nullptr);
386+
ABSL_DCHECK(message_factory != nullptr);
387+
ABSL_DCHECK(arena != nullptr);
388+
ABSL_DCHECK(result != nullptr);
389+
360390
*result = Value::WrapField(unboxing_options, value_, field, descriptor_pool,
361391
message_factory, arena);
362392
return absl::OkStatus();
363393
}
364394

365395
bool ParsedMessageValue::HasField(
366396
absl::Nonnull<const google::protobuf::FieldDescriptor*> field) const {
397+
ABSL_DCHECK(field != nullptr);
398+
367399
const auto* reflection = GetReflection();
368400
if (field->is_map() || field->is_repeated()) {
369401
return reflection->FieldSize(*value_, field) > 0;

common/values/parsed_message_value.h

+6-11
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ class ParsedMessageValue final
7272
ABSL_DCHECK_OK(CheckArena(value_, arena_));
7373
}
7474

75-
// Places the `ParsedMessageValue` into an invalid state. Anything except
76-
// assigning to `MessageValue` is undefined behavior.
77-
ParsedMessageValue() = default;
78-
75+
// Places the `ParsedMessageValue` into a special state where it is logically
76+
// equivalent to the default instance of `google.protobuf.Empty`, however
77+
// dereferencing via `operator*` or `operator->` is not allowed.
78+
ParsedMessageValue();
7979
ParsedMessageValue(const ParsedMessageValue&) = default;
8080
ParsedMessageValue(ParsedMessageValue&&) = default;
8181
ParsedMessageValue& operator=(const ParsedMessageValue&) = default;
@@ -96,13 +96,11 @@ class ParsedMessageValue final
9696
}
9797

9898
const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
99-
ABSL_DCHECK(*this);
10099
return *value_;
101100
}
102101

103102
absl::Nonnull<const google::protobuf::Message*> operator->() const
104103
ABSL_ATTRIBUTE_LIFETIME_BOUND {
105-
ABSL_DCHECK(*this);
106104
return value_;
107105
}
108106

@@ -171,9 +169,6 @@ class ParsedMessageValue final
171169
absl::Nonnull<int*> count) const;
172170
using StructValueMixin::Qualify;
173171

174-
// Returns `true` if `ParsedMessageValue` is in a valid state.
175-
explicit operator bool() const { return value_ != nullptr; }
176-
177172
friend void swap(ParsedMessageValue& lhs, ParsedMessageValue& rhs) noexcept {
178173
using std::swap;
179174
swap(lhs.value_, rhs.value_);
@@ -205,8 +200,8 @@ class ParsedMessageValue final
205200

206201
bool HasField(absl::Nonnull<const google::protobuf::FieldDescriptor*> field) const;
207202

208-
absl::Nullable<const google::protobuf::Message*> value_ = nullptr;
209-
absl::Nullable<google::protobuf::Arena*> arena_ = nullptr;
203+
absl::Nonnull<const google::protobuf::Message*> value_;
204+
absl::Nullable<google::protobuf::Arena*> arena_;
210205
};
211206

212207
inline std::ostream& operator<<(std::ostream& out,

common/values/parsed_message_value_test.cc

-10
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,6 @@ using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;
4040

4141
using ParsedMessageValueTest = common_internal::ValueTest<>;
4242

43-
TEST_F(ParsedMessageValueTest, Default) {
44-
ParsedMessageValue value;
45-
EXPECT_FALSE(value);
46-
}
47-
48-
TEST_F(ParsedMessageValueTest, Field) {
49-
ParsedMessageValue value = MakeParsedMessage<TestAllTypesProto3>();
50-
EXPECT_TRUE(value);
51-
}
52-
5343
TEST_F(ParsedMessageValueTest, Kind) {
5444
ParsedMessageValue value = MakeParsedMessage<TestAllTypesProto3>();
5545
EXPECT_EQ(value.kind(), ParsedMessageValue::kKind);

0 commit comments

Comments
 (0)