Skip to content

Commit

Permalink
Merge pull request #156 from WalletConnect/feat/default-session-saving
Browse files Browse the repository at this point in the history
default session saving completed & tested
  • Loading branch information
skibitsky authored Jan 19, 2024
2 parents 4af8225 + 758b561 commit ec81740
Show file tree
Hide file tree
Showing 43 changed files with 810 additions and 342 deletions.
4 changes: 3 additions & 1 deletion Core Modules/WalletConnectSharp.Crypto/KeyChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ private async Task<Dictionary<string, string>> GetKeyChain()

private async Task SaveKeyChain()
{
await Storage.SetItem(StorageKey, this._keyChain);
// We need to copy the contents, otherwise Dispose()
// may clear the reference stored inside InMemoryStorage
await Storage.SetItem(StorageKey, new Dictionary<string, string>(this._keyChain));
}

public void Dispose()
Expand Down
14 changes: 14 additions & 0 deletions Core Modules/WalletConnectSharp.Network/JsonRpcProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,19 @@ protected void RegisterEventListeners()
_hasRegisteredEventListeners = true;
}

protected void UnregisterEventListeners()
{
if (!_hasRegisteredEventListeners) return;

WCLogger.Log(
$"[JsonRpcProvider] Unregistering event listeners on connection object with context {_connection.ToString()} inside {Context}");
_connection.PayloadReceived -= OnPayload;
_connection.Closed -= OnConnectionDisconnected;
_connection.ErrorReceived -= OnConnectionError;

_hasRegisteredEventListeners = false;
}

private void OnConnectionError(object sender, Exception e)
{
this.ErrorReceived?.Invoke(this, e);
Expand Down Expand Up @@ -313,6 +326,7 @@ protected virtual void Dispose(bool disposing)

if (disposing)
{
UnregisterEventListeners();
_connection?.Dispose();
}

Expand Down
71 changes: 53 additions & 18 deletions Core Modules/WalletConnectSharp.Storage/FileSystemStorage.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Concurrent;
using System.Text;
using Newtonsoft.Json;
using WalletConnectSharp.Common.Logging;
Expand Down Expand Up @@ -41,7 +42,11 @@ public FileSystemStorage(string filePath = null)
/// <returns></returns>
public override async Task Init()
{
if (Initialized)
return;

_semaphoreSlim = new SemaphoreSlim(1, 1);

await Task.WhenAll(
Load(), base.Init()
);
Expand Down Expand Up @@ -89,38 +94,68 @@ private async Task Save()
Directory.CreateDirectory(path);
}

var json = JsonConvert.SerializeObject(Entries,
string json;
json = JsonConvert.SerializeObject(Entries,
new JsonSerializerSettings() { TypeNameHandling = TypeNameHandling.All });

await _semaphoreSlim.WaitAsync();
await File.WriteAllTextAsync(FilePath, json, Encoding.UTF8);
_semaphoreSlim.Release();
try
{
if (!Disposed)
await _semaphoreSlim.WaitAsync();
int count = 5;
IOException lastException;
do
{
try
{
await File.WriteAllTextAsync(FilePath, json, Encoding.UTF8);
return;
}
catch (IOException e)
{
WCLogger.LogError($"Got error saving storage file: retries left {count}");
await Task.Delay(100);
count--;
lastException = e;
}
} while (count > 0);

throw lastException;
}
finally
{
if (!Disposed)
_semaphoreSlim.Release();
}
}

private async Task Load()
{
if (!File.Exists(FilePath))
return;

await _semaphoreSlim.WaitAsync();
var json = await File.ReadAllTextAsync(FilePath, Encoding.UTF8);
_semaphoreSlim.Release();
string json;
try
{
await _semaphoreSlim.WaitAsync();
json = await File.ReadAllTextAsync(FilePath, Encoding.UTF8);
}
finally
{
_semaphoreSlim.Release();
}

// Hard fail here if the storage file is bad, unless it's serialized as a Dictionary (for backwards compatibility)
var jsonSerializerSettings = new JsonSerializerSettings { TypeNameHandling = TypeNameHandling.Auto };
try
{
Entries = JsonConvert.DeserializeObject<Dictionary<string, object>>(json,
new JsonSerializerSettings() { TypeNameHandling = TypeNameHandling.Auto });
Entries = JsonConvert.DeserializeObject<ConcurrentDictionary<string, object>>(json,
jsonSerializerSettings);
}
catch (JsonSerializationException e)
catch (JsonSerializationException)
{
// Move the file to a .unsupported file
// and start fresh
WCLogger.LogError(e);
WCLogger.LogError("Cannot load JSON file, moving data to .unsupported file to force continue");
if (File.Exists(FilePath + ".unsupported"))
File.Move(FilePath + ".unsupported", FilePath + "." + Guid.NewGuid() + ".unsupported");
File.Move(FilePath, FilePath + ".unsupported");
Entries = new Dictionary<string, object>();
var dict = JsonConvert.DeserializeObject<Dictionary<string, object>>(json, jsonSerializerSettings);
Entries = new ConcurrentDictionary<string, object>(dict);
}
}

Expand Down
18 changes: 13 additions & 5 deletions Core Modules/WalletConnectSharp.Storage/InMemoryStorage.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
using System.Collections.Concurrent;
using WalletConnectSharp.Common.Model.Errors;
using WalletConnectSharp.Storage.Interfaces;

namespace WalletConnectSharp.Storage
{
public class InMemoryStorage : IKeyValueStorage
{
protected Dictionary<string, object> Entries = new Dictionary<string, object>();
private bool _initialized = false;
protected ConcurrentDictionary<string, object> Entries = new ConcurrentDictionary<string, object>();
protected bool Initialized = false;
protected bool Disposed;

public virtual Task Init()
{
_initialized = true;
if (Initialized)
return Task.CompletedTask;

Initialized = true;
return Task.CompletedTask;
}

Expand All @@ -24,6 +28,7 @@ public virtual Task<string[]> GetKeys()
public virtual async Task<T[]> GetEntriesOfType<T>()
{
IsInitialized();
// GetEntries is thread-safe
return (await GetEntries()).OfType<T>().ToArray();
}

Expand All @@ -43,13 +48,15 @@ public virtual Task SetItem<T>(string key, T value)
{
IsInitialized();
Entries[key] = value;

return Task.CompletedTask;
}

public virtual Task RemoveItem(string key)
{
IsInitialized();
Entries.Remove(key);
Entries.Remove(key, out _);

return Task.CompletedTask;
}

Expand All @@ -63,12 +70,13 @@ public virtual Task Clear()
{
IsInitialized();
Entries.Clear();

return Task.CompletedTask;
}

protected void IsInitialized()
{
if (!_initialized)
if (!Initialized)
{
throw WalletConnectException.FromType(ErrorType.NOT_INITIALIZED, "Storage");
}
Expand Down
49 changes: 26 additions & 23 deletions Tests/WalletConnectSharp.Auth.Tests/AuthClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using WalletConnectSharp.Storage;
using WalletConnectSharp.Tests.Common;
using Xunit;
using Xunit.Abstractions;
using ErrorResponse = WalletConnectSharp.Auth.Models.ErrorResponse;

namespace WalletConnectSharp.Auth.Tests
Expand All @@ -26,6 +27,7 @@ public class AuthClientTests : IClassFixture<CryptoWalletFixture>, IAsyncLifetim
};

private readonly CryptoWalletFixture _cryptoWalletFixture;
private readonly ITestOutputHelper _testOutputHelper;

private IAuthClient PeerA;
public IAuthClient PeerB;
Expand Down Expand Up @@ -54,13 +56,14 @@ public string WalletAddress
}
}

public AuthClientTests(CryptoWalletFixture cryptoFixture)
public AuthClientTests(CryptoWalletFixture cryptoFixture, ITestOutputHelper testOutputHelper)
{
this._cryptoWalletFixture = cryptoFixture;
_testOutputHelper = testOutputHelper;
}

[Fact, Trait("Category", "unit")]
public async void TestInit()
public async Task TestInit()
{
Assert.NotNull(PeerA);
Assert.NotNull(PeerB);
Expand All @@ -77,7 +80,7 @@ public async void TestInit()
}

[Fact, Trait("Category", "unit")]
public async void TestPairs()
public async Task TestPairs()
{
var ogPairSize = PeerA.Core.Pairing.Pairings.Length;

Expand Down Expand Up @@ -110,7 +113,7 @@ public async void TestPairs()
}

[Fact, Trait("Category", "unit")]
public async void TestKnownPairings()
public async Task TestKnownPairings()
{
var ogSizeA = PeerA.Core.Pairing.Pairings.Length;
var history = await PeerA.AuthHistory();
Expand All @@ -121,7 +124,7 @@ public async void TestKnownPairings()
var ogHistorySizeB = historyB.Keys.Length;

List<TopicMessage> responses = new List<TopicMessage>();
TaskCompletionSource<TopicMessage> responseTask = new TaskCompletionSource<TopicMessage>();
TaskCompletionSource<TopicMessage> knownPairingTask = new TaskCompletionSource<TopicMessage>();

async void OnPeerBOnAuthRequested(object sender, AuthRequest request)
{
Expand All @@ -145,9 +148,9 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse args)
var sessionTopic = args.Topic;
var cacao = args.Response.Result;
var signature = cacao.Signature;
Console.WriteLine($"{sessionTopic}: {signature}");
_testOutputHelper.WriteLine($"{sessionTopic}: {signature}");
responses.Add(args);
responseTask.SetResult(args);
knownPairingTask.SetResult(args);
}

PeerA.AuthResponded += OnPeerAOnAuthResponded;
Expand All @@ -156,9 +159,9 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args)
{
var sessionTopic = args.Topic;
var error = args.Error;
Console.WriteLine($"{sessionTopic}: {error}");
_testOutputHelper.WriteLine($"{sessionTopic}: {error}");
responses.Add(args);
responseTask.SetResult(args);
knownPairingTask.SetResult(args);
}

PeerA.AuthError += OnPeerAOnAuthError;
Expand All @@ -167,18 +170,18 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args)

await PeerB.Core.Pairing.Pair(requestData.Uri);

await responseTask.Task;

await knownPairingTask.Task;
// Reset
responseTask = new TaskCompletionSource<TopicMessage>();
knownPairingTask = new TaskCompletionSource<TopicMessage>();

// Get last pairing, that is the one we just made
var knownPairing = PeerA.Core.Pairing.Pairings[^1];

var requestData2 = await PeerA.Request(DefaultRequestParams, knownPairing.Topic);

await responseTask.Task;

await knownPairingTask.Task;
Assert.Null(requestData2.Uri);

Assert.Equal(ogSizeA + 1, PeerA.Core.Pairing.Pairings.Length);
Expand All @@ -195,7 +198,7 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args)
}

[Fact, Trait("Category", "unit")]
public async void HandlesAuthRequests()
public async Task HandlesAuthRequests()
{
var ogSize = PeerB.Requests.Length;

Expand All @@ -218,7 +221,7 @@ public async void HandlesAuthRequests()
}

[Fact, Trait("Category", "unit")]
public async void TestErrorResponses()
public async Task TestErrorResponses()
{
var ogPSize = PeerA.Core.Pairing.Pairings.Length;

Expand Down Expand Up @@ -263,7 +266,7 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse response)
}

[Fact, Trait("Category", "unit")]
public async void HandlesSuccessfulResponse()
public async Task HandlesSuccessfulResponse()
{
var ogPSize = PeerA.Core.Pairing.Pairings.Length;

Expand Down Expand Up @@ -313,7 +316,7 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse response) =>
}

[Fact, Trait("Category", "unit")]
public async void TestCustomRequestExpiry()
public async Task TestCustomRequestExpiry()
{
var uri = "";
var expiry = 1000;
Expand Down Expand Up @@ -360,7 +363,7 @@ await PeerB.Respond(
}

[Fact, Trait("Category", "unit")]
public async void TestGetPendingPairings()
public async Task TestGetPendingPairings()
{
var ogCount = PeerB.PendingRequests.Count;

Expand All @@ -386,7 +389,7 @@ public async void TestGetPendingPairings()
}

[Fact, Trait("Category", "unit")]
public async void TestGetPairings()
public async Task TestGetPairings()
{
var peerAOgSize = PeerA.Core.Pairing.Pairings.Length;
var peerBOgSize = PeerB.Core.Pairing.Pairings.Length;
Expand Down Expand Up @@ -414,7 +417,7 @@ public async void TestGetPairings()
}

[Fact, Trait("Category", "unit")]
public async void TestPing()
public async Task TestPing()
{
TaskCompletionSource<bool> receivedAuthRequest = new TaskCompletionSource<bool>();
TaskCompletionSource<bool> receivedClientPing = new TaskCompletionSource<bool>();
Expand Down Expand Up @@ -453,7 +456,7 @@ public async void TestPing()
}

[Fact, Trait("Category", "unit")]
public async void TestDisconnectedPairing()
public async Task TestDisconnectedPairing()
{
var peerAOgSize = PeerA.Core.Pairing.Pairings.Length;
var peerBOgSize = PeerB.Core.Pairing.Pairings.Length;
Expand Down Expand Up @@ -493,7 +496,7 @@ public async void TestDisconnectedPairing()
}

[Fact, Trait("Category", "unit")]
public async void TestReceivesMetadata()
public async Task TestReceivesMetadata()
{
var receivedMetadataName = "";
var ogPairingSize = PeerA.Core.Pairing.Pairings.Length;
Expand Down
Loading

0 comments on commit ec81740

Please sign in to comment.