Skip to content

Commit 312bb45

Browse files
Copilotshueybubbles
andcommitted
Address PR feedback: fix CHANGELOG version and add test for pwd alias
Co-authored-by: shueybubbles <2224906+shueybubbles@users.noreply.github.com>
1 parent 016e455 commit 312bb45

3 files changed

Lines changed: 342 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Changelog
2-
## Next Release
2+
## 1.8.2
33

44
### Bug fixes
55

msdsn/conn_str_test.go.bak

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
package msdsn
2+
3+
import (
4+
"crypto/tls"
5+
"encoding/hex"
6+
"io"
7+
"os"
8+
"reflect"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestInvalidConnectionString(t *testing.T) {
16+
connStrings := []string{
17+
"log=invalid",
18+
"port=invalid",
19+
"packet size=invalid",
20+
"connection timeout=invalid",
21+
"dial timeout=invalid",
22+
"keepalive=invalid",
23+
"encrypt=invalid",
24+
"trustservercertificate=invalid",
25+
"failoverport=invalid",
26+
"applicationintent=ReadOnly",
27+
"disableretry=invalid",
28+
"multisubnetfailover=invalid",
29+
30+
// ODBC mode
31+
"odbc:password={",
32+
"odbc:password={somepass",
33+
"odbc:password={somepass}}",
34+
"odbc:password={some}pass",
35+
"odbc:=", // unexpected =
36+
"odbc: =",
37+
"odbc:password={some} a",
38+
39+
// URL mode
40+
"sqlserver://\x00",
41+
"sqlserver://host?key=value1&key=value2", // duplicate keys
42+
}
43+
for _, connStr := range connStrings {
44+
_, err := Parse(connStr)
45+
if err == nil {
46+
t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr)
47+
continue
48+
} else {
49+
t.Logf("Connection failed for %s as expected with error %v", connStr, err)
50+
}
51+
}
52+
}
53+
54+
func TestValidConnectionString(t *testing.T) {
55+
type testStruct struct {
56+
connStr string
57+
check func(Config) bool
58+
}
59+
connStrings := []testStruct{
60+
{"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p Config) bool {
61+
return p.Host == "server" && p.Instance == "instance" && p.User == "tester" && p.Password == "pwd"
62+
}},
63+
{"server=.", func(p Config) bool { return p.Host == "localhost" && !p.ColumnEncryption }},
64+
{"server=(local)", func(p Config) bool { return p.Host == "localhost" }},
65+
{"ServerSPN=serverspn;Workstation ID=workstid", func(p Config) bool { return p.ServerSPN == "serverspn" && p.Workstation == "workstid" }},
66+
{"failoverpartner=fopartner;failoverport=2000", func(p Config) bool { return p.FailOverPartner == "fopartner" && p.FailOverPort == 2000 }},
67+
{"app name=appname;applicationintent=ReadOnly;database=testdb", func(p Config) bool { return p.AppName == "appname" && p.ReadOnlyIntent }},
68+
{"encrypt=disable", func(p Config) bool { return p.Encryption == EncryptionDisabled }},
69+
{"encrypt=disable;tlsmin=1.1", func(p Config) bool { return p.Encryption == EncryptionDisabled && p.TLSConfig == nil }},
70+
{"encrypt=true", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }},
71+
{"encrypt=mandatory", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }},
72+
{"encrypt=true;tlsmin=1.0", func(p Config) bool {
73+
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS10
74+
}},
75+
{"encrypt=false;tlsmin=1.0", func(p Config) bool {
76+
return p.Encryption == EncryptionOff && p.TLSConfig.MinVersion == tls.VersionTLS10
77+
}},
78+
{"encrypt=true;tlsmin=1.1;column encryption setting=enabled", func(p Config) bool {
79+
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption
80+
}},
81+
{"encrypt=true;tlsmin=1.2", func(p Config) bool {
82+
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12
83+
}},
84+
{"encrypt=true;tlsmin=1.3", func(p Config) bool {
85+
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS13
86+
}},
87+
{"encrypt=true;tlsmin=1.4", func(p Config) bool {
88+
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0
89+
}},
90+
{"encrypt=false", func(p Config) bool { return p.Encryption == EncryptionOff }},
91+
{"encrypt=optional", func(p Config) bool { return p.Encryption == EncryptionOff }},
92+
{"encrypt=strict", func(p Config) bool { return p.Encryption == EncryptionStrict }},
93+
{"connection timeout=3;dial timeout=4;keepalive=5", func(p Config) bool {
94+
return p.ConnTimeout == 3*time.Second && p.DialTimeout == 4*time.Second && p.KeepAlive == 5*time.Second
95+
}},
96+
{"log=63", func(p Config) bool { return p.LogFlags == 63 && p.Port == 0 }},
97+
{"log=63;port=1000", func(p Config) bool { return p.LogFlags == 63 && p.Port == 1000 }},
98+
{"log=64", func(p Config) bool { return p.LogFlags == 64 }},
99+
{"log=64;packet size=0", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 512 }},
100+
{"log=64;packet size=300", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 512 }},
101+
{"log=64;packet size=8192", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 8192 }},
102+
{"log=64;packet size=48000", func(p Config) bool { return p.LogFlags == 64 && p.PacketSize == 32767 }},
103+
{"disableretry=true", func(p Config) bool { return p.DisableRetry }},
104+
{"disableretry=false", func(p Config) bool { return !p.DisableRetry }},
105+
{"disableretry=1", func(p Config) bool { return p.DisableRetry }},
106+
{"disableretry=0", func(p Config) bool { return !p.DisableRetry }},
107+
{"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }},
108+
{"MultiSubnetFailover=true;NoTraceID=true", func(p Config) bool { return p.MultiSubnetFailover && p.NoTraceID }},
109+
{"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }},
110+
// those are supported currently, but maybe should not be
111+
{"someparam", func(p Config) bool { return true }},
112+
{";;=;", func(p Config) bool { return true }},
113+
114+
// ODBC mode
115+
{"odbc:server=somehost;user id=someuser;password=somepass", func(p Config) bool {
116+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass"
117+
}},
118+
{"odbc:server=somehost;user id=someuser;password=some{pass", func(p Config) bool {
119+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some{pass"
120+
}},
121+
{"odbc:server={somehost};user id={someuser};password={somepass}", func(p Config) bool {
122+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass"
123+
}},
124+
{"odbc:server={somehost};user id={someuser};password={some=pass}", func(p Config) bool {
125+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some=pass"
126+
}},
127+
{"odbc:server={somehost};user id={someuser};password={some;pass}", func(p Config) bool {
128+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some;pass"
129+
}},
130+
{"odbc:server={somehost};user id={someuser};password={some{pass}", func(p Config) bool {
131+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some{pass"
132+
}},
133+
{"odbc:server={somehost};user id={someuser};password={some}}pass}", func(p Config) bool {
134+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some}pass"
135+
}},
136+
{"odbc:server={somehost};user id={someuser};password={some{}}p=a;ss}", func(p Config) bool {
137+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some{}p=a;ss"
138+
}},
139+
{"odbc: server = somehost; user id = someuser ; password = {some pass } ;", func(p Config) bool {
140+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "some pass "
141+
}},
142+
{"odbc:password", func(p Config) bool {
143+
return p.Password == ""
144+
}},
145+
{"odbc:", func(p Config) bool {
146+
return true
147+
}},
148+
{"odbc:password=", func(p Config) bool {
149+
return p.Password == ""
150+
}},
151+
{"odbc:password;", func(p Config) bool {
152+
return p.Password == ""
153+
}},
154+
{"odbc:password=;", func(p Config) bool {
155+
return p.Password == ""
156+
}},
157+
{"odbc:password={value} ", func(p Config) bool {
158+
return p.Password == "value"
159+
}},
160+
{"odbc:server=somehost;user id=someuser;password=somepass;disableretry=true", func(p Config) bool {
161+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" && p.DisableRetry
162+
}},
163+
{"odbc:server=somehost;user id=someuser;password=somepass; disableretry = 1 ", func(p Config) bool {
164+
return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" && p.DisableRetry
165+
}},
166+
167+
// URL mode
168+
{"sqlserver://somehost?connection+timeout=30", func(p Config) bool {
169+
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.ConnTimeout == 30*time.Second
170+
}},
171+
{"sqlserver://someuser@somehost?connection+timeout=30", func(p Config) bool {
172+
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second
173+
}},
174+
{"sqlserver://someuser:@somehost?connection+timeout=30", func(p Config) bool {
175+
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second
176+
}},
177+
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p Config) bool {
178+
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
179+
}},
180+
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p Config) bool {
181+
return p.Host == "somehost" && p.Port == 1434 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
182+
}},
183+
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p Config) bool {
184+
return p.Host == "somehost" && p.Port == 1434 && p.Instance == "someinstance" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
185+
}},
186+
{"sqlserver://someuser@somehost?disableretry=true", func(p Config) bool {
187+
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry
188+
}},
189+
{"sqlserver://someuser@somehost?connection+timeout=30&disableretry=1", func(p Config) bool {
190+
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption
191+
}},
192+
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool {
193+
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && !p.Encoding.GuidConversion
194+
}},
195+
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool {
196+
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion
197+
}},
198+
}
199+
for _, ts := range connStrings {
200+
p, err := Parse(ts.connStr)
201+
if err == nil {
202+
t.Logf("Connection string was parsed successfully %s", ts.connStr)
203+
} else {
204+
t.Errorf("Connection string %s failed to parse with error %s", ts.connStr, err)
205+
continue
206+
}
207+
208+
if !ts.check(p) {
209+
t.Errorf("Check failed on conn str %s", ts.connStr)
210+
}
211+
}
212+
}
213+
214+
func TestSplitConnectionStringURL(t *testing.T) {
215+
_, err := splitConnectionStringURL("http://bad")
216+
if err == nil {
217+
t.Error("Connection string with invalid scheme should fail to parse but it didn't")
218+
}
219+
}
220+
221+
func TestConnParseRoundTripFixed(t *testing.T) {
222+
connStr := "sqlserver://sa:sa@localhost/sqlexpress?database=master&log=127&disableretry=true&dial+timeout=30"
223+
params, err := Parse(connStr)
224+
if err != nil {
225+
t.Fatal("Test URL is not valid", err)
226+
}
227+
rtParams, err := Parse(params.URL().String())
228+
if err != nil {
229+
t.Fatal("Params after roundtrip are not valid", err)
230+
}
231+
t.Log("params.URL " + params.URL().String())
232+
params.ActivityID = nil
233+
rtParams.ActivityID = nil
234+
if !reflect.DeepEqual(params, rtParams) {
235+
t.Fatal("Parameters do not match after roundtrip", params, rtParams)
236+
}
237+
}
238+
239+
func TestServerNameInTLSConfig(t *testing.T) {
240+
var tests = []struct {
241+
dsn string
242+
host string
243+
hasTLSConfig bool
244+
}{
245+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=true", "somehost", true},
246+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=false", "somehost", true},
247+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=true&hostnameincertificate=someotherhost", "someotherhost", true},
248+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false", "somehost", true},
249+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=DISABLE", "", false},
250+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=DISABLE&hostnameincertificate=someotherhost", "", false},
251+
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=false", "somehost", true},
252+
}
253+
for _, test := range tests {
254+
cfg, err := Parse(test.dsn)
255+
if err != nil {
256+
t.Errorf("Could not parse valid connection string %s: %v", test.dsn, err)
257+
}
258+
if !test.hasTLSConfig && cfg.TLSConfig != nil {
259+
t.Errorf("Expected empty TLS config, but got %v (cfg.Host was %s)", cfg.TLSConfig, cfg.Host)
260+
}
261+
if test.hasTLSConfig && cfg.TLSConfig.ServerName != test.host {
262+
t.Errorf("Expected somehost as TLS server, but got %s (cfg.Host was %s)", cfg.TLSConfig.ServerName, cfg.Host)
263+
}
264+
}
265+
}
266+
func TestAllKeysAreAvailableInParametersMap(t *testing.T) {
267+
keys := map[string]string{
268+
"user id": "1",
269+
"testparam": "testvalue",
270+
"password": "test",
271+
"thisisanunknownkey": "thisisthevalue",
272+
"server": "name",
273+
}
274+
275+
connString := ""
276+
for key, val := range keys {
277+
connString += key + "=" + val + ";"
278+
}
279+
280+
params, err := Parse(connString)
281+
if err != nil {
282+
t.Errorf("unexpected error while parsing, %v", err)
283+
}
284+
285+
if params.Parameters == nil {
286+
t.Error("Expected parameters map to be instanciated, found nil")
287+
return
288+
}
289+
290+
if len(params.Parameters) != len(keys) {
291+
t.Errorf("Expected parameters map to be same length as input map length, expected %v, found %v", len(keys), len(params.Parameters))
292+
return
293+
}
294+
295+
for key, val := range keys {
296+
if params.Parameters[key] != val {
297+
t.Errorf("Expected parameters map to contain key %v and value %v, found %v", key, val, params.Parameters[key])
298+
}
299+
}
300+
}
301+
302+
func TestReadCertificate(t *testing.T) {
303+
304+
//Setup dummy certificate
305+
hexCertificate := "3082031830820200a00302010202103608db21691eccba415f8624d34b66fe300d06092a864886f70d01010b050030143112301006035504030c096c6f63616c686f7374301e170d3233303830383133343233375a170d3234303830383134303233375a30143112301006035504030c096c6f63616c686f737430820122300d06092a864886f70d01010105000382010f003082010a0282010100e18cd4d2923c548ac6e4fd731de116716a09fd2447feb28213810a1b508c22c108928f61531d31439b7252808d6bc6a71d50e5bb00596bbc1633d65389b80bb36f22d1546cbff570881331285cb458b3a2ad1ad0fa83081bd000f2793d29460a6adc0128a2d979d34f5cd91d60d4fef5932f393e04fcb3730a33693f3c44b882384c529f7489e58e296b0c17ca391b02f2488c38f8fc3c3afa0c1be0d22329287f93cf57ee46836a12f74de82eb54b18a5ae0134266db52633c0e33177f8ac4532045f053ddc920f0659cafa84c54c2b3cc92f4010c8af93ae0fc92e461d47c0cf2da46421189b2ddcf2f6ae17cb5ef6f1eda94452af6f714d583dcb7bcd43e90203010001a3663064300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030206082b0601050507030130140603551d11040d300b82096c6f63616c686f7374301d0603551d0e0416041443e3d9f187e9474d73794c641d54ecb810342ec6300d06092a864886f70d01010b05000382010100a227e721ac80838e66ef75d8ba080185dd8f4a5c84d7373e8ed50534100a490b577e3c1af593597303bdad8bb900e32b5d6f69941c19cc87fd426f9e4a4134f34f2ade02748d64031bc4e9c7617206a45c1d9556bb0488994cd27126adb029216f7c57852c1663983b7be638f1bc5411ba2221ce3fde29bf4818e36bec8ac25e9a37bfc41c5a3812829a6358a66c467818448346be140639957077b924b22567b75c7dab4d9d6794b4d79596d17446641684cbd193ec20a6faa85fb6b72f5f30dc57e8cd662b22152429e5b43ccb450c6840ba006e1c8e38b002aa97d8dd07e100ef76eebd9c523d8710636f060865e6198da620fedbf1ae6ed75df997641621"
306+
derfile, _ := os.CreateTemp("", "*.der")
307+
defer os.Remove(derfile.Name())
308+
certInBytes, _ := hex.DecodeString(hexCertificate)
309+
_, _ = derfile.Write(certInBytes)
310+
311+
// Test with a valid certificate
312+
cert, err := readCertificate(derfile.Name())
313+
assert.Nil(t, err, "Expected no error while reading certificate, found %v", err)
314+
assert.NotNil(t, cert, "Expected certificate to be read, found nil")
315+
316+
pemfile, _ := os.CreateTemp("", "*.pem")
317+
_, _ = io.Copy(derfile, pemfile)
318+
defer os.Remove(pemfile.Name())
319+
cert, err = readCertificate(pemfile.Name())
320+
assert.Nil(t, err, "Expected no error while reading certificate, found %v", err)
321+
assert.NotNil(t, cert, "Expected certificate to be read, found nil")
322+
323+
// Test with an invalid certificate
324+
bakfile, _ := os.CreateTemp("", "*.bak")
325+
_, _ = io.Copy(derfile, bakfile)
326+
defer os.Remove(bakfile.Name())
327+
cert, err = readCertificate(bakfile.Name())
328+
assert.NotNil(t, err, "Expected error while reading certificate, found nil")
329+
assert.Nil(t, cert, "Expected certificate to be nil, found %v", cert)
330+
}

msdsn/pwd_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package msdsn
2+
3+
import "testing"
4+
5+
// TestPwdInConnStrSimple is a simple test for pwd as a password alias
6+
func TestPwdInConnStrSimple(t *testing.T) {
7+
// Test that pwd gets mapped to password in the adoSynonyms map
8+
if adoSynonyms["pwd"] != Password {
9+
t.Errorf("Expected adoSynonyms[\"pwd\"] to be %q, got %q", Password, adoSynonyms["pwd"])
10+
}
11+
}

0 commit comments

Comments
 (0)