Skip to content

Commit dbea8c5

Browse files
authored
CPP-747 Fix case sensitive keyspaces (#467)
1 parent 6afcec7 commit dbea8c5

14 files changed

+296
-90
lines changed

src/connection.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ void Connection::maybe_set_keyspace(ResponseMessage* response) {
178178
if (response->opcode() == CQL_OPCODE_RESULT) {
179179
ResultResponse* result = static_cast<ResultResponse*>(response->response_body().get());
180180
if (result->kind() == CASS_RESULT_KIND_SET_KEYSPACE) {
181-
keyspace_ = result->keyspace().to_string();
181+
keyspace_ = result->quoted_keyspace();
182182
}
183183
}
184184
}

src/hash_table.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ class CaseInsensitiveHashTable : public Allocated {
7272

7373
private:
7474
size_t index_mask_;
75-
size_t count_;
7675
SmallVector<T*, 32> index_;
7776
EntryVec entries_;
7877

src/pooled_connection.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ChainedSetKeyspaceCallback : public SimpleRequestCallback {
3737
class SetKeyspaceRequest : public QueryRequest {
3838
public:
3939
SetKeyspaceRequest(const String& keyspace, uint64_t request_timeout_ms)
40-
: QueryRequest("USE \"" + keyspace + "\"") {
40+
: QueryRequest("USE " + keyspace) {
4141
set_request_timeout_ms(request_timeout_ms);
4242
}
4343
};

src/request_handler.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ void RequestExecution::notify_result_metadata_changed(const Request* request,
461461
if (result_response->protocol_version().supports_set_keyspace() && !request->keyspace().empty()) {
462462
keyspace = request->keyspace();
463463
} else {
464-
keyspace = result_response->keyspace().to_string();
464+
keyspace = result_response->quoted_keyspace();
465465
}
466466

467467
if (request->opcode() == CQL_OPCODE_EXECUTE && result_response->kind() == CASS_RESULT_KIND_ROWS) {
@@ -531,7 +531,7 @@ void RequestExecution::on_result_response(Connection* connection, ResponseMessag
531531

532532
case CASS_RESULT_KIND_SET_KEYSPACE:
533533
// The response is set after the keyspace is propagated to all threads.
534-
request_handler_->notify_keyspace_changed(result->keyspace().to_string(), current_host_,
534+
request_handler_->notify_keyspace_changed(result->quoted_keyspace(), current_host_,
535535
response->response_body());
536536
break;
537537

src/result_response.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ class ResultResponse : public Response {
6666
StringRef keyspace() const { return keyspace_; }
6767
StringRef table() const { return table_; }
6868

69+
String quoted_keyspace() const {
70+
String temp(keyspace_.to_string());
71+
return escape_id(temp);
72+
}
73+
6974
bool metadata_changed() { return new_metadata_id_.size() > 0; }
7075
StringRef new_metadata_id() const { return new_metadata_id_; }
7176

src/statement.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ Statement::Statement(const Prepared* prepared)
281281
// If the keyspace wasn't explictly set then attempt to set it using the
282282
// prepared statement's result metadata.
283283
if (keyspace().empty()) {
284-
set_keyspace(prepared->result()->keyspace().to_string());
284+
set_keyspace(prepared->result()->quoted_keyspace());
285285
}
286286
}
287287

src/utils.cpp

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -111,38 +111,22 @@ String& trim(String& str) {
111111
return str;
112112
}
113113

114-
static bool is_word_char(int c) {
115-
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_';
116-
}
114+
static bool is_lowercase(const String& str) {
115+
if (str.empty()) return true;
117116

118-
static bool is_lower_word_char(int c) {
119-
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_';
120-
}
117+
char c = str[0];
118+
if (!(c >= 'a' && c <= 'z')) return false;
121119

122-
bool is_valid_cql_id(const String& str) {
123-
for (String::const_iterator i = str.begin(), end = str.end(); i != end; ++i) {
124-
if (!is_word_char(*i)) {
120+
for (String::const_iterator it = str.begin() + 1, end = str.end(); it != end; ++it) {
121+
char c = *it;
122+
if (!((c >= '0' && c <= '9') || (c == '_') || (c >= 'a' && c <= 'z'))) {
125123
return false;
126124
}
127125
}
128126
return true;
129127
}
130128

131-
bool is_valid_lower_cql_id(const String& str) {
132-
if (str.empty() || !is_lower_word_char(str[0])) {
133-
return false;
134-
}
135-
if (str.size() > 1) {
136-
for (String::const_iterator i = str.begin() + 1, end = str.end(); i != end; ++i) {
137-
if (!is_lower_word_char(*i)) {
138-
return false;
139-
}
140-
}
141-
}
142-
return true;
143-
}
144-
145-
String& quote_id(String& str) {
129+
static String& quote_id(String& str) {
146130
String temp(str);
147131
str.clear();
148132
str.push_back('"');
@@ -159,18 +143,7 @@ String& quote_id(String& str) {
159143
return str;
160144
}
161145

162-
String& escape_id(String& str) { return is_valid_lower_cql_id(str) ? str : quote_id(str); }
163-
164-
String& to_cql_id(String& str) {
165-
if (is_valid_cql_id(str)) {
166-
std::transform(str.begin(), str.end(), str.begin(), tolower);
167-
return str;
168-
}
169-
if (str.length() > 2 && str[0] == '"' && str[str.length() - 1] == '"') {
170-
return str.erase(str.length() - 1, 1).erase(0, 1);
171-
}
172-
return str;
173-
}
146+
String& escape_id(String& str) { return is_lowercase(str) ? str : quote_id(str); }
174147

175148
int32_t get_pid() {
176149
#if (defined(WIN32) || defined(_WIN32))

src/utils.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,6 @@ String implode(const Vector<String>& vec, const char delimiter = ',');
6969

7070
String& trim(String& str);
7171

72-
bool is_valid_cql_id(const String& str);
73-
74-
String& to_cql_id(String& str);
75-
7672
String& escape_id(String& str);
7773

7874
inline size_t num_leading_zeros(int64_t value) {
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
Copyright (c) DataStax, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
#include "integration.hpp"
18+
19+
#include <locale>
20+
21+
/**
22+
* "USE <keyspace>" case-sensitive tests
23+
*/
24+
class UseKeyspaceCaseSensitiveTests : public Integration {
25+
public:
26+
UseKeyspaceCaseSensitiveTests() {}
27+
28+
// Make a case-sensitive keyspace capitalizing the first char and wrapping in double quotes
29+
virtual std::string default_keyspace() {
30+
std::string temp(Integration::default_keyspace());
31+
temp[0] = std::toupper(temp[0]);
32+
return "\"" + temp + "\"";
33+
}
34+
35+
virtual void SetUp() {
36+
Integration::SetUp();
37+
session_.execute(
38+
format_string(CASSANDRA_KEY_VALUE_TABLE_FORMAT, table_name_.c_str(), "int", "int"));
39+
session_.execute(
40+
format_string(CASSANDRA_KEY_VALUE_INSERT_FORMAT, table_name_.c_str(), "1", "2"));
41+
}
42+
};
43+
44+
/**
45+
* Verify that case-sensitive keyspaces work when connecting a session with a keyspace.
46+
*/
47+
CASSANDRA_INTEGRATION_TEST_F(UseKeyspaceCaseSensitiveTests, ConnectWithKeyspace) {
48+
CHECK_FAILURE;
49+
Session session = default_cluster().connect(keyspace_name_);
50+
51+
Result result =
52+
session.execute(format_string(CASSANDRA_SELECT_VALUE_FORMAT, table_name_.c_str(), "1"));
53+
54+
Row row = result.first_row();
55+
EXPECT_EQ(row.column_by_name<Integer>("value"), Integer(2));
56+
}
57+
58+
/**
59+
* Verify that case-sensitive keyspaces work with "USE <keyspace>".
60+
*/
61+
CASSANDRA_INTEGRATION_TEST_F(UseKeyspaceCaseSensitiveTests, UseKeyspace) {
62+
CHECK_FAILURE;
63+
Session session = default_cluster().connect();
64+
65+
{ // Expect failure there's no keyspace set
66+
Result result =
67+
session.execute(format_string(CASSANDRA_SELECT_VALUE_FORMAT, table_name_.c_str(), "1"),
68+
CASS_CONSISTENCY_ONE, false, false);
69+
70+
EXPECT_EQ(result.error_code(), CASS_ERROR_SERVER_INVALID_QUERY);
71+
}
72+
73+
session.execute("USE " + keyspace_name_);
74+
75+
{ // Success
76+
Result result =
77+
session.execute(format_string(CASSANDRA_SELECT_VALUE_FORMAT, table_name_.c_str(), "1"));
78+
79+
Row row = result.first_row();
80+
EXPECT_EQ(row.column_by_name<Integer>("value"), Integer(2));
81+
}
82+
}

tests/src/unit/mockssandra.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "memory.hpp"
2424
#include "scoped_lock.hpp"
2525
#include "tracing_data_handler.hpp" // For tracing query
26+
#include "utils.hpp"
2627
#include "uuids.hpp"
2728

2829
#include <openssl/bio.h>
@@ -34,10 +35,12 @@
3435
#endif
3536

3637
using datastax::internal::bind_callback;
38+
using datastax::internal::escape_id;
3739
using datastax::internal::Map;
3840
using datastax::internal::Memory;
3941
using datastax::internal::OStringStream;
4042
using datastax::internal::ScopedMutex;
43+
using datastax::internal::trim;
4144
using datastax::internal::core::UuidGen;
4245

4346
#define SSL_BUF_SIZE 8192
@@ -1357,6 +1360,10 @@ Action::Builder& Action::Builder::use_keyspace(const String& keyspace) {
13571360
return execute((new UseKeyspace(keyspace)));
13581361
}
13591362

1363+
Action::Builder& Action::Builder::use_keyspace(const Vector<String>& keyspaces) {
1364+
return execute((new UseKeyspace(keyspaces)));
1365+
}
1366+
13601367
Action::Builder& Action::Builder::plaintext_auth(const String& username, const String& password) {
13611368
return execute((new PlaintextAuth(username, password)));
13621369
}
@@ -1807,18 +1814,22 @@ void UseKeyspace::on_run(Request* request) const {
18071814
String query;
18081815
QueryParameters params;
18091816
if (request->decode_query(&query, &params)) {
1810-
query.erase(0, query.find_first_not_of(" \t"));
1811-
if (query.substr(0, 3) == "USE" || query.substr(0, 3) == "use") {
1812-
query.erase(0, 3);
1813-
query.erase(0, query.find_first_not_of(" \t"));
1814-
if (query.substr(0, keyspace.size()) == keyspace) {
1815-
String body;
1816-
encode_int32(RESULT_SET_KEYSPACE, &body);
1817-
encode_string(keyspace, &body);
1818-
request->write(OPCODE_RESULT, body);
1819-
} else {
1820-
request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist");
1817+
trim(query);
1818+
if (query.compare(0, 3, "USE") == 0 || query.compare(0, 3, "use") == 0) {
1819+
String keyspace(query.substr(query.find_first_not_of(" \t", 3)));
1820+
for (Vector<String>::const_iterator it = keyspaces.begin(), end = keyspaces.end(); it != end;
1821+
++it) {
1822+
String temp(*it);
1823+
if (keyspace == escape_id(temp)) {
1824+
String body;
1825+
encode_int32(RESULT_SET_KEYSPACE, &body);
1826+
encode_string(*it, &body);
1827+
request->client()->set_keyspace(*it);
1828+
request->write(OPCODE_RESULT, body);
1829+
return;
1830+
}
18211831
}
1832+
request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist");
18221833
} else {
18231834
run_next(request);
18241835
}

0 commit comments

Comments
 (0)