Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce RetryAttribute for test methods #4586

Merged
merged 9 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 43 additions & 22 deletions src/Adapter/MSTest.TestAdapter/Execution/TestExecutionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void SendMessage(TestMessageLevel testMessageLevel, string message)
/// </summary>
private readonly IDictionary<string, object> _sessionParameters;
private readonly IEnvironment _environment;
private readonly Func<Action, Task> _taskFactory;
private readonly Func<Func<Task>, Task> _taskFactory;

/// <summary>
/// Specifies whether the test run is canceled or not.
Expand All @@ -51,15 +51,15 @@ public TestExecutionManager()
{
}

internal TestExecutionManager(IEnvironment environment, Func<Action, Task>? taskFactory = null)
internal TestExecutionManager(IEnvironment environment, Func<Func<Task>, Task>? taskFactory = null)
{
_testMethodFilter = new TestMethodFilter();
_sessionParameters = new Dictionary<string, object>();
_environment = environment;
_taskFactory = taskFactory ?? DefaultFactoryAsync;
}

private static Task DefaultFactoryAsync(Action action)
private static Task DefaultFactoryAsync(Func<Task> taskGetter)
{
if (MSTestSettings.RunConfigurationSettings.ExecutionApartmentState == ApartmentState.STA
&& RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
Expand All @@ -69,7 +69,9 @@ private static Task DefaultFactoryAsync(Action action)
{
try
{
action();
// This is best we can do to execute in STA thread.
Task task = taskGetter();
task.GetAwaiter().GetResult();
tcs.SetResult(0);
}
catch (Exception ex)
Expand All @@ -84,7 +86,8 @@ private static Task DefaultFactoryAsync(Action action)
}
else
{
return Task.Run(action);
// NOTE: If you replace this with `return taskGetter()`, you will break parallel tests.
return Task.Run(taskGetter);
}
}

Expand Down Expand Up @@ -121,7 +124,9 @@ public void RunTests(IEnumerable<TestCase> tests, IRunContext? runContext, IFram
CacheSessionParameters(runContext, frameworkHandle);

// Execute the tests
ExecuteTests(tests, runContext, frameworkHandle, isDeploymentDone);
// This is a public API, so we can't change it to be async.
// Consider not using this API internally, and introduce an async version, and mark this as obsolete.
ExecuteTestsAsync(tests, runContext, frameworkHandle, isDeploymentDone).GetAwaiter().GetResult();
Comment on lines +127 to +129
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I need to address that in a follow-up


if (!_hasAnyTestFailed)
{
Expand Down Expand Up @@ -159,7 +164,9 @@ public void RunTests(IEnumerable<string> sources, IRunContext? runContext, IFram
CacheSessionParameters(runContext, frameworkHandle);

// Run tests.
ExecuteTests(tests, runContext, frameworkHandle, isDeploymentDone);
// This is a public API, so we can't change it to be async.
// Consider not using this API internally, and introduce an async version, and mark this as obsolete.
ExecuteTestsAsync(tests, runContext, frameworkHandle, isDeploymentDone).GetAwaiter().GetResult();

if (!_hasAnyTestFailed)
{
Expand All @@ -174,7 +181,7 @@ public void RunTests(IEnumerable<string> sources, IRunContext? runContext, IFram
/// <param name="runContext">The run context.</param>
/// <param name="frameworkHandle">Handle to record test start/end/results.</param>
/// <param name="isDeploymentDone">Indicates if deployment is done.</param>
internal virtual void ExecuteTests(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, bool isDeploymentDone)
internal virtual async Task ExecuteTestsAsync(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, bool isDeploymentDone)
{
var testsBySource = from test in tests
group test by test.Source into testGroup
Expand All @@ -183,7 +190,7 @@ group test by test.Source into testGroup
foreach (var group in testsBySource)
{
_testRunCancellationToken?.ThrowIfCancellationRequested();
ExecuteTestsInSource(group.Tests, runContext, frameworkHandle, group.Source, isDeploymentDone);
await ExecuteTestsInSourceAsync(group.Tests, runContext, frameworkHandle, group.Source, isDeploymentDone);
}
}

Expand Down Expand Up @@ -257,7 +264,7 @@ private static bool MatchTestFilter(ITestCaseFilterExpression? filterExpression,
/// <param name="frameworkHandle">Handle to record test start/end/results.</param>
/// <param name="source">The test container for the tests.</param>
/// <param name="isDeploymentDone">Indicates if deployment is done.</param>
private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, string source, bool isDeploymentDone)
private async Task ExecuteTestsInSourceAsync(IEnumerable<TestCase> tests, IRunContext? runContext, IFrameworkHandle frameworkHandle, string source, bool isDeploymentDone)
{
DebugEx.Assert(!StringEx.IsNullOrEmpty(source), "Source cannot be empty");

Expand All @@ -267,7 +274,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
}

using MSTestAdapter.PlatformServices.Interface.ITestSourceHost isolationHost = PlatformServiceProvider.Instance.CreateTestSourceHost(source, runContext?.RunSettings, frameworkHandle);

bool usesAppDomains = isolationHost is MSTestAdapter.PlatformServices.TestSourceHost { UsesAppDomain: true };
PlatformServiceProvider.Instance.AdapterTraceLogger.LogInfo("Created unit-test runner {0}", source);

// Default test set is filtered tests based on user provided filter criteria
Expand Down Expand Up @@ -363,7 +370,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
{
_testRunCancellationToken?.ThrowIfCancellationRequested();

tasks.Add(_taskFactory(() =>
tasks.Add(_taskFactory(async () =>
{
try
{
Expand All @@ -373,7 +380,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC

if (queue.TryDequeue(out IEnumerable<TestCase>? testSet))
{
ExecuteTestsWithTestRunner(testSet, frameworkHandle, source, sourceLevelParameters, testRunner);
await ExecuteTestsWithTestRunnerAsync(testSet, frameworkHandle, source, sourceLevelParameters, testRunner, usesAppDomains);
}
}
}
Expand All @@ -385,7 +392,7 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC

try
{
Task.WaitAll(tasks.ToArray());
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
Expand All @@ -399,12 +406,12 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
// Queue the non parallel set
if (nonParallelizableTestSet != null)
{
ExecuteTestsWithTestRunner(nonParallelizableTestSet, frameworkHandle, source, sourceLevelParameters, testRunner);
await ExecuteTestsWithTestRunnerAsync(nonParallelizableTestSet, frameworkHandle, source, sourceLevelParameters, testRunner, usesAppDomains);
}
}
else
{
ExecuteTestsWithTestRunner(testsToRun, frameworkHandle, source, sourceLevelParameters, testRunner);
await ExecuteTestsWithTestRunnerAsync(testsToRun, frameworkHandle, source, sourceLevelParameters, testRunner, usesAppDomains);
}

if (PlatformServiceProvider.Instance.IsGracefulStopRequested)
Expand All @@ -415,12 +422,13 @@ private void ExecuteTestsInSource(IEnumerable<TestCase> tests, IRunContext? runC
PlatformServiceProvider.Instance.AdapterTraceLogger.LogInfo("Executed tests belonging to source {0}", source);
}

private void ExecuteTestsWithTestRunner(
private async Task ExecuteTestsWithTestRunnerAsync(
IEnumerable<TestCase> tests,
ITestExecutionRecorder testExecutionRecorder,
string source,
IDictionary<string, object> sourceLevelParameters,
UnitTestRunner testRunner)
UnitTestRunner testRunner,
bool usesAppDomains)
{
bool hasAnyRunnableTests = false;
var fixtureTests = new List<TestCase>();
Expand All @@ -429,7 +437,11 @@ private void ExecuteTestsWithTestRunner(
? tests.OrderBy(t => t.GetManagedType()).ThenBy(t => t.GetManagedMethod())
: tests;

var remotingMessageLogger = new RemotingMessageLogger(testExecutionRecorder);
// If testRunner is in a different AppDomain, we cannot pass the testExecutionRecorder directly.
// Instead, we pass a proxy (remoting object) that is marshallable by ref.
IMessageLogger remotingMessageLogger = usesAppDomains
? new RemotingMessageLogger(testExecutionRecorder)
: testExecutionRecorder;

foreach (TestCase currentTest in orderedTests)
{
Expand Down Expand Up @@ -460,9 +472,18 @@ private void ExecuteTestsWithTestRunner(
IDictionary<TestProperty, object?> tcmProperties = TcmTestPropertiesProvider.GetTcmProperties(currentTest);
Dictionary<string, object?> testContextProperties = GetTestContextProperties(tcmProperties, sourceLevelParameters);

// testRunner could be in a different AppDomain. We cannot pass the testExecutionRecorder directly.
// Instead, we pass a proxy (remoting object) that is marshallable by ref.
UnitTestResult[] unitTestResult = testRunner.RunSingleTest(unitTestElement.TestMethod, testContextProperties, remotingMessageLogger);
UnitTestResult[] unitTestResult;
if (usesAppDomains)
{
#pragma warning disable VSTHRD103 // Call async methods when in an async method - We cannot do right now because we are crossing app domains.
// TODO: When app domains support is dropped, we can finally always be calling the async version.
unitTestResult = testRunner.RunSingleTest(unitTestElement.TestMethod, testContextProperties, remotingMessageLogger);
#pragma warning restore VSTHRD103 // Call async methods when in an async method
}
else
{
unitTestResult = await testRunner.RunSingleTestAsync(unitTestElement.TestMethod, testContextProperties, remotingMessageLogger);
}

PlatformServiceProvider.Instance.AdapterTraceLogger.LogInfo("Executed test {0}", unitTestElement.TestMethod.Name);

Expand Down
13 changes: 13 additions & 0 deletions src/Adapter/MSTest.TestAdapter/Execution/TestMethodInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ internal TestMethodInfo(
Parent = parent;
TestMethodOptions = testMethodOptions;
ExpectedException = ResolveExpectedException();
RetryAttribute = GetRetryAttribute();
}

/// <summary>
Expand Down Expand Up @@ -92,6 +93,8 @@ internal TestMethodInfo(

internal ExpectedExceptionBaseAttribute? ExpectedException { get; set; /*set for testing only*/ }

internal RetryAttribute? RetryAttribute { get; }

public Attribute[]? GetAllAttributes(bool inherit) => ReflectHelper.Instance.GetDerivedAttributes<Attribute>(TestMethod, inherit).ToArray();

public TAttributeType[] GetAttributes<TAttributeType>(bool inherit)
Expand Down Expand Up @@ -260,6 +263,16 @@ public virtual TestResult Invoke(object?[]? arguments)
return expectedExceptions.FirstOrDefault();
}

/// <summary>
/// Gets the number of retries this test method should make in case of failure.
/// </summary>
/// <returns>
/// The number of retries, which is always greater than or equal to 1.
/// If RetryAttribute is not present, returns 1.
Evangelink marked this conversation as resolved.
Show resolved Hide resolved
/// </returns>
private RetryAttribute? GetRetryAttribute()
=> ReflectHelper.Instance.GetFirstDerivedAttributeOrDefault<RetryAttribute>(TestMethod, inherit: true);
Youssef1313 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Execute test without timeout.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ internal List<TestResult> RunTestMethod()
{
if (_test.TestDataSourceIgnoreMessage is not null)
{
_testContext.SetOutcome(UTF.UnitTestOutcome.Ignored);
return [new() { Outcome = UTF.UnitTestOutcome.Ignored, IgnoreReason = _test.TestDataSourceIgnoreMessage }];
}

Expand Down
26 changes: 24 additions & 2 deletions src/Adapter/MSTest.TestAdapter/Execution/UnitTestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Extensions;
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Helpers;
using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices;
using Microsoft.VisualStudio.TestPlatform.MSTestAdapter.PlatformServices.Interface;
using Microsoft.VisualStudio.TestPlatform.ObjectModel.Logging;
Expand Down Expand Up @@ -122,13 +123,19 @@ internal FixtureTestResult GetFixtureTestResult(TestMethod testMethod, string fi
static UnitTestOutcome GetOutcome(Exception? exception) => exception == null ? UnitTestOutcome.Passed : UnitTestOutcome.Failed;
}

// Task cannot cross app domains.
// For now, TestExecutionManager will call this sync method which is hacky.
// If we removed AppDomains in v4, we should use the async method and remove this one.
internal UnitTestResult[] RunSingleTest(TestMethod testMethod, IDictionary<string, object?> testContextProperties, IMessageLogger messageLogger)
=> RunSingleTestAsync(testMethod, testContextProperties, messageLogger).GetAwaiter().GetResult();
Youssef1313 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Runs a single test.
/// </summary>
/// <param name="testMethod"> The test Method. </param>
/// <param name="testContextProperties"> The test context properties. </param>
/// <returns> The <see cref="UnitTestResult"/>. </returns>
internal UnitTestResult[] RunSingleTest(TestMethod testMethod, IDictionary<string, object?> testContextProperties, IMessageLogger messageLogger)
internal async Task<UnitTestResult[]> RunSingleTestAsync(TestMethod testMethod, IDictionary<string, object?> testContextProperties, IMessageLogger messageLogger)
{
Guard.NotNull(testMethod);

Expand Down Expand Up @@ -179,8 +186,23 @@ internal UnitTestResult[] RunSingleTest(TestMethod testMethod, IDictionary<strin
{
// Run the test method
testContextForTestExecution.SetOutcome(testContextForClassInit.Context.CurrentTestOutcome);
RetryAttribute? retryAttribute = testMethodInfo.RetryAttribute;
var testMethodRunner = new TestMethodRunner(testMethodInfo, testMethod, testContextForTestExecution);
result = testMethodRunner.Execute(classInitializeResult.StandardOut!, classInitializeResult.StandardError!, classInitializeResult.DebugTrace!, classInitializeResult.TestContextMessages!).ToUnitTestResults();
List<TestResult> firstRunResult = testMethodRunner.Execute(classInitializeResult.StandardOut!, classInitializeResult.StandardError!, classInitializeResult.DebugTrace!, classInitializeResult.TestContextMessages!);
result = firstRunResult.ToUnitTestResults();
if (retryAttribute is not null && !RetryAttribute.IsAcceptableResultForRetry(firstRunResult))
{
RetryResult retryResult = await retryAttribute.ExecuteAsync(
new RetryContext(
() => Task.FromResult(
testMethodRunner.Execute(
classInitializeResult.StandardOut!,
classInitializeResult.StandardError!,
classInitializeResult.DebugTrace!,
classInitializeResult.TestContextMessages!).ToArray())));

result = retryResult.TryGetLast()?.ToUnitTestResults() ?? throw ApplicationStateGuard.Unreachable();
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ internal TestSourceHost(string sourceFileName, IRunSettings? runSettings, IFrame
internal AppDomain? AppDomain { get; private set; }
#endif

#pragma warning disable CA1822 // Mark members as static - accesses instance data under .NET Framework
internal bool UsesAppDomain =>
#if NETFRAMEWORK
!_isAppDomainCreationDisabled;
#else
false;
#endif

/// <summary>
/// Setup the isolation host.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestTools.UnitTesting;

/// <summary>
/// Specifies a backoff type for the delay between retries.
/// </summary>
public enum DelayBackoffType
{
/// <summary>
/// Specifies a constant backoff type. Meaning the delay between retries is constant.
/// </summary>
Constant,

/// <summary>
/// Specifies an exponential backoff type.
/// The delay is calculated as the base delay * 2^(n-1) where n is the retry attempt.
/// For example, if the base delay is 1000ms, the delays will be 1000ms, 2000ms, 4000ms, 8000ms, etc.
/// </summary>
Exponential,
}
Loading
Loading