diff --git a/scripts/performance/common.py b/scripts/performance/common.py
index df58c05570b..ed125e5ab7e 100644
--- a/scripts/performance/common.py
+++ b/scripts/performance/common.py
@@ -18,6 +18,7 @@
import os
import sys
import time
+import base64
from typing import Callable, List, Optional, Tuple, Type, TypeVar
@@ -139,6 +140,10 @@ def get_packages_directory() -> str:
'''
return os.path.join(get_artifacts_directory(), 'packages')
+def base64_to_bytes(base64_string: str) -> bytes:
+ byte_data = base64.b64decode(base64_string)
+ return byte_data
+
@contextmanager
def push_dir(path: Optional[str] = None):
'''
@@ -234,6 +239,7 @@ def __init__(
cmdline: List[str],
success_exit_codes: Optional[List[int]] = None,
verbose: bool = False,
+ echo: bool = True,
retry: int = 0):
if cmdline is None:
raise TypeError('Unspecified command line to be executed.')
@@ -243,6 +249,7 @@ def __init__(
self.__cmdline = cmdline
self.__verbose = verbose
self.__retry = retry
+ self.__echo = echo
if success_exit_codes is None:
self.__success_exit_codes = [0]
@@ -262,6 +269,11 @@ def success_exit_codes(self) -> List[int]:
'''
return self.__success_exit_codes
+ @property
+ def echo(self) -> bool:
+ '''Enables/Disables echoing of STDOUT'''
+ return self.__echo
+
@property
def verbose(self) -> bool:
'''Enables/Disables verbosity.'''
@@ -297,7 +309,8 @@ def __runinternal(self, working_directory: Optional[str] = None) -> Tuple[int, s
line = raw_line.decode('utf-8', errors='backslashreplace')
self.__stdout.write(line)
line = line.rstrip()
- getLogger().info(line)
+ if self.echo:
+ getLogger().info(line)
proc.wait()
return (proc.returncode, quoted_cmdline)
diff --git a/scripts/performance/constants.py b/scripts/performance/constants.py
index 02b63119988..bd29d394404 100644
--- a/scripts/performance/constants.py
+++ b/scripts/performance/constants.py
@@ -5,4 +5,5 @@
UPLOAD_STORAGE_URI = 'https://pvscmdupload.{}.core.windows.net'
UPLOAD_QUEUE = 'resultsqueue'
TENANT_ID = '72f988bf-86f1-41af-91ab-2d7cd011db47'
-CLIENT_ID = 'a231f733-103b-46e9-b58a-9416edde0eb4'
+ARC_CLIENT_ID = 'a231f733-103b-46e9-b58a-9416edde0eb4'
+CERT_CLIENT_ID = '8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc'
\ No newline at end of file
diff --git a/scripts/run_performance_job.py b/scripts/run_performance_job.py
index 9ff3684825f..2d90bc48719 100644
--- a/scripts/run_performance_job.py
+++ b/scripts/run_performance_job.py
@@ -774,6 +774,32 @@ def run_performance_job(args: RunPerformanceJobArgs):
getLogger().info("Copying global.json to payload directory")
shutil.copy(os.path.join(args.performance_repo_dir, 'global.json'), os.path.join(performance_payload_dir, 'global.json'))
+ # Building CertHelper needs to happen here as we need it on every run. This also means that we will need to move the calculation
+ # of the parameters needed outside of the if block
+
+ framework = os.environ["PERFLAB_Framework"]
+ os.environ["PERFLAB_TARGET_FRAMEWORKS"] = framework
+ if args.os_group == "windows":
+ runtime_id = f"win-{args.architecture}"
+ elif args.os_group == "osx":
+ runtime_id = f"osx-{args.architecture}"
+ else:
+ runtime_id = f"linux-{args.architecture}"
+
+ dotnet_executable_path = os.path.join(ci_setup_arguments.install_dir, "dotnet")
+
+ RunCommand([
+ dotnet_executable_path, "publish",
+ "-c", "Release",
+ "-o", os.path.join(payload_dir, "certhelper"),
+ "-f", framework,
+ "-r", runtime_id,
+ "--self-contained",
+ os.path.join(args.performance_repo_dir, "src", "tools", "CertHelper", "CertHelper.csproj"),
+ f"/bl:{os.path.join(args.performance_repo_dir, 'artifacts', 'log', build_config, 'CertHelper.binlog')}",
+ "-p:DisableTransitiveFrameworkReferenceDownloads=true"],
+ verbose=True).run()
+
if args.is_scenario:
set_environment_variable("DOTNET_ROOT", ci_setup_arguments.install_dir, save_to_pipeline=True)
getLogger().info(f"Set DOTNET_ROOT to {ci_setup_arguments.install_dir}")
@@ -782,17 +808,6 @@ def run_performance_job(args: RunPerformanceJobArgs):
set_environment_variable("PATH", new_path, save_to_pipeline=True)
getLogger().info(f"Set PATH to {new_path}")
- framework = os.environ["PERFLAB_Framework"]
- os.environ["PERFLAB_TARGET_FRAMEWORKS"] = framework
- if args.os_group == "windows":
- runtime_id = f"win-{args.architecture}"
- elif args.os_group == "osx":
- runtime_id = f"osx-{args.architecture}"
- else:
- runtime_id = f"linux-{args.architecture}"
-
- dotnet_executable_path = os.path.join(ci_setup_arguments.install_dir, "dotnet")
-
os.environ["MSBUILDDISABLENODEREUSE"] = "1" # without this, MSbuild will be kept alive
# build Startup
diff --git a/scripts/upload.py b/scripts/upload.py
index 40d83e3156b..36bd40f57eb 100644
--- a/scripts/upload.py
+++ b/scripts/upload.py
@@ -3,11 +3,11 @@
from azure.storage.blob import BlobClient, ContentSettings
from azure.storage.queue import QueueClient, TextBase64EncodePolicy
from azure.core.exceptions import ResourceExistsError, ClientAuthenticationError
-from azure.identity import DefaultAzureCredential, ClientAssertionCredential
+from azure.identity import DefaultAzureCredential, ClientAssertionCredential, CertificateCredential
from traceback import format_exc
from glob import glob
-from performance.common import retry_on_exception
-from performance.constants import TENANT_ID, CLIENT_ID
+from performance.common import retry_on_exception, RunCommand, helixpayload, base64_to_bytes, extension
+from performance.constants import TENANT_ID, ARC_CLIENT_ID, CERT_CLIENT_ID
import os
import json
@@ -32,14 +32,25 @@ def upload(globpath: str, container: str, queue: str, sas_token_env: str, storag
credential = None
try:
dac = DefaultAzureCredential()
- credential = ClientAssertionCredential(TENANT_ID, CLIENT_ID, lambda: dac.get_token("api://AzureADTokenExchange/.default").token)
+ credential = ClientAssertionCredential(TENANT_ID, ARC_CLIENT_ID, lambda: dac.get_token("api://AzureADTokenExchange/.default").token)
credential.get_token("https://storage.azure.com/.default")
except ClientAuthenticationError as ex:
- getLogger().info("Unable to use managed identity. Falling back to environment variable.")
- credential = os.getenv(sas_token_env)
+ credential = None
+ getLogger().info("Unable to use managed identity. Falling back to certificate.")
+ cmd_line = [(os.path.join(str(helixpayload()), 'certhelper', "CertHelper%s" % extension()))]
+ cert_helper = RunCommand(cmd_line, None, True, False, 0)
+ cert_helper.run()
+ for cert in cert_helper.stdout.splitlines():
+ credential = CertificateCredential(TENANT_ID, CERT_CLIENT_ID, certificate_data=base64_to_bytes(cert))
+ try:
+ credential.get_token("https://storage.azure.com/.default")
+ except ClientAuthenticationError as ex:
+ credential = None
+ continue
if credential is None:
- getLogger().error("Sas token environment variable {} was not defined.".format(sas_token_env))
- return 1
+ getLogger().error("Unable to authenticate with managed identity or certificates.")
+ getLogger().info("Falling back to environment variable.")
+ credential = os.getenv(sas_token_env)
files = glob(globpath, recursive=True)
any_upload_or_queue_failed = False
diff --git a/src/tools/CertHelper/AssemblyInfo.cs b/src/tools/CertHelper/AssemblyInfo.cs
new file mode 100644
index 00000000000..edd1d789f59
--- /dev/null
+++ b/src/tools/CertHelper/AssemblyInfo.cs
@@ -0,0 +1,12 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading.Tasks;
+
+[assembly: InternalsVisibleTo("CertHelperTests")]
+namespace CertHelper;
+internal class AssemblyInfo
+{
+}
diff --git a/src/tools/CertHelper/CertHelper.csproj b/src/tools/CertHelper/CertHelper.csproj
new file mode 100644
index 00000000000..005a2ceb892
--- /dev/null
+++ b/src/tools/CertHelper/CertHelper.csproj
@@ -0,0 +1,20 @@
+
+
+
+ Exe
+ $(PERFLAB_TARGET_FRAMEWORKS)
+
+ net9.0
+ enable
+ enable
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/tools/CertHelper/CertHelper.sln b/src/tools/CertHelper/CertHelper.sln
new file mode 100644
index 00000000000..df485270020
--- /dev/null
+++ b/src/tools/CertHelper/CertHelper.sln
@@ -0,0 +1,28 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 17
+VisualStudioVersion = 17.12.35514.174 d17.12
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CertHelper", "CertHelper.csproj", "{165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CertRotatorTests", "..\CertHelperTests\CertRotatorTests.csproj", "{AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+EndGlobal
diff --git a/src/tools/CertHelper/Constants.cs b/src/tools/CertHelper/Constants.cs
new file mode 100644
index 00000000000..3ddad191905
--- /dev/null
+++ b/src/tools/CertHelper/Constants.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace CertHelper;
+public class Constants
+{
+ public static readonly string Cert1Name = "LabCert1";
+ public static readonly string Cert2Name = "LabCert2";
+ public static readonly Uri Cert1Id = new Uri("https://test.vault.azure.net/certificates/LabCert1/07a7d98bf4884e5c40e690e02b96b3b4");
+ public static readonly Uri Cert2Id = new Uri("https://test.vault.azure.net/certificates/LabCert2/07a7d98bf4884e5c41e690e02b96b3b4");
+}
diff --git a/src/tools/CertHelper/IX509Store.cs b/src/tools/CertHelper/IX509Store.cs
new file mode 100644
index 00000000000..2fd187d86dc
--- /dev/null
+++ b/src/tools/CertHelper/IX509Store.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Security.Cryptography.X509Certificates;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace CertHelper;
+public interface IX509Store
+{
+ X509Certificate2Collection Certificates { get; }
+ string? Name { get; }
+ StoreLocation Location { get; }
+ X509Store GetX509Store();
+}
+
+public class TestableX509Store : IX509Store
+{
+ public X509Certificate2Collection Certificates { get => store.Certificates; }
+
+ public string? Name => store.Name;
+
+ public StoreLocation Location => store.Location;
+
+ private X509Store store;
+ public TestableX509Store(OpenFlags flags = OpenFlags.ReadOnly)
+ {
+ store = new X509Store(StoreName.My, StoreLocation.CurrentUser, flags);
+ }
+
+ public X509Store GetX509Store()
+ {
+ return store;
+ }
+}
diff --git a/src/tools/CertHelper/KeyVaultCert.cs b/src/tools/CertHelper/KeyVaultCert.cs
new file mode 100644
index 00000000000..067e9025e94
--- /dev/null
+++ b/src/tools/CertHelper/KeyVaultCert.cs
@@ -0,0 +1,110 @@
+using Azure;
+using Azure.Core;
+using Azure.Identity;
+using Azure.Security.KeyVault.Certificates;
+using Azure.Security.KeyVault.Secrets;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Security.Cryptography.X509Certificates;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace CertHelper;
+
+public class KeyVaultCert
+{
+ private readonly string _keyVaultUrl = "https://dotnetperfkeyvault.vault.azure.net/";
+ private readonly string _tenantId = "72f988bf-86f1-41af-91ab-2d7cd011db47";
+ private readonly string _clientId = "8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc";
+
+ public X509Certificate2Collection KeyVaultCertificates { get; set; }
+ public ILocalCert LocalCerts { get; set; }
+ private TokenCredential _credential { get; set; }
+ private CertificateClient _certClient { get; set; }
+ private SecretClient _secretClient { get; set; }
+
+ public KeyVaultCert(TokenCredential? cred = null, CertificateClient? certClient = null, SecretClient? secretClient = null, ILocalCert? localCerts = null)
+ {
+ LocalCerts = localCerts ?? new LocalCert();
+ _credential = cred ?? GetCertifcateCredentialAsync(_tenantId, _clientId, LocalCerts.Certificates).Result;
+ _certClient = certClient ?? new CertificateClient(new Uri(_keyVaultUrl), _credential);
+ _secretClient = secretClient ?? new SecretClient(new Uri(_keyVaultUrl), _credential);
+ KeyVaultCertificates = new X509Certificate2Collection();
+ }
+
+ public async Task LoadKeyVaultCertsAsync()
+ {
+ KeyVaultCertificates.Add(await FindCertificateInKeyVaultAsync(Constants.Cert1Name));
+ KeyVaultCertificates.Add(await FindCertificateInKeyVaultAsync(Constants.Cert2Name));
+
+ if (KeyVaultCertificates.Where(c => c == null).Count() > 0)
+ {
+ throw new Exception("One or more certificates not found");
+ }
+ }
+
+ private async Task GetCertifcateCredentialAsync(string tenantId, string clientId, X509Certificate2Collection certCollection)
+ {
+ ClientCertificateCredential? ccc = null;
+ Exception? exception = null;
+ foreach (var cert in certCollection)
+ {
+ try
+ {
+ ccc = new ClientCertificateCredential(tenantId, clientId, cert);
+ await ccc.GetTokenAsync(new TokenRequestContext(new string[] { "https://vault.azure.net/.default" }));
+ break;
+ }
+ catch (Exception ex)
+ {
+ ccc = null;
+ exception = ex;
+ }
+ }
+ if(ccc == null)
+ {
+ throw new Exception("Both certificates failed to authenticate", exception);
+ }
+ return ccc;
+ }
+
+ private async Task FindCertificateInKeyVaultAsync(string certName)
+ {
+ var keyVaultCert = await _certClient.GetCertificateAsync(certName);
+ if(keyVaultCert.Value == null)
+ {
+ throw new Exception("Certificate not found in Key Vault");
+ }
+ var secret = await _secretClient.GetSecretAsync(keyVaultCert.Value.Name, keyVaultCert.Value.SecretId.Segments.Last());
+ if(secret.Value == null)
+ {
+ throw new Exception("Certificate secret not found in Key Vault");
+ }
+ var certBytes = Convert.FromBase64String(secret.Value.Value);
+#if NET9_0_OR_GREATER
+ var cert = X509CertificateLoader.LoadPkcs12(certBytes, "", X509KeyStorageFlags.Exportable);
+#else
+ var cert = new X509Certificate2(certBytes, "", X509KeyStorageFlags.Exportable);
+#endif
+ return cert;
+ }
+
+ public bool ShouldRotateCerts()
+ {
+ var keyVaultThumbprints = new HashSet();
+ foreach (var cert in KeyVaultCertificates)
+ {
+ keyVaultThumbprints.Add(cert.Thumbprint);
+ }
+ foreach(var cert in LocalCerts.Certificates)
+ {
+ if (!keyVaultThumbprints.Contains(cert.Thumbprint))
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+}
diff --git a/src/tools/CertHelper/LocalCert.cs b/src/tools/CertHelper/LocalCert.cs
new file mode 100644
index 00000000000..3459e7509df
--- /dev/null
+++ b/src/tools/CertHelper/LocalCert.cs
@@ -0,0 +1,44 @@
+using Azure.Security.KeyVault.Certificates;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Security.Cryptography.X509Certificates;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace CertHelper;
+
+public class LocalCert : ILocalCert
+{
+ public X509Certificate2Collection Certificates { get; set; }
+ internal IX509Store LocalMachineCerts { get; set; }
+
+ public LocalCert(IX509Store? store = null)
+ {
+ LocalMachineCerts = store ?? new TestableX509Store();
+ Certificates = new X509Certificate2Collection();
+ GetLocalCerts();
+ }
+
+ private void GetLocalCerts()
+ {
+ foreach (var cert in LocalMachineCerts.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false))
+ {
+ if (cert.Subject == "CN=dotnetperf.microsoft.com")
+ {
+ Certificates.Add(cert);
+ }
+ }
+
+ if (Certificates.Count < 2 || Certificates.Where(c => c == null).Count() > 0)
+ {
+ throw new Exception("One or more certificates not found");
+ }
+ }
+}
+
+public interface ILocalCert
+{
+ X509Certificate2Collection Certificates { get; set; }
+}
diff --git a/src/tools/CertHelper/Program.cs b/src/tools/CertHelper/Program.cs
new file mode 100644
index 00000000000..71563a205e0
--- /dev/null
+++ b/src/tools/CertHelper/Program.cs
@@ -0,0 +1,67 @@
+using Azure.Identity;
+using Azure.Storage.Blobs;
+using Azure.Storage.Blobs.Specialized;
+using System.Security.Cryptography.X509Certificates;
+using System.Text;
+
+namespace CertHelper;
+
+internal class Program
+{
+ static readonly string TENANT_ID = "72f988bf-86f1-41af-91ab-2d7cd011db47";
+ static readonly string CERT_CLIENT_ID = "8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc";
+ static async Task Main(string[] args)
+ {
+ try
+ {
+ var kvc = new KeyVaultCert();
+ await kvc.LoadKeyVaultCertsAsync();
+ if (kvc.ShouldRotateCerts())
+ {
+ using (var localMachineCerts = new X509Store(StoreName.My, StoreLocation.CurrentUser))
+ {
+ localMachineCerts.Open(OpenFlags.ReadWrite);
+ localMachineCerts.RemoveRange(kvc.LocalCerts.Certificates);
+ localMachineCerts.AddRange(kvc.KeyVaultCertificates);
+ }
+ }
+ var bcc = new BlobContainerClient(new Uri("https://pvscmdupload.blob.core.windows.net/certstatus"),
+ new ClientCertificateCredential(TENANT_ID, CERT_CLIENT_ID, kvc.KeyVaultCertificates.First()));
+ var currentKeyValutCertThumbprints = "";
+ foreach(var cert in kvc.KeyVaultCertificates)
+ {
+ currentKeyValutCertThumbprints += $"[{DateTimeOffset.UtcNow}] {cert.Thumbprint}{Environment.NewLine}";
+ }
+ var blob = bcc.GetBlobClient(System.Environment.MachineName);
+ if (blob.Exists())
+ {
+ var result = blob.DownloadContent();
+ var currentBlob = result.Value.Content.ToString();
+ currentBlob = currentBlob + currentKeyValutCertThumbprints;
+ blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentBlob)), overwrite: true);
+ }
+ else
+ {
+ blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentKeyValutCertThumbprints)), overwrite: false);
+ }
+
+ }
+ catch (Exception ex)
+ {
+ Console.WriteLine("Failed to rotate certificates");
+ Console.WriteLine(ex.Message);
+ Console.WriteLine(ex.StackTrace);
+ }
+
+
+
+ using (var store = new X509Store(StoreName.My, StoreLocation.CurrentUser, OpenFlags.ReadWrite))
+ {
+ foreach(var cert in store.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false))
+ {
+ Console.WriteLine(Convert.ToBase64String(cert.Export(X509ContentType.Pfx)));
+ }
+ }
+ return 0;
+ }
+}
diff --git a/src/tools/CertHelperTests/CertRotatorTests.csproj b/src/tools/CertHelperTests/CertRotatorTests.csproj
new file mode 100644
index 00000000000..0ea831f1d18
--- /dev/null
+++ b/src/tools/CertHelperTests/CertRotatorTests.csproj
@@ -0,0 +1,30 @@
+
+
+
+ net9.0
+ enable
+ enable
+
+ false
+ true
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
+
+
+
+
diff --git a/src/tools/CertHelperTests/KeyVaultCertTests.cs b/src/tools/CertHelperTests/KeyVaultCertTests.cs
new file mode 100644
index 00000000000..7f9621352b4
--- /dev/null
+++ b/src/tools/CertHelperTests/KeyVaultCertTests.cs
@@ -0,0 +1,163 @@
+using System;
+using System.Linq;
+using System.Reflection;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading.Tasks;
+using Azure;
+using Azure.Core;
+using Azure.Identity;
+using Azure.Security.KeyVault.Certificates;
+using Azure.Security.KeyVault.Secrets;
+using Moq;
+using Xunit;
+using Xunit.Sdk;
+
+namespace CertHelper.Tests;
+
+public class KeyVaultCertTests
+{
+ [Fact]
+ public async Task LoadKeyVaultCertsAsync_ShouldAddCertificatesToCollection()
+ {
+ // Arrange
+ Mock mockTokenCred;
+ Mock mockCertClient;
+ Mock mockSecretClient;
+ Mock mockLocalCert;
+ CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert, false);
+
+ var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object);
+
+ // Act
+ await keyVaultCert.LoadKeyVaultCertsAsync();
+
+ // Assert
+ Assert.Equal(2, keyVaultCert.KeyVaultCertificates.Count);
+ }
+
+ private static void CertStoreSetup(out Mock mockTokenCred, out Mock mockCertClient, out Mock mockSecretClient, out Mock mockLocalCert, bool missingKeyVaultCerts = false, bool localAndKeyVaultDifferent = false)
+ {
+ mockTokenCred = new Mock();
+ mockTokenCred.Setup(tc => tc.GetTokenAsync(It.IsAny(), default)).ReturnsAsync(new AccessToken("token", DateTimeOffset.Now));
+
+ mockCertClient = new Mock(new Uri("https://dotnetperfkeyvault.vault.azure.net/"), mockTokenCred.Object);
+ mockSecretClient = new Mock(new Uri("https://dotnetperfkeyvault.vault.azure.net/"), mockTokenCred.Object);
+ KeyVaultCertificateWithPolicy? mockCert1, mockCert2;
+ X509Certificate2 cert1, cert2, cert3;
+ MakeCerts(out mockCert1, out mockCert2, out cert1, out cert2, out cert3, localAndKeyVaultDifferent);
+
+ var certCollection = new X509Certificate2Collection { cert1, cert2 };
+
+ mockCertClient.Setup(c => c.GetCertificateAsync(Constants.Cert1Name, default)).ReturnsAsync(Response.FromValue(mockCert1, null!));
+ if (missingKeyVaultCerts)
+ {
+ mockCertClient.Setup(c => c.GetCertificateAsync(Constants.Cert2Name, default)).ReturnsAsync(Response.FromValue(null!, null!));
+ }
+ else
+ {
+ mockCertClient.Setup(c => c.GetCertificateAsync(Constants.Cert2Name, default)).ReturnsAsync(Response.FromValue(mockCert2, null!));
+ }
+
+ KeyVaultSecret secret1;
+ if(localAndKeyVaultDifferent)
+ {
+ secret1 = new KeyVaultSecret(Constants.Cert1Name, Convert.ToBase64String(cert3.Export(X509ContentType.Pfx)));
+ }
+ else
+ {
+ secret1 = new KeyVaultSecret(Constants.Cert1Name, Convert.ToBase64String(cert1.Export(X509ContentType.Pfx)));
+ }
+ var secret2 = new KeyVaultSecret(Constants.Cert2Name, Convert.ToBase64String(cert2.Export(X509ContentType.Pfx)));
+
+ mockSecretClient.Setup(s => s.GetSecretAsync(Constants.Cert1Name, mockCert1!.SecretId.Segments.Last(), default)).ReturnsAsync(Response.FromValue(secret1, null!));
+ mockSecretClient.Setup(s => s.GetSecretAsync(Constants.Cert2Name, mockCert2!.SecretId.Segments.Last(), default)).ReturnsAsync(Response.FromValue(secret2, null!));
+
+ mockLocalCert = new Mock();
+ mockLocalCert.Setup(lc => lc.Certificates).Returns(certCollection);
+ }
+
+ private static void MakeCerts(out KeyVaultCertificateWithPolicy? mockCert1, out KeyVaultCertificateWithPolicy? mockCert2, out X509Certificate2 cert1, out X509Certificate2 cert2, out X509Certificate2 cert3, bool localAndKeyVaultDifferent = false)
+ {
+ using var rsa1 = RSA.Create(); // generate asymmetric key pair
+ var req1 = new CertificateRequest("cn=perflabtest", rsa1, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
+ var tmpCert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ using var rsa2 = RSA.Create(); // generate asymmetric key pair
+ var req2 = new CertificateRequest("cn=perflabtest", rsa2, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
+ var tmpCert2 = req2.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ cert1 = new X509Certificate2(tmpCert1);
+ cert2 = new X509Certificate2(tmpCert2);
+
+ if(localAndKeyVaultDifferent)
+ {
+ using var rsa3 = RSA.Create(); // generate asymmetric key pair
+ var req = new CertificateRequest("cn=perflabtest", rsa3, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
+ tmpCert1 = req.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ }
+ cert3 = new X509Certificate2(tmpCert1);
+
+ mockCert1 = CertificateModelFactory.KeyVaultCertificateWithPolicy(CertificateModelFactory.CertificateProperties(Constants.Cert1Id, Constants.Cert1Name, x509thumbprint: Convert.FromHexString(tmpCert1.Thumbprint)),
+ Constants.Cert1Id, Constants.Cert1Id, tmpCert1.GetRawCertData());
+ mockCert2 = CertificateModelFactory.KeyVaultCertificateWithPolicy(CertificateModelFactory.CertificateProperties(Constants.Cert2Id, Constants.Cert2Name, x509thumbprint: Convert.FromHexString(tmpCert2.Thumbprint)),
+ Constants.Cert2Id, Constants.Cert2Id, tmpCert2.GetRawCertData());
+ }
+
+ [Fact]
+ public async Task LoadKeyVaultCertsAsync_ShouldThrowException_WhenCertificatesNotFound()
+ {
+ // Arrange
+ Mock mockTokenCred;
+ Mock mockCertClient;
+ Mock mockSecretClient;
+ Mock mockLocalCert;
+ CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert, missingKeyVaultCerts: true);
+
+ var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object);
+
+ // Act & Assert
+ await Assert.ThrowsAsync(() => keyVaultCert.LoadKeyVaultCertsAsync());
+ }
+
+ [Fact]
+ public async Task ShouldRotateCerts_ShouldReturnTrue_WhenThumbprintsDoNotMatch()
+ {
+ // Arrange
+ Mock mockTokenCred;
+ Mock mockCertClient;
+ Mock mockSecretClient;
+ Mock mockLocalCert;
+ CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert, localAndKeyVaultDifferent: true);
+
+ var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object);
+
+ // Act
+ await keyVaultCert.LoadKeyVaultCertsAsync();
+
+ var result = keyVaultCert.ShouldRotateCerts();
+
+ // Assert
+ Assert.True(result);
+ }
+
+ [Fact]
+ public async Task ShouldRotateCerts_ShouldReturnFalse_WhenThumbprintsMatch()
+ {
+ // Arrange
+ Mock mockTokenCred;
+ Mock mockCertClient;
+ Mock mockSecretClient;
+ Mock mockLocalCert;
+ CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert);
+
+ var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object);
+
+ // Act
+ await keyVaultCert.LoadKeyVaultCertsAsync();
+
+ var result = keyVaultCert.ShouldRotateCerts();
+
+ // Assert
+ Assert.False(result);
+ }
+}
+
diff --git a/src/tools/CertHelperTests/LocalCertTests.cs b/src/tools/CertHelperTests/LocalCertTests.cs
new file mode 100644
index 00000000000..fcc49ad596e
--- /dev/null
+++ b/src/tools/CertHelperTests/LocalCertTests.cs
@@ -0,0 +1,112 @@
+using System.Security.Cryptography.X509Certificates;
+using Xunit;
+using Moq;
+using Microsoft.VisualBasic;
+using System.Security.Cryptography;
+
+namespace CertHelper.Tests;
+
+public class LocalCertTests
+{
+ [Fact]
+ public void GetLocalCerts_ShouldAddCertificatesToCollection()
+ {
+ // Arrange
+ var mockStore = new Mock();
+ var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair
+ var req1 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa1, HashAlgorithmName.SHA256);
+ var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ var ecdsa2 = ECDsa.Create(); // generate asymmetric key pair
+ var req2 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa2, HashAlgorithmName.SHA256);
+ var cert2 = req2.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+
+ var mockCert1 = cert1;
+ var mockCert2 = cert2;
+
+ var certCollection = new X509Certificate2Collection { mockCert1, mockCert2 };
+ mockStore.Setup(s => s.Certificates).Returns(certCollection);
+
+
+
+ // Act
+ var localCert = new LocalCert(mockStore.Object);
+
+ // Assert
+ Assert.Equal(2, localCert.Certificates.Count);
+ Assert.Contains(mockCert1, localCert.Certificates);
+ Assert.Contains(mockCert2, localCert.Certificates);
+ }
+
+ [Fact]
+ public void GetLocalCerts_ShouldAddCertificatesToCollection_WhenSubjectMatches()
+ {
+ // Arrange
+ var mockStore = new Mock();
+ var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair
+ var req1 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa1, HashAlgorithmName.SHA256);
+ var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ var ecdsa2 = ECDsa.Create(); // generate asymmetric key pair
+ var req2 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa2, HashAlgorithmName.SHA256);
+ var cert2 = req2.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+
+ var mockCert1 = cert1;
+ var mockCert2 = cert2;
+
+ var certCollection = new X509Certificate2Collection { mockCert1, mockCert2 };
+ mockStore.Setup(s => s.Certificates).Returns(certCollection);
+
+ // Act
+ var localCert = new LocalCert(mockStore.Object);
+
+ // Assert
+ Assert.Equal(2, localCert.Certificates.Count);
+ Assert.Contains(mockCert1, localCert.Certificates);
+ Assert.Contains(mockCert2, localCert.Certificates);
+ }
+
+ [Fact]
+ public void GetLocalCerts_ShouldThrowException_WhenOneCertificateFound()
+ {
+ // Arrange
+ var mockStore = new Mock();
+ var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair
+ var req1 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa1, HashAlgorithmName.SHA256);
+ var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ var certCollection = new X509Certificate2Collection { cert1 };
+ mockStore.Setup(s => s.Certificates).Returns(certCollection);
+
+ // Act & Assert
+ Assert.Throws(() => new LocalCert(mockStore.Object));
+ }
+
+ [Fact]
+ public void GetLocalCerts_ShouldThrowException_WhenCertificatesHaveWrongSubject()
+ {
+ // Arrange
+ var mockStore = new Mock();
+ var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair
+ var req1 = new CertificateRequest("CN=dotnetperf.microsoft.co", ecdsa1, HashAlgorithmName.SHA256);
+ var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+ var ecdsa2 = ECDsa.Create(); // generate asymmetric key pair
+ var req2 = new CertificateRequest("CN=dotnetperf.microsoft.co", ecdsa1, HashAlgorithmName.SHA256);
+ var cert2 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
+
+ var certCollection = new X509Certificate2Collection { cert1, cert2 };
+ mockStore.Setup(s => s.Certificates).Returns(certCollection);
+
+ // Act & Assert
+ Assert.Throws(() => new LocalCert(mockStore.Object));
+ }
+
+ [Fact]
+ public void GetLocalCerts_ShouldThrowException_WhenCertificatesNotFound()
+ {
+ // Arrange
+ var mockStore = new Mock();
+ var certCollection = new X509Certificate2Collection();
+ mockStore.Setup(s => s.Certificates).Returns(certCollection);
+
+ // Act & Assert
+ Assert.Throws(() => new LocalCert(mockStore.Object));
+ }
+}