66import abc
77import inspect
88from collections .abc import Awaitable , Callable
9- from typing import Any , Self , TypeVar , overload
9+ from typing import Any , Generic , Self , TypeVar , overload
1010
1111from grpc .aio import AioRpcError , Channel
1212
1313from .channel import ChannelOptions , parse_grpc_uri
1414from .exception import ApiClientError , ClientNotConnected
1515
16+ StubT = TypeVar ("StubT" )
17+ """The type of the gRPC stub."""
1618
17- class BaseApiClient (abc .ABC ):
19+
20+ class BaseApiClient (abc .ABC , Generic [StubT ]):
1821 """A base class for API clients.
1922
2023 This class provides a common interface for API clients that communicate with a API
2124 server. It is designed to be subclassed by specific API clients that provide a more
2225 specific interface.
2326
27+ Note:
28+ It is recommended to add a `stub` property to the subclass that returns the gRPC
29+ stub to use but using the *async stub* type instead of the *sync stub* type.
30+ This is because the gRPC library provides async stubs that have proper async
31+ type hints, but they only live in `.pyi` files, so they can be used in a very
32+ limited way (only as type hints). Because of this, a `type: ignore` comment is
33+ needed to cast the sync stub to the async stub.
34+
35+ Please see the example below for a recommended way to implement this property.
36+
2437 Some extra tools are provided to make it easier to write API clients:
2538
2639 - [call_stub_method()][frequenz.client.base.client.call_stub_method] is a function
@@ -29,31 +42,12 @@ class BaseApiClient(abc.ABC):
2942 a class that helps sending messages from a gRPC stream to
3043 a [Broadcast][frequenz.channels.Broadcast] channel.
3144
32- Note:
33- Because grpcio doesn't provide proper type hints, a hack is needed to have
34- propepr async type hints for the stubs generated by protoc. When using
35- `mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class
36- but in the `.pyi` file, so the type can be used to specify type hints, but
37- **not** in any other context, as the class doesn't really exist for the Python
38- interpreter. This include generics, and because of this, this class can't be
39- even parametrized using the async class, so the instantiation of the stub can't
40- be done in the base class.
41-
42- Because of this, subclasses need to create the stubs by themselves, using the
43- real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use
44- the async version of the stubs.
45-
46- It is recommended to define a `stub` property that returns the async stub, so
47- this hack is completely hidden from clients, even if they need to access the
48- stub for more advanced uses.
49-
5045 Example:
5146 This example illustrates how to create a simple API client that connects to a
5247 gRPC server and calls a method on a stub.
5348
5449 ```python
5550 from collections.abc import AsyncIterable
56- from typing import cast
5751 from frequenz.client.base.client import BaseApiClient, call_stub_method
5852 from frequenz.client.base.streaming import GrpcStreamBroadcaster
5953 from frequenz.channels import Receiver
@@ -67,13 +61,13 @@ class ExampleResponse:
6761 float_value: float
6862
6963 class ExampleStub:
70- async def example_method(
64+ def example_method(
7165 self,
7266 request: ExampleRequest # pylint: disable=unused-argument
7367 ) -> ExampleResponse:
7468 ...
7569
76- def example_stream(self, _: ExampleRequest ) -> AsyncIterable[ExampleResponse]:
70+ def example_stream(self) -> AsyncIterable[ExampleResponse]:
7771 ...
7872
7973 class ExampleAsyncStub:
@@ -83,28 +77,27 @@ async def example_method(
8377 ) -> ExampleResponse:
8478 ...
8579
86- def example_stream(self, _: ExampleRequest ) -> AsyncIterable[ExampleResponse]:
80+ def example_stream(self) -> AsyncIterable[ExampleResponse]:
8781 ...
8882 # End of generated classes
8983
9084 class ExampleResponseWrapper:
91- def __init__(self, response: ExampleResponse) -> None :
85+ def __init__(self, response: ExampleResponse):
9286 self.transformed_value = f"{response.float_value:.2f}"
9387
9488 # Change defaults as needed
9589 DEFAULT_CHANNEL_OPTIONS = ChannelOptions()
9690
97- class MyApiClient(BaseApiClient):
91+ class MyApiClient(BaseApiClient[ExampleStub] ):
9892 def __init__(
9993 self,
10094 server_url: str,
10195 *,
10296 connect: bool = True,
10397 channel_defaults: ChannelOptions = DEFAULT_CHANNEL_OPTIONS,
10498 ) -> None:
105- super().__init__(server_url, connect=connect, channel_defaults=channel_defaults)
106- self._stub = cast(
107- ExampleAsyncStub, ExampleStub(self.channel)
99+ super().__init__(
100+ server_url, ExampleStub, connect=connect, channel_defaults=channel_defaults
108101 )
109102 self._broadcaster = GrpcStreamBroadcaster(
110103 "stream",
@@ -114,9 +107,13 @@ def __init__(
114107
115108 @property
116109 def stub(self) -> ExampleAsyncStub:
117- if self._channel is None:
110+ if self.channel is None or self._stub is None:
118111 raise ClientNotConnected(server_url=self.server_url, operation="stub")
119- return self._stub
112+ # This type: ignore is needed because we need to cast the sync stub to
113+ # the async stub, but we can't use cast because the async stub doesn't
114+ # actually exists to the eyes of the interpreter, it only exists for the
115+ # type-checker, so it can only be used for type hints.
116+ return self._stub # type: ignore
120117
121118 async def example_method(
122119 self, int_value: int, str_value: str
@@ -156,6 +153,7 @@ async def main():
156153 def __init__ (
157154 self ,
158155 server_url : str ,
156+ create_stub : Callable [[Channel ], StubT ],
159157 * ,
160158 connect : bool = True ,
161159 channel_defaults : ChannelOptions = ChannelOptions (),
@@ -164,6 +162,7 @@ def __init__(
164162
165163 Args:
166164 server_url: The URL of the server to connect to.
165+ create_stub: A function that creates a stub from a channel.
167166 connect: Whether to connect to the server as soon as a client instance is
168167 created. If `False`, the client will not connect to the server until
169168 [connect()][frequenz.client.base.client.BaseApiClient.connect] is
@@ -172,8 +171,10 @@ def __init__(
172171 the server URL.
173172 """
174173 self ._server_url : str = server_url
174+ self ._create_stub : Callable [[Channel ], StubT ] = create_stub
175175 self ._channel_defaults : ChannelOptions = channel_defaults
176176 self ._channel : Channel | None = None
177+ self ._stub : StubT | None = None
177178 if connect :
178179 self .connect (server_url )
179180
@@ -224,6 +225,7 @@ def connect(self, server_url: str | None = None) -> None:
224225 elif self .is_connected :
225226 return
226227 self ._channel = parse_grpc_uri (self ._server_url , self ._channel_defaults )
228+ self ._stub = self ._create_stub (self ._channel )
227229
228230 async def disconnect (self ) -> None :
229231 """Disconnect from the server.
@@ -248,6 +250,7 @@ async def __aexit__(
248250 return None
249251 result = await self ._channel .__aexit__ (_exc_type , _exc_val , _exc_tb )
250252 self ._channel = None
253+ self ._stub = None
251254 return result
252255
253256
@@ -260,7 +263,7 @@ async def __aexit__(
260263
261264@overload
262265async def call_stub_method (
263- client : BaseApiClient ,
266+ client : BaseApiClient [ StubT ] ,
264267 stub_method : Callable [[], Awaitable [StubOutT ]],
265268 * ,
266269 method_name : str | None = None ,
@@ -270,7 +273,7 @@ async def call_stub_method(
270273
271274@overload
272275async def call_stub_method (
273- client : BaseApiClient ,
276+ client : BaseApiClient [ StubT ] ,
274277 stub_method : Callable [[], Awaitable [StubOutT ]],
275278 * ,
276279 method_name : str | None = None ,
@@ -281,7 +284,7 @@ async def call_stub_method(
281284# We need the `noqa: DOC503` because `pydoclint` can't figure out that
282285# `ApiClientError.from_grpc_error()` returns a `GrpcError` instance.
283286async def call_stub_method ( # noqa: DOC503
284- client : BaseApiClient ,
287+ client : BaseApiClient [ StubT ] ,
285288 stub_method : Callable [[], Awaitable [StubOutT ]],
286289 * ,
287290 method_name : str | None = None ,
0 commit comments