Skip to content

Commit

Permalink
Fix | Throttling of token requests by calling AcquireTokenSilent (#1925
Browse files Browse the repository at this point in the history
…) (#1995)
  • Loading branch information
DavoudEshtehari committed Apr 17, 2023
1 parent 86aac95 commit 5b01cc1
Showing 1 changed file with 131 additions and 70 deletions.
Expand Up @@ -4,7 +4,11 @@

using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Runtime.Caching;
using System.Security;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
Expand All @@ -24,6 +28,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
Expand Down Expand Up @@ -171,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

AuthenticationResult result;
AuthenticationResult result = null;
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
{
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
Expand Down Expand Up @@ -207,86 +213,87 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
else
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();

result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);

IAccount account = default;
if (accounts.MoveNext())
if (null == result)
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
do
{
IAccount currentVal = accounts.Current;
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
{
account = currentVal;
break;
}
}
while (accounts.MoveNext());
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
else
{
account = accounts.Current;
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
{
string pwCacheKey = GetAccountPwCacheKey(parameters);
object previousPw = s_accountPwCache.Get(pwCacheKey);
byte[] currPwHash = GetHash(parameters.Password);

if (null != previousPw &&
previousPw is byte[] previousPwBytes &&
// Only get the cached token if the current password hash matches the previously used password hash
currPwHash.SequenceEqual(previousPwBytes))
{
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
}

if (null != account)
if (null == result)
{
try
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();

result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);

// We cache the password hash to ensure future connection requests include a validated password
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
// against the same tenant could succeed with an invalid password when we re-use the cached token.
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
s_accountPwCache.Remove(pwCacheKey);
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
}

SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
try
{
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}

if (null == result)
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
}
Expand All @@ -299,8 +306,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
}

private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts)
private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters,
string[] scopes, CancellationTokenSource cts)
{
AuthenticationResult result = null;

// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();

IAccount account = default;
if (accounts.MoveNext())
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
do
{
IAccount currentVal = accounts.Current;
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
{
account = currentVal;
break;
}
}
while (accounts.MoveNext());
}
else
{
account = accounts.Current;
}
}

if (null != account)
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}

return result;
}

private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
{
try
{
Expand All @@ -319,11 +367,11 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
*/
ctsInteractive.CancelAfter(180000);
#endif
if (_customWebUI != null)
if (customWebUI != null)
{
return await app.AcquireTokenInteractive(scopes)
.WithCorrelationId(connectionId)
.WithCustomWebUi(_customWebUI)
.WithCustomWebUi(customWebUI)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
Expand Down Expand Up @@ -357,7 +405,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
else
{
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult))
.WithCorrelationId(connectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
Expand Down Expand Up @@ -410,6 +458,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
return clientApplicationInstance;
}

private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
{
return parameters.Authority + "+" + parameters.UserId;
}

private static byte[] GetHash(string input)
{
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
SHA256 sha256 = SHA256.Create();
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
return hashedBytes;
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;
Expand Down

0 comments on commit 5b01cc1

Please sign in to comment.