Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests | Activate SplitPacket Tests #3061

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,18 @@ public static bool IsAADAuthorityURLSetup()

public static bool IsNotAzureServer()
{
return !AreConnStringsSetup() || !Utils.IsAzureSqlServer(new SqlConnectionStringBuilder((TCPConnectionString)).DataSource);
return !AreConnStringsSetup() || !Utils.IsAzureSqlServer(new SqlConnectionStringBuilder(TCPConnectionString).DataSource);
}

public static bool IsNotNamedInstance()
{
return !AreConnStringsSetup() || !new SqlConnectionStringBuilder(TCPConnectionString).DataSource.Contains(@"\");
}

public static bool IsLocalHost()
{
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
return ParseDataSource(builder.DataSource, out string hostname, out _, out _) && string.Equals("localhost", hostname, StringComparison.OrdinalIgnoreCase);
}

// Synapse: Always Encrypted is not supported with Azure Synapse.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,79 +11,83 @@

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
[ActiveIssue("5538")] // Only testable on localhost
public class SplitPacketTest
public class SplitPacketTest : IDisposable
{
private int Port = -1;
private int SplitPacketSize = 1;
private string BaseConnString;
private int _port = -1;
private int _splitPacketSize = 1;
private string _baseConnString;
private TcpListener _listener;
private CancellationTokenSource _cts = new CancellationTokenSource();

public SplitPacketTest()
{
string actualHost;
int actualPort;

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString);
GetTcpInfoFromDataSource(builder.DataSource, out actualHost, out actualPort);
DataSourceBuilder dataSourceBuilder = new DataSourceBuilder(builder.DataSource);

Task.Factory.StartNew(() => { SetupProxy(actualHost, actualPort); });
Task.Factory.StartNew(() => { SetupProxy(dataSourceBuilder.ServerName, dataSourceBuilder.Port ?? 1433, _cts.Token); });

for (int i = 0; i < 10 && Port == -1; i++)
for (int i = 0; i < 10 && _port == -1; i++)
{
Thread.Sleep(500);
}
if (Port == -1)
if (_port == -1)
throw new InvalidOperationException("Proxy local port not defined!");

builder.DataSource = "tcp:127.0.0.1," + Port;
BaseConnString = builder.ConnectionString;
builder.DataSource = "tcp:127.0.0.1," + _port;
_baseConnString = builder.ConnectionString;
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))]
public void OneByteSplitTest()
{
SplitPacketSize = 1;
_splitPacketSize = 1;
OpenConnection();
Assert.True(true);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))]
public void AlmostFullHeaderTest()
{
SplitPacketSize = 7;
_splitPacketSize = 7;
OpenConnection();
Assert.True(true);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))]
public void FullHeaderTest()
{
SplitPacketSize = 8;
_splitPacketSize = 8;
OpenConnection();
Assert.True(true);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))]
public void HeaderPlusOneTest()
{
SplitPacketSize = 9;
_splitPacketSize = 9;
OpenConnection();
Assert.True(true);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))]
public void MARSSplitTest()
{
SplitPacketSize = 1;
_splitPacketSize = 1;
OpenMarsConnection("select * from Orders");
Assert.True(true);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsTCPConnStringSetup), nameof(DataTestUtility.IsLocalHost), nameof(DataTestUtility.IsNotNamedInstance))]
public void MARSReplicateTest()
{
SplitPacketSize = 1;
_splitPacketSize = 1;
OpenMarsConnection("select REPLICATE('A', 10000)");
Assert.True(true);
}

private void OpenMarsConnection(string cmdText)
{
using (SqlConnection conn = new SqlConnection((new SqlConnectionStringBuilder(BaseConnString) { MultipleActiveResultSets = true }).ConnectionString))
using (SqlConnection conn = new SqlConnection((new SqlConnectionStringBuilder(_baseConnString) { MultipleActiveResultSets = true }).ConnectionString))
{
conn.Open();
using (SqlCommand cmd1 = new SqlCommand(cmdText, conn))
Expand All @@ -102,7 +106,7 @@ private void OpenMarsConnection(string cmdText)

private void OpenConnection()
{
using (SqlConnection conn = new SqlConnection(BaseConnString))
using (SqlConnection conn = new SqlConnection(_baseConnString))
{
conn.Open();
using (SqlCommand cmd = new SqlCommand("select * from Orders", conn))
Expand All @@ -114,23 +118,23 @@ private void OpenConnection()
}
}

private void SetupProxy(string actualHost, int actualPort)
private void SetupProxy(string actualHost, int actualPort, CancellationToken cancellationToken)
{
TcpListener listener = new TcpListener(IPAddress.Loopback, 0);
listener.Start();
Port = ((IPEndPoint)listener.LocalEndpoint).Port;
var client = listener.AcceptTcpClientAsync().GetAwaiter().GetResult();
_listener = new TcpListener(IPAddress.Loopback, 0);
_listener.Start();
_port = ((IPEndPoint)_listener.LocalEndpoint).Port;
var client = _listener.AcceptTcpClientAsync().GetAwaiter().GetResult();

var sqlClient = new TcpClient();
sqlClient.ConnectAsync(actualHost, actualPort).Wait();
sqlClient.ConnectAsync(actualHost, actualPort).Wait(cancellationToken);

Task.Factory.StartNew(() => { ForwardToSql(client, sqlClient); });
Task.Factory.StartNew(() => { ForwardToClient(client, sqlClient); });
Task.Factory.StartNew(() => { ForwardToSql(client, sqlClient, cancellationToken); }, cancellationToken);
Task.Factory.StartNew(() => { ForwardToClient(client, sqlClient, cancellationToken); }, cancellationToken);
}

private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient)
private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient, CancellationToken cancellationToken)
{
while (true)
while (!cancellationToken.IsCancellationRequested)
{
byte[] buffer = new byte[1024];
int bytesRead = ourClient.GetStream().Read(buffer, 0, buffer.Length);
Expand All @@ -139,11 +143,11 @@ private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient)
}
}

private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient)
private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient, CancellationToken cancellationToken)
{
while (true)
while (!cancellationToken.IsCancellationRequested)
{
byte[] buffer = new byte[SplitPacketSize];
byte[] buffer = new byte[_splitPacketSize];
int bytesRead = sqlClient.GetStream().Read(buffer, 0, buffer.Length);

ourClient.GetStream().Write(buffer, 0, bytesRead);
Expand All @@ -155,22 +159,24 @@ private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient)
}
}

private static void GetTcpInfoFromDataSource(string dataSource, out string hostName, out int port)
public void Dispose()
{
string[] dataSourceParts = dataSource.Split(',');
if (dataSourceParts.Length == 1)
{
hostName = dataSourceParts[0].Replace("tcp:", "");
port = 1433;
}
else if (dataSourceParts.Length == 2)
{
hostName = dataSourceParts[0].Replace("tcp:", "");
port = int.Parse(dataSourceParts[1]);
}
else
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
{
if (disposing)
{
throw new InvalidOperationException("TCP Connection String not in correct format!");
_cts.Cancel();
_cts.Dispose();
_listener?.Server.Dispose();
#if NETFRAMEWORK
_listener?.Stop();
#else
_listener?.Dispose();
#endif
}
}
}
Expand Down
Loading