Skip to content

Commit

Permalink
Tested init code paths. Fixed bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
ralph-msft committed Mar 30, 2024
1 parent e4c1c1c commit ce4f9b7
Show file tree
Hide file tree
Showing 19 changed files with 324 additions and 259 deletions.
11 changes: 6 additions & 5 deletions src/ai/Program_AI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ public AiProgramData()
() => TelemetryHelpers.InstantiateFromConfig(this),
System.Threading.LazyThreadSafetyMode.ExecutionAndPublication);

LoginManager = new AzConsoleLoginManager(TelemetryUserAgent);
var azCliClient = new AzCliClient(
LoginManager,
() => Values?.GetOrDefault("init.service.interactive", true) ?? true,
TelemetryUserAgent);
var environmentVariables = LegacyAzCli.GetUserAgentEnv();

LoginManager = new AzConsoleLoginManager(
() => Values?.GetOrDefault("init.service.interactive", true) ?? true,
environmentVariables);

var azCliClient = new AzCliClient(environmentVariables);
SubscriptionsClient = azCliClient;
CognitiveServicesClient = azCliClient;
SearchClient = azCliClient;
Expand Down
11 changes: 6 additions & 5 deletions src/ai/commands/init_command.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,18 +187,19 @@ private async Task DoInitRootVerifyConfigFileAsync(bool interactive, string file

ConsoleHelpers.WriteLineWithHighlight("\n `ATTACHED SERVICES AND RESOURCES`\n");

var message = " Validating...";
Console.Write(message);
string indent = " ";
using var cw = new ConsoleTempWriter($"{indent}Validating...");

var (hubName, openai, search) = await AiSdkConsoleGui.VerifyResourceConnections(_values, validated?.Id, groupName, projectName);
if (openai != null && search != null)
{
Console.Write($"\r{new string(' ', message.Length)}\r");
return (hubName, openai, search);
}
else
{
ConsoleHelpers.WriteLineWithHighlight($"\r{message} `#e_;WARNING: Configuration could not be validated!`");
cw.Clear();
Console.Write(indent);
ConsoleHelpers.WriteLineWithHighlight($"`#e_;WARNING: Configuration could not be validated!`");
Console.WriteLine();
return (null, null, null);
}
Expand Down Expand Up @@ -536,7 +537,7 @@ private async Task DoInitOpenAi(bool interactive, bool skipChat = false, bool al
var regionFilter = _values.GetOrDefault("init.service.resource.region.name", "");
var groupFilter = _values.GetOrDefault("init.service.resource.group.name", "");
var resourceFilter = _values.GetOrDefault("init.service.cognitiveservices.resource.name", "");
var kind = _values.GetOrDefault("init.service.cognitiveservices.resource.kind", "OpenAI;AIServices");
var kind = _values.GetOrDefault("init.service.cognitiveservices.resource.kind", "AIServices;OpenAI");
var sku = _values.GetOrDefault("init.service.cognitiveservices.resource.sku", Program.CognitiveServiceResourceSku);
var yes = _values.GetOrDefault("init.service.cognitiveservices.terms.agree", false);

Expand Down
139 changes: 63 additions & 76 deletions src/clients.cli/AzCliClient.cs

Large diffs are not rendered by default.

28 changes: 22 additions & 6 deletions src/clients.cli/AzConsoleLoginManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,35 @@ namespace Azure.AI.CLI.Clients.AzPython
/// </summary>
public class AzConsoleLoginManager : ILoginManager
{
private readonly Func<bool> _getAllowInteractive;
private readonly IDictionary<string, string> _cliEnv;
private readonly TimeSpan _minValidity;
private AuthenticationToken? _authToken;

public AzConsoleLoginManager(string userAgent, TimeSpan? minValidity = null)
/// <summary>
/// Creates a new instance of <see cref="AzConsoleLoginManager"/>
/// </summary>
/// <param name="getAllowInteractiveLogin">A function that returns whether interactive login is allowed</param>
/// <param name="environmentVariables">Environment variables to use when running the AZ CLI</param>
/// <param name="minValidity">The minimum validity an auth token should have before renewing it</param>
public AzConsoleLoginManager(Func<bool> getAllowInteractiveLogin, IDictionary<string, string>? environmentVariables = null, TimeSpan? minValidity = null)
{
_cliEnv = new Dictionary<string, string>()
{
{ "AZURE_HTTP_USER_AGENT", userAgent ?? throw new ArgumentNullException(nameof(userAgent)) }
};

_getAllowInteractive = getAllowInteractiveLogin ?? throw new ArgumentNullException(nameof(getAllowInteractiveLogin));
_cliEnv = environmentVariables ?? new Dictionary<string, string>();
_minValidity = minValidity ?? TimeSpan.FromMinutes(5);
}

/// <inheritdoc />
public bool CanAttemptLogin
{
get
{
try { return _getAllowInteractive(); }
catch (Exception) { return false; }
}
}

/// <inheritdoc />
public async Task<ClientResult> LoginAsync(LoginOptions options, CancellationToken token)
{
try
Expand Down Expand Up @@ -66,6 +81,7 @@ public async Task<ClientResult> LoginAsync(LoginOptions options, CancellationTok
}
}

/// <inheritdoc />
public async Task<ClientResult<AuthenticationToken>> GetOrRenewAuthToken(CancellationToken token)
{
try
Expand Down
97 changes: 0 additions & 97 deletions src/clients.cli/login_helpers.cs

This file was deleted.

28 changes: 28 additions & 0 deletions src/common/Clients/ILoginManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,44 @@ public enum LoginMode
UseDeviceCode,
}

/// <summary>
/// Options for logging in
/// </summary>
public struct LoginOptions
{
/// <summary>
/// The type of login to attempt
/// </summary>
public LoginMode Mode { get; set; }

//public string? TenantId { get; set; }
//public string? UserName { get; set; }
}

/// <summary>
/// Manages the login process for a client
/// </summary>
public interface ILoginManager : IDisposable
{
/// <summary>
/// Whether or not we can attempt to log in
/// </summary>
/// <returns>True or false</returns>
public bool CanAttemptLogin { get; }

/// <summary>
/// Logs the client in
/// </summary>
/// <param name="options">The options to use for the login</param>
/// <param name="token">The cancellation token to use</param>
/// <returns>Asynchronous task that completes once the login has completed (or failed)</returns>
Task<ClientResult> LoginAsync(LoginOptions options, CancellationToken token);

/// <summary>
/// Gets an authentication token to use for sending HTTP REST requests
/// </summary>
/// <param name="token">The cancellation token to use</param>
/// <returns>Asynchronous task that returns the authentication token</returns>
Task<ClientResult<AuthenticationToken>> GetOrRenewAuthToken(CancellationToken token);
}
}
65 changes: 42 additions & 23 deletions src/common/Clients/Models/ClientResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,27 +163,24 @@ internal static bool CheckIsError(ClientOutcome Outcome, string? ErrorDetails, E
|| Exception != null
|| !string.IsNullOrWhiteSpace(ErrorDetails);

internal static Exception GenerateException(ClientOutcome outcome, string? message = null, string? errorDetails = null, Exception? ex = null)
=> outcome switch
{
ClientOutcome.Canceled => new OperationCanceledException(CreateExceptionMessage("CANCELED", message, errorDetails), ex),
ClientOutcome.Failed or ClientOutcome.Unknown => new ApplicationException(CreateExceptionMessage("FAILED", message, errorDetails), ex),
ClientOutcome.LoginNeeded => new ApplicationException(CreateExceptionMessage("LOGIN REQUIRED", message, errorDetails), ex),
ClientOutcome.TimedOut => new TimeoutException(CreateExceptionMessage("TIMED OUT", message, errorDetails), ex),
_ => new ApplicationException(CreateExceptionMessage("FAILED", message, errorDetails), ex),
};

internal static void ThrowOnFail(ClientOutcome outcome, string? message, string? errorDetails, Exception? ex)
{
switch (outcome)
if (ClientOutcome.Success == outcome)
{
case ClientOutcome.Canceled:
throw new OperationCanceledException(CreateExceptionMessage("CANCELED", message, errorDetails), ex);

case ClientOutcome.Failed:
case ClientOutcome.Unknown:
throw new ApplicationException(CreateExceptionMessage("FAILED", message, errorDetails), ex);

case ClientOutcome.LoginNeeded:
throw new ApplicationException(CreateExceptionMessage("LOGIN REQUIRED", message, errorDetails), ex);

case ClientOutcome.Success:
// nothing to do
return;

case ClientOutcome.TimedOut:
throw new TimeoutException(CreateExceptionMessage("TIMED OUT", message, errorDetails), ex);
return;
}

throw GenerateException(outcome, message, errorDetails, ex);
}

private static string CreateExceptionMessage(string typeStr, string? messageStr, string? detailsStr)
Expand Down Expand Up @@ -239,27 +236,45 @@ internal static string ToString<T>(T result) where T : IClientResult

internal static ClientResult From(ProcessOutput output)
{
bool hasError = output.HasError;
ClientOutcome outcome = ClientOutcome.Success;
if (HasLoginError(output.StdError))
{
outcome = ClientOutcome.LoginNeeded;
}
else if (output.HasError)
{
outcome = ClientOutcome.Failed;
}

return new ClientResult
{
ErrorDetails = string.IsNullOrWhiteSpace(output.StdError) ? null : output.StdError,
Exception = hasError
Exception = outcome != ClientOutcome.Success
? new ApplicationException($"Process failed with exit code {output.ExitCode}. Details: {output.StdError}")
: null,
Outcome = hasError ? ClientOutcome.Failed : ClientOutcome.Success,
Outcome = outcome,
};
}

internal static ClientResult<TValue> From<TValue>(ProcessOutput output, TValue value)
{
bool hasError = output.HasError;
ClientOutcome outcome = ClientOutcome.Success;
if (HasLoginError(output.StdError))
{
outcome = ClientOutcome.LoginNeeded;
}
else if (output.HasError)
{
outcome = ClientOutcome.Failed;
}

return new ClientResult<TValue>
{
ErrorDetails = string.IsNullOrWhiteSpace(output.StdError) ? null : output.StdError,
Exception = hasError
Exception = outcome != ClientOutcome.Success
? new ApplicationException($"Process failed with exit code {output.ExitCode}. Details: {output.StdError}")
: null,
Outcome = hasError ? ClientOutcome.Failed : ClientOutcome.Success,
Outcome = outcome,
Value = value
};
}
Expand Down Expand Up @@ -290,6 +305,10 @@ internal static ClientResult<TValue> FromException<TValue>(Exception ex, TValue
ErrorDetails = ex.Message,
Value = result
};

private static bool HasLoginError(string? errorMessage)
=> errorMessage != null
&& (errorMessage.Split('\'', '"').Contains("az login") || errorMessage.Contains("refresh token"));
}

#endregion
Expand Down
Loading

0 comments on commit ce4f9b7

Please sign in to comment.