Skip to content

Commit 3d0c065

Browse files
committed
Implement a new parameter to generate only async stubs
This commit adds a new parameter `async_only` to the plugin, that when present, will only generate the async stubs. The current {}AsyncStub classes are "fake", they are not present in the .py file, so they can only be used to write type hints. Using them also only works if the user manually casts the sync stub to the async stub. Since most users will only want to use either the sync or the async stub, this new parameter will allow to generate only the stubs that are suited for the user's use case, and allow them to have less hacky code. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent d91d1df commit 3d0c065

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

mypy_protobuf/main.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,7 @@ def write_grpc_services(
781781
self,
782782
services: Iterable[d.ServiceDescriptorProto],
783783
scl_prefix: SourceCodeLocation,
784+
async_only: bool,
784785
) -> None:
785786
wl = self._write_line
786787
for i, service in enumerate(services):
@@ -797,21 +798,25 @@ def write_grpc_services(
797798
with self._indent():
798799
if self._write_comments(scl):
799800
wl("")
800-
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
801-
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
801+
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here,
802+
# but only if we are generating both sync and async stubs.
803+
channel = self._import("grpc.aio", "Channel")
804+
if not async_only:
805+
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {channel}]"
802806
wl("def __init__(self, channel: {}) -> None: ...", channel)
803-
self.write_grpc_stub_methods(service, scl)
807+
self.write_grpc_stub_methods(service, scl, is_async=async_only)
804808

805-
# The (fake) async stub client
806-
wl(
807-
"class {}AsyncStub:",
808-
service.name,
809-
)
810-
with self._indent():
811-
if self._write_comments(scl):
812-
wl("")
813-
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
814-
self.write_grpc_stub_methods(service, scl, is_async=True)
809+
if not async_only:
810+
# The (fake) async stub client
811+
wl(
812+
"class {}AsyncStub:",
813+
service.name,
814+
)
815+
with self._indent():
816+
if self._write_comments(scl):
817+
wl("")
818+
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
819+
self.write_grpc_stub_methods(service, scl, is_async=True)
815820

816821
# The service definition interface
817822
wl(
@@ -1009,6 +1014,7 @@ def generate_mypy_stubs(
10091014
def generate_mypy_grpc_stubs(
10101015
descriptors: Descriptors,
10111016
response: plugin_pb2.CodeGeneratorResponse,
1017+
async_only: bool,
10121018
quiet: bool,
10131019
readable_stubs: bool,
10141020
relax_strict_optional_primitives: bool,
@@ -1022,7 +1028,7 @@ def generate_mypy_grpc_stubs(
10221028
grpc=True,
10231029
)
10241030
pkg_writer.write_grpc_async_hacks()
1025-
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
1031+
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER], async_only)
10261032

10271033
assert name == fd.name
10281034
assert fd.name.endswith(".proto")
@@ -1079,6 +1085,7 @@ def grpc() -> None:
10791085
generate_mypy_grpc_stubs(
10801086
Descriptors(request),
10811087
response,
1088+
"async_only" in request.parameter,
10821089
"quiet" in request.parameter,
10831090
"readable_stubs" in request.parameter,
10841091
"relax_strict_optional_primitives" in request.parameter,

0 commit comments

Comments
 (0)