Skip to content

Commit cd6a441

Browse files
committed
Add more typing to transaction and base client
1 parent 0e2c7e9 commit cd6a441

File tree

2 files changed

+33
-24
lines changed

2 files changed

+33
-24
lines changed

pymodbus/client/base.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
import socket
66
from abc import abstractmethod
7-
from collections.abc import Awaitable, Callable
7+
from collections.abc import Callable, Coroutine
88
from dataclasses import dataclass
99
from typing import Any, cast
1010

@@ -20,7 +20,7 @@
2020
from pymodbus.utilities import ModbusTransactionState
2121

2222

23-
class ModbusBaseClient(ModbusClientMixin[Awaitable[ModbusResponse]]):
23+
class ModbusBaseClient(ModbusClientMixin[Coroutine[Any, Any, ModbusResponse | None]]):
2424
"""**ModbusBaseClient**.
2525
2626
Fixed parameters:
@@ -141,7 +141,7 @@ def idle_time(self) -> float:
141141
return 0
142142
return self.last_frame_end + self.silent_interval
143143

144-
def execute(self, request: ModbusRequest):
144+
def execute(self, request: ModbusRequest) -> Coroutine[Any, Any, ModbusResponse | None]:
145145
"""Execute request and get response (call **sync/async**).
146146
147147
:param request: The request to process
@@ -155,7 +155,7 @@ def execute(self, request: ModbusRequest):
155155
# ----------------------------------------------------------------------- #
156156
# Merged client methods
157157
# ----------------------------------------------------------------------- #
158-
async def async_execute(self, request) -> ModbusResponse:
158+
async def async_execute(self, request) -> ModbusResponse | None:
159159
"""Execute requests asynchronously."""
160160
request.transaction_id = self.ctx.transaction.getNextTID()
161161
packet = self.ctx.framer.buildPacket(request)
@@ -183,9 +183,9 @@ async def async_execute(self, request) -> ModbusResponse:
183183
f"ERROR: No response received after {self.retries} retries"
184184
)
185185

186-
return resp # type: ignore[return-value]
186+
return resp
187187

188-
def build_response(self, request: ModbusRequest):
188+
def build_response(self, request: ModbusRequest) -> asyncio.Future[ModbusResponse]:
189189
"""Return a deferred response for the current request."""
190190
my_future: asyncio.Future = asyncio.Future()
191191
request.fut = my_future
@@ -222,7 +222,7 @@ def __str__(self):
222222
)
223223

224224

225-
class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse]):
225+
class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse | bytes | ModbusIOException]):
226226
"""**ModbusBaseClient**.
227227
228228
Fixed parameters:
@@ -336,7 +336,7 @@ def idle_time(self) -> float:
336336
return 0
337337
return self.last_frame_end + self.silent_interval
338338

339-
def execute(self, request: ModbusRequest) -> ModbusResponse:
339+
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException:
340340
"""Execute request and get response (call **sync/async**).
341341
342342
:param request: The request to process

pymodbus/transaction.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
ModbusTlsFramer,
3030
)
3131
from pymodbus.logging import Log
32-
from pymodbus.pdu import ModbusRequest
32+
from pymodbus.pdu import ModbusRequest, ModbusResponse
3333
from pymodbus.transport import CommType
3434
from pymodbus.utilities import ModbusTransactionState, hexlify_packets
3535

@@ -167,13 +167,13 @@ def _set_adu_size(self):
167167
else:
168168
self.base_adu_size = -1
169169

170-
def _calculate_response_length(self, expected_pdu_size):
170+
def _calculate_response_length(self, expected_pdu_size: int) -> int | None:
171171
"""Calculate response length."""
172172
if self.base_adu_size == -1:
173173
return None
174174
return self.base_adu_size + expected_pdu_size
175175

176-
def _calculate_exception_length(self):
176+
def _calculate_exception_length(self) -> int | None:
177177
"""Return the length of the Modbus Exception Response according to the type of Framer."""
178178
if isinstance(self.client.framer, (ModbusSocketFramer, ModbusTlsFramer)):
179179
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
@@ -183,7 +183,9 @@ def _calculate_exception_length(self):
183183
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
184184
return None
185185

186-
def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_udp=False):
186+
def _validate_response(
187+
self, request: ModbusRequest, response: bytes | int, exp_resp_len: int | None, is_udp=False
188+
) -> bool:
187189
"""Validate Incoming response against request.
188190
189191
:param request: Request sent
@@ -208,7 +210,7 @@ def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_
208210
return mbap.get("length") == exp_resp_len
209211
return True
210212

211-
def execute(self, request: ModbusRequest): # noqa: C901
213+
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException: # noqa: C901
212214
"""Start the producer to send the next request to consumer.write(Frame(request))."""
213215
with self._transaction_lock:
214216
try:
@@ -333,7 +335,9 @@ def execute(self, request: ModbusRequest): # noqa: C901
333335
self.client.close()
334336
return exc
335337

336-
def _retry_transaction(self, retries, reason, packet, response_length, full=False):
338+
def _retry_transaction(
339+
self, retries: int, reason: str, request: ModbusRequest, response_length: int | None, full=False
340+
) -> tuple[bytes, str | Exception | None]:
337341
"""Retry transaction."""
338342
Log.debug("Retry on {} response - {}", reason, retries)
339343
Log.debug('Changing transaction state from "WAITING_FOR_REPLY" to "RETRYING"')
@@ -350,9 +354,11 @@ def _retry_transaction(self, retries, reason, packet, response_length, full=Fals
350354
if response_length == in_waiting:
351355
result = self._recv(response_length, full)
352356
return result, None
353-
return self._transact(packet, response_length, full=full)
357+
return self._transact(request, response_length, full=full)
354358

355-
def _transact(self, request: ModbusRequest, response_length, full=False, broadcast=False):
359+
def _transact(
360+
self, request: ModbusRequest, response_length: int | None, full=False, broadcast=False
361+
) -> tuple[bytes, str | Exception | None]:
356362
"""Do a Write and Read transaction.
357363
358364
:param packet: packet to be sent
@@ -368,16 +374,13 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
368374
packet = self.client.framer.buildPacket(request)
369375
Log.debug("SEND: {}", packet, ":hex")
370376
size = self._send(packet)
371-
if (
372-
isinstance(size, bytes)
373-
and self.client.state == ModbusTransactionState.RETRYING
374-
):
377+
if self.client.state == ModbusTransactionState.RETRYING:
375378
Log.debug(
376379
"Changing transaction state from "
377380
'"RETRYING" to "PROCESSING REPLY"'
378381
)
379382
self.client.state = ModbusTransactionState.PROCESSING_REPLY
380-
return size, None
383+
return b"", None
381384
if self.client.comm_params.handle_local_echo is True:
382385
if self._recv(size, full) != packet:
383386
return b"", "Wrong local echo"
@@ -405,11 +408,11 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
405408
result = b""
406409
return result, last_exception
407410

408-
def _send(self, packet: bytes, _retrying=False):
411+
def _send(self, packet: bytes, _retrying=False) -> int:
409412
"""Send."""
410413
return self.client.framer.sendPacket(packet)
411414

412-
def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
415+
def _recv(self, expected_response_length: int | None, full: bool) -> bytes: # noqa: C901
413416
"""Receive."""
414417
total = None
415418
if not full:
@@ -420,8 +423,10 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
420423
min_size = 4
421424
elif isinstance(self.client.framer, ModbusAsciiFramer):
422425
min_size = 5
423-
else:
426+
elif expected_response_length:
424427
min_size = expected_response_length
428+
else:
429+
min_size = 0
425430

426431
read_min = self.client.framer.recvPacket(min_size)
427432
if len(read_min) != min_size:
@@ -463,13 +468,17 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
463468
expected_response_length -= min_size
464469
total = expected_response_length + min_size
465470
else:
471+
if exception_length is None:
472+
exception_length = 0
466473
expected_response_length = exception_length - min_size
467474
total = expected_response_length + min_size
468475
else:
469476
total = expected_response_length
470477
else:
471478
read_min = b""
472479
total = expected_response_length
480+
if expected_response_length is None:
481+
expected_response_length = 0
473482
result = self.client.framer.recvPacket(expected_response_length)
474483
result = read_min + result
475484
actual = len(result)

0 commit comments

Comments
 (0)