diff --git a/scripts/performance/common.py b/scripts/performance/common.py index df58c05570b..ed125e5ab7e 100644 --- a/scripts/performance/common.py +++ b/scripts/performance/common.py @@ -18,6 +18,7 @@ import os import sys import time +import base64 from typing import Callable, List, Optional, Tuple, Type, TypeVar @@ -139,6 +140,10 @@ def get_packages_directory() -> str: ''' return os.path.join(get_artifacts_directory(), 'packages') +def base64_to_bytes(base64_string: str) -> bytes: + byte_data = base64.b64decode(base64_string) + return byte_data + @contextmanager def push_dir(path: Optional[str] = None): ''' @@ -234,6 +239,7 @@ def __init__( cmdline: List[str], success_exit_codes: Optional[List[int]] = None, verbose: bool = False, + echo: bool = True, retry: int = 0): if cmdline is None: raise TypeError('Unspecified command line to be executed.') @@ -243,6 +249,7 @@ def __init__( self.__cmdline = cmdline self.__verbose = verbose self.__retry = retry + self.__echo = echo if success_exit_codes is None: self.__success_exit_codes = [0] @@ -262,6 +269,11 @@ def success_exit_codes(self) -> List[int]: ''' return self.__success_exit_codes + @property + def echo(self) -> bool: + '''Enables/Disables echoing of STDOUT''' + return self.__echo + @property def verbose(self) -> bool: '''Enables/Disables verbosity.''' @@ -297,7 +309,8 @@ def __runinternal(self, working_directory: Optional[str] = None) -> Tuple[int, s line = raw_line.decode('utf-8', errors='backslashreplace') self.__stdout.write(line) line = line.rstrip() - getLogger().info(line) + if self.echo: + getLogger().info(line) proc.wait() return (proc.returncode, quoted_cmdline) diff --git a/scripts/performance/constants.py b/scripts/performance/constants.py index 02b63119988..bd29d394404 100644 --- a/scripts/performance/constants.py +++ b/scripts/performance/constants.py @@ -5,4 +5,5 @@ UPLOAD_STORAGE_URI = 'https://pvscmdupload.{}.core.windows.net' UPLOAD_QUEUE = 'resultsqueue' TENANT_ID = '72f988bf-86f1-41af-91ab-2d7cd011db47' -CLIENT_ID = 'a231f733-103b-46e9-b58a-9416edde0eb4' +ARC_CLIENT_ID = 'a231f733-103b-46e9-b58a-9416edde0eb4' +CERT_CLIENT_ID = '8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc' \ No newline at end of file diff --git a/scripts/run_performance_job.py b/scripts/run_performance_job.py index 9ff3684825f..2d90bc48719 100644 --- a/scripts/run_performance_job.py +++ b/scripts/run_performance_job.py @@ -774,6 +774,32 @@ def run_performance_job(args: RunPerformanceJobArgs): getLogger().info("Copying global.json to payload directory") shutil.copy(os.path.join(args.performance_repo_dir, 'global.json'), os.path.join(performance_payload_dir, 'global.json')) + # Building CertHelper needs to happen here as we need it on every run. This also means that we will need to move the calculation + # of the parameters needed outside of the if block + + framework = os.environ["PERFLAB_Framework"] + os.environ["PERFLAB_TARGET_FRAMEWORKS"] = framework + if args.os_group == "windows": + runtime_id = f"win-{args.architecture}" + elif args.os_group == "osx": + runtime_id = f"osx-{args.architecture}" + else: + runtime_id = f"linux-{args.architecture}" + + dotnet_executable_path = os.path.join(ci_setup_arguments.install_dir, "dotnet") + + RunCommand([ + dotnet_executable_path, "publish", + "-c", "Release", + "-o", os.path.join(payload_dir, "certhelper"), + "-f", framework, + "-r", runtime_id, + "--self-contained", + os.path.join(args.performance_repo_dir, "src", "tools", "CertHelper", "CertHelper.csproj"), + f"/bl:{os.path.join(args.performance_repo_dir, 'artifacts', 'log', build_config, 'CertHelper.binlog')}", + "-p:DisableTransitiveFrameworkReferenceDownloads=true"], + verbose=True).run() + if args.is_scenario: set_environment_variable("DOTNET_ROOT", ci_setup_arguments.install_dir, save_to_pipeline=True) getLogger().info(f"Set DOTNET_ROOT to {ci_setup_arguments.install_dir}") @@ -782,17 +808,6 @@ def run_performance_job(args: RunPerformanceJobArgs): set_environment_variable("PATH", new_path, save_to_pipeline=True) getLogger().info(f"Set PATH to {new_path}") - framework = os.environ["PERFLAB_Framework"] - os.environ["PERFLAB_TARGET_FRAMEWORKS"] = framework - if args.os_group == "windows": - runtime_id = f"win-{args.architecture}" - elif args.os_group == "osx": - runtime_id = f"osx-{args.architecture}" - else: - runtime_id = f"linux-{args.architecture}" - - dotnet_executable_path = os.path.join(ci_setup_arguments.install_dir, "dotnet") - os.environ["MSBUILDDISABLENODEREUSE"] = "1" # without this, MSbuild will be kept alive # build Startup diff --git a/scripts/upload.py b/scripts/upload.py index 40d83e3156b..36bd40f57eb 100644 --- a/scripts/upload.py +++ b/scripts/upload.py @@ -3,11 +3,11 @@ from azure.storage.blob import BlobClient, ContentSettings from azure.storage.queue import QueueClient, TextBase64EncodePolicy from azure.core.exceptions import ResourceExistsError, ClientAuthenticationError -from azure.identity import DefaultAzureCredential, ClientAssertionCredential +from azure.identity import DefaultAzureCredential, ClientAssertionCredential, CertificateCredential from traceback import format_exc from glob import glob -from performance.common import retry_on_exception -from performance.constants import TENANT_ID, CLIENT_ID +from performance.common import retry_on_exception, RunCommand, helixpayload, base64_to_bytes, extension +from performance.constants import TENANT_ID, ARC_CLIENT_ID, CERT_CLIENT_ID import os import json @@ -32,14 +32,25 @@ def upload(globpath: str, container: str, queue: str, sas_token_env: str, storag credential = None try: dac = DefaultAzureCredential() - credential = ClientAssertionCredential(TENANT_ID, CLIENT_ID, lambda: dac.get_token("api://AzureADTokenExchange/.default").token) + credential = ClientAssertionCredential(TENANT_ID, ARC_CLIENT_ID, lambda: dac.get_token("api://AzureADTokenExchange/.default").token) credential.get_token("https://storage.azure.com/.default") except ClientAuthenticationError as ex: - getLogger().info("Unable to use managed identity. Falling back to environment variable.") - credential = os.getenv(sas_token_env) + credential = None + getLogger().info("Unable to use managed identity. Falling back to certificate.") + cmd_line = [(os.path.join(str(helixpayload()), 'certhelper', "CertHelper%s" % extension()))] + cert_helper = RunCommand(cmd_line, None, True, False, 0) + cert_helper.run() + for cert in cert_helper.stdout.splitlines(): + credential = CertificateCredential(TENANT_ID, CERT_CLIENT_ID, certificate_data=base64_to_bytes(cert)) + try: + credential.get_token("https://storage.azure.com/.default") + except ClientAuthenticationError as ex: + credential = None + continue if credential is None: - getLogger().error("Sas token environment variable {} was not defined.".format(sas_token_env)) - return 1 + getLogger().error("Unable to authenticate with managed identity or certificates.") + getLogger().info("Falling back to environment variable.") + credential = os.getenv(sas_token_env) files = glob(globpath, recursive=True) any_upload_or_queue_failed = False diff --git a/src/tools/CertHelper/AssemblyInfo.cs b/src/tools/CertHelper/AssemblyInfo.cs new file mode 100644 index 00000000000..edd1d789f59 --- /dev/null +++ b/src/tools/CertHelper/AssemblyInfo.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; + +[assembly: InternalsVisibleTo("CertHelperTests")] +namespace CertHelper; +internal class AssemblyInfo +{ +} diff --git a/src/tools/CertHelper/CertHelper.csproj b/src/tools/CertHelper/CertHelper.csproj new file mode 100644 index 00000000000..005a2ceb892 --- /dev/null +++ b/src/tools/CertHelper/CertHelper.csproj @@ -0,0 +1,20 @@ + + + + Exe + $(PERFLAB_TARGET_FRAMEWORKS) + + net9.0 + enable + enable + + + + + + + + + + + diff --git a/src/tools/CertHelper/CertHelper.sln b/src/tools/CertHelper/CertHelper.sln new file mode 100644 index 00000000000..df485270020 --- /dev/null +++ b/src/tools/CertHelper/CertHelper.sln @@ -0,0 +1,28 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.12.35514.174 d17.12 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CertHelper", "CertHelper.csproj", "{165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CertRotatorTests", "..\CertHelperTests\CertRotatorTests.csproj", "{AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {165A37BD-2E9E-4D0A-8402-BB58C29A0BF4}.Release|Any CPU.Build.0 = Release|Any CPU + {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AEA0F93B-EC9B-4438-991E-A80C0C82B3D1}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/src/tools/CertHelper/Constants.cs b/src/tools/CertHelper/Constants.cs new file mode 100644 index 00000000000..3ddad191905 --- /dev/null +++ b/src/tools/CertHelper/Constants.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace CertHelper; +public class Constants +{ + public static readonly string Cert1Name = "LabCert1"; + public static readonly string Cert2Name = "LabCert2"; + public static readonly Uri Cert1Id = new Uri("https://test.vault.azure.net/certificates/LabCert1/07a7d98bf4884e5c40e690e02b96b3b4"); + public static readonly Uri Cert2Id = new Uri("https://test.vault.azure.net/certificates/LabCert2/07a7d98bf4884e5c41e690e02b96b3b4"); +} diff --git a/src/tools/CertHelper/IX509Store.cs b/src/tools/CertHelper/IX509Store.cs new file mode 100644 index 00000000000..2fd187d86dc --- /dev/null +++ b/src/tools/CertHelper/IX509Store.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; + +namespace CertHelper; +public interface IX509Store +{ + X509Certificate2Collection Certificates { get; } + string? Name { get; } + StoreLocation Location { get; } + X509Store GetX509Store(); +} + +public class TestableX509Store : IX509Store +{ + public X509Certificate2Collection Certificates { get => store.Certificates; } + + public string? Name => store.Name; + + public StoreLocation Location => store.Location; + + private X509Store store; + public TestableX509Store(OpenFlags flags = OpenFlags.ReadOnly) + { + store = new X509Store(StoreName.My, StoreLocation.CurrentUser, flags); + } + + public X509Store GetX509Store() + { + return store; + } +} diff --git a/src/tools/CertHelper/KeyVaultCert.cs b/src/tools/CertHelper/KeyVaultCert.cs new file mode 100644 index 00000000000..067e9025e94 --- /dev/null +++ b/src/tools/CertHelper/KeyVaultCert.cs @@ -0,0 +1,110 @@ +using Azure; +using Azure.Core; +using Azure.Identity; +using Azure.Security.KeyVault.Certificates; +using Azure.Security.KeyVault.Secrets; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; + +namespace CertHelper; + +public class KeyVaultCert +{ + private readonly string _keyVaultUrl = "https://dotnetperfkeyvault.vault.azure.net/"; + private readonly string _tenantId = "72f988bf-86f1-41af-91ab-2d7cd011db47"; + private readonly string _clientId = "8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc"; + + public X509Certificate2Collection KeyVaultCertificates { get; set; } + public ILocalCert LocalCerts { get; set; } + private TokenCredential _credential { get; set; } + private CertificateClient _certClient { get; set; } + private SecretClient _secretClient { get; set; } + + public KeyVaultCert(TokenCredential? cred = null, CertificateClient? certClient = null, SecretClient? secretClient = null, ILocalCert? localCerts = null) + { + LocalCerts = localCerts ?? new LocalCert(); + _credential = cred ?? GetCertifcateCredentialAsync(_tenantId, _clientId, LocalCerts.Certificates).Result; + _certClient = certClient ?? new CertificateClient(new Uri(_keyVaultUrl), _credential); + _secretClient = secretClient ?? new SecretClient(new Uri(_keyVaultUrl), _credential); + KeyVaultCertificates = new X509Certificate2Collection(); + } + + public async Task LoadKeyVaultCertsAsync() + { + KeyVaultCertificates.Add(await FindCertificateInKeyVaultAsync(Constants.Cert1Name)); + KeyVaultCertificates.Add(await FindCertificateInKeyVaultAsync(Constants.Cert2Name)); + + if (KeyVaultCertificates.Where(c => c == null).Count() > 0) + { + throw new Exception("One or more certificates not found"); + } + } + + private async Task GetCertifcateCredentialAsync(string tenantId, string clientId, X509Certificate2Collection certCollection) + { + ClientCertificateCredential? ccc = null; + Exception? exception = null; + foreach (var cert in certCollection) + { + try + { + ccc = new ClientCertificateCredential(tenantId, clientId, cert); + await ccc.GetTokenAsync(new TokenRequestContext(new string[] { "https://vault.azure.net/.default" })); + break; + } + catch (Exception ex) + { + ccc = null; + exception = ex; + } + } + if(ccc == null) + { + throw new Exception("Both certificates failed to authenticate", exception); + } + return ccc; + } + + private async Task FindCertificateInKeyVaultAsync(string certName) + { + var keyVaultCert = await _certClient.GetCertificateAsync(certName); + if(keyVaultCert.Value == null) + { + throw new Exception("Certificate not found in Key Vault"); + } + var secret = await _secretClient.GetSecretAsync(keyVaultCert.Value.Name, keyVaultCert.Value.SecretId.Segments.Last()); + if(secret.Value == null) + { + throw new Exception("Certificate secret not found in Key Vault"); + } + var certBytes = Convert.FromBase64String(secret.Value.Value); +#if NET9_0_OR_GREATER + var cert = X509CertificateLoader.LoadPkcs12(certBytes, "", X509KeyStorageFlags.Exportable); +#else + var cert = new X509Certificate2(certBytes, "", X509KeyStorageFlags.Exportable); +#endif + return cert; + } + + public bool ShouldRotateCerts() + { + var keyVaultThumbprints = new HashSet(); + foreach (var cert in KeyVaultCertificates) + { + keyVaultThumbprints.Add(cert.Thumbprint); + } + foreach(var cert in LocalCerts.Certificates) + { + if (!keyVaultThumbprints.Contains(cert.Thumbprint)) + { + return true; + } + } + return false; + } +} diff --git a/src/tools/CertHelper/LocalCert.cs b/src/tools/CertHelper/LocalCert.cs new file mode 100644 index 00000000000..3459e7509df --- /dev/null +++ b/src/tools/CertHelper/LocalCert.cs @@ -0,0 +1,44 @@ +using Azure.Security.KeyVault.Certificates; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; + +namespace CertHelper; + +public class LocalCert : ILocalCert +{ + public X509Certificate2Collection Certificates { get; set; } + internal IX509Store LocalMachineCerts { get; set; } + + public LocalCert(IX509Store? store = null) + { + LocalMachineCerts = store ?? new TestableX509Store(); + Certificates = new X509Certificate2Collection(); + GetLocalCerts(); + } + + private void GetLocalCerts() + { + foreach (var cert in LocalMachineCerts.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false)) + { + if (cert.Subject == "CN=dotnetperf.microsoft.com") + { + Certificates.Add(cert); + } + } + + if (Certificates.Count < 2 || Certificates.Where(c => c == null).Count() > 0) + { + throw new Exception("One or more certificates not found"); + } + } +} + +public interface ILocalCert +{ + X509Certificate2Collection Certificates { get; set; } +} diff --git a/src/tools/CertHelper/Program.cs b/src/tools/CertHelper/Program.cs new file mode 100644 index 00000000000..71563a205e0 --- /dev/null +++ b/src/tools/CertHelper/Program.cs @@ -0,0 +1,67 @@ +using Azure.Identity; +using Azure.Storage.Blobs; +using Azure.Storage.Blobs.Specialized; +using System.Security.Cryptography.X509Certificates; +using System.Text; + +namespace CertHelper; + +internal class Program +{ + static readonly string TENANT_ID = "72f988bf-86f1-41af-91ab-2d7cd011db47"; + static readonly string CERT_CLIENT_ID = "8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc"; + static async Task Main(string[] args) + { + try + { + var kvc = new KeyVaultCert(); + await kvc.LoadKeyVaultCertsAsync(); + if (kvc.ShouldRotateCerts()) + { + using (var localMachineCerts = new X509Store(StoreName.My, StoreLocation.CurrentUser)) + { + localMachineCerts.Open(OpenFlags.ReadWrite); + localMachineCerts.RemoveRange(kvc.LocalCerts.Certificates); + localMachineCerts.AddRange(kvc.KeyVaultCertificates); + } + } + var bcc = new BlobContainerClient(new Uri("https://pvscmdupload.blob.core.windows.net/certstatus"), + new ClientCertificateCredential(TENANT_ID, CERT_CLIENT_ID, kvc.KeyVaultCertificates.First())); + var currentKeyValutCertThumbprints = ""; + foreach(var cert in kvc.KeyVaultCertificates) + { + currentKeyValutCertThumbprints += $"[{DateTimeOffset.UtcNow}] {cert.Thumbprint}{Environment.NewLine}"; + } + var blob = bcc.GetBlobClient(System.Environment.MachineName); + if (blob.Exists()) + { + var result = blob.DownloadContent(); + var currentBlob = result.Value.Content.ToString(); + currentBlob = currentBlob + currentKeyValutCertThumbprints; + blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentBlob)), overwrite: true); + } + else + { + blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentKeyValutCertThumbprints)), overwrite: false); + } + + } + catch (Exception ex) + { + Console.WriteLine("Failed to rotate certificates"); + Console.WriteLine(ex.Message); + Console.WriteLine(ex.StackTrace); + } + + + + using (var store = new X509Store(StoreName.My, StoreLocation.CurrentUser, OpenFlags.ReadWrite)) + { + foreach(var cert in store.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false)) + { + Console.WriteLine(Convert.ToBase64String(cert.Export(X509ContentType.Pfx))); + } + } + return 0; + } +} diff --git a/src/tools/CertHelperTests/CertRotatorTests.csproj b/src/tools/CertHelperTests/CertRotatorTests.csproj new file mode 100644 index 00000000000..0ea831f1d18 --- /dev/null +++ b/src/tools/CertHelperTests/CertRotatorTests.csproj @@ -0,0 +1,30 @@ + + + + net9.0 + enable + enable + + false + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + diff --git a/src/tools/CertHelperTests/KeyVaultCertTests.cs b/src/tools/CertHelperTests/KeyVaultCertTests.cs new file mode 100644 index 00000000000..7f9621352b4 --- /dev/null +++ b/src/tools/CertHelperTests/KeyVaultCertTests.cs @@ -0,0 +1,163 @@ +using System; +using System.Linq; +using System.Reflection; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; +using Azure; +using Azure.Core; +using Azure.Identity; +using Azure.Security.KeyVault.Certificates; +using Azure.Security.KeyVault.Secrets; +using Moq; +using Xunit; +using Xunit.Sdk; + +namespace CertHelper.Tests; + +public class KeyVaultCertTests +{ + [Fact] + public async Task LoadKeyVaultCertsAsync_ShouldAddCertificatesToCollection() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert, false); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(); + + // Assert + Assert.Equal(2, keyVaultCert.KeyVaultCertificates.Count); + } + + private static void CertStoreSetup(out Mock mockTokenCred, out Mock mockCertClient, out Mock mockSecretClient, out Mock mockLocalCert, bool missingKeyVaultCerts = false, bool localAndKeyVaultDifferent = false) + { + mockTokenCred = new Mock(); + mockTokenCred.Setup(tc => tc.GetTokenAsync(It.IsAny(), default)).ReturnsAsync(new AccessToken("token", DateTimeOffset.Now)); + + mockCertClient = new Mock(new Uri("https://dotnetperfkeyvault.vault.azure.net/"), mockTokenCred.Object); + mockSecretClient = new Mock(new Uri("https://dotnetperfkeyvault.vault.azure.net/"), mockTokenCred.Object); + KeyVaultCertificateWithPolicy? mockCert1, mockCert2; + X509Certificate2 cert1, cert2, cert3; + MakeCerts(out mockCert1, out mockCert2, out cert1, out cert2, out cert3, localAndKeyVaultDifferent); + + var certCollection = new X509Certificate2Collection { cert1, cert2 }; + + mockCertClient.Setup(c => c.GetCertificateAsync(Constants.Cert1Name, default)).ReturnsAsync(Response.FromValue(mockCert1, null!)); + if (missingKeyVaultCerts) + { + mockCertClient.Setup(c => c.GetCertificateAsync(Constants.Cert2Name, default)).ReturnsAsync(Response.FromValue(null!, null!)); + } + else + { + mockCertClient.Setup(c => c.GetCertificateAsync(Constants.Cert2Name, default)).ReturnsAsync(Response.FromValue(mockCert2, null!)); + } + + KeyVaultSecret secret1; + if(localAndKeyVaultDifferent) + { + secret1 = new KeyVaultSecret(Constants.Cert1Name, Convert.ToBase64String(cert3.Export(X509ContentType.Pfx))); + } + else + { + secret1 = new KeyVaultSecret(Constants.Cert1Name, Convert.ToBase64String(cert1.Export(X509ContentType.Pfx))); + } + var secret2 = new KeyVaultSecret(Constants.Cert2Name, Convert.ToBase64String(cert2.Export(X509ContentType.Pfx))); + + mockSecretClient.Setup(s => s.GetSecretAsync(Constants.Cert1Name, mockCert1!.SecretId.Segments.Last(), default)).ReturnsAsync(Response.FromValue(secret1, null!)); + mockSecretClient.Setup(s => s.GetSecretAsync(Constants.Cert2Name, mockCert2!.SecretId.Segments.Last(), default)).ReturnsAsync(Response.FromValue(secret2, null!)); + + mockLocalCert = new Mock(); + mockLocalCert.Setup(lc => lc.Certificates).Returns(certCollection); + } + + private static void MakeCerts(out KeyVaultCertificateWithPolicy? mockCert1, out KeyVaultCertificateWithPolicy? mockCert2, out X509Certificate2 cert1, out X509Certificate2 cert2, out X509Certificate2 cert3, bool localAndKeyVaultDifferent = false) + { + using var rsa1 = RSA.Create(); // generate asymmetric key pair + var req1 = new CertificateRequest("cn=perflabtest", rsa1, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var tmpCert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + using var rsa2 = RSA.Create(); // generate asymmetric key pair + var req2 = new CertificateRequest("cn=perflabtest", rsa2, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var tmpCert2 = req2.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + cert1 = new X509Certificate2(tmpCert1); + cert2 = new X509Certificate2(tmpCert2); + + if(localAndKeyVaultDifferent) + { + using var rsa3 = RSA.Create(); // generate asymmetric key pair + var req = new CertificateRequest("cn=perflabtest", rsa3, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + tmpCert1 = req.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + } + cert3 = new X509Certificate2(tmpCert1); + + mockCert1 = CertificateModelFactory.KeyVaultCertificateWithPolicy(CertificateModelFactory.CertificateProperties(Constants.Cert1Id, Constants.Cert1Name, x509thumbprint: Convert.FromHexString(tmpCert1.Thumbprint)), + Constants.Cert1Id, Constants.Cert1Id, tmpCert1.GetRawCertData()); + mockCert2 = CertificateModelFactory.KeyVaultCertificateWithPolicy(CertificateModelFactory.CertificateProperties(Constants.Cert2Id, Constants.Cert2Name, x509thumbprint: Convert.FromHexString(tmpCert2.Thumbprint)), + Constants.Cert2Id, Constants.Cert2Id, tmpCert2.GetRawCertData()); + } + + [Fact] + public async Task LoadKeyVaultCertsAsync_ShouldThrowException_WhenCertificatesNotFound() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert, missingKeyVaultCerts: true); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act & Assert + await Assert.ThrowsAsync(() => keyVaultCert.LoadKeyVaultCertsAsync()); + } + + [Fact] + public async Task ShouldRotateCerts_ShouldReturnTrue_WhenThumbprintsDoNotMatch() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert, localAndKeyVaultDifferent: true); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(); + + var result = keyVaultCert.ShouldRotateCerts(); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task ShouldRotateCerts_ShouldReturnFalse_WhenThumbprintsMatch() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(); + + var result = keyVaultCert.ShouldRotateCerts(); + + // Assert + Assert.False(result); + } +} + diff --git a/src/tools/CertHelperTests/LocalCertTests.cs b/src/tools/CertHelperTests/LocalCertTests.cs new file mode 100644 index 00000000000..fcc49ad596e --- /dev/null +++ b/src/tools/CertHelperTests/LocalCertTests.cs @@ -0,0 +1,112 @@ +using System.Security.Cryptography.X509Certificates; +using Xunit; +using Moq; +using Microsoft.VisualBasic; +using System.Security.Cryptography; + +namespace CertHelper.Tests; + +public class LocalCertTests +{ + [Fact] + public void GetLocalCerts_ShouldAddCertificatesToCollection() + { + // Arrange + var mockStore = new Mock(); + var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair + var req1 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa1, HashAlgorithmName.SHA256); + var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + var ecdsa2 = ECDsa.Create(); // generate asymmetric key pair + var req2 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa2, HashAlgorithmName.SHA256); + var cert2 = req2.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + + var mockCert1 = cert1; + var mockCert2 = cert2; + + var certCollection = new X509Certificate2Collection { mockCert1, mockCert2 }; + mockStore.Setup(s => s.Certificates).Returns(certCollection); + + + + // Act + var localCert = new LocalCert(mockStore.Object); + + // Assert + Assert.Equal(2, localCert.Certificates.Count); + Assert.Contains(mockCert1, localCert.Certificates); + Assert.Contains(mockCert2, localCert.Certificates); + } + + [Fact] + public void GetLocalCerts_ShouldAddCertificatesToCollection_WhenSubjectMatches() + { + // Arrange + var mockStore = new Mock(); + var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair + var req1 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa1, HashAlgorithmName.SHA256); + var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + var ecdsa2 = ECDsa.Create(); // generate asymmetric key pair + var req2 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa2, HashAlgorithmName.SHA256); + var cert2 = req2.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + + var mockCert1 = cert1; + var mockCert2 = cert2; + + var certCollection = new X509Certificate2Collection { mockCert1, mockCert2 }; + mockStore.Setup(s => s.Certificates).Returns(certCollection); + + // Act + var localCert = new LocalCert(mockStore.Object); + + // Assert + Assert.Equal(2, localCert.Certificates.Count); + Assert.Contains(mockCert1, localCert.Certificates); + Assert.Contains(mockCert2, localCert.Certificates); + } + + [Fact] + public void GetLocalCerts_ShouldThrowException_WhenOneCertificateFound() + { + // Arrange + var mockStore = new Mock(); + var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair + var req1 = new CertificateRequest("CN=dotnetperf.microsoft.com", ecdsa1, HashAlgorithmName.SHA256); + var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + var certCollection = new X509Certificate2Collection { cert1 }; + mockStore.Setup(s => s.Certificates).Returns(certCollection); + + // Act & Assert + Assert.Throws(() => new LocalCert(mockStore.Object)); + } + + [Fact] + public void GetLocalCerts_ShouldThrowException_WhenCertificatesHaveWrongSubject() + { + // Arrange + var mockStore = new Mock(); + var ecdsa1 = ECDsa.Create(); // generate asymmetric key pair + var req1 = new CertificateRequest("CN=dotnetperf.microsoft.co", ecdsa1, HashAlgorithmName.SHA256); + var cert1 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + var ecdsa2 = ECDsa.Create(); // generate asymmetric key pair + var req2 = new CertificateRequest("CN=dotnetperf.microsoft.co", ecdsa1, HashAlgorithmName.SHA256); + var cert2 = req1.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + + var certCollection = new X509Certificate2Collection { cert1, cert2 }; + mockStore.Setup(s => s.Certificates).Returns(certCollection); + + // Act & Assert + Assert.Throws(() => new LocalCert(mockStore.Object)); + } + + [Fact] + public void GetLocalCerts_ShouldThrowException_WhenCertificatesNotFound() + { + // Arrange + var mockStore = new Mock(); + var certCollection = new X509Certificate2Collection(); + mockStore.Setup(s => s.Certificates).Returns(certCollection); + + // Act & Assert + Assert.Throws(() => new LocalCert(mockStore.Object)); + } +}