diff --git a/src/binder/bind_create.cpp b/src/binder/bind_create.cpp index 6b3371206..17f605652 100644 --- a/src/binder/bind_create.cpp +++ b/src/binder/bind_create.cpp @@ -46,6 +46,7 @@ #include "fmt/ranges.h" #include "nodes/nodes.hpp" #include "nodes/primnodes.hpp" +#include "nodes/value.hpp" #include "pg_definitions.hpp" #include "postgres_parser.hpp" #include "type/type_id.h" @@ -86,6 +87,16 @@ auto Binder::BindColumnDefinition(duckdb_libpgquery::PGColumnDef *cdef) -> Colum return {colname, TypeId::VARCHAR, varchar_max_length}; } + if (name == "vector") { + auto exprs = BindExpressionList(cdef->typeName->typmods); + if (exprs.size() != 1) { + throw bustub::Exception("should specify vector length"); + } + const auto &vector_length_val = dynamic_cast(*exprs[0]); + uint32_t vector_length = std::stoi(vector_length_val.ToString()); + return {colname, TypeId::VECTOR, vector_length}; + } + throw NotImplementedException(fmt::format("unsupported type: {}", name)); } @@ -156,6 +167,7 @@ auto Binder::BindCreate(duckdb_libpgquery::PGCreateStmt *pg_stmt) -> std::unique auto Binder::BindIndex(duckdb_libpgquery::PGIndexStmt *stmt) -> std::unique_ptr { std::vector> cols; + std::vector col_options; auto table = BindBaseTableRef(stmt->relation->relname, std::nullopt); for (auto cell = stmt->indexParams->head; cell != nullptr; cell = cell->next) { @@ -163,12 +175,43 @@ auto Binder::BindIndex(duckdb_libpgquery::PGIndexStmt *stmt) -> std::unique_ptr< if (index_element->name != nullptr) { auto column_ref = ResolveColumn(*table, std::vector{std::string(index_element->name)}); cols.emplace_back(std::make_unique(dynamic_cast(*column_ref))); + std::string opt; + if (index_element->opclass != nullptr) { + for (auto c = index_element->opclass->head; c != nullptr; c = lnext(c)) { + opt = reinterpret_cast(c->data.ptr_value)->val.str; + break; + } + } + col_options.emplace_back(opt); } else { throw NotImplementedException("create index by expr is not supported yet"); } } - return 
std::make_unique(stmt->idxname, std::move(table), std::move(cols)); + std::string index_type; + + if (stmt->accessMethod != nullptr) { + index_type = stmt->accessMethod; + if (index_type == "art") { + index_type = ""; + } + } + + std::vector> options; + + if (stmt->options != nullptr) { + for (auto c = stmt->options->head; c != nullptr; c = lnext(c)) { + auto def_elem = reinterpret_cast(c->data.ptr_value); + int val = 0; /* default when the option carries no argument; previously read uninitialized */ + if (def_elem->arg != nullptr) { + val = reinterpret_cast(def_elem->arg)->val.ival; + } + options.emplace_back(def_elem->defname, val); + } + } + + return std::make_unique(stmt->idxname, std::move(table), std::move(cols), std::move(index_type), + std::move(col_options), std::move(options)); } } // namespace bustub diff --git a/src/binder/bind_select.cpp b/src/binder/bind_select.cpp index e08867c3f..d9eac42e1 100644 --- a/src/binder/bind_select.cpp +++ b/src/binder/bind_select.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include "binder/binder.h" #include "binder/bound_expression.h" @@ -467,6 +468,10 @@ auto Binder::BindConstant(duckdb_libpgquery::PGAConst *node) -> std::unique_ptr< BUSTUB_ENSURE(val.val.ival <= BUSTUB_INT32_MAX, "value out of range"); return std::make_unique(ValueFactory::GetIntegerValue(static_cast(val.val.ival))); } + case duckdb_libpgquery::T_PGFloat: { + double parsed_val = std::stod(std::string(val.val.str)); + return std::make_unique(ValueFactory::GetDecimalValue(parsed_val)); + } case duckdb_libpgquery::T_PGString: { return std::make_unique(ValueFactory::GetVarcharValue(val.val.str)); } diff --git a/src/binder/statement/index_statement.cpp b/src/binder/statement/index_statement.cpp index 90471d47f..5595e0f43 100644 --- a/src/binder/statement/index_statement.cpp +++ b/src/binder/statement/index_statement.cpp @@ -7,14 +7,19 @@ namespace bustub { IndexStatement::IndexStatement(std::string index_name, std::unique_ptr table, - std::vector> cols) + std::vector> cols, std::string index_type, + std::vector 
col_options, std::vector> options) : BoundStatement(StatementType::INDEX_STATEMENT), index_name_(std::move(index_name)), table_(std::move(table)), - cols_(std::move(cols)) {} + cols_(std::move(cols)), + index_type_(std::move(index_type)), + col_options_(std::move(col_options)), + options_(std::move(options)) {} auto IndexStatement::ToString() const -> std::string { - return fmt::format("BoundIndex {{ index_name={}, table={}, cols={} }}", index_name_, *table_, cols_); + return fmt::format("BoundIndex {{ index_name={}, table={}, cols={}, using={}, col_options=[{}], options=[{}] }}", + index_name_, *table_, cols_, index_type_, fmt::join(col_options_, ","), fmt::join(options_, ",")); } } // namespace bustub diff --git a/src/catalog/column.cpp b/src/catalog/column.cpp index db9aa78a0..4877481e7 100644 --- a/src/catalog/column.cpp +++ b/src/catalog/column.cpp @@ -14,6 +14,7 @@ #include #include +#include "type/type_id.h" namespace bustub { @@ -21,6 +22,12 @@ auto Column::ToString(bool simplified) const -> std::string { if (simplified) { std::ostringstream os; os << column_name_ << ":" << Type::TypeIdToString(column_type_); + if (column_type_ == VARCHAR) { + os << "(" << length_ << ")"; + } + if (column_type_ == VECTOR) { + os << "(" << length_ / sizeof(double) << ")"; + } return (os.str()); } @@ -28,12 +35,7 @@ auto Column::ToString(bool simplified) const -> std::string { os << "Column[" << column_name_ << ", " << Type::TypeIdToString(column_type_) << ", " << "Offset:" << column_offset_ << ", "; - - if (IsInlined()) { - os << "FixedLength:" << fixed_length_; - } else { - os << "VarLength:" << variable_length_; - } + os << "Length:" << length_; os << "]"; return (os.str()); } diff --git a/src/catalog/schema.cpp b/src/catalog/schema.cpp index 5a4dd7140..dbc818654 100644 --- a/src/catalog/schema.cpp +++ b/src/catalog/schema.cpp @@ -29,7 +29,11 @@ Schema::Schema(const std::vector &columns) { } // set column offset column.column_offset_ = curr_offset; - curr_offset += 
column.GetFixedLength(); + if (column.IsInlined()) { + curr_offset += column.GetStorageSize(); + } else { + curr_offset += sizeof(uint32_t); + } // add column this->columns_.push_back(column); diff --git a/src/common/bustub_ddl.cpp b/src/common/bustub_ddl.cpp index 02fc01ad4..f3a8adef2 100644 --- a/src/common/bustub_ddl.cpp +++ b/src/common/bustub_ddl.cpp @@ -106,15 +106,37 @@ void BustubInstance::HandleIndexStatement(Transaction *txn, const IndexStatement } std::unique_lock l(catalog_lock_); - auto info = catalog_->CreateIndex( - txn, stmt.index_name_, stmt.table_->table_, stmt.table_->schema_, key_schema, col_ids, TWO_INTEGER_SIZE, - IntegerHashFunctionType{}, false, IndexType::HashTableIndex); + IndexInfo *info = nullptr; + + if (stmt.index_type_.empty()) { + info = catalog_->CreateIndex( + txn, stmt.index_name_, stmt.table_->table_, stmt.table_->schema_, key_schema, col_ids, TWO_INTEGER_SIZE, + IntegerHashFunctionType{}, false); // create default index + } else if (stmt.index_type_ == "hash") { + info = catalog_->CreateIndex( + txn, stmt.index_name_, stmt.table_->table_, stmt.table_->schema_, key_schema, col_ids, TWO_INTEGER_SIZE, + IntegerHashFunctionType{}, false, IndexType::HashTableIndex); + } else if (stmt.index_type_ == "bplustree") { + info = catalog_->CreateIndex( + txn, stmt.index_name_, stmt.table_->table_, stmt.table_->schema_, key_schema, col_ids, TWO_INTEGER_SIZE, + IntegerHashFunctionType{}, false, IndexType::BPlusTreeIndex); + } else if (stmt.index_type_ == "stl_ordered") { + info = catalog_->CreateIndex( + txn, stmt.index_name_, stmt.table_->table_, stmt.table_->schema_, key_schema, col_ids, TWO_INTEGER_SIZE, + IntegerHashFunctionType{}, false, IndexType::STLOrderedIndex); + } else if (stmt.index_type_ == "stl_unordered") { + info = catalog_->CreateIndex( + txn, stmt.index_name_, stmt.table_->table_, stmt.table_->schema_, key_schema, col_ids, TWO_INTEGER_SIZE, + IntegerHashFunctionType{}, false, IndexType::STLUnorderedIndex); + } else { + 
UNIMPLEMENTED("unsupported index type " + stmt.index_type_); + } l.unlock(); if (info == nullptr) { throw bustub::Exception("Failed to create index"); } - WriteOneCell(fmt::format("Index created with id = {}", info->index_oid_), writer); + WriteOneCell(fmt::format("Index created with id = {} with type = {}", info->index_oid_, info->index_type_), writer); } void BustubInstance::HandleExplainStatement(Transaction *txn, const ExplainStatement &stmt, ResultWriter &writer) { diff --git a/src/execution/plan_node.cpp b/src/execution/plan_node.cpp index 47a0dfdff..ed55226a0 100644 --- a/src/execution/plan_node.cpp +++ b/src/execution/plan_node.cpp @@ -18,7 +18,7 @@ auto SeqScanPlanNode::InferScanSchema(const BoundBaseTableRef &table) -> Schema std::vector schema; for (const auto &column : table.schema_.GetColumns()) { auto col_name = fmt::format("{}.{}", table.GetBoundTableName(), column.GetName()); - schema.emplace_back(Column(col_name, column)); + schema.emplace_back(col_name, column); } return Schema(schema); } @@ -38,12 +38,7 @@ auto ProjectionPlanNode::InferProjectionSchema(const std::vector schema; for (const auto &expr : expressions) { auto type_id = expr->GetReturnType(); - if (type_id != TypeId::VARCHAR) { - schema.emplace_back("", type_id); - } else { - // TODO(chi): infer the correct VARCHAR length. Maybe it doesn't matter for executors? 
- schema.emplace_back("", type_id, VARCHAR_DEFAULT_LENGTH); - } + schema.emplace_back(expr->GetReturnType().WithColumnName("")); } return Schema(schema); } @@ -55,7 +50,7 @@ auto ProjectionPlanNode::RenameSchema(const Schema &schema, const std::vector output; output.reserve(group_bys.size() + aggregates.size()); for (const auto &column : group_bys) { - // TODO(chi): correctly process VARCHAR column - if (column->GetReturnType() == TypeId::VARCHAR) { - output.emplace_back(Column("", column->GetReturnType(), 128)); - } else { - output.emplace_back(Column("", column->GetReturnType())); - } + output.emplace_back(column->GetReturnType().WithColumnName("")); } for (size_t idx = 0; idx < aggregates.size(); idx++) { // TODO(chi): correctly infer agg call return type - output.emplace_back(Column("", TypeId::INTEGER)); + output.emplace_back("", TypeId::INTEGER); } return Schema(output); } auto WindowFunctionPlanNode::InferWindowSchema(const std::vector &columns) -> Schema { std::vector output; + output.reserve(columns.size()); // TODO(avery): correctly infer window call return type for (const auto &column : columns) { - // TODO(chi): correctly process VARCHAR column - if (column->GetReturnType() == TypeId::VARCHAR) { - output.emplace_back(Column("", column->GetReturnType(), 128)); - } else { - output.emplace_back(Column("", column->GetReturnType())); - } + output.emplace_back(column->GetReturnType().WithColumnName("")); } return Schema(output); } diff --git a/src/include/binder/statement/index_statement.h b/src/include/binder/statement/index_statement.h index 97438877b..d27f539d5 100644 --- a/src/include/binder/statement/index_statement.h +++ b/src/include/binder/statement/index_statement.h @@ -9,6 +9,7 @@ #include #include +#include #include #include "binder/bound_statement.h" @@ -21,7 +22,8 @@ namespace bustub { class IndexStatement : public BoundStatement { public: explicit IndexStatement(std::string index_name, std::unique_ptr table, - std::vector> cols); + std::vector> 
cols, std::string index_type, + std::vector col_options, std::vector> options); /** Name of the index */ std::string index_name_; @@ -32,6 +34,12 @@ class IndexStatement : public BoundStatement { /** Name of the columns */ std::vector> cols_; + /** Using */ + std::string index_type_; + + std::vector col_options_; + std::vector> options_; + auto ToString() const -> std::string override; }; diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index c34ef5543..3591356d0 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -25,7 +25,10 @@ #include "storage/index/b_plus_tree_index.h" #include "storage/index/extendible_hash_table_index.h" #include "storage/index/index.h" +#include "storage/index/stl_ordered.h" +#include "storage/index/stl_unordered.h" #include "storage/table/table_heap.h" +#include "storage/table/tuple.h" namespace bustub { @@ -36,7 +39,7 @@ using table_oid_t = uint32_t; using column_oid_t = uint32_t; using index_oid_t = uint32_t; -enum class IndexType { BPlusTreeIndex, HashTableIndex }; +enum class IndexType { BPlusTreeIndex, HashTableIndex, STLOrderedIndex, STLUnorderedIndex }; /** * The TableInfo class maintains metadata about a table. @@ -75,14 +78,15 @@ struct IndexInfo { * @param key_size The size of the index key, in bytes */ IndexInfo(Schema key_schema, std::string name, std::unique_ptr &&index, index_oid_t index_oid, - std::string table_name, size_t key_size, bool is_primary_key) + std::string table_name, size_t key_size, bool is_primary_key, IndexType index_type) : key_schema_{std::move(key_schema)}, name_{std::move(name)}, index_{std::move(index)}, index_oid_{index_oid}, table_name_{std::move(table_name)}, key_size_{key_size}, - is_primary_key_{is_primary_key} {} + is_primary_key_{is_primary_key}, + index_type_(index_type) {} /** The schema for the index key */ Schema key_schema_; /** The name of the index */ @@ -98,7 +102,7 @@ struct IndexInfo { /** Is primary key index? 
*/ bool is_primary_key_; /** The index type */ - [[maybe_unused]] IndexType index_type_{IndexType::BPlusTreeIndex}; + IndexType index_type_; }; /** @@ -241,9 +245,15 @@ class Catalog { if (index_type == IndexType::HashTableIndex) { index = std::make_unique>(std::move(meta), bpm_, hash_function); - } else { - BUSTUB_ASSERT(index_type == IndexType::BPlusTreeIndex, "Unsupported Index Type"); + } else if (index_type == IndexType::BPlusTreeIndex) { index = std::make_unique>(std::move(meta), bpm_); + } else if (index_type == IndexType::STLOrderedIndex) { + index = std::make_unique>(std::move(meta), bpm_); + } else if (index_type == IndexType::STLUnorderedIndex) { + index = + std::make_unique>(std::move(meta), bpm_, hash_function); + } else { + UNIMPLEMENTED("Unsupported Index Type"); } // Populate the index with all tuples in table heap @@ -259,7 +269,7 @@ class Catalog { // Construct index information; IndexInfo takes ownership of the Index itself auto index_info = std::make_unique(key_schema, index_name, std::move(index), index_oid, table_name, - keysize, is_primary_key); + keysize, is_primary_key, index_type); auto *tmp = index_info.get(); // Update internal tracking @@ -393,3 +403,29 @@ class Catalog { }; } // namespace bustub + +template <> +struct fmt::formatter : formatter { + template + auto format(bustub::IndexType c, FormatContext &ctx) const { + string_view name; + switch (c) { + case bustub::IndexType::BPlusTreeIndex: + name = "BPlusTree"; + break; + case bustub::IndexType::HashTableIndex: + name = "Hash"; + break; + case bustub::IndexType::STLOrderedIndex: + name = "STLOrdered"; + break; + case bustub::IndexType::STLUnorderedIndex: + name = "STLUnordered"; + break; + default: + name = "Unknown"; + break; + } + return formatter::format(name, ctx); + } +}; diff --git a/src/include/catalog/column.h b/src/include/catalog/column.h index b645ca2a3..a38ae2c41 100644 --- a/src/include/catalog/column.h +++ b/src/include/catalog/column.h @@ -22,6 +22,7 @@ #include 
"common/exception.h" #include "common/macros.h" #include "type/type.h" +#include "type/type_id.h" namespace bustub { class AbstractExpression; @@ -36,8 +37,9 @@ class Column { * @param type type of the column */ Column(std::string column_name, TypeId type) - : column_name_(std::move(column_name)), column_type_(type), fixed_length_(TypeSize(type)) { + : column_name_(std::move(column_name)), column_type_(type), length_(TypeSize(type)) { BUSTUB_ASSERT(type != TypeId::VARCHAR, "Wrong constructor for VARCHAR type."); + BUSTUB_ASSERT(type != TypeId::VECTOR, "Wrong constructor for VECTOR type."); } /** @@ -48,11 +50,8 @@ class Column { * @param expr expression used to create this column */ Column(std::string column_name, TypeId type, uint32_t length) - : column_name_(std::move(column_name)), - column_type_(type), - fixed_length_(TypeSize(type)), - variable_length_(length) { - BUSTUB_ASSERT(type == TypeId::VARCHAR, "Wrong constructor for non-VARCHAR type."); + : column_name_(std::move(column_name)), column_type_(type), length_(TypeSize(type, length)) { + BUSTUB_ASSERT(type == TypeId::VARCHAR || type == TypeId::VECTOR, "Wrong constructor for fixed-size type."); } /** @@ -63,26 +62,20 @@ class Column { Column(std::string column_name, const Column &column) : column_name_(std::move(column_name)), column_type_(column.column_type_), - fixed_length_(column.fixed_length_), - variable_length_(column.variable_length_), + length_(column.length_), column_offset_(column.column_offset_) {} + auto WithColumnName(std::string column_name) -> Column { + Column c = *this; + c.column_name_ = std::move(column_name); + return c; + } + /** @return column name */ auto GetName() const -> std::string { return column_name_; } /** @return column length */ - auto GetLength() const -> uint32_t { - if (IsInlined()) { - return fixed_length_; - } - return variable_length_; - } - - /** @return column fixed length */ - auto GetFixedLength() const -> uint32_t { return fixed_length_; } - - /** @return column 
variable length */ - auto GetVariableLength() const -> uint32_t { return variable_length_; } + auto GetStorageSize() const -> uint32_t { return length_; } /** @return column's offset in the tuple */ auto GetOffset() const -> uint32_t { return column_offset_; } @@ -91,7 +84,7 @@ class Column { auto GetType() const -> TypeId { return column_type_; } /** @return true if column is inlined, false otherwise */ - auto IsInlined() const -> bool { return column_type_ != TypeId::VARCHAR; } + auto IsInlined() const -> bool { return column_type_ != TypeId::VARCHAR && column_type_ != TypeId::VECTOR; } /** @return a string representation of this column */ auto ToString(bool simplified = true) const -> std::string; @@ -102,7 +95,7 @@ class Column { * @param type type whose size is to be determined * @return size in bytes */ - static auto TypeSize(TypeId type) -> uint8_t { + static auto TypeSize(TypeId type, uint32_t length = 0) -> uint32_t { switch (type) { case TypeId::BOOLEAN: case TypeId::TINYINT: @@ -116,8 +109,9 @@ class Column { case TypeId::TIMESTAMP: return 8; case TypeId::VARCHAR: - // TODO(Amadou): Confirm this. - return 12; + return length; /* caller-supplied byte length; may exceed 255, hence the uint32_t return type */ + case TypeId::VECTOR: + return length * sizeof(double); /* length is the element count; each element is stored as a double */ default: { UNREACHABLE("Cannot get size of invalid type"); } @@ -130,11 +124,8 @@ class Column { /** Column value's type. */ TypeId column_type_; - /** For a non-inlined column, this is the size of a pointer. Otherwise, the size of the fixed length column. */ - uint32_t fixed_length_; - - /** For an inlined column, 0. Otherwise, the length of the variable length column. */ - uint32_t variable_length_{0}; + /** The size of the column. */ + uint32_t length_; /** Column offset in the tuple. 
*/ uint32_t column_offset_{0}; diff --git a/src/include/catalog/schema.h b/src/include/catalog/schema.h index 6681acda5..c85d3f5c3 100644 --- a/src/include/catalog/schema.h +++ b/src/include/catalog/schema.h @@ -91,7 +91,7 @@ class Schema { auto GetUnlinedColumnCount() const -> uint32_t { return static_cast(uninlined_columns_.size()); } /** @return the number of bytes used by one tuple */ - inline auto GetLength() const -> uint32_t { return length_; } + inline auto GetInlinedStorageSize() const -> uint32_t { return length_; } /** @return true if all columns are inlined, false otherwise */ inline auto IsInlined() const -> bool { return tuple_is_inlined_; } diff --git a/src/include/common/util/hash_util.h b/src/include/common/util/hash_util.h index 469c9ce35..e584a9463 100644 --- a/src/include/common/util/hash_util.h +++ b/src/include/common/util/hash_util.h @@ -88,7 +88,7 @@ class HashUtil { } case TypeId::VARCHAR: { auto raw = val->GetData(); - auto len = val->GetLength(); + auto len = val->GetStorageSize(); return HashBytes(raw, len); } case TypeId::TIMESTAMP: { diff --git a/src/include/execution/expressions/abstract_expression.h b/src/include/execution/expressions/abstract_expression.h index 97f6dfbe3..96d84fbf7 100644 --- a/src/include/execution/expressions/abstract_expression.h +++ b/src/include/execution/expressions/abstract_expression.h @@ -17,9 +17,11 @@ #include #include +#include "catalog/column.h" #include "catalog/schema.h" #include "fmt/format.h" #include "storage/table/tuple.h" +#include "type/type.h" #define BUSTUB_EXPR_CLONE_WITH_CHILDREN(cname) \ auto CloneWithChildren(std::vector children) const->std::unique_ptr \ @@ -45,8 +47,8 @@ class AbstractExpression { * @param children the children of this abstract expression * @param ret_type the return type of this abstract expression when it is evaluated */ - AbstractExpression(std::vector children, TypeId ret_type) - : children_{std::move(children)}, ret_type_{ret_type} {} + 
AbstractExpression(std::vector children, Column ret_type) + : children_{std::move(children)}, ret_type_{std::move(ret_type)} {} /** Virtual destructor. */ virtual ~AbstractExpression() = default; @@ -72,7 +74,7 @@ class AbstractExpression { auto GetChildren() const -> const std::vector & { return children_; } /** @return the type of this expression if it were to be evaluated */ - virtual auto GetReturnType() const -> TypeId { return ret_type_; } + virtual auto GetReturnType() const -> Column { return ret_type_; } /** @return the string representation of the plan node and its children */ virtual auto ToString() const -> std::string { return ""; } @@ -86,7 +88,7 @@ class AbstractExpression { private: /** The return type of this expression. */ - TypeId ret_type_; + Column ret_type_; }; } // namespace bustub diff --git a/src/include/execution/expressions/arithmetic_expression.h b/src/include/execution/expressions/arithmetic_expression.h index 9d99a8cec..cae3aab97 100644 --- a/src/include/execution/expressions/arithmetic_expression.h +++ b/src/include/execution/expressions/arithmetic_expression.h @@ -23,6 +23,7 @@ #include "execution/expressions/abstract_expression.h" #include "fmt/format.h" #include "storage/table/tuple.h" +#include "type/type.h" #include "type/type_id.h" #include "type/value_factory.h" @@ -38,8 +39,10 @@ class ArithmeticExpression : public AbstractExpression { public: /** Creates a new comparison expression representing (left comp_type right). 
*/ ArithmeticExpression(AbstractExpressionRef left, AbstractExpressionRef right, ArithmeticType compute_type) - : AbstractExpression({std::move(left), std::move(right)}, TypeId::INTEGER), compute_type_{compute_type} { - if (GetChildAt(0)->GetReturnType() != TypeId::INTEGER || GetChildAt(1)->GetReturnType() != TypeId::INTEGER) { + : AbstractExpression({std::move(left), std::move(right)}, Column{"", TypeId::INTEGER}), + compute_type_{compute_type} { + if (GetChildAt(0)->GetReturnType().GetType() != TypeId::INTEGER || + GetChildAt(1)->GetReturnType().GetType() != TypeId::INTEGER) { throw bustub::NotImplementedException("only support integer for now"); } } diff --git a/src/include/execution/expressions/array_expression.h b/src/include/execution/expressions/array_expression.h new file mode 100644 index 000000000..b3965cb27 --- /dev/null +++ b/src/include/execution/expressions/array_expression.h @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// +// BusTub +// +// array_expression.h +// +// Identification: src/include/execution/expressions/array_expression.h +// +// Copyright (c) 2015-19, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#include "common/exception.h" +#include "execution/expressions/abstract_expression.h" +#include "type/value_factory.h" + +namespace bustub { +/** + * ArrayExpression builds a VECTOR value from its child expressions; every child must evaluate to DECIMAL. + */ +class ArrayExpression : public AbstractExpression { + public: + /** Creates a new array expression whose elements are produced by the given child expressions. 
*/ + explicit ArrayExpression(const std::vector &children) + : AbstractExpression(children, Column{"", TypeId::VECTOR, static_cast(children.size())}) {} + + auto Evaluate(const Tuple *tuple, const Schema &schema) const -> Value override { + std::vector values; + values.reserve(children_.size()); + for (const auto &expr : children_) { + auto val = expr->Evaluate(tuple, schema); + if (val.GetTypeId() != TypeId::DECIMAL) { + throw Exception("vector value can only be constructed from decimal type"); + } + values.emplace_back(val.GetAs()); + } + return ValueFactory::GetVectorValue(values); + } + + auto EvaluateJoin(const Tuple *left_tuple, const Schema &left_schema, const Tuple *right_tuple, + const Schema &right_schema) const -> Value override { + std::vector values; + values.reserve(children_.size()); /* reserve, not resize: resize would prepend children_.size() zero entries */ + for (const auto &expr : children_) { + auto val = expr->EvaluateJoin(left_tuple, left_schema, right_tuple, right_schema); + if (val.GetTypeId() != TypeId::DECIMAL) { + throw Exception("vector value can only be constructed from decimal type"); + } + values.emplace_back(val.GetAs()); + } + return ValueFactory::GetVectorValue(values); + } + + /** @return the string representation of the plan node and its children */ + auto ToString() const -> std::string override { return fmt::format("[{}]", fmt::join(children_, ",")); } + + BUSTUB_EXPR_CLONE_WITH_CHILDREN(ArrayExpression); +}; +} // namespace bustub diff --git a/src/include/execution/expressions/column_value_expression.h b/src/include/execution/expressions/column_value_expression.h index 3969dcb91..710e7f5d3 100644 --- a/src/include/execution/expressions/column_value_expression.h +++ b/src/include/execution/expressions/column_value_expression.h @@ -14,6 +14,7 @@ #include #include +#include #include #include "catalog/schema.h" @@ -32,8 +33,8 @@ class ColumnValueExpression : public AbstractExpression { * @param col_idx the index of the column in the schema * @param ret_type the return type of the expression */ - 
ColumnValueExpression(uint32_t tuple_idx, uint32_t col_idx, TypeId ret_type) - : AbstractExpression({}, ret_type), tuple_idx_{tuple_idx}, col_idx_{col_idx} {} + ColumnValueExpression(uint32_t tuple_idx, uint32_t col_idx, Column ret_type) + : AbstractExpression({}, std::move(ret_type)), tuple_idx_{tuple_idx}, col_idx_{col_idx} {} auto Evaluate(const Tuple *tuple, const Schema &schema) const -> Value override { return tuple->GetValue(&schema, col_idx_); diff --git a/src/include/execution/expressions/comparison_expression.h b/src/include/execution/expressions/comparison_expression.h index 84f5f091d..1a3898864 100644 --- a/src/include/execution/expressions/comparison_expression.h +++ b/src/include/execution/expressions/comparison_expression.h @@ -34,7 +34,8 @@ class ComparisonExpression : public AbstractExpression { public: /** Creates a new comparison expression representing (left comp_type right). */ ComparisonExpression(AbstractExpressionRef left, AbstractExpressionRef right, ComparisonType comp_type) - : AbstractExpression({std::move(left), std::move(right)}, TypeId::BOOLEAN), comp_type_{comp_type} {} + : AbstractExpression({std::move(left), std::move(right)}, Column{"", TypeId::BOOLEAN}), + comp_type_{comp_type} {} auto Evaluate(const Tuple *tuple, const Schema &schema) const -> Value override { Value lhs = GetChildAt(0)->Evaluate(tuple, schema); diff --git a/src/include/execution/expressions/constant_value_expression.h b/src/include/execution/expressions/constant_value_expression.h index ebea3c6c7..08f4a03ce 100644 --- a/src/include/execution/expressions/constant_value_expression.h +++ b/src/include/execution/expressions/constant_value_expression.h @@ -25,7 +25,7 @@ namespace bustub { class ConstantValueExpression : public AbstractExpression { public: /** Creates a new constant value expression wrapping the given value. 
*/ - explicit ConstantValueExpression(const Value &val) : AbstractExpression({}, val.GetTypeId()), val_(val) {} + explicit ConstantValueExpression(const Value &val) : AbstractExpression({}, val.GetColumn()), val_(val) {} auto Evaluate(const Tuple *tuple, const Schema &schema) const -> Value override { return val_; } diff --git a/src/include/execution/expressions/logic_expression.h b/src/include/execution/expressions/logic_expression.h index c4636d605..bc82ae8ba 100644 --- a/src/include/execution/expressions/logic_expression.h +++ b/src/include/execution/expressions/logic_expression.h @@ -16,6 +16,7 @@ #include #include +#include "catalog/column.h" #include "catalog/schema.h" #include "common/exception.h" #include "common/macros.h" @@ -38,8 +39,10 @@ class LogicExpression : public AbstractExpression { public: /** Creates a new comparison expression representing (left comp_type right). */ LogicExpression(AbstractExpressionRef left, AbstractExpressionRef right, LogicType logic_type) - : AbstractExpression({std::move(left), std::move(right)}, TypeId::BOOLEAN), logic_type_{logic_type} { - if (GetChildAt(0)->GetReturnType() != TypeId::BOOLEAN || GetChildAt(1)->GetReturnType() != TypeId::BOOLEAN) { + : AbstractExpression({std::move(left), std::move(right)}, Column{"", TypeId::BOOLEAN}), + logic_type_{logic_type} { + if (GetChildAt(0)->GetReturnType().GetType() != TypeId::BOOLEAN || + GetChildAt(1)->GetReturnType().GetType() != TypeId::BOOLEAN) { throw bustub::NotImplementedException("expect boolean from either side"); } } diff --git a/src/include/execution/expressions/string_expression.h b/src/include/execution/expressions/string_expression.h index a2a1bc86c..ce3d874ca 100644 --- a/src/include/execution/expressions/string_expression.h +++ b/src/include/execution/expressions/string_expression.h @@ -38,9 +38,10 @@ enum class StringExpressionType { Lower, Upper }; class StringExpression : public AbstractExpression { public: StringExpression(AbstractExpressionRef arg, 
StringExpressionType expr_type) - : AbstractExpression({std::move(arg)}, TypeId::VARCHAR), expr_type_{expr_type} { - if (GetChildAt(0)->GetReturnType() != TypeId::VARCHAR) { - throw bustub::NotImplementedException("expect the first arg to be varchar"); + : AbstractExpression({std::move(arg)}, Column{"", TypeId::VARCHAR, 256 /* hardcode max length */}), + expr_type_{expr_type} { + if (GetChildAt(0)->GetReturnType().GetType() != TypeId::VARCHAR) { + BUSTUB_ENSURE(GetChildAt(0)->GetReturnType().GetType() == TypeId::VARCHAR, "unexpected arg"); } } diff --git a/src/include/storage/index/index.h b/src/include/storage/index/index.h index 2676b336f..60277d037 100644 --- a/src/include/storage/index/index.h +++ b/src/include/storage/index/index.h @@ -184,7 +184,7 @@ class Index { */ virtual void ScanKey(const Tuple &key, std::vector *result, Transaction *transaction) = 0; - private: + protected: /** The Index structure owns its metadata */ std::unique_ptr metadata_; }; diff --git a/src/include/storage/index/stl_ordered.h b/src/include/storage/index/stl_ordered.h new file mode 100644 index 000000000..526036897 --- /dev/null +++ b/src/include/storage/index/stl_ordered.h @@ -0,0 +1,102 @@ +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include + +#include "common/rid.h" +#include "container/hash/hash_function.h" +#include "storage/index/b_plus_tree.h" +#include "storage/index/index.h" +#include "storage/index/stl_comparator_wrapper.h" + +namespace bustub { + +template +class STLOrderedIndexIterator { + public: + STLOrderedIndexIterator(const std::map> *map, + typename std::map>::const_iterator iter) + : map_(map), iter_(std::move(iter)) {} + + ~STLOrderedIndexIterator() = default; + + auto IsEnd() -> bool { return iter_ == map_->cend(); } + + auto operator*() -> const std::pair & { + ret_val_ = *iter_; + return ret_val_; + } + + auto operator++() -> STLOrderedIndexIterator & { + iter_++; + return *this; + } + + inline auto operator==(const 
STLOrderedIndexIterator &itr) const -> bool { return itr.iter_ == iter_; } + + inline auto operator!=(const STLOrderedIndexIterator &itr) const -> bool { return !(*this == itr); } + + private: + const std::map> *map_; + typename std::map>::const_iterator iter_; + std::pair ret_val_; +}; + +template +class STLOrderedIndex : public Index { + public: + STLOrderedIndex(std::unique_ptr &&metadata, BufferPoolManager *buffer_pool_manager) + : Index(std::move(metadata)), + comparator_(StlComparatorWrapper(Cmp(metadata_->GetKeySchema()))), + data_(comparator_) {} + + auto InsertEntry(const Tuple &key, VT rid, Transaction *transaction) -> bool override { + KT index_key; + index_key.SetFromKey(key); + std::scoped_lock lck(lock_); + if (data_.count(index_key) == 1) { + return false; + } + data_.emplace(index_key, rid); + return true; + } + + void DeleteEntry(const Tuple &key, VT rid, Transaction *transaction) override { + KT index_key; + index_key.SetFromKey(key); + std::scoped_lock lck(lock_); + data_.erase(index_key); + } + + void ScanKey(const Tuple &key, std::vector *result, Transaction *transaction) override { + KT index_key; + index_key.SetFromKey(key); + std::scoped_lock lck(lock_); + if (data_.count(index_key) == 1) { + *result = std::vector{data_[index_key]}; + return; + } + *result = {}; + } + + auto GetBeginIterator() -> STLOrderedIndexIterator { return {&data_, data_.cbegin()}; } + + auto GetBeginIterator(const KT &key) -> STLOrderedIndexIterator { + return {&data_, data_.lower_bound(key)}; + } + + auto GetEndIterator() -> STLOrderedIndexIterator { return {&data_, data_.cend()}; } + + protected: + std::mutex lock_; + StlComparatorWrapper comparator_; + std::map> data_; +}; + +using STLOrderedIndexForTwoIntegerColumn = STLOrderedIndex, RID, GenericComparator<8>>; + +} // namespace bustub diff --git a/src/include/storage/index/stl_unordered.h b/src/include/storage/index/stl_unordered.h new file mode 100644 index 000000000..fee68f312 --- /dev/null +++ 
b/src/include/storage/index/stl_unordered.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "common/rid.h" +#include "container/hash/hash_function.h" +#include "storage/index/b_plus_tree.h" +#include "storage/index/index.h" +#include "storage/index/stl_comparator_wrapper.h" +#include "storage/index/stl_equal_wrapper.h" +#include "storage/index/stl_hasher_wrapper.h" + +namespace bustub { + +template +class STLUnorderedIndex : public Index { + public: + STLUnorderedIndex(std::unique_ptr &&metadata, BufferPoolManager *buffer_pool_manager, + const HashFunction &hash_fn) + : Index(std::move(metadata)), + comparator_(StlComparatorWrapper(Cmp(metadata_->GetKeySchema()))), + hash_fn_(StlHasherWrapper(hash_fn)), + eq_(StlEqualWrapper(Cmp(metadata_->GetKeySchema()))), + data_(0, hash_fn_, eq_) {} + + auto InsertEntry(const Tuple &key, VT rid, Transaction *transaction) -> bool override { + KT index_key; + index_key.SetFromKey(key); + std::scoped_lock lck(lock_); + if (data_.find(index_key) != data_.end()) { + return false; + } + data_.emplace(index_key, rid); + return true; + } + + void DeleteEntry(const Tuple &key, VT rid, Transaction *transaction) override { + KT index_key; + index_key.SetFromKey(key); + std::scoped_lock lck(lock_); + if (auto it = data_.find(index_key); it != data_.end()) { + data_.erase(it); + return; + } + } + + void ScanKey(const Tuple &key, std::vector *result, Transaction *transaction) override { + KT index_key; + index_key.SetFromKey(key); + std::scoped_lock lck(lock_); + if (auto it = data_.find(index_key); it != data_.end()) { + *result = std::vector{it->second}; + return; + } + *result = {}; + } + + protected: + std::mutex lock_; + StlComparatorWrapper comparator_; + StlHasherWrapper hash_fn_; + StlEqualWrapper eq_; + std::unordered_map, StlEqualWrapper> data_; +}; + +using STLUnorderedIndexForTwoIntegerColumn = STLUnorderedIndex, RID, GenericComparator<8>>; + +} // 
namespace bustub diff --git a/src/include/type/type.h b/src/include/type/type.h index ef8e7bb37..26f6188c7 100644 --- a/src/include/type/type.h +++ b/src/include/type/type.h @@ -28,6 +28,7 @@ class Type { explicit Type(TypeId type_id) : type_id_(type_id) {} virtual ~Type() = default; + // Get the size of this data type in bytes static auto GetTypeSize(TypeId type_id) -> uint64_t; @@ -99,19 +100,16 @@ class Type { virtual auto CastAs(const Value &val, TypeId type_id) const -> Value; - // Access the raw variable length data + // Access the raw varlen data stored from the tuple storage virtual auto GetData(const Value &val) const -> const char *; - // Get the length of the variable length data - virtual auto GetLength(const Value &val) const -> uint32_t; - - // Access the raw varlen data stored from the tuple storage - virtual auto GetData(char *storage) -> char *; + // Get the storage size of the value. + virtual auto GetStorageSize(const Value &val) const -> uint32_t; protected: // The actual type ID TypeId type_id_; // Singleton instances. - static Type *k_types[14]; + static Type *k_types[10]; }; } // namespace bustub diff --git a/src/include/type/type_id.h b/src/include/type/type_id.h index 54e779421..4239143b6 100644 --- a/src/include/type/type_id.h +++ b/src/include/type/type_id.h @@ -14,5 +14,5 @@ namespace bustub { // Every possible SQL type ID -enum TypeId { INVALID = 0, BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, VARCHAR, TIMESTAMP }; +enum TypeId { INVALID = 0, BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, VARCHAR, TIMESTAMP, VECTOR }; } // namespace bustub diff --git a/src/include/type/value.h b/src/include/type/value.h index d4a991e4b..fab783b49 100644 --- a/src/include/type/value.h +++ b/src/include/type/value.h @@ -16,6 +16,7 @@ #include #include #include +#include #include "fmt/format.h" @@ -24,6 +25,8 @@ namespace bustub { +class Column; + inline auto GetCmpBool(bool boolean) -> CmpBool { return boolean ? 
CmpBool::CmpTrue : CmpBool::CmpFalse; } // A value is an abstract class that represents a view over SQL data stored in @@ -42,6 +45,7 @@ class Value { friend class TimestampType; friend class BooleanType; friend class VarlenType; + friend class VectorType; public: explicit Value(const TypeId type) : manage_data_(false), type_id_(type) { size_.len_ = BUSTUB_VALUE_NULL; } @@ -61,6 +65,7 @@ class Value { // VARCHAR Value(TypeId type, const char *data, uint32_t len, bool manage_data); Value(TypeId type, const std::string &data); + Value(TypeId type, const std::vector &data); Value() : Value(TypeId::INVALID) {} Value(const Value &other); @@ -80,8 +85,11 @@ class Value { // Get the type of this value inline auto GetTypeId() const -> TypeId { return type_id_; } + // Get a Column describing the type of this value + auto GetColumn() const -> Column; + // Get the length of the variable length data - inline auto GetLength() const -> uint32_t { return Type::GetInstance(type_id_)->GetLength(*this); } + inline auto GetStorageSize() const -> uint32_t { return Type::GetInstance(type_id_)->GetStorageSize(*this); } // Access the raw variable length data inline auto GetData() const -> const char * { return Type::GetInstance(type_id_)->GetData(*this); } @@ -90,6 +98,8 @@ class Value { return *reinterpret_cast(&value_); } + auto GetVector() const -> std::vector; + inline auto CastAs(const TypeId type_id) const -> Value { return Type::GetInstance(type_id_)->CastAs(*this, type_id); } diff --git a/src/include/type/value_factory.h b/src/include/type/value_factory.h index 6d53d562f..22a0d1196 100644 --- a/src/include/type/value_factory.h +++ b/src/include/type/value_factory.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "common/macros.h" #include "common/util/string_util.h" @@ -74,6 +75,11 @@ class ValueFactory { return {TypeId::VARCHAR, value}; } + static inline auto GetVectorValue(const std::vector &value, + __attribute__((__unused__)) AbstractPool *pool = nullptr) -> Value { + return
{TypeId::VECTOR, value}; + } + static inline auto GetNullValueByType(TypeId type_id) -> Value { Value ret_value; switch (type_id) { diff --git a/src/include/type/varlen_type.h b/src/include/type/varlen_type.h index 3c791b160..bee5dffdd 100644 --- a/src/include/type/varlen_type.h +++ b/src/include/type/varlen_type.h @@ -29,7 +29,7 @@ class VarlenType : public Type { auto GetData(const Value &val) const -> const char * override; // Get the length of the variable length data - auto GetLength(const Value &val) const -> uint32_t override; + auto GetStorageSize(const Value &val) const -> uint32_t override; // Comparison functions auto CompareEquals(const Value &left, const Value &right) const -> CmpBool override; diff --git a/src/include/type/vector_type.h b/src/include/type/vector_type.h new file mode 100644 index 000000000..b6c0d2580 --- /dev/null +++ b/src/include/type/vector_type.h @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// +// BusTub +// +// vector_type.h +// +// Identification: src/include/type/vector_type.h +// +// Copyright (c) 2015-2019, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include "type/value.h" + +namespace bustub { +/* A varlen value is an abstract class representing all objects that have + * variable length. 
+ * For simplicity, for varlen_type we always set flag "inline" as true, which + * means we store actual data along with its size rather than a pointer + */ +class VectorType : public Type { + public: + VectorType(); + ~VectorType() override; + + // Access the raw variable length data + auto GetData(const Value &val) const -> const char * override; + + auto GetVector(const Value &val) const -> std::vector; + + // Get the length of the variable length data + auto GetStorageSize(const Value &val) const -> uint32_t override; + + // Comparison functions + auto CompareEquals(const Value &left, const Value &right) const -> CmpBool override; + auto CompareNotEquals(const Value &left, const Value &right) const -> CmpBool override; + auto CompareLessThan(const Value &left, const Value &right) const -> CmpBool override; + auto CompareLessThanEquals(const Value &left, const Value &right) const -> CmpBool override; + auto CompareGreaterThan(const Value &left, const Value &right) const -> CmpBool override; + auto CompareGreaterThanEquals(const Value &left, const Value &right) const -> CmpBool override; + + // Other mathematical functions + auto Min(const Value &left, const Value &right) const -> Value override; + auto Max(const Value &left, const Value &right) const -> Value override; + + auto CastAs(const Value &value, TypeId type_id) const -> Value override; + + // Vector values are never inlined + auto IsInlined(const Value & /*val*/) const -> bool override { return false; } + + // Debug + auto ToString(const Value &val) const -> std::string override; + + // Serialize this value into the given storage space + void SerializeTo(const Value &val, char *storage) const override; + + // Deserialize a value of the given type from the given storage space.
+ auto DeserializeFrom(const char *storage) const -> Value override; + + // Create a copy of this value + auto Copy(const Value &val) const -> Value override; +}; +} // namespace bustub diff --git a/src/planner/expression_factory.cpp b/src/planner/expression_factory.cpp index 6e93b2da0..cf6657f8a 100644 --- a/src/planner/expression_factory.cpp +++ b/src/planner/expression_factory.cpp @@ -1,7 +1,9 @@ #include "binder/bound_expression.h" +#include "binder/expressions/bound_func_call.h" #include "binder/statement/select_statement.h" #include "execution/expressions/abstract_expression.h" #include "execution/expressions/arithmetic_expression.h" +#include "execution/expressions/array_expression.h" #include "execution/expressions/column_value_expression.h" #include "execution/expressions/comparison_expression.h" #include "execution/expressions/constant_value_expression.h" @@ -102,4 +104,17 @@ auto Planner::GetBinaryExpressionFromFactory(const std::string &op_name, Abstrac throw Exception(fmt::format("binary op {} not supported in planner yet", op_name)); } +auto Planner::PlanFuncCall(const BoundFuncCall &expr, const std::vector &children) + -> AbstractExpressionRef { + std::vector args; + for (const auto &arg : expr.args_) { + auto [_1, arg_expr] = PlanExpression(*arg, children); + args.push_back(std::move(arg_expr)); + } + if (expr.func_name_ == "construct_array") { + return std::make_shared(args); + } + return GetFuncCallFromFactory(expr.func_name_, std::move(args)); +} + } // namespace bustub diff --git a/src/planner/plan_aggregation.cpp b/src/planner/plan_aggregation.cpp index 494f4b2d7..9686217d4 100644 --- a/src/planner/plan_aggregation.cpp +++ b/src/planner/plan_aggregation.cpp @@ -128,7 +128,7 @@ auto Planner::PlanSelectAgg(const SelectStatement &statement, AbstractPlanNodeRe agg_types.push_back(agg_type); output_col_names.emplace_back(fmt::format("agg#{}", term_idx)); ctx_.expr_in_agg_.emplace_back( - std::make_unique(0, agg_begin_idx + term_idx, 
TypeId::INTEGER)); + std::make_unique(0, agg_begin_idx + term_idx, Column("", TypeId::INTEGER))); term_idx += 1; } diff --git a/src/planner/plan_expression.cpp b/src/planner/plan_expression.cpp index 792bdd59e..a408ce47c 100644 --- a/src/planner/plan_expression.cpp +++ b/src/planner/plan_expression.cpp @@ -58,7 +58,7 @@ auto Planner::PlanColumnRef(const BoundColumnRef &expr, const std::vector(0, col_idx, col_type)); } if (children.size() == 2) { @@ -88,11 +88,11 @@ auto Planner::PlanColumnRef(const BoundColumnRef &expr, const std::vector(0, *col_idx_left, col_type)); } if (col_idx_right) { - auto col_type = right_schema.GetColumn(*col_idx_right).GetType(); + auto col_type = right_schema.GetColumn(*col_idx_right); return std::make_tuple(col_name, std::make_shared(1, *col_idx_right, col_type)); } throw bustub::Exception(fmt::format("column name {} not found", col_name)); diff --git a/src/planner/plan_func_call.cpp b/src/planner/plan_func_call.cpp index e1f5ea817..1ee50530b 100644 --- a/src/planner/plan_func_call.cpp +++ b/src/planner/plan_func_call.cpp @@ -23,16 +23,6 @@ namespace bustub { -auto Planner::PlanFuncCall(const BoundFuncCall &expr, const std::vector &children) - -> AbstractExpressionRef { - std::vector args; - for (const auto &arg : expr.args_) { - auto [_1, arg_expr] = PlanExpression(*arg, children); - args.push_back(std::move(arg_expr)); - } - return GetFuncCallFromFactory(expr.func_name_, std::move(args)); -} - // NOLINTNEXTLINE auto Planner::GetFuncCallFromFactory(const std::string &func_name, std::vector args) -> AbstractExpressionRef { diff --git a/src/planner/plan_insert.cpp b/src/planner/plan_insert.cpp index 83364d4ea..fee4338d9 100644 --- a/src/planner/plan_insert.cpp +++ b/src/planner/plan_insert.cpp @@ -67,8 +67,7 @@ auto Planner::PlanUpdate(const UpdateStatement &statement) -> AbstractPlanNodeRe for (size_t idx = 0; idx < target_exprs.size(); idx++) { if (target_exprs[idx] == nullptr) { - target_exprs[idx] = - std::make_shared(0, idx, 
filter->output_schema_->GetColumn(idx).GetType()); + target_exprs[idx] = std::make_shared(0, idx, filter->output_schema_->GetColumn(idx)); } } diff --git a/src/planner/plan_select.cpp b/src/planner/plan_select.cpp index 211ea3148..c8829ae94 100644 --- a/src/planner/plan_select.cpp +++ b/src/planner/plan_select.cpp @@ -107,7 +107,7 @@ auto Planner::PlanSelect(const SelectStatement &statement) -> AbstractPlanNodeRe std::vector distinct_exprs; size_t col_idx = 0; for (const auto &col : child->OutputSchema().GetColumns()) { - distinct_exprs.emplace_back(std::make_shared(0, col_idx++, col.GetType())); + distinct_exprs.emplace_back(std::make_shared(0, col_idx++, col)); } plan = std::make_shared(std::make_shared(child->OutputSchema()), child, diff --git a/src/planner/plan_table_ref.cpp b/src/planner/plan_table_ref.cpp index e4342de9a..22f32a576 100644 --- a/src/planner/plan_table_ref.cpp +++ b/src/planner/plan_table_ref.cpp @@ -71,7 +71,7 @@ auto Planner::PlanSubquery(const BoundSubqueryRef &table_ref, const std::string // This projection will be removed by eliminate projection rule. It's solely used for renaming columns. 
for (const auto &col : select_node->OutputSchema().GetColumns()) { - auto expr = std::make_shared(0, idx, col.GetType()); + auto expr = std::make_shared(0, idx, col); output_column_names.emplace_back(fmt::format("{}.{}", alias, fmt::join(table_ref.select_list_name_[idx], "."))); exprs.push_back(std::move(expr)); idx++; @@ -155,11 +155,7 @@ auto Planner::PlanExpressionListRef(const BoundExpressionListRef &table_ref) -> size_t idx = 0; for (const auto &col : first_row) { auto col_name = fmt::format("{}.{}", table_ref.identifier_, idx); - if (col->GetReturnType() != TypeId::VARCHAR) { - cols.emplace_back(Column(col_name, col->GetReturnType())); - } else { - cols.emplace_back(Column(col_name, col->GetReturnType(), VARCHAR_DEFAULT_LENGTH)); - } + cols.emplace_back(col->GetReturnType().WithColumnName(col_name)); idx += 1; } auto schema = std::make_shared(cols); diff --git a/src/planner/plan_window_function.cpp b/src/planner/plan_window_function.cpp index 951fd7fa2..b949e4876 100644 --- a/src/planner/plan_window_function.cpp +++ b/src/planner/plan_window_function.cpp @@ -82,7 +82,7 @@ auto Planner::PlanSelectWindow(const SelectStatement &statement, AbstractPlanNod // parse window function window_func_indexes.push_back(i); // we assign a -1 here as a placeholder - columns.emplace_back(std::make_shared(0, -1, TypeId::INTEGER)); + columns.emplace_back(std::make_shared(0, -1, Column{"", TypeId::INTEGER})); const BoundExpression *window_item = nullptr; if (item->type_ == ExpressionType::ALIAS) { diff --git a/src/storage/table/tuple.cpp b/src/storage/table/tuple.cpp index 25fb30536..de36e1c01 100644 --- a/src/storage/table/tuple.cpp +++ b/src/storage/table/tuple.cpp @@ -25,13 +25,13 @@ Tuple::Tuple(std::vector values, const Schema *schema) { assert(values.size() == schema->GetColumnCount()); // 1. Calculate the size of the tuple. 
- uint32_t tuple_size = schema->GetLength(); + uint32_t tuple_size = schema->GetInlinedStorageSize(); for (auto &i : schema->GetUnlinedColumns()) { - auto len = values[i].GetLength(); + auto len = values[i].GetStorageSize(); if (len == BUSTUB_VALUE_NULL) { len = 0; } - tuple_size += (len + sizeof(uint32_t)); + tuple_size += sizeof(uint32_t) + len; } // 2. Allocate memory. @@ -40,7 +40,7 @@ Tuple::Tuple(std::vector values, const Schema *schema) { // 3. Serialize each attribute based on the input value. uint32_t column_count = schema->GetColumnCount(); - uint32_t offset = schema->GetLength(); + uint32_t offset = schema->GetInlinedStorageSize(); for (uint32_t i = 0; i < column_count; i++) { const auto &col = schema->GetColumn(i); @@ -49,11 +49,11 @@ Tuple::Tuple(std::vector values, const Schema *schema) { *reinterpret_cast(data_.data() + col.GetOffset()) = offset; // Serialize varchar value, in place (size+data). values[i].SerializeTo(data_.data() + offset); - auto len = values[i].GetLength(); + auto len = values[i].GetStorageSize(); if (len == BUSTUB_VALUE_NULL) { len = 0; } - offset += (len + sizeof(uint32_t)); + offset += sizeof(uint32_t) + len; } else { values[i].SerializeTo(data_.data() + col.GetOffset()); } diff --git a/src/type/CMakeLists.txt b/src/type/CMakeLists.txt index 2cb28d580..23baca0dd 100644 --- a/src/type/CMakeLists.txt +++ b/src/type/CMakeLists.txt @@ -11,7 +11,8 @@ add_library( tinyint_type.cpp type.cpp value.cpp - varlen_type.cpp) + varlen_type.cpp + vector_type.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/type/type.cpp b/src/type/type.cpp index 788efaae3..6d75c32ab 100644 --- a/src/type/type.cpp +++ b/src/type/type.cpp @@ -17,16 +17,25 @@ #include "type/decimal_type.h" #include "type/integer_type.h" #include "type/smallint_type.h" +#include "type/timestamp_type.h" #include "type/tinyint_type.h" +#include "type/type_id.h" #include "type/value.h" #include "type/varlen_type.h" +#include "type/vector_type.h" namespace bustub { 
-Type *Type::k_types[] = { - new Type(TypeId::INVALID), new BooleanType(), new TinyintType(), new SmallintType(), - new IntegerType(TypeId::INTEGER), new BigintType(), new DecimalType(), new VarlenType(TypeId::VARCHAR), -}; +Type *Type::k_types[] = {new Type(TypeId::INVALID), + new BooleanType(), + new TinyintType(), + new SmallintType(), + new IntegerType(TypeId::INTEGER), + new BigintType(), + new DecimalType(), + new VarlenType(TypeId::VARCHAR), + new TimestampType(), + new VectorType()}; // Get the size of this data type in bytes auto Type::GetTypeSize(const TypeId type_id) -> uint64_t { @@ -115,6 +124,8 @@ auto Type::TypeIdToString(const TypeId type_id) -> std::string { return "TIMESTAMP"; case VARCHAR: return "VARCHAR"; + case VECTOR: + return "VECTOR"; default: return "INVALID"; } @@ -281,13 +292,8 @@ auto Type::GetData(const Value &val __attribute__((unused))) const -> const char } // Get the length of the variable length data -auto Type::GetLength(const Value &val __attribute__((unused))) const -> uint32_t { - throw NotImplementedException("GetLength not implemented"); -} - -// Access the raw varlen data stored from the tuple storage -auto Type::GetData(char *storage __attribute__((unused))) -> char * { - throw NotImplementedException("GetData not implemented"); +auto Type::GetStorageSize(const Value &val __attribute__((unused))) const -> uint32_t { + throw NotImplementedException("GetStorageSize not implemented"); } } // namespace bustub diff --git a/src/type/value.cpp b/src/type/value.cpp index 4ff5411d3..eeb6796c1 100644 --- a/src/type/value.cpp +++ b/src/type/value.cpp @@ -14,8 +14,11 @@ #include #include +#include "catalog/column.h" #include "common/exception.h" +#include "type/type.h" #include "type/value.h" +#include "type/vector_type.h" namespace bustub { Value::Value(const Value &other) { @@ -214,6 +217,7 @@ Value::Value(TypeId type, float f) : Value(type) { Value::Value(TypeId type, const char *data, uint32_t len, bool manage_data) : Value(type) 
{ switch (type) { case TypeId::VARCHAR: + case TypeId::VECTOR: if (data == nullptr) { value_.varlen_ = nullptr; size_.len_ = BUSTUB_VALUE_NULL; @@ -254,6 +258,22 @@ Value::Value(TypeId type, const std::string &data) : Value(type) { } } +Value::Value(TypeId type, const std::vector &data) : Value(type) { + switch (type) { + case TypeId::VECTOR: { + manage_data_ = true; + auto len = data.size() * sizeof(double); + value_.varlen_ = new char[len]; + assert(value_.varlen_ != nullptr); + size_.len_ = len; + memcpy(value_.varlen_, data.data(), len); + break; + } + default: + throw Exception(ExceptionType::INCOMPATIBLE_TYPE, "Invalid Type for variable-length Value constructor"); + } +} + // delete allocated char array space Value::~Value() { switch (type_id_) { @@ -310,4 +330,20 @@ auto Value::CheckInteger() const -> bool { } return false; } + +auto Value::GetColumn() const -> Column { + switch (GetTypeId()) { + case TypeId::VARCHAR: + case TypeId::VECTOR: { + return Column{"", GetTypeId(), GetStorageSize()}; + } + default: + return Column{"", GetTypeId()}; + } +} + +auto Value::GetVector() const -> std::vector { + return reinterpret_cast(Type::GetInstance(type_id_))->GetVector(*this); +} + } // namespace bustub diff --git a/src/type/varlen_type.cpp b/src/type/varlen_type.cpp index 871deecce..9206c6997 100644 --- a/src/type/varlen_type.cpp +++ b/src/type/varlen_type.cpp @@ -20,18 +20,18 @@ namespace bustub { #define VARLEN_COMPARE_FUNC(OP) \ const char *str1 = left.GetData(); \ - uint32_t len1 = GetLength(left) - 1; \ + uint32_t len1 = GetStorageSize(left) - 1; \ const char *str2; \ uint32_t len2; \ if (right.GetTypeId() == TypeId::VARCHAR) { \ str2 = right.GetData(); \ - len2 = GetLength(right) - 1; \ + len2 = GetStorageSize(right) - 1; \ /* NOLINTNEXTLINE */ \ return GetCmpBool(TypeUtil::CompareStrings(str1, len1, str2, len2) OP 0); \ } else { \ auto r_value = right.CastAs(TypeId::VARCHAR); \ str2 = r_value.GetData(); \ - len2 = GetLength(r_value) - 1; \ + len2 = 
GetStorageSize(r_value) - 1; \ /* NOLINTNEXTLINE */ \ return GetCmpBool(TypeUtil::CompareStrings(str1, len1, str2, len2) OP 0); \ } @@ -44,16 +44,13 @@ VarlenType::~VarlenType() = default; auto VarlenType::GetData(const Value &val) const -> const char * { return val.value_.varlen_; } // Get the length of the variable length data (including the length field) -auto VarlenType::GetLength(const Value &val) const -> uint32_t { return val.size_.len_; } +auto VarlenType::GetStorageSize(const Value &val) const -> uint32_t { return val.size_.len_; } auto VarlenType::CompareEquals(const Value &left, const Value &right) const -> CmpBool { assert(left.CheckComparable(right)); if (left.IsNull() || right.IsNull()) { return CmpBool::CmpNull; } - if (GetLength(left) == BUSTUB_VARCHAR_MAX_LEN || GetLength(right) == BUSTUB_VARCHAR_MAX_LEN) { - return GetCmpBool(GetLength(left) == GetLength(right)); - } VARLEN_COMPARE_FUNC(==); // NOLINT } @@ -63,9 +60,6 @@ auto VarlenType::CompareNotEquals(const Value &left, const Value &right) const - if (left.IsNull() || right.IsNull()) { return CmpBool::CmpNull; } - if (GetLength(left) == BUSTUB_VARCHAR_MAX_LEN || GetLength(right) == BUSTUB_VARCHAR_MAX_LEN) { - return GetCmpBool(GetLength(left) != GetLength(right)); - } VARLEN_COMPARE_FUNC(!=); // NOLINT } @@ -75,9 +69,6 @@ auto VarlenType::CompareLessThan(const Value &left, const Value &right) const -> if (left.IsNull() || right.IsNull()) { return CmpBool::CmpNull; } - if (GetLength(left) == BUSTUB_VARCHAR_MAX_LEN || GetLength(right) == BUSTUB_VARCHAR_MAX_LEN) { - return GetCmpBool(GetLength(left) < GetLength(right)); - } VARLEN_COMPARE_FUNC(<); // NOLINT } @@ -87,9 +78,6 @@ auto VarlenType::CompareLessThanEquals(const Value &left, const Value &right) co if (left.IsNull() || right.IsNull()) { return CmpBool::CmpNull; } - if (GetLength(left) == BUSTUB_VARCHAR_MAX_LEN || GetLength(right) == BUSTUB_VARCHAR_MAX_LEN) { - return GetCmpBool(GetLength(left) <= GetLength(right)); - } 
VARLEN_COMPARE_FUNC(<=); // NOLINT } @@ -99,9 +87,6 @@ auto VarlenType::CompareGreaterThan(const Value &left, const Value &right) const if (left.IsNull() || right.IsNull()) { return CmpBool::CmpNull; } - if (GetLength(left) == BUSTUB_VARCHAR_MAX_LEN || GetLength(right) == BUSTUB_VARCHAR_MAX_LEN) { - return GetCmpBool(GetLength(left) > GetLength(right)); - } VARLEN_COMPARE_FUNC(>); // NOLINT } @@ -111,9 +96,6 @@ auto VarlenType::CompareGreaterThanEquals(const Value &left, const Value &right) if (left.IsNull() || right.IsNull()) { return CmpBool::CmpNull; } - if (GetLength(left) == BUSTUB_VARCHAR_MAX_LEN || GetLength(right) == BUSTUB_VARCHAR_MAX_LEN) { - return GetCmpBool(GetLength(left) >= GetLength(right)); - } VARLEN_COMPARE_FUNC(>=); // NOLINT } @@ -141,7 +123,7 @@ auto VarlenType::Max(const Value &left, const Value &right) const -> Value { } auto VarlenType::ToString(const Value &val) const -> std::string { - uint32_t len = GetLength(val); + uint32_t len = GetStorageSize(val); if (val.IsNull()) { return "varlen_null"; @@ -156,7 +138,7 @@ auto VarlenType::ToString(const Value &val) const -> std::string { } void VarlenType::SerializeTo(const Value &val, char *storage) const { - uint32_t len = GetLength(val); + uint32_t len = GetStorageSize(val); if (len == BUSTUB_VALUE_NULL) { memcpy(storage, &len, sizeof(uint32_t)); return; diff --git a/src/type/vector_type.cpp b/src/type/vector_type.cpp new file mode 100644 index 000000000..25a755e39 --- /dev/null +++ b/src/type/vector_type.cpp @@ -0,0 +1,117 @@ +//===----------------------------------------------------------------------===// +// +// BusTub +// +// vector_type.cpp +// +// Identification: src/type/vector_type.cpp +// +// Copyright (c) 2015-2019, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "common/exception.h" +#include "common/macros.h" +#include "type/type_id.h" +#include "type/type_util.h"
+#include "type/vector_type.h" + +namespace bustub { + +VectorType::VectorType() : Type(TypeId::VECTOR) {} + +VectorType::~VectorType() = default; + +// Access the raw variable length data +auto VectorType::GetData(const Value &val) const -> const char * { return val.value_.varlen_; } + +auto VectorType::GetVector(const Value &val) const -> std::vector { + auto *base_ptr = reinterpret_cast(val.value_.varlen_); + auto size = val.size_.len_ / sizeof(double); + std::vector data; + data.reserve(size); + for (unsigned i = 0; i < size; i++) { + data.push_back(base_ptr[i]); + } + return data; +} + +// Get the length of the variable length data (payload bytes only, excluding the length prefix) +auto VectorType::GetStorageSize(const Value &val) const -> uint32_t { return val.size_.len_; } + +auto VectorType::CompareEquals(const Value &left, const Value &right) const -> CmpBool { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::CompareNotEquals(const Value &left, const Value &right) const -> CmpBool { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::CompareLessThan(const Value &left, const Value &right) const -> CmpBool { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::CompareLessThanEquals(const Value &left, const Value &right) const -> CmpBool { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::CompareGreaterThan(const Value &left, const Value &right) const -> CmpBool { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::CompareGreaterThanEquals(const Value &left, const Value &right) const -> CmpBool { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::Min(const Value &left, const Value &right) const -> Value { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto VectorType::Max(const Value &left, const Value &right) const -> Value { + UNIMPLEMENTED("vector type comparison not supported"); +} + +auto
VectorType::ToString(const Value &val) const -> std::string { + uint32_t len = GetStorageSize(val); + + if (val.IsNull()) { + return "vector_null"; + } + if (len == BUSTUB_VARCHAR_MAX_LEN) { + return "vector_max"; + } + if (len == 0) { + return ""; + } + return fmt::format("[{}]", fmt::join(GetVector(val), ",")); +} + +void VectorType::SerializeTo(const Value &val, char *storage) const { + uint32_t len = GetStorageSize(val); + if (len == BUSTUB_VALUE_NULL) { + memcpy(storage, &len, sizeof(uint32_t)); + return; + } + memcpy(storage, &len, sizeof(uint32_t)); + memcpy(storage + sizeof(uint32_t), val.value_.varlen_, len); +} + +// Deserialize a value of the given type from the given storage space. +auto VectorType::DeserializeFrom(const char *storage) const -> Value { + uint32_t len = *reinterpret_cast(storage); + if (len == BUSTUB_VALUE_NULL) { + return {type_id_, nullptr, len, false}; + } + // set manage_data as true + return {type_id_, storage + sizeof(uint32_t), len, true}; +} + +auto VectorType::Copy(const Value &val) const -> Value { return {val}; } + +auto VectorType::CastAs(const Value &value, const TypeId type_id) const -> Value { + UNIMPLEMENTED("vector type cast not supported"); +} +} // namespace bustub diff --git a/test/sql/index.slt b/test/sql/index.slt new file mode 100644 index 000000000..dad9b9ed2 --- /dev/null +++ b/test/sql/index.slt @@ -0,0 +1,20 @@ +statement ok +CREATE TABLE t1(v1 integer, v2 integer); + +statement ok +INSERT INTO t1 VALUES (0, 0), (1, 1), (2, 2); + +statement ok +CREATE INDEX t1v2a ON t1 (v2); + +statement ok +CREATE INDEX t1v2b ON t1 USING bplustree (v2); + +statement ok +CREATE INDEX t1v2c ON t1 USING hash (v2); + +statement ok +CREATE INDEX t1v2d ON t1 USING stl_ordered (v2); + +statement ok +CREATE INDEX t1v2e ON t1 USING stl_unordered (v2); diff --git a/test/sql/vector.slt b/test/sql/vector.slt new file mode 100644 index 000000000..cf467a9f0 --- /dev/null +++ b/test/sql/vector.slt @@ -0,0 +1,5 @@ +statement ok +CREATE TABLE 
t1(v1 VECTOR(3), v2 integer); + +statement ok +INSERT INTO t1 VALUES (ARRAY [1.0, 1.0, 1.0], 232), (ARRAY [2.0, 1.0, 1.0], 233), (ARRAY [3.0, 1.0, 1.0], 234), (ARRAY [4.0, 1.0, 1.0], 235);