Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder c
_httpClientFactory = httpClientFactory;
_payloadContentBuilder = contentBuilder;
}

// TODO: Test only, will remove later
internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer()))
{
Expand Down Expand Up @@ -78,6 +78,17 @@ public Task SendMessageWithRetryAsync(
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponseAsync, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Text.Json.Serialization;

namespace Microsoft.Azure.SignalR.Management.ClientInvocation;
#nullable enable
sealed class InvocationResponse<T>
{
[JsonPropertyName("result")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may not work if the object serializer is backed by Newtonsoft.Json.

public T? Result { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.ComponentModel;
using System.Linq;
using System.Net.Http;

using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
Expand Down Expand Up @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
var objectSerializer = _serviceProvider.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.ObjectSerializer ?? new JsonObjectSerializer();
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient, objectSerializer);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,9 @@ private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName,
: $"{pathAfterHub}?application={Uri.EscapeDataString(appName.ToLowerInvariant())}&api-version={Version}";
return new RestApiEndpoint($"{requestPrefixWithHub}{pathAfterHub}") { Query = queries };
}

public RestApiEndpoint SendClientInvocation(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}
}
77 changes: 75 additions & 2 deletions src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
using System.Linq;
using System.Net;
using System.Net.Http;
#if NET7_0_OR_GREATER
using System.IO;
#endif
using System.Threading;
using System.Threading.Tasks;

using Azure;

using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Management.ClientInvocation;
#endif
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Primitives;

Expand All @@ -29,13 +36,15 @@ internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IService
private readonly RestApiProvider _restApiProvider;
private readonly string _hubName;
private readonly string _appName;
private readonly ObjectSerializer _objectSerializer;

public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, ObjectSerializer objectSerializer)
{
_restApiProvider = new RestApiProvider(endpoint);
_appName = appName;
_hubName = hubName;
_restClient = restClient;
_objectSerializer = objectSerializer;
}

public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -351,6 +360,70 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
#nullable enable
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
// Validate input parameters
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}

// Get API endpoint and prepare for the request
var api = _restApiProvider.SendClientInvocation(_appName, _hubName, connectionId);
string? responseContent = null;
var isSuccessStatusCode = false;
// Send request and capture the response
await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
responseContent = await response.Content.ReadAsStringAsync(cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use DeserializeAsync here directly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And RestClient already contains an ObjectSerializer, you can reuse that one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated by exposing the ObjectSerializer in the RestClient and using that.

isSuccessStatusCode = response.IsSuccessStatusCode;
return response.IsSuccessStatusCode || response.StatusCode == HttpStatusCode.BadRequest;
},
cancellationToken: cancellationToken);

// Ensure we have a response
if (responseContent is null)
{
throw new HubException("Response content is null or empty");
}

if (!isSuccessStatusCode)
{
throw new HubException(responseContent);
}

using var contentStream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(responseContent));

var wrapperObj = await _objectSerializer.DeserializeAsync(
contentStream,
typeof(InvocationResponse<T>),
cancellationToken);

var wrapper = wrapperObj as InvocationResponse<T>
?? throw new HubException("Failed to deserialize response");

return wrapper.Result ?? throw new HubException("Result not found in response");
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}
#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Loading
Loading