Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev | Improved connection timeout #2098

Merged
merged 6 commits into from Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -3,9 +3,13 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.ProviderBase;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -194,7 +198,7 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C
return true;
}
}

/// <summary>
/// We validate the provided certificate provided by the client with the one from the server to see if it matches.
/// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method]
Expand Down Expand Up @@ -239,6 +243,24 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
{
int remainingTimeout = timeout.MillisecondsRemainingInt;
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO,
"Getting DNS host entries for serverName {0} within {1} milliseconds.",
args0: serverName,
args1: remainingTimeout);
using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout);
// using this overload to support netstandard
Task<IPAddress[]> task = Dns.GetHostAddressesAsync(serverName);
task.ConfigureAwait(false);
task.Wait(cts.Token);
return task.Result;
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
Expand Down
Expand Up @@ -10,6 +10,7 @@
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using Microsoft.Data.ProviderBase;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -37,7 +38,7 @@ internal sealed class SNINpHandle : SNIPhysicalHandle
private int _bufferSize = TdsEnums.DEFAULT_LOGIN_PACKET_SIZE;
private readonly Guid _connectionId = Guid.NewGuid();

public SNINpHandle(string serverName, string pipeName, long timerExpire, bool tlsFirst)
public SNINpHandle(string serverName, string pipeName, TimeoutTimer timeout, bool tlsFirst)
{
using (TrySNIEventScope.Create(nameof(SNINpHandle)))
{
Expand All @@ -54,17 +55,25 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, bool tl
PipeDirection.InOut,
PipeOptions.Asynchronous | PipeOptions.WriteThrough);

bool isInfiniteTimeOut = long.MaxValue == timerExpire;
if (isInfiniteTimeOut)
if (timeout.IsInfinite)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO,
"Connection Id {0}, Setting server name = {1}, pipe name = {2}. Connecting with infinite timeout.",
args0: _connectionId,
args1: serverName,
args2: pipeName);
_pipeStream.Connect(Timeout.Infinite);
}
else
{
TimeSpan ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;

_pipeStream.Connect((int)ts.TotalMilliseconds);
int timeoutMilliseconds = timeout.MillisecondsRemainingInt;
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO,
"Connection Id {0}, Setting server name = {1}, pipe name = {2}. Connecting within the {3} sepecified milliseconds.",
args0: _connectionId,
args1: serverName,
args2: pipeName,
args3: timeoutMilliseconds);
_pipeStream.Connect(timeoutMilliseconds);
}
}
catch (TimeoutException te)
Expand Down
Expand Up @@ -9,6 +9,7 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Data.ProviderBase;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -130,7 +131,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
/// Create a SNI connection handle
/// </summary>
/// <param name="fullServerName">Full server name from connection string</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="timeout">Timer expiration</param>
/// <param name="instanceName">Instance name</param>
/// <param name="spnBuffer">SPN</param>
/// <param name="serverSPN">pre-defined SPN</param>
Expand All @@ -147,7 +148,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
/// <returns>SNI handle</returns>
internal static SNIHandle CreateConnectionHandle(
string fullServerName,
long timerExpire,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
string serverSPN,
Expand Down Expand Up @@ -186,11 +187,11 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
case DataSource.Protocol.Admin:
case DataSource.Protocol.None: // default to using tcp if no protocol is provided
case DataSource.Protocol.TCP:
sniHandle = CreateTcpHandle(details, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo,
sniHandle = CreateTcpHandle(details, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo,
tlsFirst, hostNameInCertificate, serverCertificateFilename);
break;
case DataSource.Protocol.NP:
sniHandle = CreateNpHandle(details, timerExpire, parallel, tlsFirst);
sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst);
break;
default:
Debug.Fail($"Unexpected connection protocol: {details._connectionProtocol}");
Expand Down Expand Up @@ -279,7 +280,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// Creates an SNITCPHandle object
/// </summary>
/// <param name="details">Data source</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="timeout">Timer expiration</param>
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
Expand All @@ -290,7 +291,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <returns>SNITCPHandle</returns>
private static SNITCPHandle CreateTcpHandle(
DataSource details,
long timerExpire,
TimeoutTimer timeout,
bool parallel,
SqlConnectionIPAddressPreference ipPreference,
string cachedFQDN,
Expand All @@ -317,8 +318,8 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
try
{
port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference);
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference);
}
catch (SocketException se)
{
Expand All @@ -335,27 +336,27 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort;
}

return new SNITCPHandle(hostName, port, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo,
return new SNITCPHandle(hostName, port, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo,
tlsFirst, hostNameInCertificate, serverCertificateFilename);
}

/// <summary>
/// Creates an SNINpHandle object
/// </summary>
/// <param name="details">Data source</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="timeout">Timer expiration</param>
/// <param name="parallel">Should MultiSubnetFailover be used. Only returns an error for named pipes.</param>
/// <param name="tlsFirst"></param>
/// <returns>SNINpHandle</returns>
private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel, bool tlsFirst)
private static SNINpHandle CreateNpHandle(DataSource details, TimeoutTimer timeout, bool parallel, bool tlsFirst)
{
if (parallel)
{
// Connecting to a SQL Server instance using the MultiSubnetFailover connection option is only supported when using the TCP protocol
SNICommon.ReportSNIError(SNIProviders.NP_PROV, 0, SNICommon.MultiSubnetFailoverWithNonTcpProtocol, Strings.SNI_ERROR_49);
return null;
}
return new SNINpHandle(details.PipeHostName, details.PipeName, timerExpire, tlsFirst);
return new SNINpHandle(details.PipeHostName, details.PipeName, timeout, tlsFirst);
}

/// <summary>
Expand Down