Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .grpcClient import GrpcClientFactory
from .grpc_utils import get_grpc_uri, get_grpc_max_message_length, parse_grpc_args
from .utils import convert_to_bytestring


class SettlementError(Exception):
Expand Down Expand Up @@ -80,12 +81,12 @@ def complete(self, message) -> None:
raise SettlementError("complete",
f"Failed to complete message {locktoken}", e)

def abandon(self, message) -> None:
def abandon(self, message, properties_to_modify: Optional[dict] = {}) -> None:
try:
locktoken = self._validate_lock_token(message)
request = AbandonRequest()
request.locktoken = str(locktoken)
request.propertiesToModify = b""
request.propertiesToModify = convert_to_bytestring(properties_to_modify)
self._client.Abandon(request)
except Exception as e:
raise SettlementError("abandon",
Expand All @@ -95,12 +96,13 @@ def deadletter(
self,
message,
deadletter_reason: Optional[str] = None,
deadletter_error_description: Optional[str] = None) -> None:
deadletter_error_description: Optional[str] = None,
properties_to_modify: Optional[dict] = {}) -> None:
try:
locktoken = self._validate_lock_token(message)
request = DeadletterRequest()
request.locktoken = str(locktoken)
request.propertiesToModify = b""
request.propertiesToModify = convert_to_bytestring(properties_to_modify)

if deadletter_reason:
request.deadletterReason.CopyFrom(StringValue(value=deadletter_reason))
Expand All @@ -113,12 +115,12 @@ def deadletter(
raise SettlementError("deadletter",
f"Failed to deadletter message {locktoken}", e)

def defer(self, message) -> None:
def defer(self, message, properties_to_modify: Optional[dict] = {}) -> None:
try:
locktoken = self._validate_lock_token(message)
request = DeferRequest()
request.locktoken = str(locktoken)
request.propertiesToModify = b""
request.propertiesToModify = convert_to_bytestring(properties_to_modify)

self._client.Defer(request)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import datetime
import struct
import uamqp
import uuid


_X_OPT_LOCK_TOKEN = b"x-opt-lock-token"

# AMQP format codes (subset)
FMT_NULL = 0x40
FMT_BOOL_TRUE = 0x41
FMT_BOOL_FALSE = 0x42
FMT_UINT = 0x70
FMT_INT = 0x71
FMT_LONG = 0x81
FMT_DOUBLE = 0x82
FMT_UTF8_SMALL = 0xA1
FMT_UTF8_LARGE = 0xB1
FMT_UUID = 0x98
FMT_MAP8 = 0xC1
FMT_MAP32 = 0xD1


def get_lock_token(message: bytes, index: int) -> str:
# Get the lock token from the message
Expand Down Expand Up @@ -52,3 +68,61 @@ def get_decoded_message(content: bytes):
except Exception as e:
raise ValueError(f"Failed to decode ServiceBus content: {e}") from e
return None


def encode_amqp_value(value):
if value is None:
return bytes([FMT_NULL])
elif isinstance(value, bool):
return bytes([FMT_BOOL_TRUE if value else FMT_BOOL_FALSE])
elif isinstance(value, int):
# encode as int32 or int64 depending on value
if -2**31 <= value < 2**31:
return bytes([FMT_INT]) + struct.pack(">i", value)
else:
return bytes([FMT_LONG]) + struct.pack(">q", value)
elif isinstance(value, float):
return bytes([FMT_DOUBLE]) + struct.pack(">d", value)
elif isinstance(value, str):
utf8 = value.encode("utf-8")
if len(utf8) < 256:
return bytes([FMT_UTF8_SMALL, len(utf8)]) + utf8
else:
return bytes([FMT_UTF8_LARGE]) + struct.pack(">I", len(utf8)) + utf8
elif isinstance(value, uuid.UUID):
return bytes([FMT_UUID]) + value.bytes
elif isinstance(value, datetime.timedelta):
ticks = int(value.total_seconds() * 10_000_000)
return encode_amqp_value(ticks)
elif isinstance(value, datetime.datetime):
# UTC ticks since 1970-01-01
if value.tzinfo is None:
value = value.replace(tzinfo=datetime.timezone.utc)
epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
ms = int((value - epoch).total_seconds() * 1000)
return encode_amqp_value(ms)
else:
raise TypeError(f"Unsupported type: {type(value)}")


# Encode map
def encode_amqp_map(dct):
if not dct:
return b""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the dct is empty, shouldn't we still encode it with size= 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this change, I had request.propertiesToModify = b"". I return just b"" here to keep the behavior consistent with what we had before

items_bytes = b"".join(
encode_amqp_value(k) + encode_amqp_value(v) for k, v in dct.items()
)
size = len(items_bytes) + 1 # 1 byte for count
count = len(dct) * 2
if size < 256:
return bytes([FMT_MAP8, size, count]) + items_bytes
else:
return (bytes([FMT_MAP32])
+ struct.pack(">I", size)
+ struct.pack(">I", count)
+ items_bytes)


# Main conversion function
def convert_to_bytestring(properties_to_modify: dict) -> bytes:
return encode_amqp_map(properties_to_modify)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from azurefunctions.extensions.bindings.servicebus import ServiceBusMessageActions
from azurefunctions.extensions.bindings.servicebus.serviceBusMessageActions import SettlementError # noqa
from azurefunctions.extensions.bindings.servicebus.utils import convert_to_bytestring
from azurefunctions.extensions.bindings.protos import settlement_pb2 as pb2


Expand Down Expand Up @@ -82,6 +83,28 @@ def test_abandon_calls_grpc(self):
self.assertEqual(called_req.locktoken, "lock123")
self.assertEqual(called_req.propertiesToModify, b"")

def test_abandon_with_properties(self):
msg = DummyMessage("lock123")
props = {"status": "done", "attempt": 3}
self.actions.abandon(msg, props)
# Verify gRPC call happened
self.mock_client.Abandon.assert_called_once()
called_req = self.mock_client.Abandon.call_args[0][0]
self.assertEqual(called_req.locktoken, "lock123")
self.assertEqual(called_req.propertiesToModify, convert_to_bytestring(props))

def test_deadletter(self):
msg = DummyMessage("lock123")
self.actions.deadletter(msg)

self.mock_client.Deadletter.assert_called_once()
called_req = self.mock_client.Deadletter.call_args[0][0]
self.assertIsInstance(called_req, pb2.DeadletterRequest)
self.assertEqual(called_req.locktoken, "lock123")
self.assertEqual(called_req.propertiesToModify, b"")
self.assertEqual(called_req.deadletterReason.value, "")
self.assertEqual(called_req.deadletterErrorDescription.value, "")

def test_deadletter_with_reasons(self):
msg = DummyMessage("lock123")
self.actions.deadletter(
Expand All @@ -98,7 +121,21 @@ def test_deadletter_with_reasons(self):
self.assertEqual(called_req.deadletterReason.value, "reason")
self.assertEqual(called_req.deadletterErrorDescription.value, "desc")

def test_defer_calls_grpc(self):
def test_deadletter_with_properties_and_reason(self):
msg = DummyMessage("lock123")
props = {"errorCode": 500}
self.actions.deadletter(msg, deadletter_reason="bad data",
deadletter_error_description="validation failed",
properties_to_modify=props)
self.mock_client.Deadletter.assert_called_once()
called_req = self.mock_client.Deadletter.call_args[0][0]
self.assertEqual(called_req.propertiesToModify, convert_to_bytestring(props))
# Check reason was set
self.assertEqual(called_req.deadletterReason.value, "bad data")
self.assertEqual(called_req.deadletterErrorDescription.value,
"validation failed")

def test_defer(self):
msg = DummyMessage("lock123")
self.actions.defer(msg)

Expand All @@ -108,6 +145,14 @@ def test_defer_calls_grpc(self):
self.assertEqual(called_req.locktoken, "lock123")
self.assertEqual(called_req.propertiesToModify, b"")

def test_defer_with_properties(self):
msg = DummyMessage("lock123")
props = {"deferFlag": True}
self.actions.defer(msg, props)
self.mock_client.Defer.assert_called_once()
called_req = self.mock_client.Defer.call_args[0][0]
self.assertEqual(called_req.propertiesToModify, convert_to_bytestring(props))

def test_renew_message_lock_calls_grpc(self):
msg = DummyMessage("lock123")
self.actions.renew_message_lock(msg)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import datetime
import uuid
from azurefunctions.extensions.bindings.servicebus.utils import (encode_amqp_value,
encode_amqp_map,
convert_to_bytestring)


class TestAmqpEncoding(unittest.TestCase):

def test_encode_none(self):
result = encode_amqp_value(None)
self.assertEqual(result, bytes([0x40])) # FMT_NULL

def test_encode_bool(self):
self.assertEqual(encode_amqp_value(True), bytes([0x41])) # FMT_BOOL_TRUE
self.assertEqual(encode_amqp_value(False), bytes([0x42])) # FMT_BOOL_FALSE

def test_encode_int(self):
# Small int
small_int = 123
result = encode_amqp_value(small_int)
self.assertEqual(result[0], 0x71) # FMT_INT
self.assertEqual(int.from_bytes(result[1:], "big", signed=True), small_int)

# Large int
large_int = 2**40
result = encode_amqp_value(large_int)
self.assertEqual(result[0], 0x81) # FMT_LONG
self.assertEqual(int.from_bytes(result[1:], "big", signed=True), large_int)

def test_encode_float(self):
val = 3.1415
result = encode_amqp_value(val)
self.assertEqual(result[0], 0x82) # FMT_DOUBLE

def test_encode_str(self):
s = "hello"
result = encode_amqp_value(s)
self.assertIn(result[0], (0xA1, 0xB1)) # FMT_UTF8_SMALL or LARGE

def test_encode_uuid(self):
u = uuid.uuid4()
result = encode_amqp_value(u)
self.assertEqual(result[0], 0x98) # FMT_UUID
self.assertEqual(result[1:], u.bytes)

def test_encode_timedelta(self):
td = datetime.timedelta(seconds=5)
result = encode_amqp_value(td)
# Should encode as int ticks
ticks = int(td.total_seconds() * 10_000_000)
encoded_ticks = int.from_bytes(result[1:], "big", signed=True)
self.assertEqual(encoded_ticks, ticks)

def test_encode_datetime(self):
dt = datetime.datetime(1970, 1, 2, tzinfo=datetime.timezone.utc)
result = encode_amqp_value(dt)
# 1 day in ms = 86400000
ms = int((dt - datetime.datetime(
1970,
1,
1,
tzinfo=datetime.timezone.utc)).total_seconds() * 1000)
encoded_ms = int.from_bytes(result[1:], "big", signed=True)
self.assertEqual(encoded_ms, ms)

def test_encode_unsupported_type(self):
with self.assertRaises(TypeError):
encode_amqp_value(object())

def test_encode_amqp_map_empty(self):
result = encode_amqp_map({})
self.assertEqual(result, b"")

def test_encode_amqp_map_scalars(self):
data = {
"a": 1,
"b": True,
"c": "hi"
}
result = convert_to_bytestring(data)
self.assertIsInstance(result, bytes)
self.assertGreater(len(result), 0)

def test_encode_application_properties_empty(self):
data = {}
result = convert_to_bytestring(data)
self.assertIsInstance(result, bytes)
self.assertEqual(len(result), 0)