18
18
import apache_beam as beam
19
19
import tensorflow as tf
20
20
from tfx .dsl .component .experimental import annotations
21
+ from tfx .dsl .component .experimental import annotations_test_proto_pb2
21
22
from tfx .types import artifact
22
23
from tfx .types import standard_artifacts
23
24
from tfx .types import value_artifact
@@ -27,18 +28,21 @@ class AnnotationsTest(tf.test.TestCase):
27
28
28
29
def testArtifactGenericAnnotation (self ):
29
30
# Error: type hint whose parameter is not an Artifact subclass.
30
- with self .assertRaisesRegex (ValueError ,
31
- 'expects .* a concrete subclass of' ):
31
+ with self .assertRaisesRegex (
32
+ ValueError , 'expects .* a concrete subclass of'
33
+ ):
32
34
_ = annotations ._ArtifactGeneric [int ] # pytype: disable=unsupported-operands
33
35
34
36
# Error: type hint with abstract Artifact subclass.
35
- with self .assertRaisesRegex (ValueError ,
36
- 'expects .* a concrete subclass of' ):
37
+ with self .assertRaisesRegex (
38
+ ValueError , 'expects .* a concrete subclass of'
39
+ ):
37
40
_ = annotations ._ArtifactGeneric [artifact .Artifact ]
38
41
39
42
# Error: type hint with abstract Artifact subclass.
40
- with self .assertRaisesRegex (ValueError ,
41
- 'expects .* a concrete subclass of' ):
43
+ with self .assertRaisesRegex (
44
+ ValueError , 'expects .* a concrete subclass of'
45
+ ):
42
46
_ = annotations ._ArtifactGeneric [value_artifact .ValueArtifact ]
43
47
44
48
# OK.
@@ -49,56 +53,55 @@ def testArtifactAnnotationUsage(self):
49
53
_ = annotations .OutputArtifact [standard_artifacts .Examples ]
50
54
_ = annotations .AsyncOutputArtifact [standard_artifacts .Model ]
51
55
52
- def testPrimitiveTypeGenericAnnotation (self ):
53
- # Error: type hint whose parameter is not a primitive type
56
+ def testPrimitivAndProtoTypeGenericAnnotation (self ):
57
+ # Error: type hint whose parameter is not a primitive or a proto type
54
58
# pytype: disable=unsupported-operands
55
59
with self .assertRaisesRegex (
56
60
ValueError , 'T to be `int`, `float`, `str`, `bool`'
57
61
):
58
- _ = annotations ._PrimitiveTypeGeneric [artifact .Artifact ]
62
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [artifact .Artifact ]
59
63
with self .assertRaisesRegex (
60
64
ValueError , 'T to be `int`, `float`, `str`, `bool`'
61
65
):
62
- _ = annotations ._PrimitiveTypeGeneric [object ]
66
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [object ]
63
67
with self .assertRaisesRegex (
64
68
ValueError , 'T to be `int`, `float`, `str`, `bool`'
65
69
):
66
- _ = annotations ._PrimitiveTypeGeneric [123 ]
70
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [123 ]
67
71
with self .assertRaisesRegex (
68
72
ValueError , 'T to be `int`, `float`, `str`, `bool`'
69
73
):
70
- _ = annotations ._PrimitiveTypeGeneric ['string' ]
74
+ _ = annotations ._PrimitiveAndProtoTypeGeneric ['string' ]
71
75
with self .assertRaisesRegex (
72
76
ValueError , 'T to be `int`, `float`, `str`, `bool`'
73
77
):
74
- _ = annotations ._PrimitiveTypeGeneric [Dict [int , int ]]
78
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [Dict [int , int ]]
75
79
with self .assertRaisesRegex (
76
80
ValueError , 'T to be `int`, `float`, `str`, `bool`'
77
81
):
78
- _ = annotations ._PrimitiveTypeGeneric [bytes ]
82
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [bytes ]
79
83
# pytype: enable=unsupported-operands
80
84
# OK.
81
- _ = annotations ._PrimitiveTypeGeneric [int ]
82
- _ = annotations ._PrimitiveTypeGeneric [float ]
83
- _ = annotations ._PrimitiveTypeGeneric [str ]
84
- _ = annotations ._PrimitiveTypeGeneric [bool ]
85
- _ = annotations ._PrimitiveTypeGeneric [Dict [str , float ]]
86
- _ = annotations ._PrimitiveTypeGeneric [bool ]
85
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [int ]
86
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [float ]
87
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [str ]
88
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [bool ]
89
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [Dict [str , float ]]
90
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [bool ]
91
+ _ = annotations ._PrimitiveAndProtoTypeGeneric [
92
+ annotations_test_proto_pb2 .TestMessage
93
+ ]
87
94
88
95
def testPipelineTypeGenericAnnotation (self ):
89
96
# Error: type hint whose parameter is not a primitive type
90
- with self .assertRaisesRegex (
91
- ValueError , 'T to be `beam.Pipeline`' ):
97
+ with self .assertRaisesRegex (ValueError , 'T to be `beam.Pipeline`' ):
92
98
_ = annotations ._PipelineTypeGeneric [artifact .Artifact ]
93
- with self .assertRaisesRegex (
94
- ValueError , 'T to be `beam.Pipeline`' ):
99
+ with self .assertRaisesRegex (ValueError , 'T to be `beam.Pipeline`' ):
95
100
_ = annotations ._PipelineTypeGeneric [object ]
96
101
# pytype: disable=unsupported-operands
97
- with self .assertRaisesRegex (
98
- ValueError , 'T to be `beam.Pipeline`' ):
102
+ with self .assertRaisesRegex (ValueError , 'T to be `beam.Pipeline`' ):
99
103
_ = annotations ._PipelineTypeGeneric [123 ]
100
- with self .assertRaisesRegex (
101
- ValueError , 'T to be `beam.Pipeline`' ):
104
+ with self .assertRaisesRegex (ValueError , 'T to be `beam.Pipeline`' ):
102
105
_ = annotations ._PipelineTypeGeneric ['string' ]
103
106
# pytype: enable=unsupported-operands
104
107
@@ -110,6 +113,7 @@ def testParameterUsage(self):
110
113
_ = annotations .Parameter [float ]
111
114
_ = annotations .Parameter [str ]
112
115
_ = annotations .Parameter [bool ]
116
+ _ = annotations .Parameter [annotations_test_proto_pb2 .TestMessage ]
113
117
114
118
115
119
if __name__ == '__main__' :
0 commit comments