-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmaskargs.py
166 lines (134 loc) · 6.98 KB
/
maskargs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import functools
from typing import Callable, Sequence, Any, TypeAlias, Literal
import inspect
class ArgAdaptor:
ValueMap: TypeAlias = dict[str, str | int]
ArgumentMap: TypeAlias = dict[str, ValueMap]
HintMap: TypeAlias = dict[str, TypeAlias]
__args__: ArgumentMap = {}
def __init_subclass__(cls, **kwargs):
""" Build Literals for the subclass and formats the __args__ dictionary to lowercase """
# Lowercase the keys for case insensitivity
cls.__args__ = {k.lower(): v for k, v in cls.__args__.items()}
# Build the literals for the subclass
cls.__hints__: ArgAdaptor.HintMap = {
parameter: Literal[*list(options.keys())]
for parameter, options in cls.__args__.items()
}
@classmethod
def maskargs(adaptor: 'ArgAdaptor', masked_function: Callable) -> Callable:
# Grab signature from the masked function
masked_function_signature = inspect.signature(masked_function)
# Grab the parameters from the masked function
function_parameters = masked_function_signature.parameters
# Alias the adaptors for clearer code
adaptors = adaptor.__args__
@functools.wraps(masked_function)
def adapt_arguments(*args, **kwargs):
"""Convert a function call with named string arguments to a function call with named arguments from the adaptor"""
# Map the arguments to the function parameters
adapted_arguments: dict[str, Any] = {
parameter.name: value
# zip the positional arguments (not strict, so optional args are dropped if they are in kwargs)
for parameter, value in zip(function_parameters.values(), args)
if parameter.kind in (
# strict positional arguments
inspect.Parameter.POSITIONAL_ONLY,
# optional positional arguments
inspect.Parameter.POSITIONAL_OR_KEYWORD,
# *args
inspect.Parameter.VAR_POSITIONAL
)
}
# Add the keyword arguments to the mapping
# This will shadow any optional positional arguments
# If they were provided as keyword arguments
adapted_arguments.update(kwargs)
invalid_args: list[str] = []
for argument, arg_value in adapted_arguments.items():
# Skip non-adapted arguments
if argument not in adaptors:
continue
# Handle string arguments
if isinstance(arg_value, str):
# Lowercase the argument value for case insensitivity
arg_value = arg_value.lower()
# Add invalid arguments to the error list
if arg_value not in adaptors[argument]:
invalid_args.append(
f'Invalid value for `{argument}`: '
f"'{arg_value}' "
f'(choices are {list(adaptors[argument].keys())})'
)
continue
# Update the argument value if it is in the adaptor
adapted_arguments[argument] = adaptors[argument][arg_value]
# Handle sequence arguments
elif isinstance(arg_value, Sequence):
# Change variable name for clarity
arg_values: list[str] = arg_value
# Get invalid arguments in argument sequence
invalid_values = [
arg_val
for arg_val in arg_values
if arg_val not in adaptors[argument]
]
# Add invalid arguments to the error list
if invalid_values:
invalid_args.append(
f'Invalid value{"s"*(len(invalid_values)>1)} for {argument}:'
f'{", ".join(map(str, invalid_values))}'
f'(choices are {list(adaptors[argument].keys())})'
)
continue
# Update the argument value if it is in the adaptor
adapted_arguments[argument] = [
adaptors[argument][arg_val.lower()]
for arg_val in arg_values
]
if invalid_args:
raise ValueError('\n'.join(invalid_args))
# Pass the adapted arguments to the masked function
return masked_function(**adapted_arguments)
# Rebuild attribtues for the adapted function using the masked function
for att in ('__doc__', '__annotations__', '__esri_toolinfo__'):
setattr(
adapt_arguments,
att,
(
# Get attribute from the adapted function first
getattr(adapt_arguments, att, None) or
# Get attribute from the masked function if it is not in the adapted function
getattr(masked_function, att, None) or
# Build __esri_toolinfo__ if it is in neither the adapted or masked function
[
f"String::"
f"{'|'.join(adaptors[argument].keys())}:"
for argument in function_parameters.keys()
if argument in adaptors
]
if att == '__esri_toolinfo__'
# Default to None
else None
)
)
# Apply literal type hints to the adapted function
adapt_arguments.__annotations__.update(adaptor.__hints__)
# Allow manual annotations to override the literal type hints
adapt_arguments.__annotations__.update(masked_function.__annotations__)
return adapt_arguments
@classmethod
def maskmethods(adaptor: 'ArgAdaptor', other: type) -> None:
# Grab all non dunder/private methods
methods_to_mask = {
method_name: method_object
for method_name in dir(other)
# Check if the method is callable and not private
# Use the walrus operator to store the method object
if callable(method_object := getattr(other, method_name))
and not method_name.startswith("_")
}
# Mask the methods using the specified adaptor
for method_name, method_object in methods_to_mask.items():
setattr(other, method_name, adaptor.maskargs(method_object))
print(f'Masked method: {method_name}')