|
| 1 | +package msdsn |
| 2 | + |
| 3 | +import ( |
| 4 | + "crypto/tls" |
| 5 | + "encoding/hex" |
| 6 | + "io" |
| 7 | + "os" |
| 8 | + "reflect" |
| 9 | + "testing" |
| 10 | + "time" |
| 11 | + |
| 12 | + "github.com/stretchr/testify/assert" |
| 13 | +) |
| 14 | + |
| 15 | +func TestInvalidConnectionString(t *testing.T) { |
| 16 | + connStrings := []string{ |
| 17 | + "log=invalid", |
| 18 | + "port=invalid", |
| 19 | + "packet size=invalid", |
| 20 | + "connection timeout=invalid", |
| 21 | + "dial timeout=invalid", |
| 22 | + "keepalive=invalid", |
| 23 | + "encrypt=invalid", |
| 24 | + "trustservercertificate=invalid", |
| 25 | + "failoverport=invalid", |
| 26 | + "applicationintent=ReadOnly", |
| 27 | + "disableretry=invalid", |
| 28 | + "multisubnetfailover=invalid", |
| 29 | + |
| 30 | + // ODBC mode |
| 31 | + "odbc:password={", |
| 32 | + "odbc:password={somepass", |
| 33 | + "odbc:password={somepass}}", |
| 34 | + "odbc:password={some}pass", |
| 35 | + "odbc:=", // unexpected = |
| 36 | + "odbc: =", |
| 37 | + "odbc:password={some} a", |
| 38 | + |
| 39 | + // URL mode |
| 40 | + "sqlserver://\x00", |
| 41 | + "sqlserver://host?key=value1&key=value2", // duplicate keys |
| 42 | + } |
| 43 | + for _, connStr := range connStrings { |
| 44 | + _, err := Parse(connStr) |
| 45 | + if err == nil { |
| 46 | + t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr) |
| 47 | + continue |
| 48 | + } else { |
| 49 | + t.Logf("Connection failed for %s as expected with error %v", connStr, err) |
| 50 | + } |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +func TestValidConnectionString(t *testing.T) { |
| 55 | + type testStruct struct { |
| 56 | + connStr string |
| 57 | + check func(Config) bool |
| 58 | + } |
| 59 | + connStrings := []testStruct{ |
| 60 | + {"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p Config) bool { |
| 61 | + return p.Host == "server" && p.Instance == "instance" && p.User == "tester" && p.Password == "pwd" |
| 62 | + }}, |
| 63 | + {"server=.", func(p Config) bool { return p.Host == "localhost" && !p.ColumnEncryption }}, |
| 64 | + {"server=(local)", func(p Config) bool { return p.Host == "localhost" }}, |
| 65 | + {"ServerSPN=serverspn;Workstation ID=workstid", func(p Config) bool { return p.ServerSPN == "serverspn" && p.Workstation == "workstid" }}, |
| 66 | + {"failoverpartner=fopartner;failoverport=2000", func(p Config) bool { return p.FailOverPartner == "fopartner" && p.FailOverPort == 2000 }}, |
| 67 | + {"app name=appname;applicationintent=ReadOnly;database=testdb", func(p Config) bool { return p.AppName == "appname" && p.ReadOnlyIntent }}, |
| 68 | + {"encrypt=disable", func(p Config) bool { return p.Encryption == EncryptionDisabled }}, |
| 69 | + {"encrypt=disable;tlsmin=1.1", func(p Config) bool { return p.Encryption == EncryptionDisabled && p.TLSConfig == nil }}, |
| 70 | + {"encrypt=true", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, |
| 71 | + {"encrypt=mandatory", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, |
| 72 | + {"encrypt=true;tlsmin=1.0", func(p Config) bool { |
| 73 | + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS10 |
| 74 | + }}, |
| 75 | + {"encrypt=false;tlsmin=1.0", func(p Config) bool { |
| 76 | + return p.Encryption == EncryptionOff && p.TLSConfig.MinVersion == tls.VersionTLS10 |
| 77 | + }}, |
| 78 | + {"encrypt=true;tlsmin=1.1;column encryption setting=enabled", func(p Config) bool { |
| 79 | + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption |
| 80 | + }}, |
| 81 | + {"encrypt=true;tlsmin=1.2", func(p Config) bool { |
| 82 | + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12 |
| 83 | + }}, |
| 84 | + {"encrypt=true;tlsmin=1.3", func(p Config) bool { |
| 85 | + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS13 |
| 86 | + }}, |
| 87 | + {"encrypt=true;tlsmin=1.4", func(p Config) bool { |
| 88 | + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 |
| 89 | + }}, |
| 90 | + {"encrypt=false", func(p Config) bool { return p.Encryption == EncryptionOff }}, |
| 91 | + {"encrypt=optional", func(p Config) bool { return p.Encryption == EncryptionOff }}, |
| 92 | + {"encrypt=strict", func(p Config) bool { return p.Encryption == EncryptionStrict }}, |
| 93 | + {"connection timeout=3;dial timeout=4;keepalive=5", func(p Config) bool { |
| 94 | + return p.ConnTimeout == 3*time.Second && p.DialTimeout == 4*time.Second && p.KeepAlive == 5*time.Second |
| 95 | + }}, |
| 96 | + {"log=63", func(p Config) bool { return p.LogFlags == 63 && p.Port == 0 }}, |
| 97 | + {"log=63;port=1000", func(p Config) bool { return p.LogFlags == 63 && p.Port == 1000 }}, |
| 98 | + {"log=64", func(p Config) bool { return p.LogFlags == 64 }}, |
| 99 | + {"log=64;packet size=0", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 512 }}, |
| 100 | + {"log=64;packet size=300", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 512 }}, |
| 101 | + {"log=64;packet size=8192", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 8192 }}, |
| 102 | + {"log=64;packet size=48000", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 32767 }}, |
| 103 | + {"disableretry=true", func(p Config) bool { return p.DisableRetry }}, |
| 104 | + {"disableretry=false", func(p Config) bool { return !p.DisableRetry }}, |
| 105 | + {"disableretry=1", func(p Config) bool { return p.DisableRetry }}, |
| 106 | + {"disableretry=0", func(p Config) bool { return !p.DisableRetry }}, |
| 107 | + {"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }}, |
| 108 | + {"MultiSubnetFailover=true;NoTraceID=true", func(p Config) bool { return p.MultiSubnetFailover && p.NoTraceID }}, |
| 109 | + {"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }}, |
| 110 | + // those are supported currently, but maybe should not be |
| 111 | + {"someparam", func(p Config) bool { return true }}, |
| 112 | + {";;=;", func(p Config) bool { return true }}, |
| 113 | + |
| 114 | + // ODBC mode |
| 115 | + {"odbc:server=somehost;user id=someuser;password=somepass", func(p Config) bool { |
| 116 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" |
| 117 | + }}, |
| 118 | + {"odbc:server=somehost;user id=someuser;password=some{pass", func(p Config) bool { |
| 119 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some{pass" |
| 120 | + }}, |
| 121 | + {"odbc:server={somehost};user id={someuser};password={somepass}", func(p Config) bool { |
| 122 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" |
| 123 | + }}, |
| 124 | + {"odbc:server={somehost};user id={someuser};password={some=pass}", func(p Config) bool { |
| 125 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some=pass" |
| 126 | + }}, |
| 127 | + {"odbc:server={somehost};user id={someuser};password={some;pass}", func(p Config) bool { |
| 128 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some;pass" |
| 129 | + }}, |
| 130 | + {"odbc:server={somehost};user id={someuser};password={some{pass}", func(p Config) bool { |
| 131 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some{pass" |
| 132 | + }}, |
| 133 | + {"odbc:server={somehost};user id={someuser};password={some}}pass}", func(p Config) bool { |
| 134 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some}pass" |
| 135 | + }}, |
| 136 | + {"odbc:server={somehost};user id={someuser};password={some{}}p=a;ss}", func(p Config) bool { |
| 137 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some{}p=a;ss" |
| 138 | + }}, |
| 139 | + {"odbc: server = somehost; user id = someuser ; password = {some pass } ;", func(p Config) bool { |
| 140 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "some pass " |
| 141 | + }}, |
| 142 | + {"odbc:password", func(p Config) bool { |
| 143 | + return p.Password == "" |
| 144 | + }}, |
| 145 | + {"odbc:", func(p Config) bool { |
| 146 | + return true |
| 147 | + }}, |
| 148 | + {"odbc:password=", func(p Config) bool { |
| 149 | + return p.Password == "" |
| 150 | + }}, |
| 151 | + {"odbc:password;", func(p Config) bool { |
| 152 | + return p.Password == "" |
| 153 | + }}, |
| 154 | + {"odbc:password=;", func(p Config) bool { |
| 155 | + return p.Password == "" |
| 156 | + }}, |
| 157 | + {"odbc:password={value} ", func(p Config) bool { |
| 158 | + return p.Password == "value" |
| 159 | + }}, |
| 160 | + {"odbc:server=somehost;user id=someuser;password=somepass;disableretry=true", func(p Config) bool { |
| 161 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" && p.DisableRetry |
| 162 | + }}, |
| 163 | + {"odbc:server=somehost;user id=someuser;password=somepass; disableretry = 1 ", func(p Config) bool { |
| 164 | + return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" && p.DisableRetry |
| 165 | + }}, |
| 166 | + |
| 167 | + // URL mode |
| 168 | + {"sqlserver://somehost?connection+timeout=30", func(p Config) bool { |
| 169 | + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.ConnTimeout == 30*time.Second |
| 170 | + }}, |
| 171 | + {"sqlserver://someuser@somehost?connection+timeout=30", func(p Config) bool { |
| 172 | + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second |
| 173 | + }}, |
| 174 | + {"sqlserver://someuser:@somehost?connection+timeout=30", func(p Config) bool { |
| 175 | + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second |
| 176 | + }}, |
| 177 | + {"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p Config) bool { |
| 178 | + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second |
| 179 | + }}, |
| 180 | + {"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p Config) bool { |
| 181 | + return p.Host == "somehost" && p.Port == 1434 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second |
| 182 | + }}, |
| 183 | + {"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p Config) bool { |
| 184 | + return p.Host == "somehost" && p.Port == 1434 && p.Instance == "someinstance" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second |
| 185 | + }}, |
| 186 | + {"sqlserver://someuser@somehost?disableretry=true", func(p Config) bool { |
| 187 | + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry |
| 188 | + }}, |
| 189 | + {"sqlserver://someuser@somehost?connection+timeout=30&disableretry=1", func(p Config) bool { |
| 190 | + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption |
| 191 | + }}, |
| 192 | + {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool { |
| 193 | + return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && !p.Encoding.GuidConversion |
| 194 | + }}, |
| 195 | + {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool { |
| 196 | + return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion |
| 197 | + }}, |
| 198 | + } |
| 199 | + for _, ts := range connStrings { |
| 200 | + p, err := Parse(ts.connStr) |
| 201 | + if err == nil { |
| 202 | + t.Logf("Connection string was parsed successfully %s", ts.connStr) |
| 203 | + } else { |
| 204 | + t.Errorf("Connection string %s failed to parse with error %s", ts.connStr, err) |
| 205 | + continue |
| 206 | + } |
| 207 | + |
| 208 | + if !ts.check(p) { |
| 209 | + t.Errorf("Check failed on conn str %s", ts.connStr) |
| 210 | + } |
| 211 | + } |
| 212 | +} |
| 213 | + |
| 214 | +func TestSplitConnectionStringURL(t *testing.T) { |
| 215 | + _, err := splitConnectionStringURL("http://bad") |
| 216 | + if err == nil { |
| 217 | + t.Error("Connection string with invalid scheme should fail to parse but it didn't") |
| 218 | + } |
| 219 | +} |
| 220 | + |
| 221 | +func TestConnParseRoundTripFixed(t *testing.T) { |
| 222 | + connStr := "sqlserver://sa:sa@localhost/sqlexpress?database=master&log=127&disableretry=true&dial+timeout=30" |
| 223 | + params, err := Parse(connStr) |
| 224 | + if err != nil { |
| 225 | + t.Fatal("Test URL is not valid", err) |
| 226 | + } |
| 227 | + rtParams, err := Parse(params.URL().String()) |
| 228 | + if err != nil { |
| 229 | + t.Fatal("Params after roundtrip are not valid", err) |
| 230 | + } |
| 231 | + t.Log("params.URL " + params.URL().String()) |
| 232 | + params.ActivityID = nil |
| 233 | + rtParams.ActivityID = nil |
| 234 | + if !reflect.DeepEqual(params, rtParams) { |
| 235 | + t.Fatal("Parameters do not match after roundtrip", params, rtParams) |
| 236 | + } |
| 237 | +} |
| 238 | + |
| 239 | +func TestServerNameInTLSConfig(t *testing.T) { |
| 240 | + var tests = []struct { |
| 241 | + dsn string |
| 242 | + host string |
| 243 | + hasTLSConfig bool |
| 244 | + }{ |
| 245 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=true", "somehost", true}, |
| 246 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=false", "somehost", true}, |
| 247 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=true&hostnameincertificate=someotherhost", "someotherhost", true}, |
| 248 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false", "somehost", true}, |
| 249 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=DISABLE", "", false}, |
| 250 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=DISABLE&hostnameincertificate=someotherhost", "", false}, |
| 251 | + {"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=false", "somehost", true}, |
| 252 | + } |
| 253 | + for _, test := range tests { |
| 254 | + cfg, err := Parse(test.dsn) |
| 255 | + if err != nil { |
| 256 | + t.Errorf("Could not parse valid connection string %s: %v", test.dsn, err) |
| 257 | + } |
| 258 | + if !test.hasTLSConfig && cfg.TLSConfig != nil { |
| 259 | + t.Errorf("Expected empty TLS config, but got %v (cfg.Host was %s)", cfg.TLSConfig, cfg.Host) |
| 260 | + } |
| 261 | + if test.hasTLSConfig && cfg.TLSConfig.ServerName != test.host { |
| 262 | + t.Errorf("Expected somehost as TLS server, but got %s (cfg.Host was %s)", cfg.TLSConfig.ServerName, cfg.Host) |
| 263 | + } |
| 264 | + } |
| 265 | +} |
| 266 | +func TestAllKeysAreAvailableInParametersMap(t *testing.T) { |
| 267 | + keys := map[string]string{ |
| 268 | + "user id": "1", |
| 269 | + "testparam": "testvalue", |
| 270 | + "password": "test", |
| 271 | + "thisisanunknownkey": "thisisthevalue", |
| 272 | + "server": "name", |
| 273 | + } |
| 274 | + |
| 275 | + connString := "" |
| 276 | + for key, val := range keys { |
| 277 | + connString += key + "=" + val + ";" |
| 278 | + } |
| 279 | + |
| 280 | + params, err := Parse(connString) |
| 281 | + if err != nil { |
| 282 | + t.Errorf("unexpected error while parsing, %v", err) |
| 283 | + } |
| 284 | + |
| 285 | + if params.Parameters == nil { |
| 286 | + t.Error("Expected parameters map to be instanciated, found nil") |
| 287 | + return |
| 288 | + } |
| 289 | + |
| 290 | + if len(params.Parameters) != len(keys) { |
| 291 | + t.Errorf("Expected parameters map to be same length as input map length, expected %v, found %v", len(keys), len(params.Parameters)) |
| 292 | + return |
| 293 | + } |
| 294 | + |
| 295 | + for key, val := range keys { |
| 296 | + if params.Parameters[key] != val { |
| 297 | + t.Errorf("Expected parameters map to contain key %v and value %v, found %v", key, val, params.Parameters[key]) |
| 298 | + } |
| 299 | + } |
| 300 | +} |
| 301 | + |
| 302 | +func TestReadCertificate(t *testing.T) { |
| 303 | + |
| 304 | + //Setup dummy certificate |
| 305 | + hexCertificate := "3082031830820200a00302010202103608db21691eccba415f8624d34b66fe300d06092a864886f70d01010b050030143112301006035504030c096c6f63616c686f7374301e170d3233303830383133343233375a170d3234303830383134303233375a30143112301006035504030c096c6f63616c686f737430820122300d06092a864886f70d01010105000382010f003082010a0282010100e18cd4d2923c548ac6e4fd731de116716a09fd2447feb28213810a1b508c22c108928f61531d31439b7252808d6bc6a71d50e5bb00596bbc1633d65389b80bb36f22d1546cbff570881331285cb458b3a2ad1ad0fa83081bd000f2793d29460a6adc0128a2d979d34f5cd91d60d4fef5932f393e04fcb3730a33693f3c44b882384c529f7489e58e296b0c17ca391b02f2488c38f8fc3c3afa0c1be0d22329287f93cf57ee46836a12f74de82eb54b18a5ae0134266db52633c0e33177f8ac4532045f053ddc920f0659cafa84c54c2b3cc92f4010c8af93ae0fc92e461d47c0cf2da46421189b2ddcf2f6ae17cb5ef6f1eda94452af6f714d583dcb7bcd43e90203010001a3663064300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030206082b0601050507030130140603551d11040d300b82096c6f63616c686f7374301d0603551d0e0416041443e3d9f187e9474d73794c641d54ecb810342ec6300d06092a864886f70d01010b05000382010100a227e721ac80838e66ef75d8ba080185dd8f4a5c84d7373e8ed50534100a490b577e3c1af593597303bdad8bb900e32b5d6f69941c19cc87fd426f9e4a4134f34f2ade02748d64031bc4e9c7617206a45c1d9556bb0488994cd27126adb029216f7c57852c1663983b7be638f1bc5411ba2221ce3fde29bf4818e36bec8ac25e9a37bfc41c5a3812829a6358a66c467818448346be140639957077b924b22567b75c7dab4d9d6794b4d79596d17446641684cbd193ec20a6faa85fb6b72f5f30dc57e8cd662b22152429e5b43ccb450c6840ba006e1c8e38b002aa97d8dd07e100ef76eebd9c523d8710636f060865e6198da620fedbf1ae6ed75df997641621" |
| 306 | + derfile, _ := os.CreateTemp("", "*.der") |
| 307 | + defer os.Remove(derfile.Name()) |
| 308 | + certInBytes, _ := hex.DecodeString(hexCertificate) |
| 309 | + _, _ = derfile.Write(certInBytes) |
| 310 | + |
| 311 | + // Test with a valid certificate |
| 312 | + cert, err := readCertificate(derfile.Name()) |
| 313 | + assert.Nil(t, err, "Expected no error while reading certificate, found %v", err) |
| 314 | + assert.NotNil(t, cert, "Expected certificate to be read, found nil") |
| 315 | + |
| 316 | + pemfile, _ := os.CreateTemp("", "*.pem") |
| 317 | + _, _ = io.Copy(derfile, pemfile) |
| 318 | + defer os.Remove(pemfile.Name()) |
| 319 | + cert, err = readCertificate(pemfile.Name()) |
| 320 | + assert.Nil(t, err, "Expected no error while reading certificate, found %v", err) |
| 321 | + assert.NotNil(t, cert, "Expected certificate to be read, found nil") |
| 322 | + |
| 323 | + // Test with an invalid certificate |
| 324 | + bakfile, _ := os.CreateTemp("", "*.bak") |
| 325 | + _, _ = io.Copy(derfile, bakfile) |
| 326 | + defer os.Remove(bakfile.Name()) |
| 327 | + cert, err = readCertificate(bakfile.Name()) |
| 328 | + assert.NotNil(t, err, "Expected error while reading certificate, found nil") |
| 329 | + assert.Nil(t, cert, "Expected certificate to be nil, found %v", cert) |
| 330 | +} |
0 commit comments