diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 54058a9218..00057e3486 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -454,7 +454,18 @@ public static bool IsAADAuthorityURLSetup() public static bool IsNotAzureServer() { - return !AreConnStringsSetup() || !Utils.IsAzureSqlServer(new SqlConnectionStringBuilder((TCPConnectionString)).DataSource); + return !AreConnStringsSetup() || !Utils.IsAzureSqlServer(new SqlConnectionStringBuilder(TCPConnectionString).DataSource); + } + + public static bool IsNotNamedInstance() + { + return !AreConnStringsSetup() || !new SqlConnectionStringBuilder(TCPConnectionString).DataSource.Contains(@"\"); + } + + public static bool IsLocalHost() + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + return ParseDataSource(builder.DataSource, out string hostname, out _, out _) && string.Equals("localhost", hostname, StringComparison.OrdinalIgnoreCase); } // Synapse: Always Encrypted is not supported with Azure Synapse. diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SplitPacketTest/SplitPacketTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SplitPacketTest/SplitPacketTest.cs index 1147e704ca..4cfd148d72 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SplitPacketTest/SplitPacketTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SplitPacketTest/SplitPacketTest.cs @@ -11,79 +11,83 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests { - [ActiveIssue("5538")] // Only testable on localhost - public class SplitPacketTest + public class SplitPacketTest : IDisposable { - private int Port = -1; - private int SplitPacketSize = 1; - private string BaseConnString; + private int _port = -1; + private int _splitPacketSize = 1; + private string _baseConnString; + private TcpListener _listener; + private CancellationTokenSource _cts = new CancellationTokenSource(); public SplitPacketTest() { - string actualHost; - int actualPort; - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString); - GetTcpInfoFromDataSource(builder.DataSource, out actualHost, out actualPort); + DataSourceBuilder dataSourceBuilder = new DataSourceBuilder(builder.DataSource); - Task.Factory.StartNew(() => { SetupProxy(actualHost, actualPort); }); + Task.Factory.StartNew(() => { SetupProxy(dataSourceBuilder.ServerName, dataSourceBuilder.Port ?? 1433, _cts.Token); }); - for (int i = 0; i < 10 && Port == -1; i++) + for (int i = 0; i < 10 && _port == -1; i++) { Thread.Sleep(500); } - if (Port == -1) + if (_port == -1) throw new InvalidOperationException("Proxy local port not defined!"); - builder.DataSource = "tcp:127.0.0.1," + Port; - BaseConnString = builder.ConnectionString; + builder.DataSource = "tcp:127.0.0.1," + _port; + _baseConnString = builder.ConnectionString; } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))] public void OneByteSplitTest() { - SplitPacketSize = 1; + _splitPacketSize = 1; OpenConnection(); + Assert.True(true); } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))] public void AlmostFullHeaderTest() { - SplitPacketSize = 7; + _splitPacketSize = 7; OpenConnection(); + Assert.True(true); } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))] public void FullHeaderTest() { - SplitPacketSize = 8; + _splitPacketSize = 8; OpenConnection(); + Assert.True(true); } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))] public void HeaderPlusOneTest() { - SplitPacketSize = 9; + _splitPacketSize = 9; OpenConnection(); + Assert.True(true); } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))] public void MARSSplitTest() { - SplitPacketSize = 1; + _splitPacketSize = 1; OpenMarsConnection("select * from Orders"); + Assert.True(true); } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))] public void MARSReplicateTest() { - SplitPacketSize = 1; + _splitPacketSize = 1; OpenMarsConnection("select REPLICATE('A', 10000)"); + Assert.True(true); } private void OpenMarsConnection(string cmdText) { - using (SqlConnection conn = new SqlConnection((new SqlConnectionStringBuilder(BaseConnString) { MultipleActiveResultSets = true }).ConnectionString)) + using (SqlConnection conn = new SqlConnection((new SqlConnectionStringBuilder(_baseConnString) { MultipleActiveResultSets = true }).ConnectionString)) { conn.Open(); using (SqlCommand cmd1 = new SqlCommand(cmdText, conn)) @@ -102,7 +106,7 @@ private void OpenMarsConnection(string cmdText) private void OpenConnection() { - using (SqlConnection conn = new SqlConnection(BaseConnString)) + using (SqlConnection conn = new SqlConnection(_baseConnString)) { conn.Open(); using (SqlCommand cmd = new SqlCommand("select * from Orders", conn)) @@ -114,23 +118,23 @@ private void OpenConnection() } } - private void SetupProxy(string actualHost, int actualPort) + private void SetupProxy(string actualHost, int actualPort, CancellationToken cancellationToken) { - TcpListener listener = new TcpListener(IPAddress.Loopback, 0); - listener.Start(); - Port = ((IPEndPoint)listener.LocalEndpoint).Port; - var client = listener.AcceptTcpClientAsync().GetAwaiter().GetResult(); + _listener = new TcpListener(IPAddress.Loopback, 0); + _listener.Start(); + _port = ((IPEndPoint)_listener.LocalEndpoint).Port; + var client = _listener.AcceptTcpClientAsync().GetAwaiter().GetResult(); var sqlClient = new TcpClient(); - sqlClient.ConnectAsync(actualHost, actualPort).Wait(); + sqlClient.ConnectAsync(actualHost, actualPort).Wait(cancellationToken); - Task.Factory.StartNew(() => { ForwardToSql(client, sqlClient); }); - Task.Factory.StartNew(() => { ForwardToClient(client, sqlClient); }); + Task.Factory.StartNew(() => { ForwardToSql(client, sqlClient, cancellationToken); }, cancellationToken); + Task.Factory.StartNew(() => { ForwardToClient(client, sqlClient, cancellationToken); }, cancellationToken); } - private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient) + private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient, CancellationToken cancellationToken) { - while (true) + while (!cancellationToken.IsCancellationRequested) { byte[] buffer = new byte[1024]; int bytesRead = ourClient.GetStream().Read(buffer, 0, buffer.Length); @@ -139,11 +143,11 @@ private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient) } } - private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient) + private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient, CancellationToken cancellationToken) { - while (true) + while (!cancellationToken.IsCancellationRequested) { - byte[] buffer = new byte[SplitPacketSize]; + byte[] buffer = new byte[_splitPacketSize]; int bytesRead = sqlClient.GetStream().Read(buffer, 0, buffer.Length); ourClient.GetStream().Write(buffer, 0, bytesRead); @@ -155,22 +159,24 @@ private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient) } } - private static void GetTcpInfoFromDataSource(string dataSource, out string hostName, out int port) + public void Dispose() { - string[] dataSourceParts = dataSource.Split(','); - if (dataSourceParts.Length == 1) - { - hostName = dataSourceParts[0].Replace("tcp:", ""); - port = 1433; - } - else if (dataSourceParts.Length == 2) - { - hostName = dataSourceParts[0].Replace("tcp:", ""); - port = int.Parse(dataSourceParts[1]); - } - else + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) { - throw new InvalidOperationException("TCP Connection String not in correct format!"); + _cts.Cancel(); + _cts.Dispose(); + _listener?.Server.Dispose(); +#if NETFRAMEWORK + _listener?.Stop(); +#else + _listener?.Dispose(); +#endif } } }