Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
368 changes: 287 additions & 81 deletions simple_repository_browser/_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
from enum import Enum
import re
Expand Down Expand Up @@ -45,72 +47,289 @@ def normalise_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).lower()


# A safe SQL statement must not use *any* user-defined input in the resulting first argument (the SQL query),
# rather any user input MUST be provided as part of the arguments (second part of the value), which will be passed
# to SQLITE to deal with.
SafeSQLStmt = typing.Tuple[str, typing.Tuple[typing.Any, ...]]


def prepare_name(term: Filter) -> SafeSQLStmt:
    """Build a parameterised ``canonical_name LIKE ?`` condition for ``term``.

    A value that starts with a double quote has its first and last
    characters removed and is otherwise used as written; any other value is
    passed through ``normalise_name``. ``*`` is then translated to the SQL
    ``%`` wildcard and the result is wrapped in ``%`` on both sides.

    Returns the constant SQL fragment and a one-element parameter tuple; the
    user-supplied text only ever appears in the parameters.
    """
    raw = term.value
    # Quoted input: match the phrase as typed, without name normalisation.
    needle = raw[1:-1] if raw.startswith('"') else normalise_name(raw)
    like_pattern = "%" + needle.replace("*", "%") + "%"
    return "canonical_name LIKE ?", (like_pattern,)


def prepare_summary(term: Filter) -> SafeSQLStmt:
    """Build a parameterised ``summary LIKE ?`` condition for ``term``.

    A value that starts with a double quote has its first and last
    characters removed; any other value is used unchanged. ``*`` is then
    translated to the SQL ``%`` wildcard and the result is wrapped in ``%``
    on both sides.

    Returns the constant SQL fragment and a one-element parameter tuple.
    """
    raw = term.value
    # Quoted input: drop the surrounding quote characters and match the phrase.
    needle = raw[1:-1] if raw.startswith('"') else raw
    like_pattern = "%" + needle.replace("*", "%") + "%"
    return "summary LIKE ?", (like_pattern,)


def build_sql(term: typing.Union[Term, typing.Tuple[Term, ...]]) -> SafeSQLStmt:
# Return the query and params to be used in SQL. The query MUST NOT be produced using untrusted input, as that would make it vulnerable to SQL injection.
# Instead, any user input must be passed in the parameters, which are handled by SQLite's built-in parameter binding.
if isinstance(term, tuple):
if len(term) == 0:
return "", ()

# No known query can produce a multi-value term
assert len(term) == 1
return build_sql(term[0])

if isinstance(term, Filter):
if term.filter_on == FilterOn.name_or_summary:
sql1, terms1 = prepare_name(term)
sql2, terms2 = prepare_summary(term)
return f"({sql1} OR {sql2})", terms1 + terms2
elif term.filter_on == FilterOn.name:
return prepare_name(term)
elif term.filter_on == FilterOn.summary:
return prepare_summary(term)
@dataclasses.dataclass(frozen=True)
class SQLBuilder:
    """Frozen bundle of WHERE / ORDER BY SQL fragments and their parameters."""

    # Boolean SQL expression without the leading "WHERE"; an empty string
    # means that no WHERE part is emitted.
    where_clause: str
    # Values bound to the placeholders of ``where_clause``.
    where_params: tuple[typing.Any, ...]
    # Ordering fragment; inserted verbatim after the WHERE part.
    order_clause: str
    # Values bound to the placeholders of ``order_clause``.
    order_params: tuple[typing.Any, ...]
    # Search context carried alongside the clauses.
    search_context: SearchContext

    def build_complete_query(
        self,
        base_select: str,
        limit: int,
        offset: int,
    ) -> tuple[str, tuple[typing.Any, ...]]:
        """Assemble the final statement, ending in ``LIMIT ? OFFSET ?``.

        Returns the SQL text and the parameters in placeholder order: WHERE
        parameters, ORDER BY parameters, then ``limit`` and ``offset``.
        """
        where_part = ""
        if self.where_clause:
            where_part = "WHERE " + self.where_clause
        # Pieces are always separated by one space, even when a piece is empty.
        query = " ".join(
            (base_select, where_part, self.order_clause, "LIMIT ? OFFSET ?")
        )
        bound = self.where_params + self.order_params
        return query, bound + (limit, offset)

    def with_where(self, clause: str, params: tuple[typing.Any, ...]) -> SQLBuilder:
        """Return a copy whose WHERE clause and parameters are replaced."""
        changes = {"where_clause": clause, "where_params": params}
        return dataclasses.replace(self, **changes)

    def with_order(self, clause: str, params: tuple[typing.Any, ...]) -> SQLBuilder:
        """Return a copy whose ORDER BY clause and parameters are replaced."""
        changes = {"order_clause": clause, "order_params": params}
        return dataclasses.replace(self, **changes)


@dataclasses.dataclass(frozen=True)
class SearchContext:
    """Immutable record of the names and patterns seen while building WHERE."""

    # Names recorded via ``with_exact_name``, in first-seen order.
    exact_names: tuple[str, ...] = ()
    # Patterns recorded via ``with_fuzzy_pattern``, in first-seen order.
    fuzzy_patterns: tuple[str, ...] = ()

    @staticmethod
    def _extend_unseen(
        current: tuple[str, ...], incoming: tuple[str, ...]
    ) -> tuple[str, ...]:
        """Append the items of ``incoming`` that are not already in ``current``."""
        return current + tuple(item for item in incoming if item not in current)

    def with_exact_name(self, name: str) -> SearchContext:
        """Return a context that includes ``name``; ``self`` if already present."""
        if name not in self.exact_names:
            return dataclasses.replace(self, exact_names=self.exact_names + (name,))
        return self

    def with_fuzzy_pattern(self, pattern: str) -> SearchContext:
        """Return a context that includes ``pattern``; ``self`` if already present."""
        if pattern not in self.fuzzy_patterns:
            extended = self.fuzzy_patterns + (pattern,)
            return dataclasses.replace(self, fuzzy_patterns=extended)
        return self

    def merge(self, other: SearchContext) -> SearchContext:
        """Combine two contexts (used when joining terms with OR/AND).

        Entries of ``self`` come first, followed by the entries of ``other``
        that ``self`` does not already hold.
        """
        return dataclasses.replace(
            self,
            exact_names=self._extend_unseen(self.exact_names, other.exact_names),
            fuzzy_patterns=self._extend_unseen(
                self.fuzzy_patterns, other.fuzzy_patterns
            ),
        )


class SearchCompiler:
"""Extensible visitor-pattern compiler for search terms to SQL.

Uses AST-style method dispatch: visit_TermName maps to handle_term_TermName.
Subclasses can override specific handlers for customisation.
"""

@classmethod
def compile(cls, term: Term | None) -> SQLBuilder:
    """Turn a parsed search term into WHERE and ORDER BY clauses.

    For ``None`` every clause is empty and the context is a fresh
    ``SearchContext``. Otherwise the term is visited to produce the WHERE
    clause, and the context collected on the way drives the ORDER BY clause.
    """
    if term is not None:
        where_sql, where_args, collected = cls._visit_term(term, SearchContext())
        # Ordering is derived from what the WHERE pass recorded.
        order_sql, order_args = cls._build_ordering_from_context(collected)
    else:
        where_sql, where_args, collected = "", (), SearchContext()
        order_sql, order_args = "", ()

    return SQLBuilder(
        where_clause=where_sql,
        where_params=where_args,
        order_clause=order_sql,
        order_params=order_args,
        search_context=collected,
    )

@classmethod
def _visit_term(
cls, term: Term, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
"""Dispatch to appropriate handler using AST-style method naming."""
method_name = f"handle_term_{type(term).__name__}"
handler = getattr(cls, method_name, None)
if handler is None:
raise ValueError(f"No handler for term type {type(term).__name__}")
return handler(term, context)

@classmethod
def handle_term_Filter(
    cls, term: Filter, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
    """Route a ``Filter`` term to the handler for the field it filters on.

    Raises:
        ValueError: if ``term.filter_on`` is not one of the handled fields.
    """
    field = term.filter_on
    if field == FilterOn.name_or_summary:
        handler = cls.handle_filter_name_or_summary
    elif field == FilterOn.name:
        handler = cls.handle_filter_name
    elif field == FilterOn.summary:
        handler = cls.handle_filter_summary
    else:
        raise ValueError(f"Unhandled filter on {term.filter_on}")
    return handler(term, context)

@classmethod
def handle_term_And(
    cls, term: And, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
    """Compile both operands and join them with ``AND``.

    Each operand is visited with the same incoming ``context``; the two
    resulting contexts are merged, and the parameters are concatenated in
    left-to-right order.
    """
    left_sql, left_args, left_ctx = cls._visit_term(term.lhs, context)
    right_sql, right_args, right_ctx = cls._visit_term(term.rhs, context)

    combined_ctx = left_ctx.merge(right_ctx)
    combined_sql = "({} AND {})".format(left_sql, right_sql)
    return combined_sql, left_args + right_args, combined_ctx

@classmethod
def handle_term_Or(
    cls, term: Or, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
    """Compile both operands and join them with ``OR``.

    Each operand is visited with the same incoming ``context``; the two
    resulting contexts are merged, and the parameters are concatenated in
    left-to-right order.
    """
    left_sql, left_args, left_ctx = cls._visit_term(term.lhs, context)
    right_sql, right_args, right_ctx = cls._visit_term(term.rhs, context)

    combined_ctx = left_ctx.merge(right_ctx)
    combined_sql = "({} OR {})".format(left_sql, right_sql)
    return combined_sql, left_args + right_args, combined_ctx

@classmethod
def handle_term_Not(
    cls, term: Not, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
    """Compile the wrapped term and negate it with ``NOT``.

    The context produced while visiting the negated term is discarded; the
    incoming ``context`` is returned unchanged.
    """
    negated_sql, negated_args, _discarded = cls._visit_term(term.term, context)
    return "(NOT {})".format(negated_sql), negated_args, context

@classmethod
def handle_filter_name(
cls, term: Filter, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
if term.value.startswith('"'):
# Exact quoted match
value = term.value[1:-1]
normalised = normalise_name(value)
new_context = context.with_exact_name(normalised)
return "canonical_name = ?", (normalised,), new_context
else:
raise ValueError(f"Unhandled filter on {term.filter_on}")
elif isinstance(term, And):
sql1, terms1 = build_sql(term.lhs)
sql2, terms2 = build_sql(term.rhs)
return f"({sql1} AND {sql2})", terms1 + terms2
elif isinstance(term, Or):
sql1, terms1 = build_sql(term.lhs)
sql2, terms2 = build_sql(term.rhs)
return f"({sql1} OR {sql2})", terms1 + terms2
elif isinstance(term, Not):
sql1, terms1 = build_sql(term.term)
return f"(Not {sql1})", terms1
else:
raise ValueError(f"unknown term type {type(term)}")


def query_to_sql(query) -> SafeSQLStmt:
    """Parse ``query`` and pass the parsed terms to ``build_sql``."""
    return build_sql(parse(query))
normalised = normalise_name(term.value)
if "*" in term.value:
# Fuzzy wildcard search - respect wildcard position
# "numpy*" > "numpy%", "*numpy" > "%numpy", "*numpy*" > "%numpy%"
pattern = normalised.replace("*", "%")
new_context = context.with_fuzzy_pattern(pattern)
return "canonical_name LIKE ?", (pattern,), new_context
else:
# Simple name search
new_context = context.with_exact_name(normalised)
return "canonical_name LIKE ?", (f"%{normalised}%",), new_context

@classmethod
def handle_filter_summary(
    cls, term: Filter, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
    """Build a ``summary LIKE ?`` condition for ``term``.

    A value that starts with a double quote has its first and last
    characters removed. ``*`` is translated to the SQL ``%`` wildcard and
    the result is wrapped in ``%`` on both sides. ``context`` is returned
    unchanged.
    """
    raw = term.value
    text = raw[1:-1] if raw.startswith('"') else raw
    like_pattern = "%" + text.replace("*", "%") + "%"
    return "summary LIKE ?", (like_pattern,), context

@classmethod
def handle_filter_name_or_summary(
    cls, term: Filter, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
    """Match ``term`` against the name or the summary.

    The name and summary conditions are each built from the same incoming
    ``context`` and joined with ``OR``. Name parameters precede summary
    parameters, and the context returned is the one from the name handler.
    """
    by_name_sql, by_name_args, by_name_ctx = cls.handle_filter_name(term, context)
    by_summary_sql, by_summary_args, _unused = cls.handle_filter_summary(
        term, context
    )

    either_sql = "({} OR {})".format(by_name_sql, by_summary_sql)
    return either_sql, by_name_args + by_summary_args, by_name_ctx

@classmethod
def _build_ordering_from_context(
cls, context: SearchContext
) -> tuple[str, tuple[typing.Any, ...]]:
"""Build mixed ordering for exact names and fuzzy patterns."""

exact_names, fuzzy_patterns = context.exact_names, context.fuzzy_patterns
order_parts = []
all_params = []

# Build single comprehensive CASE statement for priority
case_conditions = []

# Add exact match conditions (priority 0)
for name in exact_names:
case_conditions.append(f"WHEN canonical_name = ? THEN 0")
all_params.append(name)

# Add fuzzy pattern conditions (priority 1)
for pattern in fuzzy_patterns:
case_conditions.append(f"WHEN canonical_name LIKE ? THEN 1")
all_params.append(f"%{pattern}%")

# Add exact-related conditions (priority 2)
for name in exact_names:
case_conditions.append(f"WHEN canonical_name LIKE ? THEN 2") # prefix
case_conditions.append(f"WHEN canonical_name LIKE ? THEN 2") # suffix
all_params.extend([f"{name}%", f"%{name}"])

if case_conditions:
cond = "\n".join(case_conditions)
priority_expr = f"""
CASE
{cond}
ELSE 3
END
"""
order_parts.append(priority_expr)

# Length-based ordering for fuzzy matches (reuse same pattern logic)
if fuzzy_patterns:
length_conditions = []
for pattern in fuzzy_patterns:
length_conditions.append(f"canonical_name LIKE ?")
all_params.append(f"%{pattern}%")

length_expr = f"CASE WHEN ({' OR '.join(length_conditions)}) THEN LENGTH(canonical_name) ELSE 999999 END"
order_parts.append(length_expr)

# Prefix distance for exact names
if exact_names:
distance_conditions = []
for name in exact_names:
distance_conditions.append(
f"WHEN INSTR(canonical_name, ?) > 0 THEN (INSTR(canonical_name, ?) - 1)"
)
all_params.extend([name, name])

if distance_conditions:
cond = "\n".join(distance_conditions)
distance_expr = f"""
CASE
{cond}
ELSE 999999
END
"""
order_parts.append(distance_expr)

# Alphabetical fallback
order_parts.append("canonical_name")

order_clause = f"ORDER BY {', '.join(order_parts)}"
return order_clause, tuple(all_params)


def build_sql(term: Term | None) -> SQLBuilder:
    """Compile ``term`` into WHERE and ORDER BY clauses via ``SearchCompiler``."""
    compiled = SearchCompiler.compile(term)
    return compiled


def query_to_sql(query) -> SQLBuilder:
    """Parse the ``query`` string and compile the result with ``build_sql``."""
    return build_sql(parse(query))


grammar = parsley.makeGrammar(
Expand Down Expand Up @@ -141,8 +360,8 @@ def query_to_sql(query) -> SafeSQLStmt:
|filter:filter -> filter
|'-' filters:filters -> Not(filters)
)
search_terms = (filters+:filters -> tuple(filters)
| -> ())
search_terms = (filters:filters -> filters
| -> None)
"""),
{
"And": And,
Expand All @@ -154,21 +373,8 @@ def query_to_sql(query) -> SafeSQLStmt:
)


def parse(query: str) -> typing.Tuple[Term, ...]:
def parse(query: str) -> Term | None:
    """Parse ``query`` with the grammar's ``search_terms`` rule.

    Leading and trailing whitespace is stripped before parsing.
    """
    trimmed = query.strip()
    parser = grammar(trimmed)
    return parser.search_terms()


ParseError = parsley.ParseError


def simple_name_from_query(terms: typing.Tuple[Term, ...]) -> typing.Optional[str]:
"""If possible, give a simple (normalised) package name which represents the query terms provided.

A name is only returned for a ``Filter`` on ``name`` or ``name_or_summary`` whose
value contains neither ``*`` nor ``"``; that value is passed through ``normalise_name``.
``None`` is returned when the loop finishes or is broken out of without such a term.
"""
# NOTE(review): indentation is not visible in this view, so which `if` the
# `else: break` below pairs with should be confirmed against the real file.
for term in terms:
if isinstance(term, Filter):
if term.filter_on in [FilterOn.name_or_summary, FilterOn.name]:
# Wildcard or quoted values are not treated as a simple name.
if "*" in term.value or '"' in term.value:
break
return normalise_name(term.value)
else:
break
return None
Loading