From f06684b0b9e0ebbd2fb25f74380ac3c537a5d47f Mon Sep 17 00:00:00 2001 From: Youssef1313 Date: Wed, 13 Aug 2025 14:10:37 +0200 Subject: [PATCH 1/2] Reduce allocations from properties dictionary of TestContextImplementation --- .../Services/TestContextDictionary.cs | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs diff --git a/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs b/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs new file mode 100644 index 0000000000..18af976efe --- /dev/null +++ b/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Interface.ObjectModel; + +using TestContext = Microsoft.VisualStudio.TestTools.UnitTesting.TestContext; + +namespace Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices; + +internal sealed class TestContextDictionary : IDictionary +{ + private ITestMethod? _testMethod; + + private IDictionary _currentDictionary; + private bool _isOriginalDictionary; + + public TestContextDictionary(IDictionary originalDictionary, ITestMethod? testMethod) + { + // IMPORTANT: TestContextDictionary shouldn't mutate the original dictionary. + // We keep a flag to track if we are using the original dictionary or a copy. + // The idea here is to avoid always creating a copy dictionary if users don't end up mutating the dictionary (common scenario). + _currentDictionary = originalDictionary; + _isOriginalDictionary = true; + _testMethod = testMethod; + } + + public object? this[string key] + { + get + { + if (key == TestContext.FullyQualifiedTestClassNameLabel) + { + return _testMethod?.FullClassName; + } + else if (key == TestContext.ManagedTypeLabel) + { + return _testMethod?.ManagedTypeName; + } + else if (key == TestContext.ManagedMethodLabel) + { + return _testMethod?.ManagedMethodName; + } + else if (key == TestContext.TestNameLabel) + { + return _testMethod?.Name; + } + + return _currentDictionary[key]; + } + + set + { + if (key == TestContext.FullyQualifiedTestClassNameLabel || + key == TestContext.ManagedTypeLabel || + key == TestContext.ManagedMethodLabel || + key == TestContext.TestNameLabel) + { + throw new InvalidOperationException(); + } + + if (_isOriginalDictionary) + { + _currentDictionary = new Dictionary(_currentDictionary); + _isOriginalDictionary = false; + } + + _currentDictionary[key] = value; + } + } + + private sealed class TestContextDictionaryKeyCollection : ICollection + { + private readonly TestContextDictionary _testContextDictionary; + + public TestContextDictionaryKeyCollection(TestContextDictionary testContextDictionary) + => _testContextDictionary = testContextDictionary; + + public int Count => _testContextDictionary.Count; + + public bool IsReadOnly => true; + + public void Add(string item) => throw new NotSupportedException(); + + public void Clear() => throw new NotSupportedException(); + + public bool Contains(string item) => _testContextDictionary.ContainsKey(item); + + public void CopyTo(string[] array, int arrayIndex) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + + if (arrayIndex < 0 || arrayIndex > array.Length) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + + if (array.Length - arrayIndex < _testContextDictionary.Count) + { + throw new ArgumentException(); + } + + // TODO: + } + + public IEnumerator GetEnumerator() => throw new NotImplementedException(); + + public bool Remove(string item) => throw new NotSupportedException(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public ICollection Keys => new TestContextDictionaryKeyCollection(this); + + public ICollection Values => throw new NotImplementedException(); + + public int Count => _currentDictionary.Count + + (_testMethod?.FullClassName is null ? 0 : 1) + + (_testMethod?.ManagedTypeName is null ? 0 : 1) + + (_testMethod?.ManagedMethodName is null ? 0 : 1) + + (_testMethod?.Name is null ? 0 : 1); + + public bool IsReadOnly => _currentDictionary.IsReadOnly; + + public void Add(string key, object? value) => throw new NotImplementedException(); + + public void Add(KeyValuePair item) + => Add(item.Key, item.Value); + + public void Clear() + { + _testMethod = null; + if (_isOriginalDictionary) + { + _currentDictionary = new Dictionary(); + _isOriginalDictionary = false; + } + else + { + _currentDictionary.Clear(); + } + } + + public bool Contains(KeyValuePair item) + => _currentDictionary.TryGetValue(item.Key, out object? value) && EqualityComparer.Default.Equals(value, item.Value); + + public bool ContainsKey(string key) + { + if (key == TestContext.FullyQualifiedTestClassNameLabel) + { + return _testMethod?.FullClassName is not null; + } + else if (key == TestContext.ManagedTypeLabel) + { + return _testMethod?.ManagedTypeName is not null; + } + else if (key == TestContext.ManagedMethodLabel) + { + return _testMethod?.ManagedMethodName is not null; + } + else if (key == TestContext.TestNameLabel) + { + return _testMethod?.Name is not null; + } + + return _currentDictionary.ContainsKey(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new NotImplementedException(); + + public IEnumerator> GetEnumerator() => throw new NotImplementedException(); + + public bool Remove(string key) => throw new NotImplementedException(); + + public bool Remove(KeyValuePair item) => throw new NotImplementedException(); + + public bool TryGetValue(string key, out object? value) => throw new NotImplementedException(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} From c3a3e4792d42e9f6518e7454908aaa02cdf0e265 Mon Sep 17 00:00:00 2001 From: Youssef1313 Date: Wed, 13 Aug 2025 14:33:26 +0200 Subject: [PATCH 2/2] Progress --- .../Services/TestContextDictionary.cs | 84 ++++++++++++++----- 1 file changed, 64 insertions(+), 20 deletions(-) diff --git a/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs b/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs index 18af976efe..2bdea82929 100644 --- a/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs +++ b/src/Adapter/MSTestAdapter.PlatformServices/Services/TestContextDictionary.cs @@ -30,19 +30,19 @@ public object? this[string key] { if (key == TestContext.FullyQualifiedTestClassNameLabel) { - return _testMethod?.FullClassName; + return _testMethod?.FullClassName ?? throw new KeyNotFoundException(); } else if (key == TestContext.ManagedTypeLabel) { - return _testMethod?.ManagedTypeName; + return _testMethod?.ManagedTypeName ?? throw new KeyNotFoundException(); } else if (key == TestContext.ManagedMethodLabel) { - return _testMethod?.ManagedMethodName; + return _testMethod?.ManagedMethodName ?? throw new KeyNotFoundException(); } else if (key == TestContext.TestNameLabel) { - return _testMethod?.Name; + return _testMethod?.Name ?? throw new KeyNotFoundException(); } return _currentDictionary[key]; @@ -50,19 +50,8 @@ public object? this[string key] set { - if (key == TestContext.FullyQualifiedTestClassNameLabel || - key == TestContext.ManagedTypeLabel || - key == TestContext.ManagedMethodLabel || - key == TestContext.TestNameLabel) - { - throw new InvalidOperationException(); - } - - if (_isOriginalDictionary) - { - _currentDictionary = new Dictionary(_currentDictionary); - _isOriginalDictionary = false; - } + ThrowIfKnownKey(key); + CloneDictionaryIfNeeded(); _currentDictionary[key] = value; } @@ -124,7 +113,13 @@ public void CopyTo(string[] array, int arrayIndex) public bool IsReadOnly => _currentDictionary.IsReadOnly; - public void Add(string key, object? value) => throw new NotImplementedException(); + public void Add(string key, object? value) + { + ThrowIfKnownKey(key); + CloneDictionaryIfNeeded(); + + _currentDictionary.Add(key, value); + } public void Add(KeyValuePair item) => Add(item.Key, item.Value); @@ -172,11 +167,60 @@ public bool ContainsKey(string key) public IEnumerator> GetEnumerator() => throw new NotImplementedException(); - public bool Remove(string key) => throw new NotImplementedException(); + public bool Remove(string key) + { + ThrowIfKnownKey(key); + CloneDictionaryIfNeeded(); + return _currentDictionary.Remove(key); + } public bool Remove(KeyValuePair item) => throw new NotImplementedException(); - public bool TryGetValue(string key, out object? value) => throw new NotImplementedException(); + public bool TryGetValue(string key, out object? value) + { + if (key == TestContext.FullyQualifiedTestClassNameLabel) + { + value = _testMethod?.FullClassName; + return value is not null; + } + else if (key == TestContext.ManagedTypeLabel) + { + value = _testMethod?.ManagedTypeName; + return value is not null; + } + else if (key == TestContext.ManagedMethodLabel) + { + value = _testMethod?.ManagedMethodName; + return value is not null; + } + else if (key == TestContext.TestNameLabel) + { + value = _testMethod?.Name; + return value is not null; + } + + return _currentDictionary.TryGetValue(key, out value); + } IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + private static void ThrowIfKnownKey(string key) + { + if (key == TestContext.FullyQualifiedTestClassNameLabel || + key == TestContext.ManagedTypeLabel || + key == TestContext.ManagedMethodLabel || + key == TestContext.TestNameLabel) + { + throw new InvalidOperationException(); + } + } + + private void CloneDictionaryIfNeeded() + { + if (_isOriginalDictionary) + { + _currentDictionary = new Dictionary(_currentDictionary); + _isOriginalDictionary = false; + } + } }