diff --git a/samples/cs/Directory.Packages.props b/samples/cs/Directory.Packages.props index 77b68c4cc..1132933a1 100644 --- a/samples/cs/Directory.Packages.props +++ b/samples/cs/Directory.Packages.props @@ -4,8 +4,8 @@ true - - + + diff --git a/samples/cs/README.md b/samples/cs/README.md index ad10a3c65..9207fe400 100644 --- a/samples/cs/README.md +++ b/samples/cs/README.md @@ -18,6 +18,7 @@ Both packages provide the same APIs, so the same source code works on all platfo | [tool-calling-foundry-local-sdk](tool-calling-foundry-local-sdk/) | Use tool calling with native chat completions. | | [tool-calling-foundry-local-web-server](tool-calling-foundry-local-web-server/) | Use tool calling with the local web server. | | [model-management-example](model-management-example/) | Manage models, variant selection, and updates. | +| [private-catalog](private-catalog/) | Register a private MDS-backed catalog with `AddCatalogAsync`, list public + private models, and chat with one. | | [tutorial-chat-assistant](tutorial-chat-assistant/) | Build an interactive chat assistant (tutorial). | | [tutorial-document-summarizer](tutorial-document-summarizer/) | Summarize documents with AI (tutorial). | | [tutorial-tool-calling](tutorial-tool-calling/) | Create a tool-calling assistant (tutorial). 
| diff --git a/samples/cs/nuget.config b/samples/cs/nuget.config index 63954b2fb..294478a7c 100644 --- a/samples/cs/nuget.config +++ b/samples/cs/nuget.config @@ -3,17 +3,6 @@ - - + - - - - - - - - - - \ No newline at end of file + diff --git a/samples/cs/private-catalog/PrivateCatalog.csproj b/samples/cs/private-catalog/PrivateCatalog.csproj new file mode 100644 index 000000000..27be80d6f --- /dev/null +++ b/samples/cs/private-catalog/PrivateCatalog.csproj @@ -0,0 +1,62 @@ + + + + Exe + enable + enable + + + $(NoWarn);NU1604;NU1701 + + + + + net9.0-windows10.0.26100 + false + ARM64;x64 + None + false + + + + + net9.0 + + + + $(NETCoreSdkRuntimeIdentifier) + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/samples/cs/private-catalog/PrivateCatalog.sln b/samples/cs/private-catalog/PrivateCatalog.sln new file mode 100644 index 000000000..6d66e4fa5 --- /dev/null +++ b/samples/cs/private-catalog/PrivateCatalog.sln @@ -0,0 +1,34 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31903.59 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PrivateCatalog", "PrivateCatalog.csproj", "{B1C23D45-6789-4ABC-DEF0-123456789ABC}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 + Release|Any CPU = Release|Any CPU + Release|x64 = Release|x64 + Release|x86 = Release|x86 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Debug|Any CPU.ActiveCfg = Debug|ARM64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Debug|Any CPU.Build.0 = Debug|ARM64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Debug|x64.ActiveCfg = Debug|x64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Debug|x64.Build.0 = Debug|x64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Debug|x86.ActiveCfg = Debug|ARM64 + 
{B1C23D45-6789-4ABC-DEF0-123456789ABC}.Debug|x86.Build.0 = Debug|ARM64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Release|Any CPU.ActiveCfg = Release|ARM64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Release|Any CPU.Build.0 = Release|ARM64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Release|x64.ActiveCfg = Release|x64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Release|x64.Build.0 = Release|x64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Release|x86.ActiveCfg = Release|ARM64 + {B1C23D45-6789-4ABC-DEF0-123456789ABC}.Release|x86.Build.0 = Release|ARM64 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/samples/cs/private-catalog/Program.cs b/samples/cs/private-catalog/Program.cs new file mode 100644 index 000000000..b7af32955 --- /dev/null +++ b/samples/cs/private-catalog/Program.cs @@ -0,0 +1,274 @@ +using Microsoft.AI.Foundry.Local; +using Betalgo.Ranul.OpenAI.ObjectModels.RequestModels; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; + +// --------------------------------------------------------------------------- +// Private Catalog sample — register a customer MDS catalog with a self-signed +// JWT, list public + private models, and run a streaming chat completion. +// +// Required: +// --customer Customer name (or env MDS_CUSTOMER) +// --key-dir Directory with -key.pem (or env MDS_KEY_DIR) +// +// Optional: +// --model Alias or variant id (otherwise interactive picker) +// --prompt Prompt (default "Why is the sky blue?") +// --list List models and exit +// --no-private Skip private-catalog registration (public only) +// --show-uri Print variant URIs alongside the listing +// --------------------------------------------------------------------------- + +string? mdsCustomer = Environment.GetEnvironmentVariable("MDS_CUSTOMER"); +string? mdsKeyDir = Environment.GetEnvironmentVariable("MDS_KEY_DIR"); +string? 
cliModel = null; +string cliPrompt = "Why is the sky blue?"; +bool listOnly = false; +bool noPrivate = false; +bool showUri = false; + +for (int i = 0; i < args.Length; i++) +{ + string Next() + { + if (i + 1 >= args.Length) + { + Console.WriteLine($"Error: {args[i]} requires a value."); + Environment.Exit(1); + } + return args[++i]; + } + + switch (args[i]) + { + case "-c": case "--customer": mdsCustomer = Next(); break; + case "--key-dir": mdsKeyDir = Next(); break; + case "-m": case "--model": cliModel = Next(); break; + case "-p": case "--prompt": cliPrompt = Next(); break; + case "-l": case "--list": listOnly = true; break; + case "--no-private": noPrivate = true; break; + case "--show-uri": showUri = true; break; + case "-h": case "--help": PrintUsage(); return; + default: + Console.WriteLine($"Unknown argument: {args[i]}"); + PrintUsage(); + return; + } +} + +if (string.IsNullOrWhiteSpace(mdsCustomer)) +{ + Console.WriteLine("Error: --customer (or MDS_CUSTOMER) is required."); + PrintUsage(); + return; +} +if (string.IsNullOrWhiteSpace(mdsKeyDir)) +{ + Console.WriteLine("Error: --key-dir (or MDS_KEY_DIR) is required."); + PrintUsage(); + return; +} + +// --- Derive customer resources (same convention as mds/scripts/onboard.py) --- +var customerLower = mdsCustomer.ToLowerInvariant(); +var safeName = customerLower.Replace(" ", "").Replace("-", ""); +var registryName = $"mds-{customerLower}-registry"; +var issuer = $"https://mds{safeName}jwks.blob.core.windows.net/jwks"; +var kid = $"mds-{customerLower}-key-1"; +var keyPath = Path.Combine(mdsKeyDir, $"{customerLower}-key.pem"); + +if (!File.Exists(keyPath)) +{ + Console.WriteLine($"Error: Private key not found at {keyPath}"); + Console.WriteLine("Run: python mds/scripts/onboard.py --customer --test-keys"); + return; +} + +var jwt = SignJwt(keyPath, kid, issuer, registryName); +Console.WriteLine($"Signed JWT for '{mdsCustomer}' (registry={registryName})"); + +// --- Init Foundry Local --- +await 
FoundryLocalManager.CreateAsync( + new Configuration + { + AppName = "private_catalog_sample", + LogLevel = Microsoft.AI.Foundry.Local.LogLevel.Information, + }, + Utils.GetAppLogger()); +var mgr = FoundryLocalManager.Instance; + +// --- Register private catalog (falls back to public-only on failure) --- +var catalog = await mgr.GetCatalogAsync(); +bool privateRegistered = false; + +if (noPrivate) +{ + Console.WriteLine("\n[--no-private] Skipping AddCatalogAsync."); +} +else +{ + Console.WriteLine("\nRegistering private catalog..."); + try + { + await catalog.AddCatalogAsync("private", new PrivateCatalogOptions + { + BearerToken = jwt, + Audience = "model-distribution-service", + }); + privateRegistered = true; + Console.WriteLine("Private catalog registered."); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + Console.WriteLine($"Warning: could not register private catalog ({ex.Message})."); + Console.WriteLine("Continuing with public catalog only."); + } +} + +// --- List models grouped by origin --- +var allModels = await catalog.ListModelsAsync(); +var allVariants = allModels.SelectMany(m => m.Variants).ToList(); + +bool IsPrivate(IModel v) => + v.Info.Uri?.Contains(registryName, StringComparison.OrdinalIgnoreCase) == true; + +var publicVariants = allVariants.Where(v => !IsPrivate(v)).ToList(); +var privateVariants = allVariants.Where(IsPrivate).ToList(); +allVariants = publicVariants.Concat(privateVariants).ToList(); + +int idx = 0; +Console.WriteLine($"\n=== Public Models ({publicVariants.Count}) ==="); +foreach (var v in publicVariants) +{ + Console.WriteLine($" [{++idx}] {v.Alias} ({v.Id})"); + if (showUri) Console.WriteLine($" uri: {v.Info.Uri}"); +} +if (privateRegistered) +{ + Console.WriteLine($"\n=== Private Models ({privateVariants.Count}) ==="); + if (privateVariants.Count == 0) Console.WriteLine(" (none)"); + foreach (var v in privateVariants) + { + Console.WriteLine($" [{++idx}] {v.Alias} ({v.Id})"); + if (showUri) 
Console.WriteLine($" uri: {v.Info.Uri}"); + } +} + +if (listOnly) return; + +// --- Resolve a model (CLI or interactive) --- +string? input = cliModel; +if (string.IsNullOrWhiteSpace(input)) +{ + Console.Write("\nEnter model number, alias, or variant id (q to quit): "); + input = Console.ReadLine()?.Trim(); + if (string.IsNullOrEmpty(input) || input.Equals("q", StringComparison.OrdinalIgnoreCase)) return; + if (int.TryParse(input, out int n) && n >= 1 && n <= allVariants.Count) + input = allVariants[n - 1].Id; +} + +var model = await ResolveModel(catalog, allVariants, input!); +if (model == null) +{ + Console.WriteLine($"\nModel '{input}' not found."); + return; +} +Console.WriteLine($"\nSelected: {model.Id}"); + +// --- Download / load / chat --- +await model.DownloadAsync(p => +{ + Console.Write($"\rDownloading: {p:F1}%"); + if (p >= 100f) Console.WriteLine(); +}); + +Console.Write($"Loading {model.Id}..."); +await model.LoadAsync(); +Console.WriteLine(" done."); + +var chat = await model.GetChatClientAsync(); +var messages = new List<ChatMessage> { new() { Role = "user", Content = cliPrompt } }; + +Console.WriteLine("Chat completion:"); +await foreach (var chunk in chat.CompleteChatStreamingAsync(messages, CancellationToken.None)) +{ + Console.Write(chunk.Choices[0].Message.Content); + Console.Out.Flush(); +} +Console.WriteLine(); + +await model.UnloadAsync(); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +static void PrintUsage() // Prints the option summary; keep it in sync with the switch in the argument-parsing loop. +{ + Console.WriteLine("Usage: PrivateCatalog --customer --key-dir [options]"); + Console.WriteLine(" -c, --customer Customer name (env MDS_CUSTOMER)"); + Console.WriteLine(" --key-dir Dir with <customer>-key.pem (env MDS_KEY_DIR)"); + Console.WriteLine(" -h, --help Show this help"); + Console.WriteLine(" -m, --model Model alias or variant id"); + Console.WriteLine(" -p, --prompt Prompt (default \"Why is the sky blue?\")"); + 
Console.WriteLine(" -l, --list List models and exit"); + Console.WriteLine(" --no-private Skip private-catalog registration"); + Console.WriteLine(" --show-uri Print variant URIs in the listing"); +} + +static async Task ResolveModel( + ICatalog catalog, List allVariants, string input) +{ + // Exact variant id + var model = await catalog.GetModelVariantAsync(input); + if (model != null) return model; + + // Alias (prefer generic-cpu) + var resolved = await catalog.GetModelAsync(input); + if (resolved != null) + { + var pick = resolved.Variants.FirstOrDefault(v => + v.Id.Contains("generic-cpu", StringComparison.OrdinalIgnoreCase)) + ?? resolved.Variants[0]; + return await catalog.GetModelVariantAsync(pick.Id); + } + + // Substring match + var match = allVariants.FirstOrDefault(v => + v.Id.Contains(input, StringComparison.OrdinalIgnoreCase) || + v.Alias.Contains(input, StringComparison.OrdinalIgnoreCase)); + return match != null ? await catalog.GetModelVariantAsync(match.Id) : null; +} + +static string SignJwt(string pemPath, string kid, string issuer, string registryName) +{ + using var rsa = RSA.Create(); + rsa.ImportFromPem(File.ReadAllText(pemPath)); + + var now = DateTimeOffset.UtcNow; + var header = JsonSerializer.Serialize(new { alg = "RS256", typ = "JWT", kid }); + var payload = JsonSerializer.Serialize(new Dictionary + { + ["iss"] = issuer, + ["sub"] = "foundry-local-sample", + ["aud"] = "model-distribution-service", + ["iat"] = now.ToUnixTimeSeconds(), + ["exp"] = now.AddHours(1).ToUnixTimeSeconds(), + ["registry_name"] = registryName, + ["entitlements"] = new Dictionary + { + ["models"] = new[] { "*" }, + ["versions"] = new[] { "*" }, + }, + }); + + var h = B64Url(Encoding.UTF8.GetBytes(header)); + var p = B64Url(Encoding.UTF8.GetBytes(payload)); + var sig = rsa.SignData(Encoding.UTF8.GetBytes($"{h}.{p}"), + HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + return $"{h}.{p}.{B64Url(sig)}"; +} + +static string B64Url(byte[] data) => + 
Convert.ToBase64String(data).TrimEnd('=').Replace('+', '-').Replace('/', '_'); diff --git a/samples/cs/private-catalog/README.md b/samples/cs/private-catalog/README.md new file mode 100644 index 000000000..9c1c31330 --- /dev/null +++ b/samples/cs/private-catalog/README.md @@ -0,0 +1,76 @@ +# Private Catalog (C#) + +End-to-end sample: register a customer MDS catalog with Foundry Local using a +self-signed RS256 JWT, list public + private models, download one, and run a +streaming chat completion. + +## Prerequisites + +- .NET 9 SDK +- Windows x64 +- A customer provisioned in MDS (registry + JWKS). Run + `python scripts/onboard.py --customer <customer> --subscription <subscription> --test-keys` + from the [MDS repo](https://github.com/coreai-microsoft/MDS) — this creates + the resources and writes `<customer>-key.pem` into `MDS/scripts/`. + The matching JWKS must already be published at + `https://mds<customer>jwks.blob.core.windows.net/jwks/.well-known/jwks.json`. +- At least one model uploaded for that customer + (`python scripts/upload_model.py --customer <customer> --name <model> --path ...`). +- A running Foundry Local (`neutron`) that supports `AddCatalogAsync`. + If it doesn't, the sample falls back to the public catalog only. + +## Configure + +Pass your values through environment variables or command-line options: + +```powershell +# Option 1: environment variables +$env:MDS_CUSTOMER = "<customer>" +$env:MDS_KEY_DIR = "C:/path/to/MDS/scripts" +# Option 2: command-line options (they override the environment variables) +dotnet run -- --customer <customer> --key-dir C:/path/to/MDS/scripts +``` + +- The sample takes no MDS host setting and does not read a configuration + file; only the two values below are required. Run it with `--help` to + list the optional switches. +- `--customer` / `MDS_CUSTOMER` — the same name you passed to `onboard.py`. Used to derive + the registry (`mds-<customer>-registry`), JWKS URL, and key file name. +- `--key-dir` / `MDS_KEY_DIR` — folder containing `<customer>-key.pem` (typically + `MDS/scripts/`). + +## Run + +From `samples/cs/private-catalog`: + +```powershell +dotnet run +``` + +## What it does + +1. Reads the customer name and key directory, then derives the customer's registry, issuer, and + key path. +2. 
Signs an RS256 JWT with claims: + `iss`, `sub`, `aud=model-distribution-service`, `iat`, `exp`, + `registry_name`, `entitlements={models:["*"], versions:["*"]}`. +3. Initializes Foundry Local (CPU execution provider only — no + `DownloadAndRegisterEpsAsync` call, so the sample doesn't pull GPU/NPU EPs). +4. Calls `catalog.AddCatalogAsync("private", new PrivateCatalogOptions { BearerToken, Audience })`. + If it fails (e.g. older neutron without this API), falls back to public-only. +5. Lists all models, partitioned by `Uri`: + - **Public**: built-in Azure ML registry + - **Private**: `azureml://registries/mds-<customer>-registry/...` +6. Prompts you to pick one, downloads it, loads it, and streams a chat + completion. + +## Troubleshooting + +| Symptom | Likely cause | Fix | +|---|---|---| +| `Private key not found at ...` | `--key-dir` / `MDS_KEY_DIR` or customer name wrong | Check both values; ensure `<customer>-key.pem` exists in that directory | +| `Warning: could not register private catalog (Unknown command)` | Neutron build predates `AddCatalogAsync` | Use a newer neutron; sample continues with public-only | +| `401 Invalid token issuer` | JWKS not yet published, or wrong issuer URL | Verify `https://mds<customer>jwks.blob.core.windows.net/jwks/.well-known/jwks.json` returns your key | +| Private model appears in **Public** section | Model's registry Uri is `local://...` | Re-upload with `mds/scripts/upload_model.py` so registry stores proper blob info | +| `Failed to download model` | Same as above, or SAS generation error | Check MDS logs; confirm `blob_prefix` tag on the registry entry | +| Build fails with `MSB3027` / file locked on `*.dll` | A previous `PrivateCatalog` process is still running | Close it (or `Stop-Process -Name PrivateCatalog`) and re-run `dotnet run` | diff --git a/sdk/cs/src/Catalog.cs b/sdk/cs/src/Catalog.cs index f33dcaff5..954317d28 100644 --- a/sdk/cs/src/Catalog.cs +++ b/sdk/cs/src/Catalog.cs @@ -58,6 +58,20 @@ public async Task> ListModelsAsync(CancellationToken? 
ct = null) "Error listing models.", _logger).ConfigureAwait(false); } + public async Task> ListModelsAsync(string catalogRegistryName, CancellationToken? ct = null) + { + if (string.IsNullOrWhiteSpace(catalogRegistryName)) + { + throw new ArgumentException("Catalog registry name must be a non-empty string.", nameof(catalogRegistryName)); + } + var all = await ListModelsAsync(ct).ConfigureAwait(false); + // Match Model itself or any variant: ModelInfo merges variants per alias, + // so the variant check is needed for public+private alias collisions. + return all.Where(m => m.Info.IsFromCatalogRegistry(catalogRegistryName) || + m.Variants.Any(v => v.Info.IsFromCatalogRegistry(catalogRegistryName))) + .ToList(); + } + public async Task> GetCachedModelsAsync(CancellationToken? ct = null) { return await Utils.CallWithExceptionHandling(() => GetCachedModelsImplAsync(ct), @@ -77,6 +91,30 @@ public async Task> GetLoadedModelsAsync(CancellationToken? ct = nul .ConfigureAwait(false); } + public async Task GetModelAsync(string modelAlias, + string? preferCatalogRegistryName, + CancellationToken? ct = null) + { + if (string.IsNullOrWhiteSpace(preferCatalogRegistryName)) + { + return await GetModelAsync(modelAlias, ct).ConfigureAwait(false); + } + + var model = await GetModelAsync(modelAlias, ct).ConfigureAwait(false); + if (model == null || model.Info.IsFromCatalogRegistry(preferCatalogRegistryName)) + { + return model; + } + + // Prefer a variant from the named registry; pin it via GetModelVariantAsync + // so callers get the single-variant IModel contract. Fall back to the + // unfiltered model so the alias still resolves — preference is best-effort. + var preferred = model.Variants.FirstOrDefault(v => v.Info.IsFromCatalogRegistry(preferCatalogRegistryName)); + return preferred != null + ? await GetModelVariantAsync(preferred.Id, ct).ConfigureAwait(false) ?? model + : model; + } + public async Task GetModelVariantAsync(string modelId, CancellationToken? 
ct = null) { return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct), @@ -245,8 +283,190 @@ internal void InvalidateCache() _lastFetch = DateTime.MinValue; } + // Year 3000 ≈ 32503680000 Unix seconds. A larger 'exp'/'iat' almost + // certainly means the caller passed ToUnixTimeMilliseconds() by mistake; + // MDS would later reject it with an opaque 401. + private static void RejectMillisecondJwt(string bearer) + { + var t = bearer.Trim(); + if (t.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) { t = t.Substring(7).Trim(); } + var parts = t.Split('.'); + if (parts.Length < 2) { return; } + + string b64 = parts[1].Replace('-', '+').Replace('_', '/'); + b64 += new string('=', (4 - b64.Length % 4) % 4); + byte[] payload; + try { payload = Convert.FromBase64String(b64); } catch (FormatException) { return; } + + try + { + using var doc = JsonDocument.Parse(payload); + foreach (var claim in new[] { "exp", "iat" }) + { + if (doc.RootElement.TryGetProperty(claim, out var v) && + v.ValueKind == JsonValueKind.Number && + v.TryGetInt64(out var n) && n > 32503680000L) + { + throw new ArgumentException( + $"JWT '{claim}' claim ({n}) looks like milliseconds since epoch. " + + "Use DateTimeOffset.UtcNow.ToUnixTimeSeconds(), not ToUnixTimeMilliseconds().", + "options"); + } + } + } + catch (JsonException) { } + } + + private FoundryLocalException ClassifyCatalogError(string cmd, string err, CancellationToken? ct, string context) + { + var p = err.ToLowerInvariant(); + string? reason = + p.Contains("expired") ? "ExpiredToken" : + p.Contains("audience") || p.Contains("\"aud\"") ? "InvalidAudience" : + p.Contains("registry_name") ? "MissingRegistryName" : + p.Contains("signature") ? "InvalidSignature" : + p.Contains("401") || p.Contains("unauthorized") ? "Unauthorized" : + p.Contains("403") || p.Contains("forbidden") ? 
"Forbidden" : + null; + if (reason != null) + { + _logger.LogError("Catalog auth failure ({Reason}): {Err}", reason, err); + return new CatalogAuthException($"{context}: {err} (reason: {reason})", reason); + } + return Utils.FromNativeError(cmd, err, ct, _logger, context: context); + } + public void Dispose() { _lock.Dispose(); } + + public Task AddCatalogAsync(string name, + Dictionary<string, string>? options = null, + CancellationToken? ct = null) + => AddOrUpdateCatalogAsync(name, options, ct); + + public Task AddCatalogAsync(string name, PrivateCatalogOptions options, + CancellationToken? ct = null) + { + if (options is null) { throw new ArgumentNullException(nameof(options)); } + var d = new Dictionary<string, string>(); + if (!string.IsNullOrEmpty(options.BearerToken)) { d["BearerToken"] = options.BearerToken!; } + if (!string.IsNullOrEmpty(options.Audience)) { d["Audience"] = options.Audience!; } + return AddOrUpdateCatalogAsync(name, d, ct); + } + + public async Task AddOrUpdateCatalogAsync(string name, + Dictionary<string, string>? options = null, + CancellationToken? ct = null) + { +#if NET8_0_OR_GREATER // ThrowIfNullOrWhiteSpace first shipped in .NET 8; .NET 7 only has ThrowIfNullOrEmpty. + ArgumentException.ThrowIfNullOrWhiteSpace(name); +#else + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Catalog name must be a non-empty, non-whitespace string.", nameof(name)); + } +#endif + + if (options != null && options.TryGetValue("TokenEndpoint", out var tokenEndpoint) && tokenEndpoint != null) + { + if (!Uri.TryCreate(tokenEndpoint, UriKind.Absolute, out var parsedEndpoint)) + { + throw new ArgumentException($"Token endpoint is not a valid URL: '{tokenEndpoint}'."); + } + if (parsedEndpoint.Scheme != "https" && parsedEndpoint.Scheme != "http") + { + throw new ArgumentException($"Token endpoint must use http or https scheme, got '{parsedEndpoint.Scheme}'."); + } + } + + // Fail fast on the common 'exp/iat in milliseconds' JWT mistake. 
+ if (options != null && options.TryGetValue("BearerToken", out var bearer) && !string.IsNullOrEmpty(bearer)) + { + RejectMillisecondJwt(bearer!); + } + + await Utils.CallWithExceptionHandling(async () => + { + // Caller-supplied options first; Name/Type overlaid so they can't + // be silently overridden. Default Type to AzurePrivate; honour an + // explicit "Type" in options. + var p = new Dictionary(options ?? new Dictionary()) + { + ["Name"] = name, + }; + if (!p.TryGetValue("Type", out var typeValue) || string.IsNullOrEmpty(typeValue)) + { + p["Type"] = "AzurePrivate"; + } + + // Idempotent: if a catalog with this name already exists, remove it + // first so re-registration (e.g. token refresh) is a no-op for the + // happy path. Errors from remove are swallowed at debug \u2014 add_catalog + // will surface the real problem. + try + { + var rm = await _coreInterop.ExecuteCommandAsync( + "remove_catalog", + new CoreInteropRequest { Params = new Dictionary { ["Name"] = name } }, + ct).ConfigureAwait(false); + if (rm.Error != null) + { + _logger.LogDebug("remove_catalog('{Name}') before add returned: {Err}", name, rm.Error); + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogDebug(ex, "remove_catalog('{Name}') before add threw (ignored).", name); + } + + var add = await _coreInterop.ExecuteCommandAsync( + "add_catalog", new CoreInteropRequest { Params = p }, ct).ConfigureAwait(false); + if (add.Error != null) + { + throw ClassifyCatalogError("add_catalog", add.Error, ct, $"Error adding catalog '{name}'"); + } + + InvalidateCache(); + await UpdateModels(ct).ConfigureAwait(false); + }, $"Error adding catalog '{name}'.", _logger).ConfigureAwait(false); + } + + public async Task RemoveCatalogAsync(string name, CancellationToken? 
ct = null) + { + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Catalog name must be a non-empty, non-whitespace string.", nameof(name)); + } + + await Utils.CallWithExceptionHandling(async () => + { + var request = new CoreInteropRequest { Params = new Dictionary { ["Name"] = name } }; + var result = await _coreInterop.ExecuteCommandAsync("remove_catalog", request, ct).ConfigureAwait(false); + if (result.Error != null) + { + throw Utils.FromNativeError("remove_catalog", result.Error, ct, _logger, + context: $"Error removing catalog '{name}'"); + } + InvalidateCache(); + await UpdateModels(ct).ConfigureAwait(false); + }, $"Error removing catalog '{name}'.", _logger).ConfigureAwait(false); + } + + public async Task> GetCatalogNamesAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(async () => + { + CoreInteropRequest? input = null; + var result = await _coreInterop.ExecuteCommandAsync("get_catalog_names", input, ct) + .ConfigureAwait(false); + if (result.Error != null) + { + throw new FoundryLocalException($"Error getting catalog names: {result.Error}", _logger); + } + + return JsonSerializer.Deserialize(result.Data ?? "[]", JsonSerializationContext.Default.ListString) ?? []; + }, "Error getting catalog names.", _logger).ConfigureAwait(false); + } } diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 7239a48e4..034bcb53c 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -243,7 +243,7 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, } catch (Exception ex) when (ex is not OperationCanceledException) { - var msg = $"Error executing command '{commandName}' with input {commandInput ?? 
"null"}"; + var msg = $"Error executing command '{commandName}'"; throw new FoundryLocalException(msg, ex, _logger); } } diff --git a/sdk/cs/src/Detail/JsonSerializationContext.cs b/sdk/cs/src/Detail/JsonSerializationContext.cs index 0fe5e6770..d0180d5a5 100644 --- a/sdk/cs/src/Detail/JsonSerializationContext.cs +++ b/sdk/cs/src/Detail/JsonSerializationContext.cs @@ -1,4 +1,4 @@ -// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- // // Copyright (c) Microsoft. All rights reserved. // @@ -41,6 +41,7 @@ namespace Microsoft.AI.Foundry.Local.Detail; // which has AOT-incompatible JsonConverters, so we only register the raw deserialization type) --- [JsonSerializable(typeof(LiveAudioTranscriptionRaw))] [JsonSerializable(typeof(CoreErrorResponse))] +[JsonSerializable(typeof(List))] // catalog names [JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, WriteIndented = false)] internal partial class JsonSerializationContext : JsonSerializerContext diff --git a/sdk/cs/src/Detail/ModelLoadManager.cs b/sdk/cs/src/Detail/ModelLoadManager.cs index 76b48539a..e323bc96c 100644 --- a/sdk/cs/src/Detail/ModelLoadManager.cs +++ b/sdk/cs/src/Detail/ModelLoadManager.cs @@ -53,7 +53,8 @@ public async Task LoadAsync(string modelId, CancellationToken? ct = null) var result = await _coreInterop.ExecuteCommandAsync("load_model", request, ct).ConfigureAwait(false); if (result.Error != null) { - throw new FoundryLocalException($"Error loading model {modelId}: {result.Error}"); + throw Utils.FromNativeError("load_model", result.Error, ct, _logger, + context: $"Error loading model {modelId}"); } // currently just a 'model loaded successfully' message @@ -72,7 +73,8 @@ public async Task UnloadAsync(string modelId, CancellationToken? 
ct = null) var result = await _coreInterop.ExecuteCommandAsync("unload_model", request, ct).ConfigureAwait(false); if (result.Error != null) { - throw new FoundryLocalException($"Error unloading model {modelId}: {result.Error}"); + throw Utils.FromNativeError("unload_model", result.Error, ct, _logger, + context: $"Error unloading model {modelId}"); } _logger.LogInformation("Model {ModelId} unloaded successfully: {Message}", modelId, result.Data); diff --git a/sdk/cs/src/Detail/ModelVariant.cs b/sdk/cs/src/Detail/ModelVariant.cs index 250c601a2..4c800555e 100644 --- a/sdk/cs/src/Detail/ModelVariant.cs +++ b/sdk/cs/src/Detail/ModelVariant.cs @@ -127,22 +127,22 @@ private async Task GetPathImplAsync(CancellationToken? ct = null) var result = await _coreInterop.ExecuteCommandAsync("get_model_path", request, ct).ConfigureAwait(false); if (result.Error != null) { - throw new FoundryLocalException( - $"Error getting path for model {Id}: {result.Error}. Has it been downloaded?"); + throw Utils.FromNativeError("get_model_path", result.Error, ct, _logger, + context: $"Error getting path for model {Id} (has it been downloaded?)"); } var path = result.Data!; return path; } + // Re-fire the user's progress callback every 5s when the native layer + // goes quiet, so spinners/UIs don't look hung during slow blob reads. + private static readonly TimeSpan DownloadHeartbeatInterval = TimeSpan.FromSeconds(5); + private async Task DownloadImplAsync(Action? downloadProgress = null, CancellationToken? ct = null) { - var request = new CoreInteropRequest - { - Params = new() { { "Model", Id } } - }; - + var request = new CoreInteropRequest { Params = new() { { "Model", Id } } }; ICoreInterop.Response? response; if (downloadProgress == null) @@ -151,21 +151,56 @@ private async Task DownloadImplAsync(Action? 
downloadProgress = null, } else { - var callback = new ICoreInterop.CallbackFn(progressString => + float lastProgress = 0f; + var lastUtc = DateTime.UtcNow; + var sync = new object(); + + var callback = new ICoreInterop.CallbackFn(s => { - if (float.TryParse(progressString, out var progress)) + if (float.TryParse(s, out var p)) { - downloadProgress(progress); + lock (sync) { lastProgress = p; lastUtc = DateTime.UtcNow; } + downloadProgress(p); } }); - response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request, - callback, ct).ConfigureAwait(false); + using var hbCts = new CancellationTokenSource(); + var hb = Task.Run(async () => + { + try + { + while (!hbCts.IsCancellationRequested) + { + await Task.Delay(DownloadHeartbeatInterval, hbCts.Token).ConfigureAwait(false); + float p; TimeSpan idle; + lock (sync) { p = lastProgress; idle = DateTime.UtcNow - lastUtc; } + if (idle >= DownloadHeartbeatInterval) + { + _logger.LogDebug("Download heartbeat {ModelId}: {Pct:F1}% (idle {Idle:F0}s)", + Id, p, idle.TotalSeconds); + try { downloadProgress(p); } catch { /* don't let a buggy callback abort the download */ } + } + } + } + catch (OperationCanceledException) { } + }, hbCts.Token); + + try + { + response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request, + callback, ct).ConfigureAwait(false); + } + finally + { + hbCts.Cancel(); + try { await hb.ConfigureAwait(false); } catch { } + } } if (response.Error != null) { - throw new FoundryLocalException($"Error downloading model {Id}: {response.Error}"); + throw Utils.FromNativeError("download_model", response.Error, ct, _logger, + context: $"Error downloading model {Id}"); } } @@ -176,7 +211,8 @@ private async Task RemoveFromCacheImplAsync(CancellationToken? 
ct = null) var result = await _coreInterop.ExecuteCommandAsync("remove_cached_model", request, ct).ConfigureAwait(false); if (result.Error != null) { - throw new FoundryLocalException($"Error removing model {Id} from cache: {result.Error}"); + throw Utils.FromNativeError("remove_cached_model", result.Error, ct, _logger, + context: $"Error removing model {Id} from cache"); } } diff --git a/sdk/cs/src/FoundryLocalException.cs b/sdk/cs/src/FoundryLocalException.cs index dae5ef042..8432976f6 100644 --- a/sdk/cs/src/FoundryLocalException.cs +++ b/sdk/cs/src/FoundryLocalException.cs @@ -30,3 +30,13 @@ internal FoundryLocalException(string message, Exception innerException, ILogger logger.LogError(innerException, message); } } + +/// +/// Thrown when a private catalog operation fails authentication (bad/expired +/// token, wrong aud, missing registry_name, etc.). +/// +public class CatalogAuthException : FoundryLocalException +{ + public string Reason { get; } + public CatalogAuthException(string message, string reason) : base(message) { Reason = reason; } +} diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index b014850f6..687ca58f3 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -11,6 +11,7 @@ namespace Microsoft.AI.Foundry.Local; using Microsoft.AI.Foundry.Local.Detail; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; public class FoundryLocalManager : IDisposable { @@ -46,15 +47,17 @@ public class FoundryLocalManager : IDisposable /// Create the singleton instance. /// /// Configuration to use. - /// Application logger to use. - /// Use Microsoft.Extensions.Logging.NullLogger.Instance if you wish to ignore log output from the SDK. + /// Application logger to use. If null, log output from the SDK is + /// discarded (equivalent to passing ). /// /// Optional cancellation token for the initialization. /// Task creating the instance. 
/// - public static async Task CreateAsync(Configuration configuration, ILogger logger, + public static async Task CreateAsync(Configuration configuration, ILogger? logger = null, CancellationToken? ct = null) { + logger ??= NullLogger.Instance; + using var disposable = await asyncLock.LockAsync().ConfigureAwait(false); if (instance != null) diff --git a/sdk/cs/src/FoundryModelInfo.cs b/sdk/cs/src/FoundryModelInfo.cs index 2d1327cc9..17c2f5c8f 100644 --- a/sdk/cs/src/FoundryModelInfo.cs +++ b/sdk/cs/src/FoundryModelInfo.cs @@ -131,4 +131,31 @@ public record ModelInfo [JsonPropertyName("capabilities")] public string? Capabilities { get; init; } + + /// + /// Registry name parsed from when it matches + /// azureml://registries/<name>/... (or any URI containing + /// registries/<name>/). Null otherwise. Used by + /// to filter private vs public models. + /// + [JsonIgnore] + public string? RegistryName + { + get + { + const string marker = "registries/"; + if (string.IsNullOrEmpty(Uri)) return null; + var idx = Uri.IndexOf(marker, System.StringComparison.OrdinalIgnoreCase); + if (idx < 0) return null; + var start = idx + marker.Length; + var end = Uri.IndexOf('/', start); + if (end < 0) end = Uri.Length; + return end > start ? Uri.Substring(start, end - start) : null; + } + } + + /// Case-insensitive match against . + public bool IsFromCatalogRegistry(string? registryName) => + !string.IsNullOrEmpty(registryName) && + string.Equals(RegistryName, registryName, System.StringComparison.OrdinalIgnoreCase); } diff --git a/sdk/cs/src/ICatalog.cs b/sdk/cs/src/ICatalog.cs index 4dca8e7d9..5aadcd8e0 100644 --- a/sdk/cs/src/ICatalog.cs +++ b/sdk/cs/src/ICatalog.cs @@ -1,4 +1,4 @@ -// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- // // Copyright (c) Microsoft. All rights reserved. 
// @@ -7,6 +7,17 @@ namespace Microsoft.AI.Foundry.Local; using System.Collections.Generic; +/// +/// Strongly-typed options for . +/// +public sealed class PrivateCatalogOptions +{ + /// JWT bearer token. exp/iat MUST be Unix seconds, not milliseconds. + public string? BearerToken { get; set; } + /// JWT aud claim MDS expects (e.g. "model-distribution-service"). + public string? Audience { get; set; } +} + public interface ICatalog { /// @@ -21,6 +32,23 @@ public interface ICatalog /// List of IModel instances. Task> ListModelsAsync(CancellationToken? ct = null); + /// + /// List the available models filtered to those whose + /// matches (case-insensitive). + /// + /// + /// Registry name to filter by (e.g. mds-acme-registry for an MDS-hosted + /// private catalog, or the name of the Azure ML public registry). Required. + /// + /// Optional CancellationToken. + /// List of IModel instances whose URI is rooted at the given registry. + /// + /// This is a client-side filter on the result of ; + /// no extra round-trips to the broker. Useful for samples / UIs that want to + /// group "public catalog" vs "private catalog" models. + /// + Task> ListModelsAsync(string catalogRegistryName, CancellationToken? ct = null); + /// /// Lookup a model by its alias. /// @@ -29,6 +57,23 @@ public interface ICatalog /// The matching IModel, or null if no model with the given alias exists. Task GetModelAsync(string modelAlias, CancellationToken? ct = null); + /// + /// Lookup a model by its alias, preferring variants that originate from + /// the named catalog registry. If no variants match the preferred registry, + /// falls back to the unfiltered result (same as ). + /// + /// Model alias. + /// + /// Catalog registry name to prefer (e.g. mds-acme-registry). When the + /// same alias is published by both the public and a private catalog this + /// disambiguates which one the caller wants. Null or empty disables the + /// preference and behaves like the single-argument overload. 
+ /// + /// Optional CancellationToken. + Task GetModelAsync(string modelAlias, + string? preferCatalogRegistryName, + CancellationToken? ct = null); + /// /// Lookup a model variant by its unique model id. /// NOTE: This will return an IModel with a single variant. Use GetModelAsync to get an IModel with all available @@ -61,4 +106,46 @@ public interface ICatalog /// Optional CancellationToken. /// The latest version of the model. Will match the input if it is the latest version. Task GetLatestVersionAsync(IModel model, CancellationToken? ct = null); + + /// + /// Add a private model catalog. The endpoint is fixed (the SDK targets the + /// Model Distribution Service); only and credentials + /// are caller-provided. Idempotent: calling with the same + /// replaces the existing registration (use this to + /// rotate an expired BearerToken). The model list is refreshed + /// before returning. + /// + /// + /// Recognised keys: BearerToken, Audience, TokenEndpoint, + /// ClientId, ClientSecret, Type (default + /// "AzurePrivate"). If BearerToken is a JWT, its exp/iat + /// MUST be Unix seconds (not milliseconds); the SDK rejects ms-shaped values. + /// Prefer the overload for IntelliSense. + /// + /// Bad name, or JWT exp/iat in milliseconds. + /// MDS rejected the bearer token. + Task AddCatalogAsync(string name, Dictionary? options = null, + CancellationToken? ct = null); + + /// Strongly-typed overload of . + Task AddCatalogAsync(string name, PrivateCatalogOptions options, + CancellationToken? ct = null); + + /// Alias for ; same idempotent behavior. + Task AddOrUpdateCatalogAsync(string name, Dictionary? options = null, + CancellationToken? ct = null); + + /// + /// Remove a previously-added private catalog by name. No-op for the built-in + /// public catalog. After removal the model list is refreshed so cached + /// models from the removed catalog no longer appear in . + /// + Task RemoveCatalogAsync(string name, CancellationToken? 
ct = null); + + /// + /// Get the names of all registered catalogs. + /// + /// Optional CancellationToken. + /// List of catalog name strings. + Task> GetCatalogNamesAsync(CancellationToken? ct = null); } diff --git a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj index 384b44151..7e7242a2e 100644 --- a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj +++ b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj @@ -15,6 +15,7 @@ git net8.0;netstandard2.0 + win-x64;win-arm64;osx-arm64;linux-x64 true enable True diff --git a/sdk/cs/src/Utils.cs b/sdk/cs/src/Utils.cs index 09338497b..458572f96 100644 --- a/sdk/cs/src/Utils.cs +++ b/sdk/cs/src/Utils.cs @@ -47,4 +47,36 @@ internal static T CallWithExceptionHandling(Func func, string errorMsg, IL throw new FoundryLocalException(errorMsg, ex, logger); } } + + /// + /// Build a from a native broker error string. + /// Adds a hint when the native layer reports "cancelled" but the caller did not + /// actually cancel — almost always a server/auth/network failure mis-reported as + /// cancellation. Stashes raw error + command on . + /// + internal static FoundryLocalException FromNativeError( + string commandName, string nativeError, CancellationToken? ct, ILogger logger, string? context = null) + { + var prefix = context ?? $"Error executing '{commandName}'"; + var userCancelled = ct.HasValue && ct.Value.IsCancellationRequested; +#if NETSTANDARD2_0 + var looksCancel = nativeError.IndexOf("cancel", System.StringComparison.OrdinalIgnoreCase) >= 0; +#else + var looksCancel = nativeError.Contains("cancel", System.StringComparison.OrdinalIgnoreCase); +#endif + var message = (looksCancel && !userCancelled) + ? $"{prefix}: {nativeError}. Caller did not cancel — likely a server-side failure " + + "(authentication, expired credentials, or network error)." 
+ : $"{prefix}: {nativeError}"; + + logger.LogError("Native command '{Command}' failed: {NativeError} (caller cancelled: {UserCancelled})", + commandName, nativeError, userCancelled); + + var ex = new FoundryLocalException(message); + ex.Data["Command"] = commandName; + ex.Data["NativeError"] = nativeError; + ex.Data["UserCancelled"] = userCancelled; + return ex; + } } + diff --git a/sdk/cs/test/FoundryLocal.Tests/CatalogManagementTests.cs b/sdk/cs/test/FoundryLocal.Tests/CatalogManagementTests.cs new file mode 100644 index 000000000..7c6a7ca48 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/CatalogManagementTests.cs @@ -0,0 +1,59 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using System.Text.Json; +using Microsoft.AI.Foundry.Local.Detail; +using Moq; + +public class CatalogManagementTests +{ + private static async Task CreateCatalogWithIntercepts( + List extra) + { + var logger = Utils.CreateCapturingLoggerMock([]); + var lm = new Mock(); + lm.Setup(m => m.ListLoadedModelsAsync(It.IsAny())).ReturnsAsync(Array.Empty()); + + List intercepts = + [ + new() { CommandName = "get_catalog_name", ResponseData = "Test" }, + new() { CommandName = "get_model_list", + ResponseData = JsonSerializer.Serialize(Utils.TestCatalog.TestCatalog, + JsonSerializationContext.Default.ListModelInfo) }, + new() { CommandName = "get_cached_models", ResponseData = "[]" }, + .. 
extra + ]; + + var ci = Utils.CreateCoreInteropWithIntercept(Utils.CoreInterop, intercepts); + return await Catalog.CreateAsync(lm.Object, ci.Object, logger.Object); + } + + [Test] + public async Task Test_AddCatalog() + { + using var catalog = await CreateCatalogWithIntercepts( + [ + new() { CommandName = "add_catalog", ResponseData = "OK" } + ]); + + await catalog.AddCatalogAsync("priv", + new Dictionary { ["ClientId"] = "id", ["ClientSecret"] = "secret" }); + await Assert.That(catalog).IsNotNull(); + } + + [Test] + public async Task Test_GetCatalogNames() + { + using var catalog = await CreateCatalogWithIntercepts( + [new() { CommandName = "get_catalog_names", ResponseData = "[\"public\",\"private\"]" }]); + + var names = await catalog.GetCatalogNamesAsync(); + await Assert.That(names.Count).IsEqualTo(2); + await Assert.That(names).Contains("private"); + } +}