Skip to content

Commit 3db64bc

Browse files
committed
Automated rollback of commit b0ab1f3
PiperOrigin-RevId: 660923516
1 parent 947086d commit 3db64bc

File tree

7 files changed

+132
-44
lines changed

7 files changed

+132
-44
lines changed

build/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ sh_binary(
2020
name = "gen_proto",
2121
srcs = ["gen_proto.sh"],
2222
data = [
23+
"//tfx/dsl/component/experimental:annotations_test_proto_pb2.py",
2324
"//tfx/examples/custom_components/presto_example_gen/proto:presto_config_pb2.py",
2425
"//tfx/extensions/experimental/kfp_compatibility/proto:kfp_component_spec_pb2.py",
2526
"//tfx/extensions/google_cloud_big_query/experimental/elwc_example_gen/proto:elwc_config_pb2.py",
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
load("//tfx:tfx.bzl", "tfx_py_proto_library")
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
package(default_visibility = ["//visibility:public"])
17+
18+
licenses(["notice"]) # Apache 2.0
19+
20+
exports_files(["LICENSE"])
21+
22+
tfx_py_proto_library(
23+
name = "annotations_test_proto_py_pb2",
24+
srcs = ["annotations_test_proto.proto"],
25+
)

tfx/dsl/component/experimental/annotations.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tfx.types import artifact
2424
from tfx.utils import deprecation_utils
2525

26+
from google.protobuf import message
27+
2628
try:
2729
import apache_beam as beam # pytype: disable=import-error # pylint: disable=g-import-not-at-top
2830

@@ -107,31 +109,43 @@ def __repr__(self):
107109
return '%s[%s]' % (self.__class__.__name__, self.type)
108110

109111

110-
class _PrimitiveTypeGenericMeta(type):
112+
class _PrimitiveAndProtoTypeGenericMeta(type):
111113
"""Metaclass for _PrimitiveTypeGeneric, to enable primitive type indexing."""
112114

113115
def __getitem__(
114-
cls: Type['_PrimitiveTypeGeneric'],
115-
params: Type[Union[int, float, str, bool, List[Any], Dict[Any, Any]]],
116+
cls: Type['_PrimitiveAndProtoTypeGeneric'],
117+
params: Type[
118+
Union[
119+
int,
120+
float,
121+
str,
122+
bool,
123+
List[Any],
124+
Dict[Any, Any],
125+
message.Message,
126+
],
127+
],
116128
):
117129
"""Metaclass method allowing indexing class (`_PrimitiveTypeGeneric[T]`)."""
118130
return cls._generic_getitem(params) # pytype: disable=attribute-error
119131

120132

121-
class _PrimitiveTypeGeneric(metaclass=_PrimitiveTypeGenericMeta):
133+
class _PrimitiveAndProtoTypeGeneric(
134+
metaclass=_PrimitiveAndProtoTypeGenericMeta
135+
):
122136
"""A generic that takes a primitive type as its single argument."""
123137

124138
def __init__( # pylint: disable=invalid-name
125139
self,
126-
artifact_type: Type[Union[int, float, str, bool]],
140+
artifact_type: Type[Union[int, float, str, bool, message.Message]],
127141
_init_via_getitem=False,
128142
):
129143
if not _init_via_getitem:
130144
class_name = self.__class__.__name__
131145
raise ValueError(
132146
(
133147
'%s should be instantiated via the syntax `%s[T]`, where T is '
134-
'`int`, `float`, `str`, or `bool`.'
148+
'`int`, `float`, `str`, `bool` or proto type.'
135149
)
136150
% (class_name, class_name)
137151
)
@@ -143,17 +157,20 @@ def _generic_getitem(cls, params):
143157
# Check that the given parameter is a primitive type.
144158
if (
145159
inspect.isclass(params)
146-
and params in (int, float, str, bool)
160+
and (
161+
params in (int, float, str, bool)
162+
or issubclass(params, message.Message)
163+
)
147164
or json_compat.is_json_compatible(params)
148165
):
149166
return cls(params, _init_via_getitem=True)
150167
else:
151168
class_name = cls.__name__
152169
raise ValueError(
153170
(
154-
'Generic type `%s[T]` expects the single parameter T to be '
155-
'`int`, `float`, `str`, `bool` or JSON-compatible types '
156-
'(Dict[str, T], List[T]) (got %r instead).'
171+
'Generic type `%s[T]` expects the single parameter T to be `int`,'
172+
' `float`, `str`, `bool`, JSON-compatible types (Dict[str, T],'
173+
' List[T]) or a proto type. (got %r instead).'
157174
)
158175
% (class_name, params)
159176
)
@@ -252,7 +269,7 @@ class AsyncOutputArtifact(Generic[T]):
252269
"""Intermediate artifact object type annotation."""
253270

254271

255-
class Parameter(_PrimitiveTypeGeneric):
272+
class Parameter(_PrimitiveAndProtoTypeGeneric):
256273
"""Component parameter type annotation."""
257274

258275

tfx/dsl/component/experimental/annotations_test.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import apache_beam as beam
1919
import tensorflow as tf
2020
from tfx.dsl.component.experimental import annotations
21+
from tfx.dsl.component.experimental import annotations_test_proto_pb2
2122
from tfx.types import artifact
2223
from tfx.types import standard_artifacts
2324
from tfx.types import value_artifact
@@ -27,18 +28,21 @@ class AnnotationsTest(tf.test.TestCase):
2728

2829
def testArtifactGenericAnnotation(self):
2930
# Error: type hint whose parameter is not an Artifact subclass.
30-
with self.assertRaisesRegex(ValueError,
31-
'expects .* a concrete subclass of'):
31+
with self.assertRaisesRegex(
32+
ValueError, 'expects .* a concrete subclass of'
33+
):
3234
_ = annotations._ArtifactGeneric[int] # pytype: disable=unsupported-operands
3335

3436
# Error: type hint with abstract Artifact subclass.
35-
with self.assertRaisesRegex(ValueError,
36-
'expects .* a concrete subclass of'):
37+
with self.assertRaisesRegex(
38+
ValueError, 'expects .* a concrete subclass of'
39+
):
3740
_ = annotations._ArtifactGeneric[artifact.Artifact]
3841

3942
# Error: type hint with abstract Artifact subclass.
40-
with self.assertRaisesRegex(ValueError,
41-
'expects .* a concrete subclass of'):
43+
with self.assertRaisesRegex(
44+
ValueError, 'expects .* a concrete subclass of'
45+
):
4246
_ = annotations._ArtifactGeneric[value_artifact.ValueArtifact]
4347

4448
# OK.
@@ -49,56 +53,55 @@ def testArtifactAnnotationUsage(self):
4953
_ = annotations.OutputArtifact[standard_artifacts.Examples]
5054
_ = annotations.AsyncOutputArtifact[standard_artifacts.Model]
5155

52-
def testPrimitiveTypeGenericAnnotation(self):
53-
# Error: type hint whose parameter is not a primitive type
56+
def testPrimitivAndProtoTypeGenericAnnotation(self):
57+
# Error: type hint whose parameter is not a primitive or a proto type
5458
# pytype: disable=unsupported-operands
5559
with self.assertRaisesRegex(
5660
ValueError, 'T to be `int`, `float`, `str`, `bool`'
5761
):
58-
_ = annotations._PrimitiveTypeGeneric[artifact.Artifact]
62+
_ = annotations._PrimitiveAndProtoTypeGeneric[artifact.Artifact]
5963
with self.assertRaisesRegex(
6064
ValueError, 'T to be `int`, `float`, `str`, `bool`'
6165
):
62-
_ = annotations._PrimitiveTypeGeneric[object]
66+
_ = annotations._PrimitiveAndProtoTypeGeneric[object]
6367
with self.assertRaisesRegex(
6468
ValueError, 'T to be `int`, `float`, `str`, `bool`'
6569
):
66-
_ = annotations._PrimitiveTypeGeneric[123]
70+
_ = annotations._PrimitiveAndProtoTypeGeneric[123]
6771
with self.assertRaisesRegex(
6872
ValueError, 'T to be `int`, `float`, `str`, `bool`'
6973
):
70-
_ = annotations._PrimitiveTypeGeneric['string']
74+
_ = annotations._PrimitiveAndProtoTypeGeneric['string']
7175
with self.assertRaisesRegex(
7276
ValueError, 'T to be `int`, `float`, `str`, `bool`'
7377
):
74-
_ = annotations._PrimitiveTypeGeneric[Dict[int, int]]
78+
_ = annotations._PrimitiveAndProtoTypeGeneric[Dict[int, int]]
7579
with self.assertRaisesRegex(
7680
ValueError, 'T to be `int`, `float`, `str`, `bool`'
7781
):
78-
_ = annotations._PrimitiveTypeGeneric[bytes]
82+
_ = annotations._PrimitiveAndProtoTypeGeneric[bytes]
7983
# pytype: enable=unsupported-operands
8084
# OK.
81-
_ = annotations._PrimitiveTypeGeneric[int]
82-
_ = annotations._PrimitiveTypeGeneric[float]
83-
_ = annotations._PrimitiveTypeGeneric[str]
84-
_ = annotations._PrimitiveTypeGeneric[bool]
85-
_ = annotations._PrimitiveTypeGeneric[Dict[str, float]]
86-
_ = annotations._PrimitiveTypeGeneric[bool]
85+
_ = annotations._PrimitiveAndProtoTypeGeneric[int]
86+
_ = annotations._PrimitiveAndProtoTypeGeneric[float]
87+
_ = annotations._PrimitiveAndProtoTypeGeneric[str]
88+
_ = annotations._PrimitiveAndProtoTypeGeneric[bool]
89+
_ = annotations._PrimitiveAndProtoTypeGeneric[Dict[str, float]]
90+
_ = annotations._PrimitiveAndProtoTypeGeneric[bool]
91+
_ = annotations._PrimitiveAndProtoTypeGeneric[
92+
annotations_test_proto_pb2.TestMessage
93+
]
8794

8895
def testPipelineTypeGenericAnnotation(self):
8996
# Error: type hint whose parameter is not a primitive type
90-
with self.assertRaisesRegex(
91-
ValueError, 'T to be `beam.Pipeline`'):
97+
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
9298
_ = annotations._PipelineTypeGeneric[artifact.Artifact]
93-
with self.assertRaisesRegex(
94-
ValueError, 'T to be `beam.Pipeline`'):
99+
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
95100
_ = annotations._PipelineTypeGeneric[object]
96101
# pytype: disable=unsupported-operands
97-
with self.assertRaisesRegex(
98-
ValueError, 'T to be `beam.Pipeline`'):
102+
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
99103
_ = annotations._PipelineTypeGeneric[123]
100-
with self.assertRaisesRegex(
101-
ValueError, 'T to be `beam.Pipeline`'):
104+
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
102105
_ = annotations._PipelineTypeGeneric['string']
103106
# pytype: enable=unsupported-operands
104107

@@ -110,6 +113,7 @@ def testParameterUsage(self):
110113
_ = annotations.Parameter[float]
111114
_ = annotations.Parameter[str]
112115
_ = annotations.Parameter[bool]
116+
_ = annotations.Parameter[annotations_test_proto_pb2.TestMessage]
113117

114118

115119
if __name__ == '__main__':
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright 2024 Google LLC. All Rights Reserved.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
syntax = "proto3";
15+
16+
package tfx.dsl.component.experimental;
17+
18+
message TestMessage {
19+
int32 number = 1;
20+
string name = 2;
21+
}

tfx/dsl/component/experimental/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tfx.types import artifact
2626
from tfx.types import component_spec
2727
from tfx.types import system_executions
28+
from google.protobuf import message
2829

2930

3031
class ArgFormats(enum.Enum):
@@ -224,10 +225,17 @@ def _create_component_spec_class(
224225
json_compatible_outputs[key],
225226
)
226227
if parameters:
227-
for key, primitive_type in parameters.items():
228-
spec_parameters[key] = component_spec.ExecutionParameter(
229-
type=primitive_type, optional=(key in arg_defaults)
230-
)
228+
for key, param_type in parameters.items():
229+
if inspect.isclass(param_type) and issubclass(
230+
param_type, message.Message
231+
):
232+
spec_parameters[key] = component_spec.ExecutionParameter(
233+
type=param_type, optional=(key in arg_defaults), use_proto=True
234+
)
235+
else:
236+
spec_parameters[key] = component_spec.ExecutionParameter(
237+
type=param_type, optional=(key in arg_defaults)
238+
)
231239
component_spec_class = type(
232240
'%s_Spec' % func.__name__,
233241
(tfx_types.ComponentSpec,),

tfx/dsl/component/experimental/utils_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Dict, List
1919
import tensorflow as tf
2020
from tfx.dsl.component.experimental import annotations
21+
from tfx.dsl.component.experimental import annotations_test_proto_pb2
2122
from tfx.dsl.component.experimental import decorators
2223
from tfx.dsl.component.experimental import function_parser
2324
from tfx.dsl.component.experimental import utils
@@ -106,6 +107,9 @@ def func_with_primitive_parameter(
106107
float_param: annotations.Parameter[float],
107108
str_param: annotations.Parameter[str],
108109
bool_param: annotations.Parameter[bool],
110+
proto_param: annotations.Parameter[
111+
annotations_test_proto_pb2.TestMessage
112+
],
109113
dict_int_param: annotations.Parameter[Dict[str, int]],
110114
list_bool_param: annotations.Parameter[List[bool]],
111115
dict_list_bool_param: annotations.Parameter[Dict[str, List[bool]]],
@@ -124,6 +128,7 @@ def func_with_primitive_parameter(
124128
'float_param': float,
125129
'str_param': str,
126130
'bool_param': bool,
131+
'proto_param': annotations_test_proto_pb2.TestMessage,
127132
'dict_int_param': Dict[str, int],
128133
'list_bool_param': List[bool],
129134
'dict_list_bool_param': Dict[str, List[bool]],
@@ -193,6 +198,9 @@ def func(
193198
standard_artifacts.Examples
194199
],
195200
int_param: annotations.Parameter[int],
201+
proto_param: annotations.Parameter[
202+
annotations_test_proto_pb2.TestMessage
203+
],
196204
json_compat_param: annotations.Parameter[Dict[str, int]],
197205
str_param: annotations.Parameter[str] = 'foo',
198206
) -> annotations.OutputDict(
@@ -257,11 +265,15 @@ def func(
257265
spec_outputs['map_str_float_output'].type, standard_artifacts.JsonValue
258266
)
259267
spec_parameter = actual_spec_class.PARAMETERS
260-
self.assertLen(spec_parameter, 3)
268+
self.assertLen(spec_parameter, 4)
261269
self.assertEqual(spec_parameter['int_param'].type, int)
262270
self.assertEqual(spec_parameter['int_param'].optional, False)
263271
self.assertEqual(spec_parameter['str_param'].type, str)
264272
self.assertEqual(spec_parameter['str_param'].optional, True)
273+
self.assertEqual(
274+
spec_parameter['proto_param'].type,
275+
annotations_test_proto_pb2.TestMessage,
276+
)
265277
self.assertEqual(spec_parameter['json_compat_param'].type, Dict[str, int])
266278
self.assertEqual(spec_parameter['json_compat_param'].optional, False)
267279
self.assertEqual(actual_spec_class.TYPE_ANNOTATION, type_annotation)

0 commit comments

Comments
 (0)