Skip to content

Commit

Permalink
update connect timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
DavoudEshtehari committed Jul 28, 2023
1 parent dcccdd0 commit 1670faa
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -194,7 +197,7 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C
return true;
}
}

/// <summary>
/// We validate the provided certificate provided by the client with the one from the server to see if it matches.
/// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method]
Expand Down Expand Up @@ -239,6 +242,23 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName, ref TimeSpan timeout)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0} with {1} timeout.", args0: serverName, args1: timeout);
using CancellationTokenSource cts = new CancellationTokenSource(timeout);
Stopwatch stopwatch = Stopwatch.StartNew();
// using this overload to support netstandard
Task<IPAddress[]> task = Dns.GetHostAddressesAsync(serverName);
task.ConfigureAwait(false);
task.Wait(cts.Token);
timeout -= stopwatch.Elapsed;
stopwatch.Stop();
return task.Result;
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ public override int ProtocolVersion
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

Stopwatch stopwatch = Stopwatch.StartNew();

bool reportError = true;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port);
Expand All @@ -183,6 +185,11 @@ public override int ProtocolVersion
}
catch (Exception ex)
{
TimeSpan timeLeft = ts - stopwatch.Elapsed;
if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero)
{
throw;
}
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
{
Expand Down Expand Up @@ -214,26 +221,31 @@ public override int ProtocolVersion
{
if (parallel)
{
_socket = TryConnectParallel(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
_socket = TryConnectParallel(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
}
else
{
_socket = Connect(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
_socket = Connect(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
}
}
catch (Exception exRetry)
{
timeLeft = ts - stopwatch.Elapsed;
if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero)
{
throw;
}
if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying exception {1}", args0: _connectionId, args1: exRetry?.Message);
if (parallel)
{
_socket = TryConnectParallel(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
_socket = TryConnectParallel(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
}
else
{
_socket = Connect(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
_socket = Connect(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
}
}
else
Expand All @@ -249,6 +261,10 @@ public override int ProtocolVersion
throw;
}
}
finally
{
stopwatch.Stop();
}

if (_socket == null || !_socket.Connected)
{
Expand Down Expand Up @@ -304,8 +320,11 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
{
Socket availableSocket = null;
Task<Socket> connectTask;
TimeSpan timeout = ts;

IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName);
IPAddress[] serverAddresses = isInfiniteTimeOut
? SNICommon.GetDnsIpAddresses(hostName)
: SNICommon.GetDnsIpAddresses(hostName, ref timeout);

if (serverAddresses.Length > MaxParallelIpAddresses)
{
Expand Down Expand Up @@ -338,7 +357,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i

connectTask = ParallelConnectAsync(serverAddresses, port);

if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts)))
if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(timeout)))
{
callerReportError = false;
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40);
Expand All @@ -349,7 +368,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
availableSocket = connectTask.Result;
return availableSocket;
}

/// <summary>
/// Returns array of IP addresses for the given server name, sorted according to the given preference.
/// </summary>
Expand Down Expand Up @@ -389,7 +408,7 @@ private static IEnumerable<IPAddress> GetHostAddressesSortedByPreference(string
}
}
}

// Connect to server with hostName and port.
// The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point.
// Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server.
Expand Down Expand Up @@ -422,26 +441,44 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
port,
ipAddress.AddressFamily,
isInfiniteTimeout);

bool isConnected;
try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select
{
socket.Connect(ipAddress, port);
if (!isInfiniteTimeout)
if (isInfiniteTimeout)
{
socket.Connect(ipAddress, port);
}
else
{
TimeSpan timeLeft = timeout - timeTaken.Elapsed;
if (timeLeft <= TimeSpan.Zero)
{
return null;
}
// Socket.Connect does not support infinite timeouts, so we use Task to simulate it
Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port));
socketConnectTask.ConfigureAwait(false);
socketConnectTask.Start();
if (!socketConnectTask.Wait(timeLeft))
{
throw ADP.TimeoutException($"The socket couldn't connect during the expected {timeLeft} remaining time to connect.");
}
throw SQL.SocketDidNotThrow();
}

isConnected = true;
}
catch (SocketException socketException) when (!isInfiniteTimeout &&
socketException.SocketErrorCode ==
SocketError.WouldBlock)
catch (AggregateException aggregateException) when (!isInfiniteTimeout
&& aggregateException.InnerException is SocketException socketException
&& socketException.SocketErrorCode == SocketError.WouldBlock)
{
// https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118
// Socket.Select is used because it supports timeouts, while Socket.Connect does not

List<Socket> checkReadLst; List<Socket> checkWriteLst; List<Socket> checkErrorLst;
List<Socket> checkReadLst;
List<Socket> checkWriteLst;
List<Socket> checkErrorLst;

// Repeating Socket.Select several times if our timeout is greater
// than int.MaxValue microseconds because of
Expand All @@ -450,9 +487,10 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
do
{
TimeSpan timeLeft = timeout - timeTaken.Elapsed;

if (timeLeft <= TimeSpan.Zero)
if (!isInfiniteTimeout && timeLeft <= TimeSpan.Zero)
{
return null;
}

int socketSelectTimeout =
checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000));
Expand Down Expand Up @@ -487,11 +525,15 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
return socket;
}
}
catch (SocketException e)
catch (AggregateException aggregateException) when (aggregateException.InnerException is SocketException socketException)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: socketException?.Message);
SqlClientEventSource.Log.TryAdvancedTraceEvent(
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {socketException}");
}
catch (AggregateException aggregateException) when (aggregateException.InnerException is TimeoutException timeoutException)
{
Console.WriteLine(timeoutException); // temporary for testing
}
finally
{
Expand Down Expand Up @@ -675,7 +717,7 @@ private bool ValidateServerCertificate(object sender, X509Certificate serverCert
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Certificate will not be validated.", args0: _connectionId);
return true;
}

string serverNameToValidate;
if (!string.IsNullOrEmpty(_hostNameInCertificate))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,17 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
TimeSpan ts = default;
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
// The infinite Timeout is a function of ConnectionString Timeout=0
if (long.MaxValue != timerExpire)
bool isInfiniteTimeout = long.MaxValue == timerExpire;
if (!isInfiniteTimeout)
{
ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
IPAddress[] ipAddresses = isInfiniteTimeout
? SNICommon.GetDnsIpAddresses(browserHostname)
: SNICommon.GetDnsIpAddresses(browserHostname, ref ts);

Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");
IPAddress[] ipv4Addresses = null;
IPAddress[] ipv6Addresses = null;
Expand Down

0 comments on commit 1670faa

Please sign in to comment.