Skip to content

Commit 75466f4

Browse files
committed
Updates for custom records and method updates
1 parent 3ba4073 commit 75466f4

11 files changed

+970
-195
lines changed

asyncpg/cluster.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,15 @@ def get_status(self) -> str:
130130

131131
async def connect(self,
132132
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
133-
**kwargs: typing.Any) -> 'connection.Connection':
133+
**kwargs: typing.Any) \
134+
-> 'connection.Connection[typing.Any]':
134135
conn_info = typing.cast(typing.Dict[str, typing.Any],
135136
self.get_connection_spec())
136137
conn_info.update(kwargs)
137-
return await asyncpg.connect(loop=loop, **conn_info)
138+
return typing.cast(
139+
'connection.Connection[typing.Any]',
140+
await asyncpg.connect(loop=loop, **conn_info)
141+
)
138142

139143
def init(self, **settings: str) -> str:
140144
"""Initialize cluster."""
@@ -521,7 +525,7 @@ def _test_connection(self, timeout: int = 60) -> str:
521525

522526
try:
523527
con = loop.run_until_complete(
524-
asyncpg.connect(database='postgres', # type: ignore[misc] # noqa: E501
528+
asyncpg.connect(database='postgres', # type: ignore[arg-type] # noqa: E501
525529
user='postgres',
526530
timeout=5, loop=loop,
527531
**self._connection_addr))

asyncpg/connect_utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
_Connection = typing.TypeVar('_Connection')
3131
_Protocol = typing.TypeVar('_Protocol', bound=asyncio.Protocol)
32+
_Record = typing.TypeVar('_Record', bound=protocol.Record)
3233

3334
_TPTupleType = typing.Tuple[asyncio.WriteTransport, _Protocol]
3435
AddrType = typing.Union[typing.Tuple[str, int], str]
@@ -654,7 +655,7 @@ async def _connect_addr(
654655
params: _ConnectionParameters,
655656
config: _ClientConfiguration,
656657
connection_class: typing.Type[_Connection],
657-
record_class: typing.Any
658+
record_class: typing.Type[_Record]
658659
) -> _Connection:
659660
assert loop is not None
660661

@@ -680,7 +681,7 @@ async def _connect_addr(
680681
assert not params.ssl
681682
connector = typing.cast(
682683
typing.Coroutine[typing.Any, None,
683-
_TPTupleType[protocol.Protocol]],
684+
_TPTupleType['protocol.Protocol[_Record]']],
684685
loop.create_unix_connection(proto_factory, addr))
685686
elif params.ssl:
686687
connector = _create_ssl_connection(
@@ -689,7 +690,7 @@ async def _connect_addr(
689690
else:
690691
connector = typing.cast(
691692
typing.Coroutine[typing.Any, None,
692-
_TPTupleType[protocol.Protocol]],
693+
_TPTupleType['protocol.Protocol[_Record]']],
693694
loop.create_connection(proto_factory, *addr))
694695

695696
connector_future = asyncio.ensure_future(connector)
@@ -716,7 +717,7 @@ async def _connect(
716717
loop: typing.Optional[asyncio.AbstractEventLoop],
717718
timeout: float,
718719
connection_class: typing.Type[_Connection],
719-
record_class: typing.Any,
720+
record_class: typing.Type[_Record],
720721
**kwargs: typing.Any
721722
) -> _Connection:
722723
if loop is None:

0 commit comments

Comments
 (0)