diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlCommandTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlCommandTest.cs index c5e3b000dc..2e3d3dcb7d 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlCommandTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlCommandTest.cs @@ -295,16 +295,18 @@ public void CommandTimeout_Value_Negative() Assert.Equal("CommandTimeout", ex.ParamName); } - [Fact] - public void CommandType_Value_Invalid() + [Theory] + [InlineData((CommandType)666)] + [InlineData(CommandType.TableDirect)] + public void CommandType_Value_Invalid(CommandType commandType) { SqlCommand cmd = new SqlCommand(); - ArgumentOutOfRangeException ex = Assert.Throws(() => cmd.CommandType = (CommandType)(666)); + ArgumentOutOfRangeException ex = Assert.Throws(() => cmd.CommandType = commandType); // The CommandType enumeration value, 666, is invalid Assert.Null(ex.InnerException); Assert.NotNull(ex.Message); - Assert.True(ex.Message.IndexOf("666", StringComparison.Ordinal) != -1); + Assert.True(ex.Message.IndexOf(((int)commandType).ToString(), StringComparison.Ordinal) != -1); Assert.Equal("CommandType", ex.ParamName); } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs index 0e748c5d74..2e9de8e7ea 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Data; using System.Data.SqlTypes; +using System.Text; using Microsoft.Data.SqlClient.Server; using Microsoft.SqlServer.Types; using Xunit; @@ -15,133 +16,34 @@ namespace Microsoft.Data.SqlClient.Tests { public class SqlDataRecordTest { - [Fact] - public void SqlRecordFillTest() + public void GetDataTypeName_ReturnsMetaDataTypeIfUdtType() { SqlMetaData[] metaData = new SqlMetaData[] { - new SqlMetaData("col1", SqlDbType.Bit), - new SqlMetaData("col2", SqlDbType.TinyInt), - new SqlMetaData("col3", SqlDbType.VarBinary, 1000), - new SqlMetaData("col4", SqlDbType.NVarChar, 1000), - new SqlMetaData("col5", SqlDbType.DateTime), - new SqlMetaData("col6", SqlDbType.Float), - new SqlMetaData("col7", SqlDbType.UniqueIdentifier), - new SqlMetaData("col8", SqlDbType.SmallInt), - new SqlMetaData("col9", SqlDbType.Int), - new SqlMetaData("col10", SqlDbType.BigInt), - new SqlMetaData("col11", SqlDbType.Real), - new SqlMetaData("col12", SqlDbType.Decimal), - new SqlMetaData("col13", SqlDbType.Money), - new SqlMetaData("col14", SqlDbType.Variant) + new SqlMetaData("col1", SqlDbType.Udt, typeof(TestUdt), "sql_TestUdt") }; SqlDataRecord record = new SqlDataRecord(metaData); - for (int i = 0; i < record.FieldCount; i++) - { - Assert.Equal($"col{i + 1}", record.GetName(i)); - } - - record.SetBoolean(0, true); - Assert.True(record.GetBoolean(0)); - - record.SetByte(1, 1); - Assert.Equal(1, record.GetByte(1)); - - byte[] bb1 = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }; - byte[] bb2 = new byte[5]; - record.SetSqlBinary(2, new SqlBinary(new byte[0])); - record.SetBytes(2, 0, bb1, 0, 3); - record.SetBytes(2, 2, bb1, 6, 3); - - // Verify the length of the byte array - Assert.Equal(5, record.GetBytes(2, 0, bb2, 0, 5)); - - Assert.Equal(5, record.GetBytes(2, 0, null, 0, 0)); - - byte[] expected = new byte[] { 1, 2, 7, 8, 9 }; - Assert.Equal(expected, bb2); - - char[] cb1 = new char[] { 'a', 'b', 'c', 'd', 'e', 'f', 'g' }; - char[] cb2 = new char[5]; - record.SetChars(3, 0, cb1, 0, 3); - record.SetChars(3, 2, cb1, 4, 3); - - char[] expectedValue = new char[] { 'a', 'b', 'e', 'f', 'g' }; - Assert.Equal(expectedValue.Length, record.GetChars(3, 0, cb2, 0, 5)); - Assert.Equal(expectedValue, new string(cb2, 0, (int)record.GetChars(3, 0, null, 0, 0)).ToCharArray()); - - record.SetString(3, ""); - string xyz = "xyz"; - record.SetString(3, "xyz"); - Assert.Equal(xyz, record.GetString(3)); - Assert.Equal(xyz.Length, record.GetChars(3, 0, cb2, 0, 5)); - Assert.Equal(xyz, new string(cb2, 0, (int)record.GetChars(3, 0, null, 0, 0))); - - record.SetChars(3, 2, cb1, 4, 3); - Assert.Equal(5, record.GetChars(3, 0, cb2, 0, 5)); - - string interleavedResult = "xyefg"; - Assert.Equal(interleavedResult, new string(cb2, 0, (int)record.GetChars(3, 0, null, 0, 0))); - Assert.Equal(interleavedResult, record.GetString(3)); - - record.SetSqlDateTime(4, SqlDateTime.MaxValue); - Assert.Equal(SqlDateTime.MaxValue, record.GetSqlDateTime(4)); - - record.SetSqlDouble(5, SqlDouble.MaxValue); - Assert.Equal(SqlDouble.MaxValue, record.GetSqlDouble(5)); - - SqlGuid guid = new SqlGuid("F9168C5E-CEB2-4faa-B6BF-329BF39FA1E4"); - record.SetSqlGuid(6, guid); - Assert.Equal(guid, record.GetSqlGuid(6)); - - record.SetSqlInt16(7, SqlInt16.MaxValue); - Assert.Equal(SqlInt16.MaxValue, record.GetSqlInt16(7)); - - record.SetSqlInt32(8, SqlInt32.MaxValue); - Assert.Equal(SqlInt32.MaxValue, record.GetSqlInt32(8)); - - record.SetSqlInt64(9, SqlInt64.MaxValue); - Assert.Equal(SqlInt64.MaxValue, record.GetSqlInt64(9)); - - record.SetSqlSingle(10, SqlSingle.MinValue); - Assert.Equal(SqlSingle.MinValue, record.GetSqlSingle(10)); - - record.SetSqlDecimal(11, SqlDecimal.Null); - record.SetSqlDecimal(11, SqlDecimal.MaxValue); - Assert.Equal(SqlDecimal.MaxValue, record.GetSqlDecimal(11)); - - record.SetSqlMoney(12, SqlMoney.MaxValue); - Assert.Equal(SqlMoney.MaxValue, record.GetSqlMoney(12)); - + Assert.Equal("Microsoft.Data.SqlClient.Tests.TestUdt", record.GetDataTypeName(0)); + } - // Try adding different values to SqlVariant type - for (int i = 0; i < record.FieldCount - 1; ++i) + [Fact] + public void GetDataTypeName_ReturnsTypeFromMetaTypeIfNotUdt() + { + SqlMetaData[] metaData = new SqlMetaData[] { - object valueToSet = record.GetSqlValue(i); - record.SetValue(record.FieldCount - 1, valueToSet); - object o = record.GetSqlValue(record.FieldCount - 1); - - if (o is SqlBinary) - { - Assert.Equal(((SqlBinary)valueToSet).Value, ((SqlBinary)o).Value); - } - else - { - Assert.Equal(valueToSet, o); - } + new SqlMetaData("col1", SqlDbType.NVarChar, 50) + }; - record.SetDBNull(record.FieldCount - 1); - Assert.Equal(DBNull.Value, record.GetSqlValue(record.FieldCount - 1)); + SqlDataRecord record = new SqlDataRecord(metaData); - record.SetDBNull(i); - Assert.Equal(DBNull.Value, record.GetValue(i)); - } + Assert.Equal("nvarchar", record.GetDataTypeName(0)); } + [Fact] - public void GetDataTypeName_ReturnsMetaDataTypeIfUdtType() + public void GetFieldType_ReturnMetaDataTypeIfUdtType() { SqlMetaData[] metaData = new SqlMetaData[] { @@ -150,23 +52,43 @@ public void GetDataTypeName_ReturnsMetaDataTypeIfUdtType() SqlDataRecord record = new SqlDataRecord(metaData); - Assert.Equal("Microsoft.Data.SqlClient.Tests.TestUdt", record.GetDataTypeName(0)); +#if NET + Assert.Equal(typeof(object), record.GetFieldType(0)); +#else + Assert.Equal(typeof(TestUdt), record.GetFieldType(0)); +#endif + Assert.Equal(typeof(object), record.GetSqlFieldType(0)); } - [Fact] - public void GetDataTypeName_ReturnsTypeFromMetaTypeIfNotUdt() + [Theory] + [ClassData(typeof(DbTypeData))] + public void GetFieldType_ReturnMetaTypeClassTypeIfNotUdt(SqlDbType dbType, int? length, Type expectedClrType, Type expectedSqlType) { SqlMetaData[] metaData = new SqlMetaData[] { - new SqlMetaData("col1", SqlDbType.NVarChar, 50) + length == null ? new SqlMetaData("col1", dbType) : new SqlMetaData("col1", dbType, length.Value) }; SqlDataRecord record = new SqlDataRecord(metaData); - Assert.Equal("nvarchar", record.GetDataTypeName(0)); + Assert.Equal(expectedClrType, record.GetFieldType(0)); + Assert.Equal(expectedSqlType, record.GetSqlFieldType(0)); + } + + [Fact] + public void Ctor_ThrowsIfNullMetadata() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + null + }; + + Assert.Throws(() => new SqlDataRecord(null)); + Assert.Throws(() => new SqlDataRecord(metaData)); } + [Fact] - public void GetFieldType_ReturnMetaTypeClassType() + public void IDataRecord_GetData_ThrowsNotSupported() { SqlMetaData[] metaData = new SqlMetaData[] { @@ -175,7 +97,7 @@ public void GetFieldType_ReturnMetaTypeClassType() SqlDataRecord record = new SqlDataRecord(metaData); - Assert.Equal(typeof(string), record.GetFieldType(0)); + Assert.Throws(() => ((IDataRecord)record).GetData(0)); } [Fact] @@ -189,6 +111,7 @@ public void GetValues_ThrowsIfNull() SqlDataRecord record = new SqlDataRecord(metaData); Assert.Throws(() => record.GetValues(null)); + Assert.Throws(() => record.GetSqlValues(null)); } [Fact] @@ -211,6 +134,15 @@ public void GetValues_IfValuesBiggerThanColumnCount_LastArrayItemKeptEmpty() Assert.Null(values[i]); } Assert.Equal(2, columnCount); + + values = new object[5]; + columnCount = record.GetSqlValues(values); + + for (int i = 2; i < 5; i++) + { + Assert.Null(values[i]); + } + Assert.Equal(2, columnCount); } [Fact] @@ -230,6 +162,11 @@ public void GetValues_IfValuesShorterThanColumnCount_FillOnlyFirstColumn() Assert.Equal("test", values[0]); Assert.Equal(1, columnCount); + + columnCount = record.GetSqlValues(values); + + Assert.Equal(new SqlString("test"), values[0]); + Assert.Equal(1, columnCount); } [Fact] @@ -250,6 +187,108 @@ public void GetValues_FillsArrayAndRespectColumnOrder() Assert.Equal("test", values[0]); Assert.Equal(2, values[1]); Assert.Equal(2, columnCount); + + columnCount = record.GetSqlValues(values); + + Assert.Equal(new SqlString("test"), values[0]); + Assert.Equal(new SqlInt32(2), values[1]); + Assert.Equal(2, columnCount); + } + + [Fact] + public void SetValues_ThrowsIfNull() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 50) + }; + + SqlDataRecord record = new SqlDataRecord(metaData); + + Assert.Throws(() => record.SetValues(null)); + } + + [Fact] + public void SetValues_ThrowsIfTypeMismatch() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 50), + new SqlMetaData("col2", SqlDbType.Int), + new SqlMetaData("col3", SqlDbType.NVarChar, 50) + }; + + SqlDataRecord record = new SqlDataRecord(metaData); + object[] values = new object[3] { "one", "2", "three" }; + + Assert.Throws(() => record.SetValues(values)); + Assert.True(record.IsDBNull(0)); + } + + [Fact] + public void SetValues_IfValuesBiggerThanColumnCount_LastArrayItemIgnored() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 50), + new SqlMetaData("col2", SqlDbType.Int) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + object[] values = new object[5] { "test", 2, null, null, null }; + int columnCount = record.SetValues(values); + + Assert.Equal((string)values[0], record.GetString(0)); + Assert.Equal((int)values[1], record.GetInt32(1)); + Assert.Equal(2, columnCount); + } + + [Fact] + public void SetValues_IfValuesShorterThanColumnCount_FillOnlyFirstColumn() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 50), + new SqlMetaData("col2", SqlDbType.Int) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + record.SetString(0, "test"); + record.SetSqlInt32(1, 2); + + object[] values = new object[1] { "test2" }; + int columnCount = record.SetValues(values); + + Assert.Equal("test2", record.GetString(0)); + Assert.Equal(2, record.GetInt32(1)); + + Assert.Equal(1, columnCount); + } + + [Fact] + public void SetSingleValue_ThrowsIfTypeMismatch() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 50), + new SqlMetaData("col2", SqlDbType.Int), + new SqlMetaData("col3", SqlDbType.NVarChar, 50) + }; + + SqlDataRecord record = new SqlDataRecord(metaData); + + Assert.Throws(() => record.SetValue(1, "2")); + } + + [Fact] + public void GetName_ReturnsNameOfColumn() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 50) + }; + + SqlDataRecord record = new SqlDataRecord(metaData); + + Assert.Equal("col1", record.GetName(0)); } [Fact] @@ -294,6 +333,7 @@ public void GetOrdinal_ReturnsIndexOfColumn() Assert.Equal(1, record.GetOrdinal("col2")); } + [Fact] public void GetOrdinal_ReturnsIndexOfColumn_CaseInsensitive() { @@ -319,6 +359,17 @@ public void GetChar_ThrowsNotSupported() Assert.Throws(() => record.GetChar(0)); } + [Fact] + public void SetChar_ThrowsNotSupported() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.Char, 100) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + Assert.Throws(() => record.SetChar(0, 'c')); + } + [Theory] [ClassData(typeof(GetUdtTypeTestData))] public void GetUdt_ReturnsValue(Type udtType, object value, string serverTypeName) @@ -347,17 +398,164 @@ public void GetXXX_ThrowsIfBadType(Func getXXX) } [Theory] - [ClassData(typeof(GetXXXCheckValueTestData))] - public void GetXXX_ReturnValue(SqlDbType dbType, object value, Func getXXX) + [InlineData(-1)] + [InlineData(1)] + public void InvalidIndexAccess_Throws(int index) + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.NVarChar, 1) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + + Assert.Throws(() => record.GetSqlMetaData(index)); + Assert.Throws(() => record.SetDBNull(index)); + Assert.Throws(() => record.IsDBNull(index)); + Assert.Throws(() => record.GetValue(index)); + } + + [Theory] + [ClassData(typeof(GetFixedLengthCheckValueTestData))] + public void GetFixedLength_ReturnValue(SqlDbType dbType, object value, + Action setXXX, Func getXXX, + Action equalityAssertion) { SqlMetaData[] metaData = new SqlMetaData[] { new SqlMetaData("col1", dbType) }; SqlDataRecord record = new SqlDataRecord(metaData); + setXXX(record, value); + + Assert.False(record.IsDBNull(0)); + equalityAssertion(value, getXXX(record)); + + record.SetDBNull(0); + Assert.True(record.IsDBNull(0)); + } + + [Theory] + [ClassData(typeof(GetVariableLengthCheckValueTestData))] + public void GetVariableLength_ReturnValue(SqlDbType dbType, object value, + Action setXXX, Func getXXX, + Action equalityAssertion) + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", dbType, 50) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + setXXX(record, value); + + Assert.False(record.IsDBNull(0)); + equalityAssertion(value, getXXX(record)); + + record.SetDBNull(0); + Assert.True(record.IsDBNull(0)); + } + + [Fact] + public void GetSqlXml_ReturnValue() + { + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.Xml) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + + string xmlString = ""; + using System.IO.MemoryStream xmlMS = new System.IO.MemoryStream(Encoding.Unicode.GetBytes(xmlString)); + SqlXml value = new SqlXml(xmlMS); record.SetValue(0, value); - Assert.Equal(value, record.GetValue(0)); - Assert.Equal(value, getXXX(record)); + + Assert.False(record.IsDBNull(0)); + Assert.Equal(xmlString, record.GetSqlXml(0).Value); + Assert.Equal(xmlString, record.GetString(0)); + } + + [Fact] + public void GetBytes_SetBytes_Succeed() + { + byte[] digits = new byte[16] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; + byte[] zeroBytes = new byte[digits.Length]; + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.VarBinary, digits.Length) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + byte[] buffer = new byte[8]; + long byteCount; + + zeroBytes.AsSpan().Clear(); + buffer.AsSpan().Fill(0xFF); + + Assert.True(record.IsDBNull(0)); + record.SetBytes(0, 0, zeroBytes, 0, zeroBytes.Length); + Assert.False(record.IsDBNull(0)); + + // Read the first 8 bytes, confirming that 8 bytes were read and that all are zero + byteCount = record.GetBytes(0, 0, buffer, 0, buffer.Length); + Assert.Equal(buffer.Length, byteCount); + Assert.All(buffer, (b) => Assert.Equal(0, b)); + + // Write four bytes (index 8-11) from the sequence to the record. Read them back, confirming + // that they've been written and that the surrounding bytes remain zero. + record.SetBytes(0, 2, digits, 8, 4); + byteCount = record.GetBytes(0, 2, buffer, 2, 4); + Assert.Equal(0, buffer[0]); + Assert.Equal(0, buffer[1]); + Assert.Equal(9, buffer[2]); + Assert.Equal(10, buffer[3]); + Assert.Equal(11, buffer[4]); + Assert.Equal(12, buffer[5]); + Assert.Equal(0, buffer[6]); + Assert.Equal(0, buffer[7]); + + Assert.Equal(4, byteCount); + } + + [Fact] + public void GetChars_SetChars_Succeed() + { + char[] alpha = new char[16] { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p' }; + char[] zeroChars = new char[alpha.Length]; + SqlMetaData[] metaData = new SqlMetaData[] + { + new SqlMetaData("col1", SqlDbType.VarChar, alpha.Length) + }; + SqlDataRecord record = new SqlDataRecord(metaData); + char[] buffer = new char[8]; + long charCount; + + zeroChars.AsSpan().Fill('0'); + buffer.AsSpan().Fill('Z'); + + Assert.True(record.IsDBNull(0)); + record.SetChars(0, 0, zeroChars, 0, zeroChars.Length); + Assert.False(record.IsDBNull(0)); + + // Read the first 8 chars, confirming that 8 chars were read and that all are '0' + charCount = record.GetChars(0, 0, buffer, 0, buffer.Length); + Assert.Equal(buffer.Length, charCount); + Assert.All(buffer, (b) => Assert.Equal('0', b)); + + // Write four chars (index 8-11) from the sequence to the record. Read them back, confirming + // that they've been written and that the surrounding chars remain '0'. + record.SetChars(0, 2, alpha, 8, 4); + charCount = record.GetChars(0, 2, buffer, 2, 4); + Assert.Equal('0', buffer[0]); + Assert.Equal('0', buffer[1]); + Assert.Equal('i', buffer[2]); + Assert.Equal('j', buffer[3]); + Assert.Equal('k', buffer[4]); + Assert.Equal('l', buffer[5]); + Assert.Equal('0', buffer[6]); + Assert.Equal('0', buffer[7]); + + Assert.Equal(4, charCount); + + string resultantString = record.GetString(0); + Assert.Equal("00ijkl0000000000", resultantString); } } @@ -398,31 +596,373 @@ IEnumerator IEnumerable.GetEnumerator() } } - public class GetXXXCheckValueTestData : IEnumerable + public class DbTypeData : IEnumerable { public IEnumerator GetEnumerator() { - yield return new object[] { SqlDbType.UniqueIdentifier, Guid.NewGuid(), new Func(r => r.GetGuid(0)) }; - yield return new object[] { SqlDbType.SmallInt, (short)123, new Func(r => r.GetInt16(0)) }; - yield return new object[] { SqlDbType.Int, 123456, new Func(r => r.GetInt32(0)) }; - yield return new object[] { SqlDbType.BigInt, (long)123456789, new Func(r => r.GetInt64(0)) }; - yield return new object[] { SqlDbType.Float, (double)1.2, new Func(r => r.GetDouble(0)) }; - yield return new object[] { SqlDbType.Real, (float)1.2, new Func(r => r.GetFloat(0)) }; - yield return new object[] { SqlDbType.Decimal, 1.2m, new Func(r => r.GetDecimal(0)) }; - yield return new object[] { SqlDbType.DateTime, DateTime.Now, new Func(r => r.GetDateTime(0)) }; - yield return new object[] { SqlDbType.DateTimeOffset, new DateTimeOffset(DateTime.Now), new Func(r => r.GetDateTimeOffset(0)) }; - yield return new object[] { SqlDbType.Time, TimeSpan.FromHours(1), new Func(r => r.GetTimeSpan(0)) }; - yield return new object[] { SqlDbType.Date, DateTime.Now.Date, new Func(r => r.GetDateTime(0)) }; - yield return new object[] { SqlDbType.Bit, bool.Parse(bool.TrueString), new Func(r => r.GetBoolean(0)) }; - yield return new object[] { SqlDbType.SmallDateTime, DateTime.Now, new Func(r => r.GetDateTime(0)) }; - yield return new object[] { SqlDbType.TinyInt, (byte)1, new Func(r => r.GetByte(0)) }; + yield return new object[] { SqlDbType.BigInt, null, typeof(long), typeof(SqlInt64) }; + yield return new object[] { SqlDbType.Binary, 50, typeof(byte[]), typeof(SqlBinary) }; + yield return new object[] { SqlDbType.Bit, null, typeof(bool), typeof(SqlBoolean) }; + yield return new object[] { SqlDbType.Char, 50, typeof(string), typeof(SqlString) }; + yield return new object[] { SqlDbType.Date, null, typeof(DateTime), typeof(DateTime) }; + yield return new object[] { SqlDbType.DateTime, null, typeof(DateTime), typeof(SqlDateTime) }; + yield return new object[] { SqlDbType.DateTime2, null, typeof(DateTime), typeof(DateTime) }; + yield return new object[] { SqlDbType.DateTimeOffset, null, typeof(DateTimeOffset), typeof(DateTimeOffset) }; + yield return new object[] { SqlDbType.Decimal, null, typeof(decimal), typeof(SqlDecimal) }; + yield return new object[] { SqlDbType.Float, null, typeof(double), typeof(SqlDouble) }; + yield return new object[] { SqlDbType.Image, -1, typeof(byte[]), typeof(SqlBinary) }; + yield return new object[] { SqlDbType.Int, null, typeof(int), typeof(SqlInt32) }; + yield return new object[] { SqlDbType.Money, null, typeof(decimal), typeof(SqlMoney) }; + yield return new object[] { SqlDbType.NChar, 50, typeof(string), typeof(SqlString) }; + yield return new object[] { SqlDbType.NText, -1, typeof(string), typeof(SqlString) }; + yield return new object[] { SqlDbType.NVarChar, 50, typeof(string), typeof(SqlString) }; + yield return new object[] { SqlDbType.Real, null, typeof(float), typeof(SqlSingle) }; + yield return new object[] { SqlDbType.SmallDateTime, null, typeof(DateTime), typeof(SqlDateTime) }; + yield return new object[] { SqlDbType.SmallInt, null, typeof(short), typeof(SqlInt16) }; + yield return new object[] { SqlDbType.SmallMoney, null, typeof(decimal), typeof(SqlMoney) }; + yield return new object[] { SqlDbType.Text, -1, typeof(string), typeof(SqlString) }; + yield return new object[] { SqlDbType.Time, null, typeof(TimeSpan), typeof(TimeSpan) }; + yield return new object[] { SqlDbType.Timestamp, null, typeof(byte[]), typeof(SqlBinary) }; + yield return new object[] { SqlDbType.UniqueIdentifier, null, typeof(Guid), typeof(SqlGuid) }; + yield return new object[] { SqlDbType.VarBinary, 50, typeof(byte[]), typeof(SqlBinary) }; + yield return new object[] { SqlDbType.VarChar, 50, typeof(string), typeof(SqlString) }; + yield return new object[] { SqlDbType.Variant, null, typeof(object), typeof(object) }; + yield return new object[] { SqlDbType.Xml, null, typeof(string), typeof(SqlXml) }; } IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + } + + public abstract class GetXXXCheckValueTestData + : IEnumerable + { + public abstract IEnumerator GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + + private static IEnumerable GenerateCombination(SqlDbType dbType, T value, + Action setter, + params Func[] getters) { - return GetEnumerator(); + Action equalityAssertion = (o1, o2) => Assert.Equal((T)o1, (T)o2, EqualityComparer.Default); + + foreach (Func getter in getters) + { + yield return new object[] { dbType, value, + new Action((r, o) => r.SetValue(0, o)), getter, equalityAssertion }; + yield return new object[] { dbType, value, setter, getter, equalityAssertion }; + } } + + protected static IEnumerable GenerateClrCombination(SqlDbType dbType, T value, + Action setter, + params Func[] getters) + { + Action equalityAssertion = (o1, o2) => Assert.Equal((T)o1, (T)o2, EqualityComparer.Default); + + foreach (object[] data in GenerateCombination(dbType, value, setter, getters)) + { + yield return data; + } + foreach (object[] data in GenerateCombination(dbType, value, setter, r => (T)r[0], r => (T)r["col1"])) + { + yield return data; + } + + yield return new object[] { dbType, value, + new Action((r, o) => r.SetValue(0, o)), + new Func(r => r.GetValue(0)), + equalityAssertion }; + yield return new object[] { dbType, value, setter, new Func(r => r.GetValue(0)), equalityAssertion }; + } + + protected static IEnumerable GenerateSqlCombination(SqlDbType dbType, T value, + Action setter, + params Func[] getters) + => GenerateCombination(dbType, value, setter, getters); } + + public class GetFixedLengthCheckValueTestData : GetXXXCheckValueTestData + { + public override IEnumerator GetEnumerator() + { + foreach (object[] data in GenerateClrCombination(SqlDbType.UniqueIdentifier, Guid.NewGuid(), + (r, o) => r.SetGuid(0, (Guid)o), + r => r.GetGuid(0), r => r.GetSqlGuid(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.UniqueIdentifier, new SqlGuid(Guid.NewGuid()), + (r, o) => r.SetSqlGuid(0, (SqlGuid)o), + r => new SqlGuid(r.GetGuid(0)), r => r.GetSqlGuid(0), r => new SqlGuid((Guid)r[0]), r => new SqlGuid((Guid)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.TinyInt, (byte)1, + (r, o) => r.SetByte(0, (byte)o), + r => r.GetByte(0), r => r.GetSqlByte(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.TinyInt, new SqlByte(1), + (r, o) => r.SetSqlByte(0, (SqlByte)o), + r => new SqlByte(r.GetByte(0)), r => r.GetSqlByte(0), r => new SqlByte((byte)r[0]), r => new SqlByte((byte)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.SmallInt, (short)123, + (r, o) => r.SetInt16(0, (short)o), + r => r.GetInt16(0), r => r.GetSqlInt16(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.SmallInt, new SqlInt16(123), + (r, o) => r.SetSqlInt16(0, (SqlInt16)o), + r => new SqlInt16(r.GetInt16(0)), r => r.GetSqlInt16(0), r => new SqlInt16((short)r[0]), r => new SqlInt16((short)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Int, 123456, + (r, o) => r.SetInt32(0, (int)o), + r => r.GetInt32(0), r => r.GetSqlInt32(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Int, new SqlInt32(123456), + (r, o) => r.SetSqlInt32(0, (SqlInt32)o), + r => new SqlInt32(r.GetInt32(0)), r => r.GetSqlInt32(0), r => new SqlInt32((int)r[0]), r => new SqlInt32((int)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.BigInt, (long)123456789, + (r, o) => r.SetInt64(0, (long)o), + r => r.GetInt64(0), r => r.GetSqlInt64(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.BigInt, new SqlInt64(123456789), + (r, o) => r.SetSqlInt64(0, (SqlInt64)o), + r => new SqlInt64(r.GetInt64(0)), r => r.GetSqlInt64(0), r => new SqlInt64((long)r[0]), r => new SqlInt64((long)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Float, (double)1.2, + (r, o) => r.SetDouble(0, (double)o), + r => r.GetDouble(0), r => r.GetSqlDouble(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Float, new SqlDouble(1.2), + (r, o) => r.SetSqlDouble(0, (SqlDouble)o), + r => new SqlDouble(r.GetDouble(0)), r => r.GetSqlDouble(0), r => new SqlDouble((double)r[0]), r => new SqlDouble((double)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Real, (float)1.2, + (r, o) => r.SetFloat(0, (float)o), + r => r.GetFloat(0), r => r.GetSqlSingle(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Real, new SqlSingle(1.2), + (r, o) => r.SetSqlSingle(0, (SqlSingle)o), + r => new SqlSingle(r.GetFloat(0)), r => r.GetSqlSingle(0), r => new SqlSingle((float)r[0]), r => new SqlSingle((float)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Decimal, 1.2m, + (r, o) => r.SetDecimal(0, (decimal)o), + r => r.GetDecimal(0), r => r.GetSqlDecimal(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Decimal, new SqlDecimal(1.2), + (r, o) => r.SetSqlDecimal(0, (SqlDecimal)o), + r => new SqlDecimal(r.GetDecimal(0)), r => r.GetSqlDecimal(0), r => new SqlDecimal((decimal)r[0]), r => new SqlDecimal((decimal)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Money, new SqlMoney(1.2), + (r, o) => r.SetSqlMoney(0, (SqlMoney)o), + r => new SqlMoney(r.GetDecimal(0)), r => r.GetSqlMoney(0), r => new SqlMoney((decimal)r[0]), r => new SqlMoney((decimal)r["col1"]))) + { + yield return data; + } + + // The precision of a datetime in SQL Server is limited. Hardcode a value which will always compare correctly + foreach (object[] data in GenerateClrCombination(SqlDbType.DateTime, new DateTime(2010, 1, 1, 12, 34, 56), + (r, o) => r.SetDateTime(0, (DateTime)o), + r => r.GetDateTime(0), r => r.GetSqlDateTime(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.DateTime, new SqlDateTime(2010, 1, 1, 12, 34, 56), + (r, o) => r.SetSqlDateTime(0, (SqlDateTime)o), + r => new SqlDateTime(r.GetDateTime(0)), r => r.GetSqlDateTime(0), r => new SqlDateTime((DateTime)r[0]), r => new SqlDateTime((DateTime)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.SmallDateTime, new DateTime(2010, 1, 1, 12, 34, 56), + (r, o) => r.SetDateTime(0, (DateTime)o), + r => r.GetDateTime(0), r => r.GetSqlDateTime(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.SmallDateTime, new SqlDateTime(2010, 1, 1, 12, 34, 56), + (r, o) => r.SetSqlDateTime(0, (SqlDateTime)o), + r => new SqlDateTime(r.GetDateTime(0)), r => r.GetSqlDateTime(0), r => new SqlDateTime((DateTime)r[0]), r => new SqlDateTime((DateTime)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.DateTime2, new DateTime(2010, 1, 1, 12, 34, 56), + (r, o) => r.SetDateTime(0, (DateTime)o), + r => r.GetDateTime(0), r => r.GetSqlDateTime(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.DateTime2, new SqlDateTime(2010, 1, 1, 12, 34, 56), + (r, o) => r.SetSqlDateTime(0, (SqlDateTime)o), + r => new SqlDateTime(r.GetDateTime(0)), r => r.GetSqlDateTime(0), r => new SqlDateTime((DateTime)r[0]), r => new SqlDateTime((DateTime)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Date, new DateTime(2010, 1, 1), + (r, o) => r.SetDateTime(0, (DateTime)o), + r => r.GetDateTime(0), r => r.GetSqlDateTime(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Date, new SqlDateTime(2010, 1, 1), + (r, o) => r.SetSqlDateTime(0, (SqlDateTime)o), + r => new SqlDateTime(r.GetDateTime(0)), r => r.GetSqlDateTime(0), r => new SqlDateTime((DateTime)r[0]), r => new SqlDateTime((DateTime)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.DateTimeOffset, DateTimeOffset.Now, + (r, o) => r.SetDateTimeOffset(0, (DateTimeOffset)o), + r => r.GetDateTimeOffset(0))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Time, TimeSpan.FromHours(1), + (r, o) => r.SetTimeSpan(0, (TimeSpan)o), + r => r.GetTimeSpan(0))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Bit, bool.Parse(bool.TrueString), + (r, o) => r.SetBoolean(0, (bool)o), + r => r.GetBoolean(0), r => r.GetSqlBoolean(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Bit, new SqlBoolean(bool.Parse(bool.TrueString)), + (r, o) => r.SetSqlBoolean(0, (SqlBoolean)o), + r => new SqlBoolean(r.GetBoolean(0)), r => r.GetSqlBoolean(0), r => new SqlBoolean((bool)r[0]), r => new SqlBoolean((bool)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateClrCombination(SqlDbType.Xml, "", + (r, o) => r.SetString(0, (string)o), + r => r.GetString(0), r => r.GetSqlString(0).Value, r => r.GetSqlXml(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.Xml, "", + (r, o) => r.SetSqlXml(0, CreateSqlXmlFromString((string)o)), + r => r.GetString(0), r => r.GetSqlXml(0).Value, r => (string)r[0], r => (string)r["col1"])) + { + yield return data; + } + } + + private static SqlXml CreateSqlXmlFromString(string xmlString) + { + byte[] xmlBytes = Encoding.Unicode.GetBytes(xmlString); + + return new SqlXml(new System.IO.MemoryStream(xmlBytes)); + } + } + + public class GetVariableLengthCheckValueTestData : GetXXXCheckValueTestData + { + public override IEnumerator GetEnumerator() + { + SqlDbType[] dbTypes = new SqlDbType[] { SqlDbType.VarChar, SqlDbType.NVarChar, SqlDbType.Char, SqlDbType.NChar }; + byte[] binaryValue = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + + foreach (SqlDbType characterType in dbTypes) + { + foreach (object[] data in GenerateClrCombination(characterType, "string", + (r, o) => r.SetString(0, (string)o), + r => r.GetString(0), r => r.GetSqlString(0).Value)) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(characterType, new SqlString("string"), + (r, o) => r.SetSqlString(0, (SqlString)o), + r => new SqlString(r.GetString(0)), r => r.GetSqlString(0), r => new SqlString((string)r[0]), r => new SqlString((string)r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(characterType, new SqlString("string"), + (r, o) => r.SetSqlChars(0, new SqlChars((SqlString)o)), + r => new SqlString(r.GetString(0)), r => r.GetSqlString(0), r => r.GetSqlChars(0).ToSqlString(), + r => new SqlString((string)r[0]), r => new SqlString((string)r["col1"]))) + { + yield return data; + } + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.VarBinary, new SqlBinary(binaryValue), + (r, o) => r.SetSqlBinary(0, (SqlBinary)o), + r => r.GetSqlBinary(0), r => r.GetSqlBytes(0).ToSqlBinary(), r => new SqlBinary((byte[])r[0]), r => new SqlBinary((byte[])r["col1"]))) + { + yield return data; + } + + foreach (object[] data in GenerateSqlCombination(SqlDbType.VarBinary, new SqlBinary(binaryValue), + (r, o) => r.SetSqlBytes(0, new SqlBytes((SqlBinary)o)), + r => r.GetSqlBinary(0), r => r.GetSqlBytes(0).ToSqlBinary(), r => new SqlBinary((byte[])r[0]), r => new SqlBinary((byte[])r["col1"]))) + { + yield return data; + } + } + } + [SqlServer.Server.SqlUserDefinedType(SqlServer.Server.Format.UserDefined)] public class TestUdt {} } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlMetaDataTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlMetaDataTest.cs index dd6c228fd2..a27acd86ae 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlMetaDataTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlMetaDataTest.cs @@ -37,7 +37,7 @@ public void Adjust(SqlDbType dbType, object expected) } [Theory] - [MemberData(nameof(SqlMetaDataMaxLengthTrimValues))] + [MemberData(nameof(SqlMetaDataMaxLengthTrimOrPadValues))] public void AdjustWithGreaterThanMaxLengthValues(SqlDbType dbType, object value) { int maxLength = 4; @@ -82,6 +82,28 @@ public void AdjustWithInvalidType_Throws(SqlDbType dbType, object expected) Assert.Contains("invalid", ex.Message, StringComparison.OrdinalIgnoreCase); } + [Fact] + public void AdjustWithUdtValue_Throws() + { + SqlMetaData metaData = new SqlMetaData( + "col1", + SqlDbType.Variant, + 4, + 2, + 2, + 0, + SqlCompareOptions.IgnoreCase, + null, + true, + true, + SortOrder.Ascending, + 0); + ArgumentException ex = Assert.ThrowsAny(() => + { + object actual = metaData.Adjust(new Address()); + }); + Assert.Contains("no mapping exists from object type", ex.Message, StringComparison.OrdinalIgnoreCase); + } [Fact] public void AdjustWithNullBytes() @@ -105,12 +127,13 @@ public void AdjustWithNullBytes() Assert.Null(actual); } - [Fact] - public void AdjustWithNullChars() + [Theory] + [MemberData(nameof(SqlMetaDataStringTypes))] + public void AdjustWithNullChars(SqlDbType dbType) { SqlMetaData metaData = new SqlMetaData( "col1", - SqlDbType.VarChar, + dbType, 4, 2, 2, @@ -127,12 +150,13 @@ public void AdjustWithNullChars() Assert.Null(actual); } - [Fact] - public void AdjustWithNullString() + [Theory] + [MemberData(nameof(SqlMetaDataStringTypes))] + public void AdjustWithNullString(SqlDbType dbType) { SqlMetaData metaData = new SqlMetaData( "col1", - SqlDbType.VarChar, + dbType, 4, 2, 2, @@ -585,10 +609,7 @@ public void InferFromValue(SqlDbType expectedDbType, object value) } [Theory] - [InlineData((SByte)1)] - [InlineData((UInt16)1)] - [InlineData((UInt32)1)] - [InlineData((UInt64)1)] + [MemberData(nameof(SqlMetaDataInvalidInferredValues))] public void InferFromValueWithInvalidValue_Throws(object value) { ArgumentException ex = Assert.Throws(() => @@ -801,6 +822,18 @@ public void XmlConstructorWithNullObjectName_Throws() SqlMetaData metaData = new SqlMetaData("col1", SqlDbType.Xml, "NorthWindDb", "schema", null); }); Assert.Contains("null", ex.Message, StringComparison.OrdinalIgnoreCase); + + ex = Assert.Throws(() => + { + SqlMetaData metaData = new SqlMetaData("col1", SqlDbType.Xml, "NorthWindDb", null, null); + }); + Assert.Contains("null", ex.Message, StringComparison.OrdinalIgnoreCase); + + ex = Assert.Throws(() => + { + SqlMetaData metaData = new SqlMetaData("col1", SqlDbType.Xml, null, "schema", null); + }); + Assert.Contains("null", ex.Message, StringComparison.OrdinalIgnoreCase); } #region Test values @@ -813,15 +846,28 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.SmallDateTime, DateTime.Today}, }; - public static readonly object[][] SqlMetaDataMaxLengthTrimValues = + public static readonly object[][] SqlMetaDataStringTypes = + { + new object[] {SqlDbType.VarChar}, + new object[] {SqlDbType.NVarChar}, + new object[] {SqlDbType.Char}, + new object[] {SqlDbType.NChar}, + }; + + public static readonly object[][] SqlMetaDataMaxLengthTrimOrPadValues = { new object[] {SqlDbType.Binary, new SqlBinary(new byte[] { 1, 2, 3, 4, 5 })}, new object[] {SqlDbType.Binary, new byte[] { 1, 2, 3, 4, 5 }}, new object[] {SqlDbType.Char, "Tests"}, new object[] {SqlDbType.Char, "T"}, new object[] {SqlDbType.Char, new char[]{'T','e','s','t','s'}}, + new object[] {SqlDbType.Char, new char[]{'T'}}, + new object[] {SqlDbType.Char, new SqlChars(new char[]{'T'})}, new object[] {SqlDbType.NChar, "T"}, new object[] {SqlDbType.NChar, "Tests"}, + new object[] {SqlDbType.NChar, new char[]{'T','e','s','t','s'}}, + new object[] {SqlDbType.NChar, new char[]{'T'}}, + new object[] {SqlDbType.NChar, new SqlChars(new char[]{'T'})}, new object[] {SqlDbType.VarChar, "Tests" }, new object[] {SqlDbType.VarChar, new SqlString("Tests")}, new object[] {SqlDbType.VarChar, new char[]{'T','e','s','t','s'}}, @@ -880,11 +926,13 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.Bit, (UInt64)1}, new object[] {SqlDbType.Bit, (sbyte)0}, new object[] {SqlDbType.Int, Guid.Empty}, + new object[] {SqlDbType.Int, 'T'}, new object[] {SqlDbType.NText, 'T'}, new object[] {SqlDbType.SmallMoney, (decimal)int.MaxValue}, new object[] {SqlDbType.SmallMoney, "Money" }, new object[] {SqlDbType.Bit, 1.0M }, new object[] {SqlDbType.Bit, DateTime.Today}, + new object[] {SqlDbType.Variant, new object()}, }; public static readonly object[][] SqlMetaDataAdjustValues = @@ -910,6 +958,8 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.SmallMoney, 10.01M }, new object[] {SqlDbType.Decimal, 0M }, new object[] {SqlDbType.Decimal, SqlDecimal.Null}, + new object[] {SqlDbType.Decimal, new SqlDecimal(2, 1, true, 0, 0, 0, 0)}, + new object[] {SqlDbType.Decimal, new SqlDecimal(1, 0, true, 0, 0, 0, 0)}, new object[] {SqlDbType.Char, SqlString.Null}, new object[] {SqlDbType.Char, new char[] {'T','e','s', 't'}}, new object[] {SqlDbType.Char, "Test"}, @@ -917,6 +967,7 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.Char, new SqlString("Test")}, new object[] {SqlDbType.Char, SqlChars.Null}, new object[] {SqlDbType.Char, new SqlChars(new char[] { 'T', 'e', 's', 't' })}, + new object[] {SqlDbType.Char, new SqlChars(new char[] { 'T', 'e', 's', 't', 's'})}, new object[] {SqlDbType.NChar, SqlString.Null}, new object[] {SqlDbType.NChar, new char[] {'T','e' ,'s', 't'}}, new object[] {SqlDbType.NChar, SqlChars.Null}, @@ -924,6 +975,7 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.NChar, new SqlString("T")}, new object[] {SqlDbType.NChar, new SqlString("Test")}, new object[] {SqlDbType.NChar, new SqlChars(new char[] { 'T', 'e', 's', 't' })}, + new object[] {SqlDbType.NChar, new SqlChars(new char[] { 'T', 'e', 's', 't', 's'})}, new object[] {SqlDbType.VarChar, 'T'}, new object[] {SqlDbType.VarChar, "T"}, new object[] {SqlDbType.VarChar, "Test"}, @@ -966,6 +1018,7 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.DateTimeOffset, new DateTimeOffset(new DateTime(0), TimeSpan.Zero)}, new object[] {SqlDbType.UniqueIdentifier, SqlGuid.Null}, new object[] {SqlDbType.UniqueIdentifier, Guid.Empty}, + new object[] {SqlDbType.Xml, new SqlXml(new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes("
")))}, }; public static readonly object[][] SqlMetaDataInferredValues = @@ -1028,6 +1081,15 @@ public void XmlConstructorWithNullObjectName_Throws() new object[] {SqlDbType.Xml, new SqlXml()}, new object[] {SqlDbType.Variant, new object()} }; + + public static readonly object[][] SqlMetaDataInvalidInferredValues = + { + new object[] { (SByte)1 }, + new object[] { (UInt16)1 }, + new object[] { (UInt32)1 }, + new object[] { (UInt64)1 }, + new object[] { DBNull.Value }, + }; #endregion } } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs index ad6b7ec259..b911540834 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlParameterTest.cs @@ -331,6 +331,8 @@ public void CompareInfo() Assert.Equal(SqlCompareOptions.None, parameter.CompareInfo); parameter.CompareInfo = SqlCompareOptions.IgnoreNonSpace; Assert.Equal(SqlCompareOptions.IgnoreNonSpace, parameter.CompareInfo); + + Assert.Throws(() => parameter.CompareInfo = (SqlCompareOptions)int.MaxValue); } [Fact] @@ -947,6 +949,9 @@ public void ParameterName() p.ParameterName = string.Empty; Assert.Equal(string.Empty, p.ParameterName); Assert.Equal(string.Empty, p.SourceColumn); + + Assert.Throws(() => p.ParameterName = new string('a', 128)); + Assert.Throws(() => p.ParameterName = "@" + new string('a', 128)); } [Fact] @@ -1141,24 +1146,58 @@ public void SqlDbTypeTest() Assert.Equal(3510, p.Value); } + [Theory] + [InlineData((SqlDbType)666)] + [InlineData((SqlDbType)24)] + public void SqlDbTypeTest_Value_Invalid(SqlDbType type) + { + SqlParameter p = new SqlParameter("zipcode", 3510); + // The SqlDbType enumeration value, (int)type, is invalid + ArgumentOutOfRangeException ex = Assert.Throws(() => p.SqlDbType = type); + + Assert.Null(ex.InnerException); + Assert.NotNull(ex.Message); + Assert.True(ex.Message.IndexOf(((int)type).ToString(), StringComparison.Ordinal) != -1); + Assert.Equal(nameof(p.SqlDbType), ex.ParamName); + } + [Fact] - public void SqlDbTypeTest_Value_Invalid() + public void DirectionTest_Value_Invalid() { SqlParameter p = new SqlParameter("zipcode", 3510); - try - { - p.SqlDbType = (SqlDbType)666; - } - catch (ArgumentOutOfRangeException ex) - { - // The SqlDbType enumeration value, 666, is - // invalid - Assert.Equal(typeof(ArgumentOutOfRangeException), ex.GetType()); - Assert.Null(ex.InnerException); - Assert.NotNull(ex.Message); - Assert.True(ex.Message.IndexOf("666", StringComparison.Ordinal) != -1); - Assert.Equal("SqlDbType", ex.ParamName); - } + // The ParameterDirection enumeration value, int.MaxValue, is invalid + ArgumentOutOfRangeException ex = Assert.Throws(() => p.Direction = (ParameterDirection)int.MaxValue); + + Assert.Null(ex.InnerException); + Assert.NotNull(ex.Message); + Assert.True(ex.Message.IndexOf(int.MaxValue.ToString(), StringComparison.Ordinal) != -1); + Assert.Equal(nameof(ParameterDirection), ex.ParamName); + } + + [Fact] + public void SourceVersionTest_Value_Invalid() + { + SqlParameter p = new SqlParameter("zipcode", 3510); + // The DataRowVersion enumeration value, int.MaxValue, is invalid + ArgumentOutOfRangeException ex = Assert.Throws(() => p.SourceVersion = (DataRowVersion)int.MaxValue); + + Assert.Null(ex.InnerException); + Assert.NotNull(ex.Message); + Assert.True(ex.Message.IndexOf(int.MaxValue.ToString(), StringComparison.Ordinal) != -1); + Assert.Equal(nameof(DataRowVersion), ex.ParamName); + } + + [Fact] + public void OffsetTest_Value_Invalid() + { + SqlParameter p = new SqlParameter("zipcode", 3510); + // Invalid parameter Offset value -1. The value must be greater than or equal to 0. + ArgumentException ex = Assert.Throws(() => p.Offset = -1); + + Assert.Null(ex.InnerException); + Assert.NotNull(ex.Message); + Assert.True(ex.Message.IndexOf("-1", StringComparison.Ordinal) != -1); + Assert.Null(ex.ParamName); } [Fact]