diff --git a/simple_repository_browser/_search.py b/simple_repository_browser/_search.py index 4fcc96e..0bdbb7a 100644 --- a/simple_repository_browser/_search.py +++ b/simple_repository_browser/_search.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from enum import Enum import re @@ -45,72 +47,289 @@ def normalise_name(name: str) -> str: return re.sub(r"[-_.]+", "-", name).lower() -# A safe SQL statement must not use *any* user-defined input in the resulting first argument (the SQL query), -# rather any user input MUST be provided as part of the arguments (second part of the value), which will be passed -# to SQLITE to deal with. -SafeSQLStmt = typing.Tuple[str, typing.Tuple[typing.Any, ...]] - - -def prepare_name(term: Filter) -> SafeSQLStmt: - if term.value.startswith('"'): - # Match the phase precisely. - value = term.value[1:-1] - else: - value = normalise_name(term.value) - value = value.replace("*", "%") - return "canonical_name LIKE ?", (f"%{value}%",) - - -def prepare_summary(term: Filter) -> SafeSQLStmt: - if term.value.startswith('"'): - # Match the phase precisely. - value = term.value[1:-1] - else: - value = term.value - value = value.replace("*", "%") - return "summary LIKE ?", (f"%{value}%",) - - -def build_sql(term: typing.Union[Term, typing.Tuple[Term, ...]]) -> SafeSQLStmt: - # Return query and params to be used in SQL. query MUST not be produced using untrusted input, as is vulnerable to SQL injection. - # Instead, any user input must be in the parameters, which undergoes sqllite built-in cleaning. 
- if isinstance(term, tuple): - if len(term) == 0: - return "", () - - # No known query can produce a multi-value term - assert len(term) == 1 - return build_sql(term[0]) - - if isinstance(term, Filter): - if term.filter_on == FilterOn.name_or_summary: - sql1, terms1 = prepare_name(term) - sql2, terms2 = prepare_summary(term) - return f"({sql1} OR {sql2})", terms1 + terms2 - elif term.filter_on == FilterOn.name: - return prepare_name(term) - elif term.filter_on == FilterOn.summary: - return prepare_summary(term) +@dataclasses.dataclass(frozen=True) +class SQLBuilder: + """Immutable SQL WHERE and ORDER BY clauses with parameters.""" + + where_clause: str + where_params: tuple[typing.Any, ...] + order_clause: str + order_params: tuple[typing.Any, ...] + search_context: SearchContext + + def build_complete_query( + self, + base_select: str, + limit: int, + offset: int, + ) -> tuple[str, tuple[typing.Any, ...]]: + """Build complete query with LIMIT/OFFSET""" + where_part = f"WHERE {self.where_clause}" if self.where_clause else "" + query = f"{base_select} {where_part} {self.order_clause} LIMIT ? OFFSET ?" + return query, self.where_params + self.order_params + (limit, offset) + + def with_where(self, clause: str, params: tuple[typing.Any, ...]) -> SQLBuilder: + """Return new SQLBuilder with updated WHERE clause""" + return dataclasses.replace(self, where_clause=clause, where_params=params) + + def with_order(self, clause: str, params: tuple[typing.Any, ...]) -> SQLBuilder: + """Return new SQLBuilder with updated ORDER BY clause""" + return dataclasses.replace(self, order_clause=clause, order_params=params) + + +@dataclasses.dataclass(frozen=True) +class SearchContext: + """Context collected during WHERE clause building.""" + + exact_names: tuple[str, ...] = () + fuzzy_patterns: tuple[str, ...] 
= () + + def with_exact_name(self, name: str) -> SearchContext: + """Add an exact name match.""" + if name in self.exact_names: + return self + else: + return dataclasses.replace(self, exact_names=self.exact_names + (name,)) + + def with_fuzzy_pattern(self, pattern: str) -> SearchContext: + """Add a fuzzy search pattern.""" + if pattern in self.fuzzy_patterns: + return self + else: + return dataclasses.replace( + self, fuzzy_patterns=self.fuzzy_patterns + (pattern,) + ) + + def merge(self, other: SearchContext) -> SearchContext: + """Merge contexts from multiple terms (for OR/AND).""" + names = self.exact_names + tuple( + name for name in other.exact_names if name not in self.exact_names + ) + patterns = self.fuzzy_patterns + tuple( + pattern + for pattern in other.fuzzy_patterns + if pattern not in self.fuzzy_patterns + ) + + return dataclasses.replace(self, exact_names=names, fuzzy_patterns=patterns) + + +class SearchCompiler: + """Extensible visitor-pattern compiler for search terms to SQL. + + Uses AST-style method dispatch: visit_TermName maps to handle_term_TermName. + Subclasses can override specific handlers for customisation. 
+ """ + + @classmethod + def compile(cls, term: Term | None) -> SQLBuilder: + """Compile search terms into SQL WHERE and ORDER BY clauses.""" + if term is None: + return SQLBuilder( + where_clause="", + where_params=(), + order_clause="", + order_params=(), + search_context=SearchContext(), + ) + + # Build WHERE clause and collect context + context = SearchContext() + where_clause, where_params, final_context = cls._visit_term(term, context) + + # Build ORDER BY clause based on collected context + order_clause, order_params = cls._build_ordering_from_context(final_context) + + return SQLBuilder( + where_clause=where_clause, + where_params=where_params, + order_clause=order_clause, + order_params=order_params, + search_context=final_context, + ) + + @classmethod + def _visit_term( + cls, term: Term, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + """Dispatch to appropriate handler using AST-style method naming.""" + method_name = f"handle_term_{type(term).__name__}" + handler = getattr(cls, method_name, None) + if handler is None: + raise ValueError(f"No handler for term type {type(term).__name__}") + return handler(term, context) + + @classmethod + def handle_term_Filter( + cls, term: Filter, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + """Dispatch to field-specific filter handler.""" + match term.filter_on: + case FilterOn.name_or_summary: + return cls.handle_filter_name_or_summary(term, context) + case FilterOn.name: + return cls.handle_filter_name(term, context) + case FilterOn.summary: + return cls.handle_filter_summary(term, context) + case _: + raise ValueError(f"Unhandled filter on {term.filter_on}") + + @classmethod + def handle_term_And( + cls, term: And, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + lhs_sql, lhs_params, lhs_context = cls._visit_term(term.lhs, context) + rhs_sql, rhs_params, rhs_context = cls._visit_term(term.rhs, context) + + 
merged_context = lhs_context.merge(rhs_context) + return f"({lhs_sql} AND {rhs_sql})", lhs_params + rhs_params, merged_context + + @classmethod + def handle_term_Or( + cls, term: Or, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + lhs_sql, lhs_params, lhs_context = cls._visit_term(term.lhs, context) + rhs_sql, rhs_params, rhs_context = cls._visit_term(term.rhs, context) + + merged_context = lhs_context.merge(rhs_context) + return f"({lhs_sql} OR {rhs_sql})", lhs_params + rhs_params, merged_context + + @classmethod + def handle_term_Not( + cls, term: Not, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + inner_sql, inner_params, _ = cls._visit_term(term.term, context) + return f"(NOT {inner_sql})", inner_params, context + + @classmethod + def handle_filter_name( + cls, term: Filter, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + if term.value.startswith('"'): + # Exact quoted match + value = term.value[1:-1] + normalised = normalise_name(value) + new_context = context.with_exact_name(normalised) + return "canonical_name = ?", (normalised,), new_context else: - raise ValueError(f"Unhandled filter on {term.filter_on}") - elif isinstance(term, And): - sql1, terms1 = build_sql(term.lhs) - sql2, terms2 = build_sql(term.rhs) - return f"({sql1} AND {sql2})", terms1 + terms2 - elif isinstance(term, Or): - sql1, terms1 = build_sql(term.lhs) - sql2, terms2 = build_sql(term.rhs) - return f"({sql1} OR {sql2})", terms1 + terms2 - elif isinstance(term, Not): - sql1, terms1 = build_sql(term.term) - return f"(Not {sql1})", terms1 - else: - raise ValueError(f"unknown term type {type(term)}") - - -def query_to_sql(query) -> SafeSQLStmt: - terms = parse(query) - return build_sql(terms) + normalised = normalise_name(term.value) + if "*" in term.value: + # Fuzzy wildcard search - respect wildcard position + # "numpy*" > "numpy%", "*numpy" > "%numpy", "*numpy*" > "%numpy%" + pattern = 
normalised.replace("*", "%") + new_context = context.with_fuzzy_pattern(pattern) + return "canonical_name LIKE ?", (pattern,), new_context + else: + # Simple name search + new_context = context.with_exact_name(normalised) + return "canonical_name LIKE ?", (f"%{normalised}%",), new_context + + @classmethod + def handle_filter_summary( + cls, term: Filter, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + if term.value.startswith('"'): + value = term.value[1:-1] + else: + value = term.value + value = value.replace("*", "%") + return "summary LIKE ?", (f"%{value}%",), context + + @classmethod + def handle_filter_name_or_summary( + cls, term: Filter, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...], SearchContext]: + """Handle filtering across both name and summary fields.""" + name_sql, name_params, name_context = cls.handle_filter_name(term, context) + summary_sql, summary_params, _ = cls.handle_filter_summary(term, context) + + combined_sql = f"({name_sql} OR {summary_sql})" + combined_params = name_params + summary_params + return combined_sql, combined_params, name_context + + @classmethod + def _build_ordering_from_context( + cls, context: SearchContext + ) -> tuple[str, tuple[typing.Any, ...]]: + """Build mixed ordering for exact names and fuzzy patterns.""" + + exact_names, fuzzy_patterns = context.exact_names, context.fuzzy_patterns + order_parts = [] + all_params = [] + + # Build single comprehensive CASE statement for priority + case_conditions = [] + + # Add exact match conditions (priority 0) + for name in exact_names: + case_conditions.append(f"WHEN canonical_name = ? THEN 0") + all_params.append(name) + + # Add fuzzy pattern conditions (priority 1) + for pattern in fuzzy_patterns: + case_conditions.append(f"WHEN canonical_name LIKE ? THEN 1") + all_params.append(f"%{pattern}%") + + # Add exact-related conditions (priority 2) + for name in exact_names: + case_conditions.append(f"WHEN canonical_name LIKE ? 
THEN 2") # prefix + case_conditions.append(f"WHEN canonical_name LIKE ? THEN 2") # suffix + all_params.extend([f"{name}%", f"%{name}"]) + + if case_conditions: + cond = "\n".join(case_conditions) + priority_expr = f""" + CASE + {cond} + ELSE 3 + END + """ + order_parts.append(priority_expr) + + # Length-based ordering for fuzzy matches (reuse same pattern logic) + if fuzzy_patterns: + length_conditions = [] + for pattern in fuzzy_patterns: + length_conditions.append(f"canonical_name LIKE ?") + all_params.append(f"%{pattern}%") + + length_expr = f"CASE WHEN ({' OR '.join(length_conditions)}) THEN LENGTH(canonical_name) ELSE 999999 END" + order_parts.append(length_expr) + + # Prefix distance for exact names + if exact_names: + distance_conditions = [] + for name in exact_names: + distance_conditions.append( + f"WHEN INSTR(canonical_name, ?) > 0 THEN (INSTR(canonical_name, ?) - 1)" + ) + all_params.extend([name, name]) + + if distance_conditions: + cond = "\n".join(distance_conditions) + distance_expr = f""" + CASE + {cond} + ELSE 999999 + END + """ + order_parts.append(distance_expr) + + # Alphabetical fallback + order_parts.append("canonical_name") + + order_clause = f"ORDER BY {', '.join(order_parts)}" + return order_clause, tuple(all_params) + + +def build_sql(term: Term | None) -> SQLBuilder: + """Build SQL WHERE and ORDER BY clauses from search terms.""" + return SearchCompiler.compile(term) + + +def query_to_sql(query) -> SQLBuilder: + term = parse(query) + return build_sql(term) grammar = parsley.makeGrammar( @@ -141,8 +360,8 @@ def query_to_sql(query) -> SafeSQLStmt: |filter:filter -> filter |'-' filters:filters -> Not(filters) ) - search_terms = (filters+:filters -> tuple(filters) - | -> ()) + search_terms = (filters:filters -> filters + | -> None) """), { "And": And, @@ -154,21 +373,8 @@ def query_to_sql(query) -> SafeSQLStmt: ) -def parse(query: str) -> typing.Tuple[Term, ...]: +def parse(query: str) -> Term | None: return 
grammar(query.strip()).search_terms() ParseError = parsley.ParseError - - -def simple_name_from_query(terms: typing.Tuple[Term, ...]) -> typing.Optional[str]: - """If possible, give a simple (normalized) package name which represents the query terms provided""" - for term in terms: - if isinstance(term, Filter): - if term.filter_on in [FilterOn.name_or_summary, FilterOn.name]: - if "*" in term.value or '"' in term.value: - break - return normalise_name(term.value) - else: - break - return None diff --git a/simple_repository_browser/model.py b/simple_repository_browser/model.py index 760d7d8..fc01c4b 100644 --- a/simple_repository_browser/model.py +++ b/simple_repository_browser/model.py @@ -114,26 +114,28 @@ async def project_query( self, query: str, page_size: int, page: int ) -> QueryResultModel: try: - search_terms = _search.parse(query) + search_term = _search.parse(query) except _search.ParseError: raise errors.InvalidSearchQuery("Invalid search pattern") - if not search_terms: + if search_term is None: raise errors.InvalidSearchQuery("Please specify a search query") try: - condition_query, condition_terms = _search.build_sql(search_terms) + sql_builder = _search.build_sql(search_term) except ValueError as err: raise errors.InvalidSearchQuery(f"Search query invalid ({str(err)})") - single_name_proposal = _search.simple_name_from_query(search_terms) - exact = None - offset = (page - 1) * page_size # page is 1 based. 
with self.projects_db as cursor: + # Count query uses only WHERE parameters + count_query = f"SELECT COUNT(*) as count FROM projects" + if sql_builder.where_clause: + count_query += f" WHERE {sql_builder.where_clause}" + result_count = cursor.execute( - f"SELECT COUNT(*) as count FROM projects WHERE {condition_query}", - condition_terms, + count_query, + sql_builder.where_params, ).fetchone() n_results = result_count["count"] @@ -143,29 +145,38 @@ async def project_query( f"Requested page (page: {page}) is beyond the number of pages ({n_pages})", ) - results = cursor.execute( - "SELECT canonical_name, summary, release_version, release_date FROM projects WHERE " - f"{condition_query} LIMIT ? OFFSET ?", - condition_terms + (page_size, offset), - ).fetchall() + # Main query uses the builder's complete query method + query, params = sql_builder.build_complete_query( + "SELECT canonical_name, summary, release_version, release_date FROM projects", + page_size, + offset, + ) + results = cursor.execute(query, params).fetchall() # Convert results to SearchResultItem objects results = [SearchResultItem(*result) for result in results] - # Check if single_name_proposal is already in the results - if single_name_proposal and page == 1: - exact_found = any(r.canonical_name == single_name_proposal for r in results) - if not exact_found: - # Not in results, check if it exists in repository - try: - await self.source.get_project_page(single_name_proposal) - # Package exists in repository! Add it to the beginning - results.insert( - 0, SearchResultItem(canonical_name=single_name_proposal) - ) - n_results += 1 - except PackageNotFoundError: - pass + # If the search was for a specific name, then make sure we return it if + # it is in the package repository. 
+ if page == 1: + result_names = {r.canonical_name for r in results} + missing = tuple( + name + for name in sql_builder.search_context.exact_names + if name not in result_names + ) + if missing: + for name_proposal in reversed(missing): + # Not in results, check if it exists in repository + try: + await self.source.get_project_page(name_proposal) + # Package exists in repository! Add it to the beginning + results.insert( + 0, SearchResultItem(canonical_name=name_proposal) + ) + n_results += 1 + except PackageNotFoundError: + pass return QueryResultModel( search_query=query, diff --git a/simple_repository_browser/tests/test_search.py b/simple_repository_browser/tests/test_search.py index 7ddc323..81954ac 100644 --- a/simple_repository_browser/tests/test_search.py +++ b/simple_repository_browser/tests/test_search.py @@ -1,101 +1,92 @@ +from pathlib import Path +import sqlite3 +import tempfile + +import diskcache import parsley import pytest -from simple_repository_browser import _search +from simple_repository_browser import _search, model from simple_repository_browser._search import Filter, FilterOn @pytest.mark.parametrize( ["query", "expected_expression_graph"], [ - ("", ()), - pytest.param("some-name", (Filter(FilterOn.name_or_summary, "some-name"),)), + ("", None), + pytest.param("some-name", Filter(FilterOn.name_or_summary, "some-name")), pytest.param( "some name", - ( - _search.And( - Filter(FilterOn.name_or_summary, "some"), - Filter(FilterOn.name_or_summary, "name"), - ), + _search.And( + Filter(FilterOn.name_or_summary, "some"), + Filter(FilterOn.name_or_summary, "name"), ), ), - pytest.param("som*name", (Filter(FilterOn.name_or_summary, "som*name"),)), - pytest.param('"some name"', (Filter(FilterOn.name_or_summary, '"some name"'),)), - pytest.param('"some-name"', (Filter(FilterOn.name_or_summary, '"some-name"'),)), - pytest.param('"CASE"', (Filter(FilterOn.name_or_summary, '"CASE"'),)), - pytest.param("-foo", (_search.Not(Filter(FilterOn.name_or_summary, 
"foo")),)), + pytest.param("som*name", Filter(FilterOn.name_or_summary, "som*name")), + pytest.param('"some name"', Filter(FilterOn.name_or_summary, '"some name"')), + pytest.param('"some-name"', Filter(FilterOn.name_or_summary, '"some-name"')), + pytest.param('"CASE"', Filter(FilterOn.name_or_summary, '"CASE"')), + pytest.param("-foo", _search.Not(Filter(FilterOn.name_or_summary, "foo"))), pytest.param( - '-"foo bar"', (_search.Not(Filter(FilterOn.name_or_summary, '"foo bar"')),) + '-"foo bar"', _search.Not(Filter(FilterOn.name_or_summary, '"foo bar"')) ), pytest.param( - '-name:"foo bar"', (_search.Not(Filter(FilterOn.name, '"foo bar"')),) + '-name:"foo bar"', _search.Not(Filter(FilterOn.name, '"foo bar"')) ), - pytest.param("name:foo", (Filter(FilterOn.name, "foo"),)), + pytest.param("name:foo", Filter(FilterOn.name, "foo")), pytest.param( "name:foo OR name:bar", - ( - _search.Or( - Filter(FilterOn.name, "foo"), - Filter(FilterOn.name, "bar"), - ), + _search.Or( + Filter(FilterOn.name, "foo"), + Filter(FilterOn.name, "bar"), ), ), pytest.param( 'name:foo AND "fiddle AND sticks"', - ( - _search.And( - Filter(FilterOn.name, "foo"), - Filter(FilterOn.name_or_summary, '"fiddle AND sticks"'), - ), + _search.And( + Filter(FilterOn.name, "foo"), + Filter(FilterOn.name_or_summary, '"fiddle AND sticks"'), ), ), - pytest.param("summary:foo", (Filter(FilterOn.summary, "foo"),)), + pytest.param("summary:foo", Filter(FilterOn.summary, "foo")), pytest.param( 'name:"NAME OR" AND "fiddle AND sticks"', - ( - _search.And( - Filter(FilterOn.name, '"NAME OR"'), - Filter(FilterOn.name_or_summary, '"fiddle AND sticks"'), - ), + _search.And( + Filter(FilterOn.name, '"NAME OR"'), + Filter(FilterOn.name_or_summary, '"fiddle AND sticks"'), ), ), - pytest.param("(((a)))", (Filter(FilterOn.name_or_summary, "a"),)), + pytest.param("(((a)))", Filter(FilterOn.name_or_summary, "a")), pytest.param( "(((a) OR (b)))", - ( - _search.Or( - Filter(FilterOn.name_or_summary, "a"), - 
Filter(FilterOn.name_or_summary, "b"), - ), + _search.Or( + Filter(FilterOn.name_or_summary, "a"), + Filter(FilterOn.name_or_summary, "b"), ), ), pytest.param( "(a AND b) OR (c AND d)", - ( - _search.Or( - _search.And( - Filter(FilterOn.name_or_summary, "a"), - Filter(FilterOn.name_or_summary, "b"), - ), - _search.And( - Filter(FilterOn.name_or_summary, "c"), - Filter(FilterOn.name_or_summary, "d"), - ), + _search.Or( + _search.And( + Filter(FilterOn.name_or_summary, "a"), + Filter(FilterOn.name_or_summary, "b"), + ), + _search.And( + Filter(FilterOn.name_or_summary, "c"), + Filter(FilterOn.name_or_summary, "d"), ), ), ), pytest.param( "((a AND b)) OR (c AND -d)", - ( - _search.Or( - _search.And( - Filter(FilterOn.name_or_summary, "a"), - Filter(FilterOn.name_or_summary, "b"), - ), - _search.And( - Filter(FilterOn.name_or_summary, "c"), - _search.Not(Filter(FilterOn.name_or_summary, "d")), - ), + _search.Or( + _search.And( + Filter(FilterOn.name_or_summary, "a"), + Filter(FilterOn.name_or_summary, "b"), + ), + _search.And( + Filter(FilterOn.name_or_summary, "c"), + _search.Not(Filter(FilterOn.name_or_summary, "d")), ), ), ), @@ -106,26 +97,6 @@ def test_parse_query(query, expected_expression_graph): assert result == expected_expression_graph -@pytest.mark.parametrize( - ["query", "expected_result"], - [ - ("", None), - ("name:foo", "foo"), - ("name:foo__unnormed", "foo-unnormed"), - ("foo", "foo"), - ("some*.Name", None), - ('summary:"Some Description"', None), - ("foo bar", None), - ("foo OR bar", None), - ("-name:foo OR -bar", None), - ], -) -def test_simple_name_proposal(query, expected_result): - terms = _search.parse(query) - result = _search.simple_name_from_query(terms) - assert result == expected_result - - @pytest.mark.parametrize( ["query", "expected_predicate"], [ @@ -138,7 +109,14 @@ def test_simple_name_proposal(query, expected_result): "some*.Name", ( "(canonical_name LIKE ? 
class MockSimpleRepository:
    """Mock repository for testing search functionality."""

    def __init__(self, available_packages=None):
        # Packages that exist in the repository but not in the database.
        self.available_packages = available_packages or set()

    async def get_project_page(self, name: str):
        """Return a minimal project page for *name*, or raise if it is unknown.

        Raises:
            PackageNotFoundError: if *name* is not in ``available_packages``.
        """
        from simple_repository.errors import PackageNotFoundError

        if name in self.available_packages:
            # Return a mock project detail for testing
            from simple_repository.model import Meta, ProjectDetail

            return ProjectDetail(meta=Meta("1.0"), name=name, files=())
        raise PackageNotFoundError(f"Project {name} not found")


@pytest.fixture
def test_database(tmp_path: Path):
    """Create a temporary SQLite database with test data for search ordering."""
    db_path = tmp_path / "test.db"
    con = sqlite3.connect(db_path)
    con.row_factory = sqlite3.Row

    # Create projects table matching the real schema
    con.execute("""
        CREATE TABLE projects (
            canonical_name TEXT PRIMARY KEY,
            summary TEXT,
            release_version TEXT,
            release_date TEXT
        )
    """)

    # Insert test data designed for ordering tests
    test_projects = [
        # numpy family - for testing exact name closeness
        ("numpy", "Fundamental package for array computing", "1.24.0", "2023-01-01"),
        ("numpy-image", "Image processing with numpy", "0.1.0", "2023-02-01"),
        ("xnumpy", "Extended numpy functionality", "0.2.0", "2023-03-01"),
        ("amazeballs-numpy", "Extended numpy functionality", "0.2.0", "2023-03-01"),
        ("anumpyb", "Extended numpy functionality", "0.2.0", "2023-03-01"),
        ("numpyish", "Numpy-like functionality", "0.1.0", "2023-04-01"),
        ("abc", "Not at all like numpy", "0.1.0", "2023-04-01"),
        # scipy family - for testing exact name closeness
        ("scipy", "Scientific computing library", "1.10.0", "2023-01-15"),
        ("scipy2", "Alternative scipy implementation", "0.5.0", "2023-02-15"),
        ("scipylab", "Scipy laboratory", "0.3.0", "2023-03-15"),
        # scikit family - for testing fuzzy pattern matching
        ("scikit-amazeballs", "The bee's knees of scikits", "1.2.0", "2023-01-20"),
        ("scikit-learn", "Machine learning library", "1.2.0", "2023-01-20"),
        ("scikit-image", "Image processing library", "0.20.0", "2023-02-20"),
        ("scikit-optimize", "Optimisation library", "0.9.0", "2023-03-20"),
        # Other packages
        ("pandas", "Data manipulation library", "2.0.0", "2023-03-01"),
        ("matplotlib", "Plotting library", "3.7.0", "2023-01-10"),
        ("requests", "HTTP library", "2.28.0", "2022-12-01"),
    ]

    try:
        con.executemany(
            "INSERT INTO projects (canonical_name, summary, release_version, release_date) VALUES (?, ?, ?, ?)",
            test_projects,
        )
        con.commit()
        yield con
    finally:
        # Close the connection even if set-up or the test itself failed.
        con.close()


@pytest.fixture
def test_model(test_database):
    """Create a model instance backed by the test database and a mock source."""
    # The cache directory is removed again on teardown; the cache is closed
    # first (inner ``finally``) so that no file in it is still open.
    with tempfile.TemporaryDirectory() as cache_dir:
        cache = diskcache.Cache(cache_dir)
        try:
            yield model.Model(
                source=MockSimpleRepository(),
                projects_db=test_database,
                cache=cache,
                crawler=None,  # Not needed for search tests
            )
        finally:
            cache.close()


def assert_order(expected_names, actual_results):
    """Assert that *expected_names* all appear in the results, in that relative order.

    Other results may be interleaved; only the relative position of the
    expected names is checked.
    """
    actual_names = [item.canonical_name for item in actual_results]

    # Check that all expected names are present
    for name in expected_names:
        assert name in actual_names, (
            f"Expected '{name}' not found in results: {actual_names}"
        )

    # Check relative ordering of each neighbouring pair
    for current_name, next_name in zip(expected_names, expected_names[1:]):
        assert actual_names.index(current_name) < actual_names.index(next_name), (
            f"Expected '{current_name}' to come before '{next_name}' in {actual_names}"
        )


@pytest.mark.asyncio
async def test_exact_name_search_ordering(test_model):
    """Test that exact name searches return results in closeness order."""
    result = await test_model.project_query("numpy", page_size=10, page=1)

    # numpy should come first (exact match), then prefix matches, then suffix matches
    assert_order(["numpy", "numpy-image", "xnumpy", "abc"], result["results"])


@pytest.mark.asyncio
async def test_exact_name_search_scipy_ordering(test_model):
    """Test exact name search ordering with scipy family."""
    result = await test_model.project_query("scipy", page_size=10, page=1)

    # scipy should come first, then prefix matches (scipy2, scipylab)
    assert_order(["scipy", "scipy2"], result["results"])


@pytest.mark.asyncio
async def test_fuzzy_pattern_search_ordering(test_model):
    """Test that fuzzy pattern searches work correctly."""
    result = await test_model.project_query("scikit-*", page_size=10, page=1)
    # Should include all scikit-* packages, ordered by shortest, then alphabetically
    assert_order(
        ["scikit-image", "scikit-learn", "scikit-optimize", "scikit-amazeballs"],
        result["results"],
    )


@pytest.mark.asyncio
async def test_fuzzy_pattern_search_ordering_not_matching_prefix(test_model):
    """Test that a trailing wildcard on a name filter only matches as a prefix."""
    result = await test_model.project_query("name:numpy*", page_size=10, page=1)
    names = [item.canonical_name for item in result["results"]]
    assert "xnumpy" not in names
    assert "anumpyb" not in names
    assert "abc" not in names
    assert_order(["numpy", "numpy-image"], result["results"])


@pytest.mark.asyncio
async def test_mixed_search_ordering_scipy_or_scikit(test_model):
    """Test mixed search: 'scipy OR scikit-*' - scipy first, then scikit patterns."""
    result = await test_model.project_query("scipy OR scikit-*", page_size=10, page=1)
    # Should get the exact match, then all scikits, then similar to exact match (scipy2)
    assert_order(["scipy", "scikit-amazeballs", "scipy2"], result["results"])


@pytest.mark.asyncio
async def test_quoted_exact_search(test_model):
    """Test quoted exact searches."""
    result = await test_model.project_query('"numpy"', page_size=10, page=1)
    names = [item.canonical_name for item in result["results"]]
    # Should return numpy first (exact match)
    assert names[0] == "numpy"
    assert_order(["numpy", "numpy-image", "xnumpy", "anumpyb"], result["results"])


@pytest.mark.asyncio
async def test_fuzzy_search(test_model):
    """Test wildcard searches that are not restricted to a single field."""
    result = await test_model.project_query("num*-*", page_size=10, page=1)
    names = [item.canonical_name for item in result["results"]]
    assert "numpy" not in names
    # We should also find numpyish because of its summary containing "Numpy-like"
    assert_order(["numpy-image", "numpyish"], result["results"])


@pytest.mark.asyncio
async def test_name_field_specific_search(test_model):
    """Test name field-specific searches."""
    result = await test_model.project_query("name:numpy", page_size=10, page=1)
    assert_order(["numpy", "numpy-image", "xnumpy"], result["results"])


@pytest.mark.asyncio
async def test_summary_field_specific_search(test_model):
    """Test summary field-specific searches."""
    result = await test_model.project_query("summary:computing", page_size=10, page=1)
    assert_order(["numpy", "scipy"], result["results"])


@pytest.mark.asyncio
async def test_not_operator(test_model):
    """Test NOT operator functionality."""
    result = await test_model.project_query("-scipy", page_size=10, page=1)

    names = [item.canonical_name for item in result["results"]]
    # Should not include scipy
    assert "scipy" not in names
    assert_order(["numpy", "pandas"], result["results"])


@pytest.mark.asyncio
async def test_complex_mixed_query(test_model):
    """Test complex mixed query with multiple exact names."""
    result = await test_model.project_query("numpy OR scipy", page_size=10, page=1)

    assert_order(["numpy", "scipy", "numpy-image", "scipy2"], result["results"])


@pytest.mark.asyncio
async def test_mixed_query_suffix_ordering(test_model):
    """Test that suffix matches (xnumpy) come after fuzzy patterns in mixed queries."""
    result = await test_model.project_query("numpy OR scikit-*", page_size=10, page=1)

    assert_order(
        [
            "numpy",
            "scikit-learn",
            "numpy-image",
            "xnumpy",
            "amazeballs-numpy",
            "anumpyb",
        ],
        result["results"],
    )


@pytest.mark.asyncio
async def test_empty_results(test_model):
    """Test queries that return no results."""
    result = await test_model.project_query("nonexistentpackage", page_size=10, page=1)

    assert result["results"] == []
    assert result["results_count"] == 0


@pytest.mark.asyncio
async def test_summary_search(test_model):
    """Test that a summary filter matches on the summary and not on the name."""
    result = await test_model.project_query("summary:numpy", page_size=1000, page=1)
    names = [item.canonical_name for item in result["results"]]
    assert "scipy" not in names
    assert "numpy" not in names
    assert "abc" in names


@pytest.mark.asyncio
async def test_injected_results_when_not_in_db(test_model):
    """Test that named packages missing from the database are injected from the source."""
    assert isinstance(test_model.source, MockSimpleRepository)
    test_model.source.available_packages = {
        "jingo",
        "wibble",
        "wibbleof",
        "bongo",
        "bongo-bong",
    }
    result = await test_model.project_query(
        "jingo OR wibble* OR boNgo OR bongo_Bong OR numpy OR totallyMissing",
        page_size=10,
        page=1,
    )

    names = [item.canonical_name for item in result["results"]]

    # Wildcard terms are not exact names, so they must not be injected.
    assert "wibble" not in names
    assert_order(
        ["jingo", "bongo", "bongo-bong", "numpy"],
        result["results"],
    )
    # Check that the one result that came from the db actually has the summary.
    [numpy_result] = [
        item for item in result["results"] if item.canonical_name == "numpy"
    ]
    assert "Fundamental" in numpy_result.summary