|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | #include <arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h> |
| 19 | +#include <arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h> |
| 20 | +#include <arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h> |
19 | 21 | #include <arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h> |
20 | 22 | #include <arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_environment.h> |
21 | 23 | #include <arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h> |
@@ -121,6 +123,108 @@ SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle) { |
121 | 123 | return SQL_ERROR; |
122 | 124 | } |
123 | 125 |
|
| 126 | +SQLRETURN SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle, |
| 127 | + SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, |
| 128 | + SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, |
| 129 | + SQLSMALLINT* stringLengthPtr) { |
| 130 | + using driver::odbcabstraction::Diagnostics; |
| 131 | + using ODBC::GetStringAttribute; |
| 132 | + using ODBC::ODBCConnection; |
| 133 | + using ODBC::ODBCEnvironment; |
| 134 | + |
| 135 | + if (!handle) { |
| 136 | + return SQL_INVALID_HANDLE; |
| 137 | + } |
| 138 | + |
| 139 | + if (!diagInfoPtr) { |
| 140 | + return SQL_ERROR; |
| 141 | + } |
| 142 | + |
| 143 | + // Set character type to be Unicode by defualt (not Ansi) |
| 144 | + bool isUnicode = true; |
| 145 | + Diagnostics* diagnostics = nullptr; |
| 146 | + |
| 147 | + switch (handleType) { |
| 148 | + case SQL_HANDLE_ENV: { |
| 149 | + ODBCEnvironment* environment = reinterpret_cast<ODBCEnvironment*>(handle); |
| 150 | + diagnostics = &environment->GetDiagnostics(); |
| 151 | + break; |
| 152 | + } |
| 153 | + |
| 154 | + case SQL_HANDLE_DBC: { |
| 155 | + ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle); |
| 156 | + diagnostics = &connection->GetDiagnostics(); |
| 157 | + break; |
| 158 | + } |
| 159 | + |
| 160 | + default: |
| 161 | + return SQL_ERROR; |
| 162 | + } |
| 163 | + |
| 164 | + if (!diagnostics) { |
| 165 | + return SQL_ERROR; |
| 166 | + } |
| 167 | + |
| 168 | + // Retrieve header level diagnostics if Record 0 specified |
| 169 | + if (recNumber == 0) { |
| 170 | + switch (diagIdentifier) { |
| 171 | + case SQL_DIAG_NUMBER: { |
| 172 | + SQLINTEGER count = static_cast<SQLINTEGER>(diagnostics->GetRecordCount()); |
| 173 | + *static_cast<SQLINTEGER*>(diagInfoPtr) = count; |
| 174 | + if (stringLengthPtr) { |
| 175 | + *stringLengthPtr = sizeof(SQLINTEGER); |
| 176 | + } |
| 177 | + |
| 178 | + return SQL_SUCCESS; |
| 179 | + } |
| 180 | + |
| 181 | + case SQL_DIAG_SERVER_NAME: { |
| 182 | + const std::string source = diagnostics->GetDataSourceComponent(); |
| 183 | + return GetStringAttribute(isUnicode, source, false, diagInfoPtr, bufferLength, |
| 184 | + stringLengthPtr, *diagnostics); |
| 185 | + } |
| 186 | + |
| 187 | + default: |
| 188 | + return SQL_ERROR; |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + // Retrieve record level diagnostics from specified 1 based record |
| 193 | + uint32_t recordIndex = static_cast<uint32_t>(recNumber - 1); |
| 194 | + if (!diagnostics->HasRecord(recordIndex)) { |
| 195 | + return SQL_NO_DATA; |
| 196 | + } |
| 197 | + |
| 198 | + // Retrieve record field data |
| 199 | + switch (diagIdentifier) { |
| 200 | + case SQL_DIAG_MESSAGE_TEXT: { |
| 201 | + const std::string message = diagnostics->GetMessageText(recordIndex); |
| 202 | + return GetStringAttribute(isUnicode, message, false, diagInfoPtr, bufferLength, |
| 203 | + stringLengthPtr, *diagnostics); |
| 204 | + } |
| 205 | + |
| 206 | + case SQL_DIAG_NATIVE: { |
| 207 | + *static_cast<SQLINTEGER*>(diagInfoPtr) = diagnostics->GetNativeError(recordIndex); |
| 208 | + if (stringLengthPtr) { |
| 209 | + *stringLengthPtr = sizeof(SQLINTEGER); |
| 210 | + } |
| 211 | + |
| 212 | + return SQL_SUCCESS; |
| 213 | + } |
| 214 | + |
| 215 | + case SQL_DIAG_SQLSTATE: { |
| 216 | + const std::string state = diagnostics->GetSQLState(recordIndex); |
| 217 | + return GetStringAttribute(isUnicode, state, false, diagInfoPtr, bufferLength, |
| 218 | + stringLengthPtr, *diagnostics); |
| 219 | + } |
| 220 | + |
| 221 | + default: |
| 222 | + return SQL_ERROR; |
| 223 | + } |
| 224 | + |
| 225 | + return SQL_ERROR; |
| 226 | +} |
| 227 | + |
124 | 228 | SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, |
125 | 229 | SQLINTEGER bufferLen, SQLINTEGER* strLenPtr) { |
126 | 230 | using driver::odbcabstraction::DriverException; |
|
0 commit comments