diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..5f0a7be
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,81 @@
+# Agent Instructions for pragmaticDB
+
+These rules are **mandatory** for any AI coding assistant working on this project.
+Violations of these rules are unacceptable under any circumstance.
+
+---
+
+## Rule 1: Zero Deletions
+
+**Do NOT delete a single line of existing code.**
+
+- No existing function, method, struct, class, enum, variable, comment, or include may be removed.
+- No existing function signature may be changed in a way that breaks callers.
+- If a function needs new parameters, use **default parameters** so all existing call sites compile unchanged.
+- If a struct/class needs new fields, **append** them — never reorder, rename, or remove existing fields.
+- New default values on appended fields must ensure the old behavior when the field is not explicitly set.
+
+## Rule 2: Existing Features Must Not Change Behavior
+
+**Every existing feature must continue to work exactly as it did before your changes.**
+
+This includes but is not limited to:
+- `CREATE TABLE`
+- `INSERT INTO`
+- `SELECT * FROM <table>;` (single-table select)
+- `DELETE FROM <table>;` and `DELETE FROM <table> WHERE col = val;`
+- `COMMIT;`
+- `exit` / `quit`
+- Persistence (catalog.db, table_N.db files)
+- TCP server behavior
+
+If you add a new code path (e.g., JOIN), it must be gated behind a condition check
+so that existing queries never enter the new path. Example pattern:
+
+```cpp
+if (new_feature_is_active) {
+ return NewFeaturePath();
+}
+// ... entire existing code below, untouched ...
+```
+
+## Rule 3: Additive-Only Architecture
+
+All new features must be implemented as **additions**, not modifications:
+
+- **New files** are always preferred over modifying existing files.
+- When modifying an existing file is unavoidable, only **append** new code (new methods, new includes, new fields).
+- The only acceptable in-body change to an existing function is inserting a short dispatch/guard at the **top** that routes to a new function, leaving the rest of the function body untouched.
+
+## Rule 4: Test Preservation
+
+- **Never modify or delete existing test logic.** Avoid modifying existing test files. If the project uses a central test runner/registry (e.g., `tests/test_main.cpp`, `include/tests.h`), appending a new test invocation or declaration there is allowed when necessary to wire new tests.
+- All existing tests must continue to compile and pass after your changes.
+- New tests go in **new test files** (e.g., `tests/test_join.cpp`).
+- After making changes, verify that `make test` passes all existing tests.
+
+## Rule 5: The Expression AST Is Join-Only
+
+The Expression AST (`Expression`, `ComparisonExpression`, `LogicalExpression`, etc.)
+and the expression evaluator (`EvaluateExpression`) are **exclusively** for JOIN condition evaluation.
+
+- Do NOT refactor existing WHERE clause handling to use the AST.
+- Do NOT refactor existing DELETE WHERE to use the AST.
+- These existing features use their own simple string-comparison logic and must continue to do so.
+- If a future feature needs expressions (e.g., SELECT WHERE with complex conditions),
+ that is a separate, deliberate decision — not something to do as a side effect.
+
+---
+
+## Summary
+
+| Action | Allowed? |
+|---|---|
+| Adding new files | ✅ Always |
+| Appending new methods/fields to existing files | ✅ Yes |
+| Adding a 3-line dispatch guard at the top of an existing function | ✅ Yes |
+| Deleting any existing line of code | ❌ Never |
+| Renaming any existing function, variable, or file | ❌ Never |
+| Changing an existing function's behavior | ❌ Never |
+| Modifying or deleting existing test logic (appending a new test invocation/declaration to a central runner is allowed) | ❌ Never |
+| Using the Expression AST outside of JOIN | ❌ Never |
diff --git a/include/catalog/schema.h b/include/catalog/schema.h
index ea06413..47a0e0a 100644
--- a/include/catalog/schema.h
+++ b/include/catalog/schema.h
@@ -19,6 +19,11 @@ class Schema {
uint32_t GetLength() const;
uint32_t GetColumnCount() const;
+ static Schema Merge(
+ const Schema& left, const std::string& left_table,
+ const Schema& right, const std::string& right_table
+ );
+
private:
uint32_t length_;
std::vector columns_;
diff --git a/include/ds/statement.h b/include/ds/statement.h
index 3395197..4c9ac64 100644
--- a/include/ds/statement.h
+++ b/include/ds/statement.h
@@ -2,7 +2,9 @@
#include
#include
+#include
#include "catalog/column.h"
+#include "query/expression.h"
// ── Statement type tag ────────────────────────────────────────────────────────
enum class StatementType {
@@ -40,6 +42,9 @@ struct InsertStatement : public Statement {
struct SelectStatement : public Statement {
std::string table_name;
+ std::string join_table_name;
+ std::unique_ptr join_condition;
+
SelectStatement() : Statement(StatementType::SELECT) {}
};
diff --git a/include/query/executor.h b/include/query/executor.h
index 6e39c17..2ce6cb1 100644
--- a/include/query/executor.h
+++ b/include/query/executor.h
@@ -4,6 +4,8 @@
#include "../ds/statement.h"
#include "../ds/query_result.h"
#include "catalog/schema.h"
+#include "query/optimizer.h"
+#include "query/index_provider.h"
/**
* @brief Executes a parsed Statement against the database Catalog.
*
@@ -13,7 +15,8 @@
*/
class Executor {
public:
- explicit Executor(Catalog& catalog) : catalog_(catalog) {}
+ explicit Executor(Catalog& catalog, const IndexProvider& idx = kDefaultIndexProvider)
+ : catalog_(catalog), index_provider_(idx) {}
/**
* @brief Execute the given statement and return a QueryResult.
@@ -28,5 +31,14 @@ class Executor {
QueryResult ExecuteCommit();
QueryResult ExecuteDelete(const DeleteStatement& stmt);
+ QueryResult ExecuteJoin(const SelectStatement& stmt);
+ std::vector> ExecuteBranch(
+ const BranchPlan& branch,
+ TableInfo* left_tbl, TableInfo* right_tbl,
+ const Schema& left_schema, const Schema& right_schema,
+ const Schema& merged_schema
+ );
+
Catalog& catalog_;
+ const IndexProvider& index_provider_;
};
diff --git a/include/query/expression.h b/include/query/expression.h
new file mode 100644
index 0000000..266deef
--- /dev/null
+++ b/include/query/expression.h
@@ -0,0 +1,60 @@
+#pragma once
+#include
+#include
+
+// Expression type tags
+enum class ExpressionType {
+ COLUMN_REF, // e.g. "users.id"
+ CONSTANT, // e.g. "42", "true"
+ COMPARISON, // =, >, <, >=, <=, !=
+ LOGICAL // AND, OR
+};
+
+enum class ComparisonType { EQ, NEQ, LT, GT, LTE, GTE };
+enum class LogicType { AND, OR };
+
+// Base class — all expression nodes inherit from this
+struct Expression {
+ ExpressionType expr_type;
+ virtual ~Expression() = default;
+
+protected:
+ Expression(ExpressionType type) : expr_type(type) {}
+};
+
+// Leaf: references a column like "users.id"
+struct ColumnRefExpression : Expression {
+ std::string table_name; // "users" (from "users.id")
+ std::string col_name; // "id"
+
+ ColumnRefExpression(std::string table, std::string col)
+ : Expression(ExpressionType::COLUMN_REF), table_name(std::move(table)), col_name(std::move(col)) {}
+};
+
+// Leaf: a literal value like 42 or true
+struct ConstantExpression : Expression {
+ std::string raw_value; // stored as string, converted at eval time
+
+ ConstantExpression(std::string val)
+ : Expression(ExpressionType::CONSTANT), raw_value(std::move(val)) {}
+};
+
+// Internal node: left <op> right
+struct ComparisonExpression : Expression {
+ ComparisonType comp_type;
+ std::unique_ptr left;
+ std::unique_ptr right;
+
+ ComparisonExpression(ComparisonType comp, std::unique_ptr l, std::unique_ptr r)
+ : Expression(ExpressionType::COMPARISON), comp_type(comp), left(std::move(l)), right(std::move(r)) {}
+};
+
+// Internal node: left AND/OR right
+struct LogicalExpression : Expression {
+ LogicType logic_type;
+ std::unique_ptr left;
+ std::unique_ptr right;
+
+ LogicalExpression(LogicType logic, std::unique_ptr l, std::unique_ptr r)
+ : Expression(ExpressionType::LOGICAL), logic_type(logic), left(std::move(l)), right(std::move(r)) {}
+};
diff --git a/include/query/expression_eval.h b/include/query/expression_eval.h
new file mode 100644
index 0000000..1d0fec5
--- /dev/null
+++ b/include/query/expression_eval.h
@@ -0,0 +1,12 @@
+#pragma once
+#include "query/expression.h"
+#include "type/tuple.h"
+#include "catalog/schema.h"
+
+// Evaluates any expression tree against a single (possibly merged) tuple.
+// Returns true if the condition holds.
+bool EvaluateExpression(
+ const Expression* expr,
+ const Tuple& tuple,
+ const Schema& schema
+);
diff --git a/include/query/index_provider.h b/include/query/index_provider.h
new file mode 100644
index 0000000..5c5aadd
--- /dev/null
+++ b/include/query/index_provider.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include
+#include
+#include "../ds/record_id.h"
+
+class IndexProvider {
+public:
+ virtual ~IndexProvider() = default;
+
+ // Given a table name, column name, and a value, return all RecordIds matching col=val
+ virtual std::vector LookupEquals(
+ const std::string& table_name,
+ const std::string& col_name,
+ const std::string& value) const = 0;
+};
+
+class NullIndexProvider : public IndexProvider {
+public:
+ std::vector LookupEquals(
+ const std::string&, const std::string&, const std::string&) const override {
+ return {}; // always returns empty
+ }
+};
+
+// Shared default instance with static storage duration, usable as a default reference argument; `inline` (C++17) yields one object program-wide instead of one copy per including translation unit
+inline const NullIndexProvider kDefaultIndexProvider{};
diff --git a/include/query/optimizer.h b/include/query/optimizer.h
new file mode 100644
index 0000000..6fa9a5c
--- /dev/null
+++ b/include/query/optimizer.h
@@ -0,0 +1,31 @@
+#pragma once
+#include
+#include
+#include "query/expression.h"
+#include "query/index_provider.h"
+
+struct EquiCondition {
+ std::string left_col;
+ std::string right_col;
+};
+
+struct IndexCondition {
+ std::string table;
+ std::string col;
+ std::string val;
+};
+
+struct BranchPlan {
+ std::vector index_conditions;
+ std::vector equi_conditions;
+ std::unique_ptr theta_filter; // anything not captured above
+};
+
+struct JoinPlan {
+ std::vector branches;
+};
+
+class Optimizer {
+public:
+ static JoinPlan PlanJoin(const std::unique_ptr& condition, const IndexProvider& idx_provider);
+};
diff --git a/include/query/parser.h b/include/query/parser.h
index 69754d7..6be491e 100644
--- a/include/query/parser.h
+++ b/include/query/parser.h
@@ -27,4 +27,9 @@ class Parser {
std::unique_ptr ParseInsert(std::istringstream& ss);
std::unique_ptr ParseSelect(std::istringstream& ss);
std::unique_ptr ParseDelete(std::istringstream& ss);
+
+ std::unique_ptr ParseExpression(std::istringstream& ss);
+ std::unique_ptr ParseAndExpression(std::istringstream& ss);
+ std::unique_ptr ParseComparison(std::istringstream& ss);
+ std::unique_ptr ParseAtom(std::istringstream& ss);
};
diff --git a/include/tests.h b/include/tests.h
index 3947710..6694a6d 100644
--- a/include/tests.h
+++ b/include/tests.h
@@ -11,4 +11,5 @@ class test {
void TestCatalogClass();
void TestTableIteratorClass();
void TestQueryEngineClass();
+ void TestJoin();
};
\ No newline at end of file
diff --git a/include/type/tuple.h b/include/type/tuple.h
index e526202..fceca64 100644
--- a/include/type/tuple.h
+++ b/include/type/tuple.h
@@ -31,6 +31,8 @@ class Tuple {
const char* GetData() const;
uint32_t GetLength() const;
+ static Tuple Merge(const Tuple& left, const Tuple& right);
+
private:
std::vector data_;
};
diff --git a/include/type/value.h b/include/type/value.h
index 0553e65..1046110 100644
--- a/include/type/value.h
+++ b/include/type/value.h
@@ -17,6 +17,15 @@ class Value {
data_ = std::make_any(val);
}
void test();
+
+ TypeId GetTypeId() const { return type_id_; }
+
+ bool CompareEquals(const Value& other) const;
+ bool CompareLessThan(const Value& other) const;
+ bool CompareGreaterThan(const Value& other) const;
+ bool CompareLessThanOrEqual(const Value& other) const;
+ bool CompareGreaterThanOrEqual(const Value& other) const;
+ bool CompareNotEqual(const Value& other) const;
private:
TypeId type_id_;
std::any data_;
diff --git a/src/catalog/schema.cpp b/src/catalog/schema.cpp
index cd6c206..ba1e4ec 100644
--- a/src/catalog/schema.cpp
+++ b/src/catalog/schema.cpp
@@ -23,3 +23,17 @@ uint32_t Schema::GetLength() const { return length_; }
uint32_t Schema::GetColumnCount() const {
return static_cast(columns_.size());
}
+
+Schema Schema::Merge( // Builds one schema holding all columns of `left` followed by all columns of `right`.
+ const Schema& left, const std::string& left_table,
+ const Schema& right, const std::string& right_table
+) {
+ std::vector merged_cols;
+ for (const auto& col : left.GetColumns()) {
+ merged_cols.emplace_back(left_table + "." + col.GetName(), col.GetType()); // column is renamed to "<left_table>.<column>", type is kept
+ }
+ for (const auto& col : right.GetColumns()) {
+ merged_cols.emplace_back(right_table + "." + col.GetName(), col.GetType()); // same renaming with the right table's name
+ }
+ return Schema(merged_cols);
+}
diff --git a/src/query/executor.cpp b/src/query/executor.cpp
index 733df84..29f6d80 100644
--- a/src/query/executor.cpp
+++ b/src/query/executor.cpp
@@ -3,6 +3,9 @@
#include "type/tuple.h"
#include "type/value.h"
#include "factory/value_factory.h"
+#include "query/expression_eval.h"
+#include
+#include
// Implement your Executor methods here!
@@ -60,6 +63,10 @@ QueryResult Executor::ExecuteInsert(const InsertStatement& stmt) {
}
QueryResult Executor::ExecuteSelect(const SelectStatement& stmt) {
+ if (!stmt.join_table_name.empty()) {
+ return ExecuteJoin(stmt);
+ }
+
try {
TableInfo* tbl = catalog_.GetTable(stmt.table_name);
Schema& schema = tbl->schema_;
@@ -130,3 +137,143 @@ QueryResult Executor::ExecuteDelete(const DeleteStatement& stmt) {
return {false, e.what(), {}};
}
}
+
+// Resolves a (possibly table-qualified) column reference to its index in one table's schema.
+// Matches either the bare column name or "<table_name>.<column>"; returns -1 when unresolved.
+static int GetSchemaColIdx(const Schema& schema, const std::string& table_name, const std::string& full_col_name) {
+    const std::string qualifier_prefix = table_name + ".";  // loop-invariant, built once
+    for (uint32_t i = 0; i < schema.GetColumnCount(); i++) {
+        const std::string col_name = schema.GetColumn(i).GetName();
+        if (col_name == full_col_name) return static_cast<int>(i);
+        // Only accept a qualified name whose qualifier is this table. Matching on the suffix
+        // alone let "A.id" bind to table B's "id" column and produced wrong hash-join keys.
+        if (qualifier_prefix + col_name == full_col_name) return static_cast<int>(i);
+    }
+    return -1;
+}
+
+QueryResult Executor::ExecuteJoin(const SelectStatement& stmt) {
+ try {
+ TableInfo* left_tbl = catalog_.GetTable(stmt.table_name);
+ TableInfo* right_tbl = catalog_.GetTable(stmt.join_table_name);
+ if (!left_tbl || !right_tbl) throw std::runtime_error("Table not found for JOIN");
+
+ Schema& left_schema = left_tbl->schema_;
+ Schema& right_schema = right_tbl->schema_;
+ Schema merged_schema = Schema::Merge(left_schema, left_tbl->name_, right_schema, right_tbl->name_);
+
+ JoinPlan plan = Optimizer::PlanJoin(stmt.join_condition, index_provider_);
+
+ std::unordered_set seen;
+ std::vector> unique_pairs;
+
+ for (const auto& branch : plan.branches) {
+ auto branch_pairs = ExecuteBranch(branch, left_tbl, right_tbl, left_schema, right_schema, merged_schema);
+
+ for (const auto& pair : branch_pairs) {
+ std::string key;
+ key.append((char*)&pair.first.page_id, sizeof(page_id_t));
+ key.append((char*)&pair.first.slot_id, sizeof(uint16_t));
+ key.append((char*)&pair.second.page_id, sizeof(page_id_t));
+ key.append((char*)&pair.second.slot_id, sizeof(uint16_t));
+
+ if (seen.insert(key).second) {
+ unique_pairs.push_back(pair);
+ }
+ }
+ }
+
+ QueryResult result;
+ result.success = true;
+ for (const auto& pair : unique_pairs) {
+ Tuple left_tuple = left_tbl->table_->GetTuple(pair.first, left_schema);
+ Tuple right_tuple = right_tbl->table_->GetTuple(pair.second, right_schema);
+ Tuple merged = Tuple::Merge(left_tuple, right_tuple);
+
+ std::vector row;
+ for (uint32_t i = 0; i < merged_schema.GetColumnCount(); i++) {
+ Value val = merged.GetValue(merged_schema, i);
+ TypeId type = merged_schema.GetColumn(i).GetType();
+ row.push_back(ValueFactory::ToString(val, type));
+ }
+ result.rows.push_back(row);
+ }
+ result.message = std::to_string(result.rows.size()) + " rows returned.";
+ return result;
+
+ } catch (const std::exception& e) {
+ return {false, e.what(), {}};
+ }
+}
+
+std::vector> Executor::ExecuteBranch(
+ const BranchPlan& branch,
+ TableInfo* left_tbl, TableInfo* right_tbl,
+ const Schema& left_schema, const Schema& right_schema,
+ const Schema& merged_schema
+) {
+ std::vector> surviving_pairs;
+
+ // Stage 1: Index Probe (Skipped for now)
+
+ // Stage 2: Hash Join or Cross Product
+ if (!branch.equi_conditions.empty()) {
+ std::unordered_map>> hash_map;
+
+ // Build phase on right_tbl
+ for (auto it = right_tbl->table_->Begin(right_schema); it != right_tbl->table_->End(right_schema); ++it) {
+ std::string key;
+ for (const auto& equi : branch.equi_conditions) {
+ int col_idx = GetSchemaColIdx(right_schema, right_tbl->name_, equi.right_col);
+ if (col_idx == -1) col_idx = GetSchemaColIdx(right_schema, right_tbl->name_, equi.left_col);
+ if (col_idx != -1) {
+ Value val = (*it).GetValue(right_schema, col_idx);
+ key += ValueFactory::ToString(val, right_schema.GetColumn(col_idx).GetType()) + "|";
+ }
+ }
+ if (!key.empty()) {
+ hash_map[key].push_back({*it, it.GetRid()});
+ }
+ }
+
+ // Probe phase on left_tbl
+ for (auto it = left_tbl->table_->Begin(left_schema); it != left_tbl->table_->End(left_schema); ++it) {
+ std::string key;
+ for (const auto& equi : branch.equi_conditions) {
+ int col_idx = GetSchemaColIdx(left_schema, left_tbl->name_, equi.left_col);
+ if (col_idx == -1) col_idx = GetSchemaColIdx(left_schema, left_tbl->name_, equi.right_col);
+ if (col_idx != -1) {
+ Value val = (*it).GetValue(left_schema, col_idx);
+ key += ValueFactory::ToString(val, left_schema.GetColumn(col_idx).GetType()) + "|";
+ }
+ }
+
+ if (!key.empty() && hash_map.find(key) != hash_map.end()) {
+ for (const auto& right_entry : hash_map[key]) {
+ Tuple merged = Tuple::Merge(*it, right_entry.first);
+
+ // Stage 3: Theta Filter
+ if (branch.theta_filter) {
+ if (!EvaluateExpression(branch.theta_filter.get(), merged, merged_schema)) {
+ continue;
+ }
+ }
+ surviving_pairs.push_back({it.GetRid(), right_entry.second});
+ }
+ }
+ }
+ } else {
+ // Cross product
+ for (auto left_it = left_tbl->table_->Begin(left_schema); left_it != left_tbl->table_->End(left_schema); ++left_it) {
+ for (auto right_it = right_tbl->table_->Begin(right_schema); right_it != right_tbl->table_->End(right_schema); ++right_it) {
+ Tuple merged = Tuple::Merge(*left_it, *right_it);
+
+ if (branch.theta_filter) {
+ if (!EvaluateExpression(branch.theta_filter.get(), merged, merged_schema)) {
+ continue;
+ }
+ }
+ surviving_pairs.push_back({left_it.GetRid(), right_it.GetRid()});
+ }
+ }
+ }
+
+ return surviving_pairs;
+}
diff --git a/src/query/expression_eval.cpp b/src/query/expression_eval.cpp
new file mode 100644
index 0000000..51be951
--- /dev/null
+++ b/src/query/expression_eval.cpp
@@ -0,0 +1,92 @@
+#include "query/expression_eval.h"
+#include
+#include
+
+static int32_t GetColIdx(const Schema& schema, const std::string& col_name, const std::string& table_name) { // Returns the index of the referenced column in `schema`, or -1 if no column matches.
+ const auto& cols = schema.GetColumns();
+ for (size_t i = 0; i < cols.size(); ++i) {
+ if (!table_name.empty()) { // qualified reference: only an exact "table.column" name matches
+ if (cols[i].GetName() == table_name + "." + col_name) {
+ return i;
+ }
+ } else { // unqualified reference
+ if (cols[i].GetName() == col_name) { // an exact name match returns immediately, without an ambiguity check
+ return i;
+ }
+ size_t dot_pos = cols[i].GetName().find('.');
+ if (dot_pos != std::string::npos && cols[i].GetName().substr(dot_pos + 1) == col_name) { // compare against the part after the first '.'
+ // Check for ambiguity: see if another column also matches
+ for (size_t j = i + 1; j < cols.size(); ++j) {
+ size_t dot_pos2 = cols[j].GetName().find('.');
+ if (dot_pos2 != std::string::npos && cols[j].GetName().substr(dot_pos2 + 1) == col_name) {
+ throw std::runtime_error("Ambiguous column reference: '" + col_name + "'. Use table-qualified name.");
+ }
+ }
+ return i;
+ }
+ }
+ }
+ return -1;
+}
+
+// Evaluates a leaf expression (column reference or literal) against `tuple`.
+// Column references are looked up in `schema`; the literals "true"/"false" become BOOLEAN
+// values and any other literal is parsed as an INTEGER with std::stoi.
+// Throws std::runtime_error for a missing operand, an unknown column, or a non-leaf node;
+// std::stoi may throw std::invalid_argument / std::out_of_range for a malformed literal.
+static Value EvaluateAtom(const Expression* expr, const Tuple& tuple, const Schema& schema) {
+    // A comparison node can carry a null operand when the parser ran out of tokens
+    // (ParseAtom returns nullptr at end of input), so fail cleanly instead of dereferencing it.
+    if (!expr) {
+        throw std::runtime_error("Missing operand in expression");
+    }
+    if (expr->expr_type == ExpressionType::COLUMN_REF) {
+        const auto* col_expr = static_cast<const ColumnRefExpression*>(expr);
+        int32_t col_idx = GetColIdx(schema, col_expr->col_name, col_expr->table_name);
+        if (col_idx == -1) {
+            throw std::runtime_error("Column not found in schema");
+        }
+        return tuple.GetValue(schema, col_idx);
+    } else if (expr->expr_type == ExpressionType::CONSTANT) {
+        const auto* const_expr = static_cast<const ConstantExpression*>(expr);
+        if (const_expr->raw_value == "true" || const_expr->raw_value == "false") {
+            Value val(TypeId::BOOLEAN);
+            val.Set(const_expr->raw_value == "true" ? 1 : 0);
+            return val;
+        } else {
+            Value val(TypeId::INTEGER);
+            val.Set(std::stoi(const_expr->raw_value));
+            return val;
+        }
+    }
+    throw std::runtime_error("Invalid atom expression type");
+}
+
+bool EvaluateExpression(const Expression* expr, const Tuple& tuple, const Schema& schema) { // Evaluates an expression tree against `tuple`; throws std::runtime_error if the tree cannot be reduced to a boolean.
+ if (!expr) return true; // a null expression is treated as "no condition"
+
+ if (expr->expr_type == ExpressionType::LOGICAL) {
+ const auto* log_expr = static_cast(expr);
+ bool left_res = EvaluateExpression(log_expr->left.get(), tuple, schema);
+
+ if (log_expr->logic_type == LogicType::AND) {
+ if (!left_res) return false; // short-circuit: right side is not evaluated
+ return EvaluateExpression(log_expr->right.get(), tuple, schema);
+ } else {
+ if (left_res) return true; // short-circuit for OR
+ return EvaluateExpression(log_expr->right.get(), tuple, schema);
+ }
+ } else if (expr->expr_type == ExpressionType::COMPARISON) {
+ const auto* comp_expr = static_cast(expr);
+ Value left_val = EvaluateAtom(comp_expr->left.get(), tuple, schema); // both operands are evaluated as leaves (column or literal)
+ Value right_val = EvaluateAtom(comp_expr->right.get(), tuple, schema);
+
+ switch (comp_expr->comp_type) {
+ case ComparisonType::EQ: return left_val.CompareEquals(right_val);
+ case ComparisonType::NEQ: return left_val.CompareNotEqual(right_val);
+ case ComparisonType::LT: return left_val.CompareLessThan(right_val);
+ case ComparisonType::GT: return left_val.CompareGreaterThan(right_val);
+ case ComparisonType::LTE: return left_val.CompareLessThanOrEqual(right_val);
+ case ComparisonType::GTE: return left_val.CompareGreaterThanOrEqual(right_val);
+ }
+ }
+
+ if (expr->expr_type == ExpressionType::COLUMN_REF || expr->expr_type == ExpressionType::CONSTANT) { // a bare leaf used as a condition
+ Value val = EvaluateAtom(expr, tuple, schema);
+ if (val.GetTypeId() == TypeId::BOOLEAN) {
+ return val.Get() != 0;
+ }
+ }
+
+ throw std::runtime_error("Invalid expression tree structure"); // reached for non-boolean leaves and unmatched node types
+}
diff --git a/src/query/optimizer.cpp b/src/query/optimizer.cpp
new file mode 100644
index 0000000..85e88b4
--- /dev/null
+++ b/src/query/optimizer.cpp
@@ -0,0 +1,98 @@
+#include "query/optimizer.h"
+
+static void CollectORBranches(const Expression* expr, std::vector& branches) { // Appends to `branches` every maximal non-OR subtree of `expr`, left to right.
+ if (!expr) return; // null subtrees contribute nothing
+ if (expr->expr_type == ExpressionType::LOGICAL) {
+ const auto* log = static_cast(expr);
+ if (log->logic_type == LogicType::OR) { // descend through nested ORs only
+ CollectORBranches(log->left.get(), branches);
+ CollectORBranches(log->right.get(), branches);
+ return;
+ }
+ }
+ branches.push_back(expr); // non-owning pointer into the caller's tree
+}
+
+static void CollectANDConditions(const Expression* expr, std::vector& conds) { // Appends to `conds` every maximal non-AND subtree of `expr`, left to right.
+ if (!expr) return; // null subtrees contribute nothing
+ if (expr->expr_type == ExpressionType::LOGICAL) {
+ const auto* log = static_cast(expr);
+ if (log->logic_type == LogicType::AND) { // descend through nested ANDs only; an OR node is kept whole
+ CollectANDConditions(log->left.get(), conds);
+ CollectANDConditions(log->right.get(), conds);
+ return;
+ }
+ }
+ conds.push_back(expr); // non-owning pointer into the caller's tree
+}
+
+static std::unique_ptr CopyExpression(const Expression* expr) {
+ if (!expr) return nullptr;
+ if (expr->expr_type == ExpressionType::COLUMN_REF) {
+ const auto* col = static_cast(expr);
+ return std::make_unique(col->table_name, col->col_name);
+ } else if (expr->expr_type == ExpressionType::CONSTANT) {
+ const auto* con = static_cast(expr);
+ return std::make_unique(con->raw_value);
+ } else if (expr->expr_type == ExpressionType::COMPARISON) {
+ const auto* comp = static_cast(expr);
+ return std::make_unique(
+ comp->comp_type, CopyExpression(comp->left.get()), CopyExpression(comp->right.get()));
+ } else if (expr->expr_type == ExpressionType::LOGICAL) {
+ const auto* log = static_cast(expr);
+ return std::make_unique(
+ log->logic_type, CopyExpression(log->left.get()), CopyExpression(log->right.get()));
+ }
+ return nullptr;
+}
+
+static std::unique_ptr CombineAND(std::unique_ptr left, std::unique_ptr right) { // Joins two optional expressions with AND, taking ownership of both.
+ if (!left) return right; // nothing on the left: result is the right side (which may itself be null)
+ if (!right) return left;
+ return std::make_unique(LogicType::AND, std::move(left), std::move(right));
+}
+
+// Builds a JoinPlan from a JOIN ... ON condition.
+// The condition is split on OR into branches. Within each branch the AND-ed terms are
+// classified: `column = column` terms become equi-conditions, every other term is copied
+// and AND-ed into the branch's theta filter. A null condition yields one empty branch.
+// `idx_provider` is not consulted here and `index_conditions` is left empty.
+JoinPlan Optimizer::PlanJoin(const std::unique_ptr<Expression>& condition, const IndexProvider& idx_provider) {
+    JoinPlan plan;
+    if (!condition) {
+        plan.branches.emplace_back();
+        return plan;
+    }
+
+    std::vector<const Expression*> or_branches;
+    CollectORBranches(condition.get(), or_branches);
+
+    for (const Expression* branch_expr : or_branches) {
+        BranchPlan b_plan;
+        std::vector<const Expression*> and_conds;
+        CollectANDConditions(branch_expr, and_conds);
+
+        for (const Expression* cond : and_conds) {
+            bool handled = false;
+            if (cond->expr_type == ExpressionType::COMPARISON) {
+                const auto* comp = static_cast<const ComparisonExpression*>(cond);
+                // Operands can be null when the parser ran out of tokens, so check before
+                // looking at their type.
+                const bool both_columns =
+                    comp->left && comp->right &&
+                    comp->left->expr_type == ExpressionType::COLUMN_REF &&
+                    comp->right->expr_type == ExpressionType::COLUMN_REF;
+
+                if (comp->comp_type == ComparisonType::EQ && both_columns) {
+                    const auto* l_col = static_cast<const ColumnRefExpression*>(comp->left.get());
+                    const auto* r_col = static_cast<const ColumnRefExpression*>(comp->right.get());
+
+                    // Two columns explicitly qualified with the same table (A.x = A.y) do not
+                    // relate the two join inputs; keep them as a theta filter rather than a
+                    // hash-join key, which could never be resolved on the other table.
+                    const bool same_table =
+                        !l_col->table_name.empty() && l_col->table_name == r_col->table_name;
+
+                    if (!same_table) {
+                        EquiCondition equi;
+                        equi.left_col = (l_col->table_name.empty() ? l_col->col_name : l_col->table_name + "." + l_col->col_name);
+                        equi.right_col = (r_col->table_name.empty() ? r_col->col_name : r_col->table_name + "." + r_col->col_name);
+
+                        b_plan.equi_conditions.push_back(equi);
+                        handled = true;
+                    }
+                }
+            }
+            if (!handled) {
+                b_plan.theta_filter = CombineAND(std::move(b_plan.theta_filter), CopyExpression(cond));
+            }
+        }
+
+        plan.branches.push_back(std::move(b_plan));
+    }
+
+    return plan;
+}
diff --git a/src/query/parser.cpp b/src/query/parser.cpp
index d1f2621..8887d98 100644
--- a/src/query/parser.cpp
+++ b/src/query/parser.cpp
@@ -101,6 +101,32 @@ std::unique_ptr Parser::ParseSelect(std::istringstream& ss) {
stmt->table_name.pop_back();
}
+ // ── NEW: Check for optional JOIN clause ──
+ std::streampos pos = ss.tellg();
+ std::string next;
+ if (ss >> next) {
+ std::string next_upper = next;
+ for (auto& c : next_upper) c = toupper(c);
+ if (next_upper == "JOIN") {
+ ss >> stmt->join_table_name;
+ std::string on_kw;
+ if (ss >> on_kw) {
+ for (auto& c : on_kw) c = toupper(c);
+ if (on_kw == "ON") {
+ stmt->join_condition = ParseExpression(ss);
+ } else {
+ return nullptr; // JOIN requires ON clause
+ }
+ } else {
+ return nullptr; // JOIN requires ON clause
+ }
+ } else {
+ // Not a JOIN, restore stream (might be just a semicolon)
+ ss.clear();
+ ss.seekg(pos);
+ }
+ }
+
return stmt;
}
@@ -135,3 +161,114 @@ std::unique_ptr Parser::ParseDelete(std::istringstream& ss) {
return stmt;
}
+
+std::unique_ptr Parser::ParseExpression(std::istringstream& ss) {
+ auto left = ParseAndExpression(ss);
+
+ while (true) {
+ std::streampos pos = ss.tellg();
+ std::string next;
+ if (!(ss >> next)) break;
+ std::string next_upper = next;
+ for (auto& c : next_upper) c = toupper(c);
+
+ if (next_upper == "OR") {
+ auto right = ParseAndExpression(ss);
+ left = std::make_unique(LogicType::OR, std::move(left), std::move(right));
+ } else {
+ ss.clear();
+ ss.seekg(pos);
+ break;
+ }
+ }
+ return left;
+}
+
+std::unique_ptr Parser::ParseAndExpression(std::istringstream& ss) {
+ auto left = ParseComparison(ss);
+
+ while (true) {
+ std::streampos pos = ss.tellg();
+ std::string next;
+ if (!(ss >> next)) break;
+ std::string next_upper = next;
+ for (auto& c : next_upper) c = toupper(c);
+
+ if (next_upper == "AND") {
+ auto right = ParseComparison(ss);
+ left = std::make_unique(LogicType::AND, std::move(left), std::move(right));
+ } else {
+ ss.clear();
+ ss.seekg(pos);
+ break;
+ }
+ }
+ return left;
+}
+
+std::unique_ptr Parser::ParseComparison(std::istringstream& ss) {
+ auto left = ParseAtom(ss);
+
+ std::streampos pos = ss.tellg();
+ std::string op;
+ if (ss >> op) {
+ ComparisonType comp_type;
+ bool is_comp = true;
+ if (op == "=") comp_type = ComparisonType::EQ;
+ else if (op == "!=") comp_type = ComparisonType::NEQ;
+ else if (op == "<") comp_type = ComparisonType::LT;
+ else if (op == ">") comp_type = ComparisonType::GT;
+ else if (op == "<=") comp_type = ComparisonType::LTE;
+ else if (op == ">=") comp_type = ComparisonType::GTE;
+ else is_comp = false;
+
+ if (is_comp) {
+ auto right = ParseAtom(ss);
+ return std::make_unique(comp_type, std::move(left), std::move(right));
+ } else {
+ ss.clear();
+ ss.seekg(pos);
+ }
+ }
+ return left;
+}
+
+std::unique_ptr Parser::ParseAtom(std::istringstream& ss) {
+ std::string token;
+
+ if (!(ss >> token)) return nullptr;
+
+ if (token == "(") {
+ auto expr = ParseExpression(ss);
+ std::string close;
+ if (!(ss >> close) || close != ")") return nullptr; // missing closing paren
+ return expr;
+ }
+
+ // Strip trailing ';' or ')' if attached
+ while (!token.empty() && (token.back() == ';' || token.back() == ')')) {
+ token.pop_back();
+ }
+
+ // Strip leading '(' if attached
+ if (!token.empty() && token.front() == '(') {
+ token.erase(0, 1);
+ }
+
+ size_t dot_pos = token.find('.');
+ if (dot_pos != std::string::npos) {
+ return std::make_unique(token.substr(0, dot_pos), token.substr(dot_pos + 1));
+ }
+
+ std::string token_upper = token;
+ for (auto& c : token_upper) c = toupper(c);
+ if (token_upper == "TRUE") {
+ return std::make_unique("true");
+ } else if (token_upper == "FALSE") {
+ return std::make_unique("false");
+ } else if ((token[0] >= '0' && token[0] <= '9') || (token[0] == '-' && token.size() > 1)) {
+ return std::make_unique(token);
+ }
+
+ return std::make_unique("", token);
+}
diff --git a/src/type/tuple.cpp b/src/type/tuple.cpp
index 937f78e..f1c88c8 100644
--- a/src/type/tuple.cpp
+++ b/src/type/tuple.cpp
@@ -28,3 +28,11 @@ uint32_t Tuple::GetLength() const {
Tuple::Tuple(const char* raw_data, uint32_t size) {
data_.assign(raw_data, raw_data + size);
}
+
+Tuple Tuple::Merge(const Tuple& left, const Tuple& right) { // Returns a new tuple whose bytes are `left`'s data followed by `right`'s data.
+ Tuple merged;
+ merged.data_.reserve(left.GetLength() + right.GetLength()); // size the buffer once before the two appends
+ merged.data_.insert(merged.data_.end(), left.data_.begin(), left.data_.end());
+ merged.data_.insert(merged.data_.end(), right.data_.begin(), right.data_.end());
+ return merged;
+}
diff --git a/src/type/value.cpp b/src/type/value.cpp
index 38a4119..5ccf855 100644
--- a/src/type/value.cpp
+++ b/src/type/value.cpp
@@ -35,3 +35,42 @@ void Value::DeserializeFromChar(const char *data, TypeId type_id) {
data_ = std::make_any(val);
}
}
+
+bool Value::CompareEquals(const Value& other) const {
+ if (type_id_ != other.type_id_) return false;
+ if (type_id_ == TypeId::INTEGER) return Get() == other.Get();
+ if (type_id_ == TypeId::BOOLEAN) return Get() == other.Get();
+ return false;
+}
+
+bool Value::CompareNotEqual(const Value& other) const {
+ return !CompareEquals(other);
+}
+
+bool Value::CompareLessThan(const Value& other) const {
+ if (type_id_ != other.type_id_) return false;
+ if (type_id_ == TypeId::INTEGER) return Get() < other.Get();
+ if (type_id_ == TypeId::BOOLEAN) return Get() < other.Get();
+ return false;
+}
+
+bool Value::CompareGreaterThan(const Value& other) const {
+ if (type_id_ != other.type_id_) return false;
+ if (type_id_ == TypeId::INTEGER) return Get() > other.Get();
+ if (type_id_ == TypeId::BOOLEAN) return Get() > other.Get();
+ return false;
+}
+
+bool Value::CompareLessThanOrEqual(const Value& other) const {
+ if (type_id_ != other.type_id_) return false;
+ if (type_id_ == TypeId::INTEGER) return Get() <= other.Get();
+ if (type_id_ == TypeId::BOOLEAN) return Get() <= other.Get();
+ return false;
+}
+
+bool Value::CompareGreaterThanOrEqual(const Value& other) const {
+ if (type_id_ != other.type_id_) return false;
+ if (type_id_ == TypeId::INTEGER) return Get() >= other.Get();
+ if (type_id_ == TypeId::BOOLEAN) return Get() >= other.Get();
+ return false;
+}
diff --git a/tests/test_join.cpp b/tests/test_join.cpp
new file mode 100644
index 0000000..1aac678
--- /dev/null
+++ b/tests/test_join.cpp
@@ -0,0 +1,79 @@
+#include
+#include "tests.h"
+#include "catalog/catalog.h"
+#include "query/executor.h"
+#include "query/parser.h"
+
+#define ASSERT_JOIN(expr) \
+ if (!(expr)) { \
+ std::cout << "[FAIL] " << #expr << " on line " << __LINE__ << "\n"; \
+ failed_count++; \
+ } else { \
+ passed_count++; \
+ }
+
+void test::TestJoin() {
+ int passed_count = 0;
+ int failed_count = 0;
+
+ std::cout << "\n--- Testing JOIN ---\n";
+
+ Catalog catalog;
+ Executor executor(catalog);
+ Parser parser;
+
+ auto exec = [&](const std::string& sql) -> QueryResult {
+ auto stmt = parser.Parse(sql);
+ ASSERT_JOIN(stmt != nullptr);
+ if (!stmt) return {false, "Parse failed", {}};
+ return executor.Execute(*stmt);
+ };
+
+ // 1. Setup tables
+ exec("CREATE TABLE A (id INTEGER, val INTEGER);");
+ exec("CREATE TABLE B (a_id INTEGER, type INTEGER);");
+
+ // Insert data into A
+ exec("INSERT INTO A VALUES (1, 100);");
+ exec("INSERT INTO A VALUES (2, 200);");
+ exec("INSERT INTO A VALUES (3, 300);");
+
+ // Insert data into B
+ exec("INSERT INTO B VALUES (1, 10);");
+ exec("INSERT INTO B VALUES (2, 20);");
+ exec("INSERT INTO B VALUES (2, 20);"); // Duplicate for dedup checking
+ exec("INSERT INTO B VALUES (4, 40);");
+
+ std::cout << "Testing Basic Equi-Join..." << std::endl;
+ auto res1 = exec("SELECT * FROM A JOIN B ON A.id = B.a_id;");
+ ASSERT_JOIN(res1.success);
+ ASSERT_JOIN(res1.rows.size() == 3); // (1,1), (2,2), (2,2)
+
+ std::cout << "Testing Theta-Only Join..." << std::endl;
+ auto res2 = exec("SELECT * FROM A JOIN B ON A.val > B.type;");
+ ASSERT_JOIN(res2.success);
+ ASSERT_JOIN(res2.rows.size() == 12); // Cross product (3 * 4) since all A.val (100+) > B.type (10-40)
+
+ std::cout << "Testing Mixed AND (order independent)..." << std::endl;
+ auto res3 = exec("SELECT * FROM A JOIN B ON A.id = B.a_id AND A.val = 200;");
+ ASSERT_JOIN(res3.success);
+ ASSERT_JOIN(res3.rows.size() == 2); // only id=2 matches, B has two rows for a_id=2
+
+ auto res4 = exec("SELECT * FROM A JOIN B ON A.val = 200 AND A.id = B.a_id;");
+ ASSERT_JOIN(res4.success);
+ ASSERT_JOIN(res4.rows.size() == 2);
+
+ std::cout << "Testing OR with Mixed Conditions (Dedup)..." << std::endl;
+ // id=1 OR val=200
+ auto res5 = exec("SELECT * FROM A JOIN B ON A.id = B.a_id AND A.val = 100 OR A.val = 200 AND A.id = B.a_id;");
+ ASSERT_JOIN(res5.success);
+ ASSERT_JOIN(res5.rows.size() == 3);
+
+ std::cout << "Testing Empty Table Join..." << std::endl;
+ exec("CREATE TABLE C (id INTEGER);");
+ auto res6 = exec("SELECT * FROM A JOIN C ON A.id = C.id;");
+ ASSERT_JOIN(res6.success);
+ ASSERT_JOIN(res6.rows.size() == 0);
+
+ std::cout << "JOIN test summary: " << passed_count << " passed, " << failed_count << " failed.\n";
+}
diff --git a/tests/test_main.cpp b/tests/test_main.cpp
index 5d97118..9749036 100644
--- a/tests/test_main.cpp
+++ b/tests/test_main.cpp
@@ -12,5 +12,6 @@ int main() {
tests.TestCatalogClass();
tests.TestTableIteratorClass();
tests.TestQueryEngineClass();
+ tests.TestJoin();
return 0;
}