Skip to content

Commit

Permalink
Introduce RetryAttribute for test methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Youssef1313 committed Jan 14, 2025
1 parent 63ec778 commit c45c457
Show file tree
Hide file tree
Showing 25 changed files with 432 additions and 188 deletions.
64 changes: 42 additions & 22 deletions src/Adapter/MSTest.TestAdapter/Execution/TestExecutionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void SendMessage(TestMessageLevel testMessageLevel, string message)
/// </summary>
private readonly IDictionary<string, object> _sessionParameters;
private readonly IEnvironment _environment;
private readonly Func<Action, Task> _taskFactory;
private readonly Func<Func<Task>, Task> _taskFactory;

/// <summary>
/// Specifies whether the test run is canceled or not.
Expand All @@ -51,15 +51,15 @@ public TestExecutionManager()
{
}

internal TestExecutionManager(IEnvironment environment, Func<Action, Task>? taskFactory = null)
internal TestExecutionManager(IEnvironment environment, Func<Func<Task>, Task>? taskFactory = null)
{
_testMethodFilter = new TestMethodFilter();
_sessionParameters = new Dictionary<string, object>();
_environment = environment;
_taskFactory = taskFactory ?? DefaultFactoryAsync;
}

private static Task DefaultFactoryAsync(Action action)
private static Task DefaultFactoryAsync(Func<Task> taskGetter)
{
if (MSTestSettings.RunConfigurationSettings.ExecutionApartmentState == ApartmentState.STA
&& RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
Expand All @@ -69,7 +69,9 @@ private static Task DefaultFactoryAsync(Action action)
{
try
{
action();
// This is best we can do to execute in STA thread.
Task task = taskGetter();
task.GetAwaiter().GetResult();
tcs.SetResult(0);
}
catch (Exception ex)
Expand All @@ -84,7 +86,7 @@ private static Task DefaultFactoryAsync(Action action)
}
else
{
return Task.Run(action);
return taskGetter();
}
}

Expand Down Expand Up @@ -121,7 +123,9 @@ public void RunTests(IEnumerable<TestCase> tests, IRunContext? runContext, IFram
CacheSessionParameters(runContext, frameworkHandle);

// Execute the tests
ExecuteTests(tests, runContext, frameworkHandle, isDeploymentDone);
// This is a public API, so we can't change it to be async.
// Consider not using this API internally, and introduce an async version, and mark this as obsolete.
ExecuteTestsAsync(tests, runContext, frameworkHandle, isDeploymentDone).GetAwaiter().GetResult();

if (!_hasAnyTestFailed)
{
Expand Down Expand Up @@ -159,7 +163,9 @@ public void RunTests(IEnumerable<string> sources, IRunContext? runContext, IFram
CacheSessionParameters(runContext, frameworkHandle);

// Run tests.
ExecuteTests(tests, runContext, frameworkHandle, isDeploymentDone);
// This is a public API, so we can't change it to be async.
// Consider not using this API internally, and introduce an async version, and mark this as obsolete.
ExecuteTestsAsync(tests, runContext, frameworkHandle, isDeploymentDone).GetAwaiter().GetResult();

if (!_hasAnyTestFailed)
{
Expand All @@ -174,7 +180,7 @@ public void RunTests(IEnumerable<string> sources, IRunContext? runContext, IFram
/// <param name="runContext">The run context.</param>
/// <param name="frameworkHandle">Handle to record test start/end/results.</param>
/// <param name="isDeploymentDone">Indicates if deployment is done.</param>
internal virtual void ExecuteTests(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, bool isDeploymentDone)
internal virtual async Task ExecuteTestsAsync(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, bool isDeploymentDone)
{
var testsBySource = from test in tests
group test by test.Source into testGroup
Expand All @@ -183,7 +189,7 @@ group test by test.Source into testGroup
foreach (var group in testsBySource)
{
_testRunCancellationToken?.ThrowIfCancellationRequested();
ExecuteTestsInSource(group.Tests, runContext, frameworkHandle, group.Source, isDeploymentDone);
await ExecuteTestsInSourceAsync(group.Tests, runContext, frameworkHandle, group.Source, isDeploymentDone);
}
}

Expand Down Expand Up @@ -257,7 +263,7 @@ private static bool MatchTestFilter(ITestCaseFilterExpression? filterExpression,
/// <param name="frameworkHandle">Handle to record test start/end/results.</param>
/// <param name="source">The test container for the tests.</param>
/// <param name="isDeploymentDone">Indicates if deployment is done.</param>
private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, string source, bool isDeploymentDone)
private async Task ExecuteTestsInSourceAsync(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, string source, bool isDeploymentDone)
{
DebugEx.Assert(!StringEx.IsNullOrEmpty(source), "Source cannot be empty");

Expand All @@ -267,7 +273,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
}

using MSTestAdapter.PlatformServices.Interface.ITestSourceHost isolationHost = PlatformServiceProvider.Instance.CreateTestSourceHost(source, runContext?.RunSettings, frameworkHandle);

bool usesAppDomains = isolationHost is MSTestAdapter.PlatformServices.TestSourceHost { UsesAppDomain: true };
PlatformServiceProvider.Instance.AdapterTraceLogger.LogInfo("Created unit-test runner {0}", source);

// Default test set is filtered tests based on user provided filter criteria
Expand Down Expand Up @@ -363,7 +369,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
{
_testRunCancellationToken?.ThrowIfCancellationRequested();

tasks.Add(_taskFactory(() =>
tasks.Add(_taskFactory(async () =>
{
try
{
Expand All @@ -373,7 +379,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC

if (queue.TryDequeue(out IEnumerable<TestCase>? testSet))
{
ExecuteTestsWithTestRunner(testSet, frameworkHandle, source, sourceLevelParameters, testRunner);
await ExecuteTestsWithTestRunnerAsync(testSet, frameworkHandle, source, sourceLevelParameters, testRunner, usesAppDomains);
}
}
}
Expand All @@ -385,7 +391,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC

try
{
Task.WaitAll(tasks.ToArray());
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
Expand All @@ -399,12 +405,12 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
// Queue the non parallel set
if (nonParallelizableTestSet != null)
{
ExecuteTestsWithTestRunner(nonParallelizableTestSet, frameworkHandle, source, sourceLevelParameters, testRunner);
await ExecuteTestsWithTestRunnerAsync(nonParallelizableTestSet, frameworkHandle, source, sourceLevelParameters, testRunner, usesAppDomains);
}
}
else
{
ExecuteTestsWithTestRunner(testsToRun, frameworkHandle, source, sourceLevelParameters, testRunner);
await ExecuteTestsWithTestRunnerAsync(testsToRun, frameworkHandle, source, sourceLevelParameters, testRunner, usesAppDomains);
}

if (PlatformServiceProvider.Instance.IsGracefulStopRequested)
Expand All @@ -415,12 +421,13 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
PlatformServiceProvider.Instance.AdapterTraceLogger.LogInfo("Executed tests belonging to source {0}", source);
}

private void ExecuteTestsWithTestRunner(
private async Task ExecuteTestsWithTestRunnerAsync(
IEnumerable<TestCase> tests,
ITestExecutionRecorder testExecutionRecorder,
string source,
IDictionary<string, object> sourceLevelParameters,
UnitTestRunner testRunner)
UnitTestRunner testRunner,
bool usesAppDomains)
{
bool hasAnyRunnableTests = false;
var fixtureTests = new List<TestCase>();
Expand All @@ -429,7 +436,11 @@ private void ExecuteTestsWithTestRunner(
? tests.OrderBy(t => t.GetManagedType()).ThenBy(t => t.GetManagedMethod())
: tests;

var remotingMessageLogger = new RemotingMessageLogger(testExecutionRecorder);
// If testRunner is in a different AppDomain, we cannot pass the testExecutionRecorder directly.
// Instead, we pass a proxy (remoting object) that is marshallable by ref.
IMessageLogger remotingMessageLogger = usesAppDomains
? new RemotingMessageLogger(testExecutionRecorder)
: testExecutionRecorder;

foreach (TestCase currentTest in orderedTests)
{
Expand Down Expand Up @@ -460,9 +471,18 @@ private void ExecuteTestsWithTestRunner(
IDictionary<TestProperty, object?> tcmProperties = TcmTestPropertiesProvider.GetTcmProperties(currentTest);
Dictionary<string, object?> testContextProperties = GetTestContextProperties(tcmProperties, sourceLevelParameters);

// testRunner could be in a different AppDomain. We cannot pass the testExecutionRecorder directly.
// Instead, we pass a proxy (remoting object) that is marshallable by ref.
UnitTestResult[] unitTestResult = testRunner.RunSingleTest(unitTestElement.TestMethod, testContextProperties, remotingMessageLogger);
UnitTestResult[] unitTestResult;
if (usesAppDomains)
{
#pragma warning disable VSTHRD103 // Call async methods when in an async method - We cannot do right now because we are crossing app domains.
// TODO: When app domains support is dropped, we can finally always be calling the async version.
unitTestResult = testRunner.RunSingleTest(unitTestElement.TestMethod, testContextProperties, remotingMessageLogger);
#pragma warning restore VSTHRD103 // Call async methods when in an async method
}
else
{
unitTestResult = await testRunner.RunSingleTestAsync(unitTestElement.TestMethod, testContextProperties, remotingMessageLogger);
}

PlatformServiceProvider.Instance.AdapterTraceLogger.LogInfo("Executed test {0}", unitTestElement.TestMethod.Name);

Expand Down
13 changes: 13 additions & 0 deletions src/Adapter/MSTest.TestAdapter/Execution/TestMethodInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ internal TestMethodInfo(
Parent = parent;
TestMethodOptions = testMethodOptions;
ExpectedException = ResolveExpectedException();
RetryAttribute = GetRetryAttribute();
}

/// <summary>
Expand Down Expand Up @@ -92,6 +93,8 @@ internal TestMethodInfo(

internal ExpectedExceptionBaseAttribute? ExpectedException { get; set; /*set for testing only*/ }

internal RetryAttribute? RetryAttribute { get; }

public Attribute[]? GetAllAttributes(bool inherit) => ReflectHelper.Instance.GetDerivedAttributes<Attribute>(TestMethod, inherit).ToArray();

public TAttributeType[] GetAttributes<TAttributeType>(bool inherit)
Expand Down Expand Up @@ -260,6 +263,16 @@ public virtual TestResult Invoke(object?[]? arguments)
return expectedExceptions.FirstOrDefault();
}

/// <summary>
/// Gets the number of retries this test method should make in case of failure.
/// </summary>
/// <returns>
/// The number of retries, which is always greater than or equal to 1.
/// If RetryAttribute is not present, returns 1.
/// </returns>
private RetryAttribute? GetRetryAttribute()
=> ReflectHelper.Instance.GetFirstDerivedAttributeOrDefault<RetryAttribute>(TestMethod, inherit: true);

/// <summary>
/// Execute test without timeout.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ internal List<TestResult> RunTestMethod()
{
if (_test.TestDataSourceIgnoreMessage is not null)
{
_testContext.SetOutcome(UTF.UnitTestOutcome.Ignored);
return [new() { Outcome = UTF.UnitTestOutcome.Ignored, IgnoreReason = _test.TestDataSourceIgnoreMessage }];
}

Expand Down
26 changes: 24 additions & 2 deletions src/Adapter/MSTest.TestAdapter/Execution/UnitTestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Extensions;
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Helpers;
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Interface;
using Microsoft.VisualStudio.TestPlatform.ObjectModel.Logging;
Expand Down Expand Up @@ -122,13 +123,19 @@ internal FixtureTestResult GetFixtureTestResult(TestMethod testMethod, string fi
static UnitTestOutcome GetOutcome(Exception? exception) => exception == null ? UnitTestOutcome.Passed : UnitTestOutcome.Failed;
}

// Task cannot cross app domains.
// For now, TestExecutionManager will call this sync method which is hacky.
// If we removed AppDomains in v4, we should use the async method and remove this one.
internal UnitTestResult[] RunSingleTest(TestMethod testMethod, IDictionary<string, object?> testContextProperties, IMessageLogger messageLogger)
=> RunSingleTestAsync(testMethod, testContextProperties, messageLogger).GetAwaiter().GetResult();

/// <summary>
/// Runs a single test.
/// </summary>
/// <param name="testMethod"> The test Method. </param>
/// <param name="testContextProperties"> The test context properties. </param>
/// <returns> The <see cref="UnitTestResult"/>. </returns>
internal UnitTestResult[] RunSingleTest(TestMethod testMethod, IDictionary<string, object?> testContextProperties, IMessageLogger messageLogger)
internal async Task<UnitTestResult[]> RunSingleTestAsync(TestMethod testMethod, IDictionary<string, object?> testContextProperties, IMessageLogger messageLogger)
{
Guard.NotNull(testMethod);

Expand Down Expand Up @@ -179,8 +186,23 @@ internal UnitTestResult[] RunSingleTest(TestMethod testMethod, IDictionary<strin
{
// Run the test method
testContextForTestExecution.SetOutcome(testContextForClassInit.Context.CurrentTestOutcome);
RetryAttribute? retryAttribute = testMethodInfo.RetryAttribute;
var testMethodRunner = new TestMethodRunner(testMethodInfo, testMethod, testContextForTestExecution);
result = testMethodRunner.Execute(classInitializeResult.StandardOut!, classInitializeResult.StandardError!, classInitializeResult.DebugTrace!, classInitializeResult.TestContextMessages!).ToUnitTestResults();
List<TestResult> firstRunResult = testMethodRunner.Execute(classInitializeResult.StandardOut!, classInitializeResult.StandardError!, classInitializeResult.DebugTrace!, classInitializeResult.TestContextMessages!);
result = firstRunResult.ToUnitTestResults();
if (retryAttribute is not null && !RetryAttribute.IsAcceptableResultForRetry(firstRunResult))
{
RetryResult retryResult = await retryAttribute.ExecuteAsync(
new RetryContext(
() => Task.FromResult(
testMethodRunner.Execute(
classInitializeResult.StandardOut!,
classInitializeResult.StandardError!,
classInitializeResult.DebugTrace!,
classInitializeResult.TestContextMessages!).ToArray())));

result = retryResult.TryGetLast()?.ToUnitTestResults() ?? throw ApplicationStateGuard.Unreachable();
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ internal TestSourceHost(string sourceFileName, IRunSettings? runSettings, IFrame
internal AppDomain? AppDomain { get; private set; }
#endif

#pragma warning disable CA1822 // Mark members as static - accesses instance data under .NET Framework
internal bool UsesAppDomain =>
#if NETFRAMEWORK
!_isAppDomainCreationDisabled;
#else
false;
#endif

/// <summary>
/// Setup the isolation host.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestTools.UnitTesting;

/// <summary>
/// Specifies a backoff type for the delay between retries.
/// </summary>
public enum DelayBackoffType
{
/// <summary>
/// Specifies a constant backoff type. Meaning the delay between retries is constant.
/// </summary>
Constant,

/// <summary>
/// Specifies an exponential backoff type.
/// The delay is calculated as the base delay * 2^(n-1) where n is the retry attempt.
/// For example, if the base delay is 1000ms, the delays will be 1000ms, 2000ms, 4000ms, 8000ms, etc.
/// </summary>
Exponential,
}
Loading

0 comments on commit c45c457

Please sign in to comment.