diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..5f0a7be --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,81 @@ +# Agent Instructions for pragmaticDB + +These rules are **mandatory** for any AI coding assistant working on this project. +Violations of these rules are unacceptable under any circumstance. + +--- + +## Rule 1: Zero Deletions + +**Do NOT delete a single line of existing code.** + +- No existing function, method, struct, class, enum, variable, comment, or include may be removed. +- No existing function signature may be changed in a way that breaks callers. +- If a function needs new parameters, use **default parameters** so all existing call sites compile unchanged. +- If a struct/class needs new fields, **append** them — never reorder, rename, or remove existing fields. +- New default values on appended fields must ensure the old behavior when the field is not explicitly set. + +## Rule 2: Existing Features Must Not Change Behavior + +**Every existing feature must continue to work exactly as it did before your changes.** + +This includes but is not limited to: +- `CREATE TABLE` +- `INSERT INTO` +- `SELECT * FROM <table>;` (single-table select) +- `DELETE FROM
<table>;` and `DELETE FROM <table>
WHERE col = val;` +- `COMMIT;` +- `exit` / `quit` +- Persistence (catalog.db, table_N.db files) +- TCP server behavior + +If you add a new code path (e.g., JOIN), it must be gated behind a condition check +so that existing queries never enter the new path. Example pattern: + +```cpp +if (new_feature_is_active) { + return NewFeaturePath(); +} +// ... entire existing code below, untouched ... +``` + +## Rule 3: Additive-Only Architecture + +All new features must be implemented as **additions**, not modifications: + +- **New files** are always preferred over modifying existing files. +- When modifying an existing file is unavoidable, only **append** new code (new methods, new includes, new fields). +- The only acceptable in-body change to an existing function is inserting a short dispatch/guard at the **top** that routes to a new function, leaving the rest of the function body untouched. + +## Rule 4: Test Preservation + +- **Never modify or delete existing test logic.** Avoid modifying existing test files. If the project uses a central test runner/registry (e.g., `tests/test_main.cpp`, `include/tests.h`), appending a new test invocation or declaration there is allowed when necessary to wire new tests. +- All existing tests must continue to compile and pass after your changes. +- New tests go in **new test files** (e.g., `tests/test_join.cpp`). +- After making changes, verify that `make test` passes all existing tests. + +## Rule 5: The Expression AST Is Join-Only + +The Expression AST (`Expression`, `ComparisonExpression`, `LogicalExpression`, etc.) +and the expression evaluator (`EvaluateExpression`) are **exclusively** for JOIN condition evaluation. + +- Do NOT refactor existing WHERE clause handling to use the AST. +- Do NOT refactor existing DELETE WHERE to use the AST. +- These existing features use their own simple string-comparison logic and must continue to do so. 
+- If a future feature needs expressions (e.g., SELECT WHERE with complex conditions), + that is a separate, deliberate decision — not something to do as a side effect. + +--- + +## Summary + +| Action | Allowed? | +|---|---| +| Adding new files | ✅ Always | +| Appending new methods/fields to existing files | ✅ Yes | +| Adding a 3-line dispatch guard at the top of an existing function | ✅ Yes | +| Deleting any existing line of code | ❌ Never | +| Renaming any existing function, variable, or file | ❌ Never | +| Changing an existing function's behavior | ❌ Never | +| Modifying existing test files | ❌ Never | +| Using the Expression AST outside of JOIN | ❌ Never | diff --git a/include/catalog/schema.h b/include/catalog/schema.h index ea06413..47a0e0a 100644 --- a/include/catalog/schema.h +++ b/include/catalog/schema.h @@ -19,6 +19,11 @@ class Schema { uint32_t GetLength() const; uint32_t GetColumnCount() const; + static Schema Merge( + const Schema& left, const std::string& left_table, + const Schema& right, const std::string& right_table + ); + private: uint32_t length_; std::vector columns_; diff --git a/include/ds/statement.h b/include/ds/statement.h index 3395197..4c9ac64 100644 --- a/include/ds/statement.h +++ b/include/ds/statement.h @@ -2,7 +2,9 @@ #include #include +#include #include "catalog/column.h" +#include "query/expression.h" // ── Statement type tag ──────────────────────────────────────────────────────── enum class StatementType { @@ -40,6 +42,9 @@ struct InsertStatement : public Statement { struct SelectStatement : public Statement { std::string table_name; + std::string join_table_name; + std::unique_ptr join_condition; + SelectStatement() : Statement(StatementType::SELECT) {} }; diff --git a/include/query/executor.h b/include/query/executor.h index 6e39c17..2ce6cb1 100644 --- a/include/query/executor.h +++ b/include/query/executor.h @@ -4,6 +4,8 @@ #include "../ds/statement.h" #include "../ds/query_result.h" #include "catalog/schema.h" +#include 
"query/optimizer.h" +#include "query/index_provider.h" /** * @brief Executes a parsed Statement against the database Catalog. * @@ -13,7 +15,8 @@ */ class Executor { public: - explicit Executor(Catalog& catalog) : catalog_(catalog) {} + explicit Executor(Catalog& catalog, const IndexProvider& idx = kDefaultIndexProvider) + : catalog_(catalog), index_provider_(idx) {} /** * @brief Execute the given statement and return a QueryResult. @@ -28,5 +31,14 @@ class Executor { QueryResult ExecuteCommit(); QueryResult ExecuteDelete(const DeleteStatement& stmt); + QueryResult ExecuteJoin(const SelectStatement& stmt); + std::vector> ExecuteBranch( + const BranchPlan& branch, + TableInfo* left_tbl, TableInfo* right_tbl, + const Schema& left_schema, const Schema& right_schema, + const Schema& merged_schema + ); + Catalog& catalog_; + const IndexProvider& index_provider_; }; diff --git a/include/query/expression.h b/include/query/expression.h new file mode 100644 index 0000000..266deef --- /dev/null +++ b/include/query/expression.h @@ -0,0 +1,60 @@ +#pragma once +#include +#include + +// Expression type tags +enum class ExpressionType { + COLUMN_REF, // e.g. "users.id" + CONSTANT, // e.g. 
"42", "true" + COMPARISON, // =, >, <, >=, <=, != + LOGICAL // AND, OR +}; + +enum class ComparisonType { EQ, NEQ, LT, GT, LTE, GTE }; +enum class LogicType { AND, OR }; + +// Base class — all expression nodes inherit from this +struct Expression { + ExpressionType expr_type; + virtual ~Expression() = default; + +protected: + Expression(ExpressionType type) : expr_type(type) {} +}; + +// Leaf: references a column like "users.id" +struct ColumnRefExpression : Expression { + std::string table_name; // "users" (from "users.id") + std::string col_name; // "id" + + ColumnRefExpression(std::string table, std::string col) + : Expression(ExpressionType::COLUMN_REF), table_name(std::move(table)), col_name(std::move(col)) {} +}; + +// Leaf: a literal value like 42 or true +struct ConstantExpression : Expression { + std::string raw_value; // stored as string, converted at eval time + + ConstantExpression(std::string val) + : Expression(ExpressionType::CONSTANT), raw_value(std::move(val)) {} +}; + +// Internal node: left right +struct ComparisonExpression : Expression { + ComparisonType comp_type; + std::unique_ptr left; + std::unique_ptr right; + + ComparisonExpression(ComparisonType comp, std::unique_ptr l, std::unique_ptr r) + : Expression(ExpressionType::COMPARISON), comp_type(comp), left(std::move(l)), right(std::move(r)) {} +}; + +// Internal node: left AND/OR right +struct LogicalExpression : Expression { + LogicType logic_type; + std::unique_ptr left; + std::unique_ptr right; + + LogicalExpression(LogicType logic, std::unique_ptr l, std::unique_ptr r) + : Expression(ExpressionType::LOGICAL), logic_type(logic), left(std::move(l)), right(std::move(r)) {} +}; diff --git a/include/query/expression_eval.h b/include/query/expression_eval.h new file mode 100644 index 0000000..1d0fec5 --- /dev/null +++ b/include/query/expression_eval.h @@ -0,0 +1,12 @@ +#pragma once +#include "query/expression.h" +#include "type/tuple.h" +#include "catalog/schema.h" + +// Evaluates any 
expression tree against a single (possibly merged) tuple. +// Returns true if the condition holds. +bool EvaluateExpression( + const Expression* expr, + const Tuple& tuple, + const Schema& schema +); diff --git a/include/query/index_provider.h b/include/query/index_provider.h new file mode 100644 index 0000000..5c5aadd --- /dev/null +++ b/include/query/index_provider.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include "../ds/record_id.h" + +class IndexProvider { +public: + virtual ~IndexProvider() = default; + + // Given a table name, column name, and a value, return all RecordIds matching col=val + virtual std::vector LookupEquals( + const std::string& table_name, + const std::string& col_name, + const std::string& value) const = 0; +}; + +class NullIndexProvider : public IndexProvider { +public: + std::vector LookupEquals( + const std::string&, const std::string&, const std::string&) const override { + return {}; // always returns empty + } +}; + +// Static default instance with stable lifetime for use as a default reference +static NullIndexProvider kDefaultIndexProvider{}; diff --git a/include/query/optimizer.h b/include/query/optimizer.h new file mode 100644 index 0000000..6fa9a5c --- /dev/null +++ b/include/query/optimizer.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include +#include "query/expression.h" +#include "query/index_provider.h" + +struct EquiCondition { + std::string left_col; + std::string right_col; +}; + +struct IndexCondition { + std::string table; + std::string col; + std::string val; +}; + +struct BranchPlan { + std::vector index_conditions; + std::vector equi_conditions; + std::unique_ptr theta_filter; // anything not captured above +}; + +struct JoinPlan { + std::vector branches; +}; + +class Optimizer { +public: + static JoinPlan PlanJoin(const std::unique_ptr& condition, const IndexProvider& idx_provider); +}; diff --git a/include/query/parser.h b/include/query/parser.h index 69754d7..6be491e 100644 --- 
a/include/query/parser.h +++ b/include/query/parser.h @@ -27,4 +27,9 @@ class Parser { std::unique_ptr ParseInsert(std::istringstream& ss); std::unique_ptr ParseSelect(std::istringstream& ss); std::unique_ptr ParseDelete(std::istringstream& ss); + + std::unique_ptr ParseExpression(std::istringstream& ss); + std::unique_ptr ParseAndExpression(std::istringstream& ss); + std::unique_ptr ParseComparison(std::istringstream& ss); + std::unique_ptr ParseAtom(std::istringstream& ss); }; diff --git a/include/tests.h b/include/tests.h index 3947710..6694a6d 100644 --- a/include/tests.h +++ b/include/tests.h @@ -11,4 +11,5 @@ class test { void TestCatalogClass(); void TestTableIteratorClass(); void TestQueryEngineClass(); + void TestJoin(); }; \ No newline at end of file diff --git a/include/type/tuple.h b/include/type/tuple.h index e526202..fceca64 100644 --- a/include/type/tuple.h +++ b/include/type/tuple.h @@ -31,6 +31,8 @@ class Tuple { const char* GetData() const; uint32_t GetLength() const; + static Tuple Merge(const Tuple& left, const Tuple& right); + private: std::vector data_; }; diff --git a/include/type/value.h b/include/type/value.h index 0553e65..1046110 100644 --- a/include/type/value.h +++ b/include/type/value.h @@ -17,6 +17,15 @@ class Value { data_ = std::make_any(val); } void test(); + + TypeId GetTypeId() const { return type_id_; } + + bool CompareEquals(const Value& other) const; + bool CompareLessThan(const Value& other) const; + bool CompareGreaterThan(const Value& other) const; + bool CompareLessThanOrEqual(const Value& other) const; + bool CompareGreaterThanOrEqual(const Value& other) const; + bool CompareNotEqual(const Value& other) const; private: TypeId type_id_; std::any data_; diff --git a/src/catalog/schema.cpp b/src/catalog/schema.cpp index cd6c206..ba1e4ec 100644 --- a/src/catalog/schema.cpp +++ b/src/catalog/schema.cpp @@ -23,3 +23,17 @@ uint32_t Schema::GetLength() const { return length_; } uint32_t Schema::GetColumnCount() const { return 
static_cast(columns_.size()); } + +Schema Schema::Merge( + const Schema& left, const std::string& left_table, + const Schema& right, const std::string& right_table +) { + std::vector merged_cols; + for (const auto& col : left.GetColumns()) { + merged_cols.emplace_back(left_table + "." + col.GetName(), col.GetType()); + } + for (const auto& col : right.GetColumns()) { + merged_cols.emplace_back(right_table + "." + col.GetName(), col.GetType()); + } + return Schema(merged_cols); +} diff --git a/src/query/executor.cpp b/src/query/executor.cpp index 733df84..29f6d80 100644 --- a/src/query/executor.cpp +++ b/src/query/executor.cpp @@ -3,6 +3,9 @@ #include "type/tuple.h" #include "type/value.h" #include "factory/value_factory.h" +#include "query/expression_eval.h" +#include +#include // Implement your Executor methods here! @@ -60,6 +63,10 @@ QueryResult Executor::ExecuteInsert(const InsertStatement& stmt) { } QueryResult Executor::ExecuteSelect(const SelectStatement& stmt) { + if (!stmt.join_table_name.empty()) { + return ExecuteJoin(stmt); + } + try { TableInfo* tbl = catalog_.GetTable(stmt.table_name); Schema& schema = tbl->schema_; @@ -130,3 +137,143 @@ QueryResult Executor::ExecuteDelete(const DeleteStatement& stmt) { return {false, e.what(), {}}; } } + +static int GetSchemaColIdx(const Schema& schema, const std::string& table_name, const std::string& full_col_name) { + for (uint32_t i = 0; i < schema.GetColumnCount(); i++) { + std::string col_name = schema.GetColumn(i).GetName(); + if (col_name == full_col_name) return i; + if (table_name + "." 
+ col_name == full_col_name) return i; + size_t dot_pos = full_col_name.find('.'); + if (dot_pos != std::string::npos && full_col_name.substr(dot_pos + 1) == col_name) return i; + } + return -1; +} + +QueryResult Executor::ExecuteJoin(const SelectStatement& stmt) { + try { + TableInfo* left_tbl = catalog_.GetTable(stmt.table_name); + TableInfo* right_tbl = catalog_.GetTable(stmt.join_table_name); + if (!left_tbl || !right_tbl) throw std::runtime_error("Table not found for JOIN"); + + Schema& left_schema = left_tbl->schema_; + Schema& right_schema = right_tbl->schema_; + Schema merged_schema = Schema::Merge(left_schema, left_tbl->name_, right_schema, right_tbl->name_); + + JoinPlan plan = Optimizer::PlanJoin(stmt.join_condition, index_provider_); + + std::unordered_set seen; + std::vector> unique_pairs; + + for (const auto& branch : plan.branches) { + auto branch_pairs = ExecuteBranch(branch, left_tbl, right_tbl, left_schema, right_schema, merged_schema); + + for (const auto& pair : branch_pairs) { + std::string key; + key.append((char*)&pair.first.page_id, sizeof(page_id_t)); + key.append((char*)&pair.first.slot_id, sizeof(uint16_t)); + key.append((char*)&pair.second.page_id, sizeof(page_id_t)); + key.append((char*)&pair.second.slot_id, sizeof(uint16_t)); + + if (seen.insert(key).second) { + unique_pairs.push_back(pair); + } + } + } + + QueryResult result; + result.success = true; + for (const auto& pair : unique_pairs) { + Tuple left_tuple = left_tbl->table_->GetTuple(pair.first, left_schema); + Tuple right_tuple = right_tbl->table_->GetTuple(pair.second, right_schema); + Tuple merged = Tuple::Merge(left_tuple, right_tuple); + + std::vector row; + for (uint32_t i = 0; i < merged_schema.GetColumnCount(); i++) { + Value val = merged.GetValue(merged_schema, i); + TypeId type = merged_schema.GetColumn(i).GetType(); + row.push_back(ValueFactory::ToString(val, type)); + } + result.rows.push_back(row); + } + result.message = std::to_string(result.rows.size()) + " rows 
returned."; + return result; + + } catch (const std::exception& e) { + return {false, e.what(), {}}; + } +} + +std::vector> Executor::ExecuteBranch( + const BranchPlan& branch, + TableInfo* left_tbl, TableInfo* right_tbl, + const Schema& left_schema, const Schema& right_schema, + const Schema& merged_schema +) { + std::vector> surviving_pairs; + + // Stage 1: Index Probe (Skipped for now) + + // Stage 2: Hash Join or Cross Product + if (!branch.equi_conditions.empty()) { + std::unordered_map>> hash_map; + + // Build phase on right_tbl + for (auto it = right_tbl->table_->Begin(right_schema); it != right_tbl->table_->End(right_schema); ++it) { + std::string key; + for (const auto& equi : branch.equi_conditions) { + int col_idx = GetSchemaColIdx(right_schema, right_tbl->name_, equi.right_col); + if (col_idx == -1) col_idx = GetSchemaColIdx(right_schema, right_tbl->name_, equi.left_col); + if (col_idx != -1) { + Value val = (*it).GetValue(right_schema, col_idx); + key += ValueFactory::ToString(val, right_schema.GetColumn(col_idx).GetType()) + "|"; + } + } + if (!key.empty()) { + hash_map[key].push_back({*it, it.GetRid()}); + } + } + + // Probe phase on left_tbl + for (auto it = left_tbl->table_->Begin(left_schema); it != left_tbl->table_->End(left_schema); ++it) { + std::string key; + for (const auto& equi : branch.equi_conditions) { + int col_idx = GetSchemaColIdx(left_schema, left_tbl->name_, equi.left_col); + if (col_idx == -1) col_idx = GetSchemaColIdx(left_schema, left_tbl->name_, equi.right_col); + if (col_idx != -1) { + Value val = (*it).GetValue(left_schema, col_idx); + key += ValueFactory::ToString(val, left_schema.GetColumn(col_idx).GetType()) + "|"; + } + } + + if (!key.empty() && hash_map.find(key) != hash_map.end()) { + for (const auto& right_entry : hash_map[key]) { + Tuple merged = Tuple::Merge(*it, right_entry.first); + + // Stage 3: Theta Filter + if (branch.theta_filter) { + if (!EvaluateExpression(branch.theta_filter.get(), merged, merged_schema)) { 
+ continue; + } + } + surviving_pairs.push_back({it.GetRid(), right_entry.second}); + } + } + } + } else { + // Cross product + for (auto left_it = left_tbl->table_->Begin(left_schema); left_it != left_tbl->table_->End(left_schema); ++left_it) { + for (auto right_it = right_tbl->table_->Begin(right_schema); right_it != right_tbl->table_->End(right_schema); ++right_it) { + Tuple merged = Tuple::Merge(*left_it, *right_it); + + if (branch.theta_filter) { + if (!EvaluateExpression(branch.theta_filter.get(), merged, merged_schema)) { + continue; + } + } + surviving_pairs.push_back({left_it.GetRid(), right_it.GetRid()}); + } + } + } + + return surviving_pairs; +} diff --git a/src/query/expression_eval.cpp b/src/query/expression_eval.cpp new file mode 100644 index 0000000..51be951 --- /dev/null +++ b/src/query/expression_eval.cpp @@ -0,0 +1,92 @@ +#include "query/expression_eval.h" +#include +#include + +static int32_t GetColIdx(const Schema& schema, const std::string& col_name, const std::string& table_name) { + const auto& cols = schema.GetColumns(); + for (size_t i = 0; i < cols.size(); ++i) { + if (!table_name.empty()) { + if (cols[i].GetName() == table_name + "." + col_name) { + return i; + } + } else { + if (cols[i].GetName() == col_name) { + return i; + } + size_t dot_pos = cols[i].GetName().find('.'); + if (dot_pos != std::string::npos && cols[i].GetName().substr(dot_pos + 1) == col_name) { + // Check for ambiguity: see if another column also matches + for (size_t j = i + 1; j < cols.size(); ++j) { + size_t dot_pos2 = cols[j].GetName().find('.'); + if (dot_pos2 != std::string::npos && cols[j].GetName().substr(dot_pos2 + 1) == col_name) { + throw std::runtime_error("Ambiguous column reference: '" + col_name + "'. 
Use table-qualified name."); + } + } + return i; + } + } + } + return -1; +} + +static Value EvaluateAtom(const Expression* expr, const Tuple& tuple, const Schema& schema) { + if (expr->expr_type == ExpressionType::COLUMN_REF) { + const auto* col_expr = static_cast(expr); + int32_t col_idx = GetColIdx(schema, col_expr->col_name, col_expr->table_name); + if (col_idx == -1) { + throw std::runtime_error("Column not found in schema"); + } + return tuple.GetValue(schema, col_idx); + } else if (expr->expr_type == ExpressionType::CONSTANT) { + const auto* const_expr = static_cast(expr); + if (const_expr->raw_value == "true" || const_expr->raw_value == "false") { + Value val(TypeId::BOOLEAN); + val.Set(const_expr->raw_value == "true" ? 1 : 0); + return val; + } else { + Value val(TypeId::INTEGER); + val.Set(std::stoi(const_expr->raw_value)); + return val; + } + } + throw std::runtime_error("Invalid atom expression type"); +} + +bool EvaluateExpression(const Expression* expr, const Tuple& tuple, const Schema& schema) { + if (!expr) return true; + + if (expr->expr_type == ExpressionType::LOGICAL) { + const auto* log_expr = static_cast(expr); + bool left_res = EvaluateExpression(log_expr->left.get(), tuple, schema); + + if (log_expr->logic_type == LogicType::AND) { + if (!left_res) return false; + return EvaluateExpression(log_expr->right.get(), tuple, schema); + } else { + if (left_res) return true; + return EvaluateExpression(log_expr->right.get(), tuple, schema); + } + } else if (expr->expr_type == ExpressionType::COMPARISON) { + const auto* comp_expr = static_cast(expr); + Value left_val = EvaluateAtom(comp_expr->left.get(), tuple, schema); + Value right_val = EvaluateAtom(comp_expr->right.get(), tuple, schema); + + switch (comp_expr->comp_type) { + case ComparisonType::EQ: return left_val.CompareEquals(right_val); + case ComparisonType::NEQ: return left_val.CompareNotEqual(right_val); + case ComparisonType::LT: return left_val.CompareLessThan(right_val); + case 
ComparisonType::GT: return left_val.CompareGreaterThan(right_val); + case ComparisonType::LTE: return left_val.CompareLessThanOrEqual(right_val); + case ComparisonType::GTE: return left_val.CompareGreaterThanOrEqual(right_val); + } + } + + if (expr->expr_type == ExpressionType::COLUMN_REF || expr->expr_type == ExpressionType::CONSTANT) { + Value val = EvaluateAtom(expr, tuple, schema); + if (val.GetTypeId() == TypeId::BOOLEAN) { + return val.Get() != 0; + } + } + + throw std::runtime_error("Invalid expression tree structure"); +} diff --git a/src/query/optimizer.cpp b/src/query/optimizer.cpp new file mode 100644 index 0000000..85e88b4 --- /dev/null +++ b/src/query/optimizer.cpp @@ -0,0 +1,98 @@ +#include "query/optimizer.h" + +static void CollectORBranches(const Expression* expr, std::vector& branches) { + if (!expr) return; + if (expr->expr_type == ExpressionType::LOGICAL) { + const auto* log = static_cast(expr); + if (log->logic_type == LogicType::OR) { + CollectORBranches(log->left.get(), branches); + CollectORBranches(log->right.get(), branches); + return; + } + } + branches.push_back(expr); +} + +static void CollectANDConditions(const Expression* expr, std::vector& conds) { + if (!expr) return; + if (expr->expr_type == ExpressionType::LOGICAL) { + const auto* log = static_cast(expr); + if (log->logic_type == LogicType::AND) { + CollectANDConditions(log->left.get(), conds); + CollectANDConditions(log->right.get(), conds); + return; + } + } + conds.push_back(expr); +} + +static std::unique_ptr CopyExpression(const Expression* expr) { + if (!expr) return nullptr; + if (expr->expr_type == ExpressionType::COLUMN_REF) { + const auto* col = static_cast(expr); + return std::make_unique(col->table_name, col->col_name); + } else if (expr->expr_type == ExpressionType::CONSTANT) { + const auto* con = static_cast(expr); + return std::make_unique(con->raw_value); + } else if (expr->expr_type == ExpressionType::COMPARISON) { + const auto* comp = static_cast(expr); + return 
std::make_unique( + comp->comp_type, CopyExpression(comp->left.get()), CopyExpression(comp->right.get())); + } else if (expr->expr_type == ExpressionType::LOGICAL) { + const auto* log = static_cast(expr); + return std::make_unique( + log->logic_type, CopyExpression(log->left.get()), CopyExpression(log->right.get())); + } + return nullptr; +} + +static std::unique_ptr CombineAND(std::unique_ptr left, std::unique_ptr right) { + if (!left) return right; + if (!right) return left; + return std::make_unique(LogicType::AND, std::move(left), std::move(right)); +} + +JoinPlan Optimizer::PlanJoin(const std::unique_ptr& condition, const IndexProvider& idx_provider) { + JoinPlan plan; + if (!condition) { + BranchPlan empty_branch; + plan.branches.push_back(std::move(empty_branch)); + return plan; + } + + std::vector or_branches; + CollectORBranches(condition.get(), or_branches); + + for (const Expression* branch_expr : or_branches) { + BranchPlan b_plan; + std::vector and_conds; + CollectANDConditions(branch_expr, and_conds); + + for (const Expression* cond : and_conds) { + bool handled = false; + if (cond->expr_type == ExpressionType::COMPARISON) { + const auto* comp = static_cast(cond); + if (comp->comp_type == ComparisonType::EQ) { + if (comp->left->expr_type == ExpressionType::COLUMN_REF && comp->right->expr_type == ExpressionType::COLUMN_REF) { + const auto* l_col = static_cast(comp->left.get()); + const auto* r_col = static_cast(comp->right.get()); + + EquiCondition equi; + equi.left_col = (l_col->table_name.empty() ? l_col->col_name : l_col->table_name + "." + l_col->col_name); + equi.right_col = (r_col->table_name.empty() ? r_col->col_name : r_col->table_name + "." 
+ r_col->col_name); + + b_plan.equi_conditions.push_back(equi); + handled = true; + } + } + } + if (!handled) { + b_plan.theta_filter = CombineAND(std::move(b_plan.theta_filter), CopyExpression(cond)); + } + } + + plan.branches.push_back(std::move(b_plan)); + } + + return plan; +} diff --git a/src/query/parser.cpp b/src/query/parser.cpp index d1f2621..8887d98 100644 --- a/src/query/parser.cpp +++ b/src/query/parser.cpp @@ -101,6 +101,32 @@ std::unique_ptr Parser::ParseSelect(std::istringstream& ss) { stmt->table_name.pop_back(); } + // ── NEW: Check for optional JOIN clause ── + std::streampos pos = ss.tellg(); + std::string next; + if (ss >> next) { + std::string next_upper = next; + for (auto& c : next_upper) c = toupper(c); + if (next_upper == "JOIN") { + ss >> stmt->join_table_name; + std::string on_kw; + if (ss >> on_kw) { + for (auto& c : on_kw) c = toupper(c); + if (on_kw == "ON") { + stmt->join_condition = ParseExpression(ss); + } else { + return nullptr; // JOIN requires ON clause + } + } else { + return nullptr; // JOIN requires ON clause + } + } else { + // Not a JOIN, restore stream (might be just a semicolon) + ss.clear(); + ss.seekg(pos); + } + } + return stmt; } @@ -135,3 +161,114 @@ std::unique_ptr Parser::ParseDelete(std::istringstream& ss) { return stmt; } + +std::unique_ptr Parser::ParseExpression(std::istringstream& ss) { + auto left = ParseAndExpression(ss); + + while (true) { + std::streampos pos = ss.tellg(); + std::string next; + if (!(ss >> next)) break; + std::string next_upper = next; + for (auto& c : next_upper) c = toupper(c); + + if (next_upper == "OR") { + auto right = ParseAndExpression(ss); + left = std::make_unique(LogicType::OR, std::move(left), std::move(right)); + } else { + ss.clear(); + ss.seekg(pos); + break; + } + } + return left; +} + +std::unique_ptr Parser::ParseAndExpression(std::istringstream& ss) { + auto left = ParseComparison(ss); + + while (true) { + std::streampos pos = ss.tellg(); + std::string next; + if (!(ss >> 
next)) break; + std::string next_upper = next; + for (auto& c : next_upper) c = toupper(c); + + if (next_upper == "AND") { + auto right = ParseComparison(ss); + left = std::make_unique(LogicType::AND, std::move(left), std::move(right)); + } else { + ss.clear(); + ss.seekg(pos); + break; + } + } + return left; +} + +std::unique_ptr Parser::ParseComparison(std::istringstream& ss) { + auto left = ParseAtom(ss); + + std::streampos pos = ss.tellg(); + std::string op; + if (ss >> op) { + ComparisonType comp_type; + bool is_comp = true; + if (op == "=") comp_type = ComparisonType::EQ; + else if (op == "!=") comp_type = ComparisonType::NEQ; + else if (op == "<") comp_type = ComparisonType::LT; + else if (op == ">") comp_type = ComparisonType::GT; + else if (op == "<=") comp_type = ComparisonType::LTE; + else if (op == ">=") comp_type = ComparisonType::GTE; + else is_comp = false; + + if (is_comp) { + auto right = ParseAtom(ss); + return std::make_unique(comp_type, std::move(left), std::move(right)); + } else { + ss.clear(); + ss.seekg(pos); + } + } + return left; +} + +std::unique_ptr Parser::ParseAtom(std::istringstream& ss) { + std::string token; + + if (!(ss >> token)) return nullptr; + + if (token == "(") { + auto expr = ParseExpression(ss); + std::string close; + if (!(ss >> close) || close != ")") return nullptr; // missing closing paren + return expr; + } + + // Strip trailing ';' or ')' if attached + while (!token.empty() && (token.back() == ';' || token.back() == ')')) { + token.pop_back(); + } + + // Strip leading '(' if attached + if (!token.empty() && token.front() == '(') { + token.erase(0, 1); + } + + size_t dot_pos = token.find('.'); + if (dot_pos != std::string::npos) { + return std::make_unique(token.substr(0, dot_pos), token.substr(dot_pos + 1)); + } + + std::string token_upper = token; + for (auto& c : token_upper) c = toupper(c); + if (token_upper == "TRUE") { + return std::make_unique("true"); + } else if (token_upper == "FALSE") { + return 
std::make_unique("false"); + } else if ((token[0] >= '0' && token[0] <= '9') || (token[0] == '-' && token.size() > 1)) { + return std::make_unique(token); + } + + return std::make_unique("", token); +} diff --git a/src/type/tuple.cpp b/src/type/tuple.cpp index 937f78e..f1c88c8 100644 --- a/src/type/tuple.cpp +++ b/src/type/tuple.cpp @@ -28,3 +28,11 @@ uint32_t Tuple::GetLength() const { Tuple::Tuple(const char* raw_data, uint32_t size) { data_.assign(raw_data, raw_data + size); } + +Tuple Tuple::Merge(const Tuple& left, const Tuple& right) { + Tuple merged; + merged.data_.reserve(left.GetLength() + right.GetLength()); + merged.data_.insert(merged.data_.end(), left.data_.begin(), left.data_.end()); + merged.data_.insert(merged.data_.end(), right.data_.begin(), right.data_.end()); + return merged; +} diff --git a/src/type/value.cpp b/src/type/value.cpp index 38a4119..5ccf855 100644 --- a/src/type/value.cpp +++ b/src/type/value.cpp @@ -35,3 +35,42 @@ void Value::DeserializeFromChar(const char *data, TypeId type_id) { data_ = std::make_any(val); } } + +bool Value::CompareEquals(const Value& other) const { + if (type_id_ != other.type_id_) return false; + if (type_id_ == TypeId::INTEGER) return Get() == other.Get(); + if (type_id_ == TypeId::BOOLEAN) return Get() == other.Get(); + return false; +} + +bool Value::CompareNotEqual(const Value& other) const { + return !CompareEquals(other); +} + +bool Value::CompareLessThan(const Value& other) const { + if (type_id_ != other.type_id_) return false; + if (type_id_ == TypeId::INTEGER) return Get() < other.Get(); + if (type_id_ == TypeId::BOOLEAN) return Get() < other.Get(); + return false; +} + +bool Value::CompareGreaterThan(const Value& other) const { + if (type_id_ != other.type_id_) return false; + if (type_id_ == TypeId::INTEGER) return Get() > other.Get(); + if (type_id_ == TypeId::BOOLEAN) return Get() > other.Get(); + return false; +} + +bool Value::CompareLessThanOrEqual(const Value& other) const { + if (type_id_ != 
other.type_id_) return false; + if (type_id_ == TypeId::INTEGER) return Get() <= other.Get(); + if (type_id_ == TypeId::BOOLEAN) return Get() <= other.Get(); + return false; +} + +bool Value::CompareGreaterThanOrEqual(const Value& other) const { + if (type_id_ != other.type_id_) return false; + if (type_id_ == TypeId::INTEGER) return Get() >= other.Get(); + if (type_id_ == TypeId::BOOLEAN) return Get() >= other.Get(); + return false; +} diff --git a/tests/test_join.cpp b/tests/test_join.cpp new file mode 100644 index 0000000..1aac678 --- /dev/null +++ b/tests/test_join.cpp @@ -0,0 +1,79 @@ +#include +#include "tests.h" +#include "catalog/catalog.h" +#include "query/executor.h" +#include "query/parser.h" + +#define ASSERT_JOIN(expr) \ + if (!(expr)) { \ + std::cout << "[FAIL] " << #expr << " on line " << __LINE__ << "\n"; \ + failed_count++; \ + } else { \ + passed_count++; \ + } + +void test::TestJoin() { + int passed_count = 0; + int failed_count = 0; + + std::cout << "\n--- Testing JOIN ---\n"; + + Catalog catalog; + Executor executor(catalog); + Parser parser; + + auto exec = [&](const std::string& sql) -> QueryResult { + auto stmt = parser.Parse(sql); + ASSERT_JOIN(stmt != nullptr); + if (!stmt) return {false, "Parse failed", {}}; + return executor.Execute(*stmt); + }; + + // 1. Setup tables + exec("CREATE TABLE A (id INTEGER, val INTEGER);"); + exec("CREATE TABLE B (a_id INTEGER, type INTEGER);"); + + // Insert data into A + exec("INSERT INTO A VALUES (1, 100);"); + exec("INSERT INTO A VALUES (2, 200);"); + exec("INSERT INTO A VALUES (3, 300);"); + + // Insert data into B + exec("INSERT INTO B VALUES (1, 10);"); + exec("INSERT INTO B VALUES (2, 20);"); + exec("INSERT INTO B VALUES (2, 20);"); // Duplicate for dedup checking + exec("INSERT INTO B VALUES (4, 40);"); + + std::cout << "Testing Basic Equi-Join..." 
<< std::endl; + auto res1 = exec("SELECT * FROM A JOIN B ON A.id = B.a_id;"); + ASSERT_JOIN(res1.success); + ASSERT_JOIN(res1.rows.size() == 3); // (1,1), (2,2), (2,2) + + std::cout << "Testing Theta-Only Join..." << std::endl; + auto res2 = exec("SELECT * FROM A JOIN B ON A.val > B.type;"); + ASSERT_JOIN(res2.success); + ASSERT_JOIN(res2.rows.size() == 12); // Cross product (3 * 4) since all A.val (100+) > B.type (10-40) + + std::cout << "Testing Mixed AND (order independent)..." << std::endl; + auto res3 = exec("SELECT * FROM A JOIN B ON A.id = B.a_id AND A.val = 200;"); + ASSERT_JOIN(res3.success); + ASSERT_JOIN(res3.rows.size() == 2); // only id=2 matches, B has two rows for a_id=2 + + auto res4 = exec("SELECT * FROM A JOIN B ON A.val = 200 AND A.id = B.a_id;"); + ASSERT_JOIN(res4.success); + ASSERT_JOIN(res4.rows.size() == 2); + + std::cout << "Testing OR with Mixed Conditions (Dedup)..." << std::endl; + // id=1 OR val=200 + auto res5 = exec("SELECT * FROM A JOIN B ON A.id = B.a_id AND A.val = 100 OR A.val = 200 AND A.id = B.a_id;"); + ASSERT_JOIN(res5.success); + ASSERT_JOIN(res5.rows.size() == 3); + + std::cout << "Testing Empty Table Join..." << std::endl; + exec("CREATE TABLE C (id INTEGER);"); + auto res6 = exec("SELECT * FROM A JOIN C ON A.id = C.id;"); + ASSERT_JOIN(res6.success); + ASSERT_JOIN(res6.rows.size() == 0); + + std::cout << "JOIN test summary: " << passed_count << " passed, " << failed_count << " failed.\n"; +} diff --git a/tests/test_main.cpp b/tests/test_main.cpp index 5d97118..9749036 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -12,5 +12,6 @@ int main() { tests.TestCatalogClass(); tests.TestTableIteratorClass(); tests.TestQueryEngineClass(); + tests.TestJoin(); return 0; }