diff --git a/service/db/base.py b/service/db/base.py index 652be61..81118ec 100644 --- a/service/db/base.py +++ b/service/db/base.py @@ -83,6 +83,19 @@ async def add_store(self, store: Store) -> int: """ pass + @abstractmethod + async def add_many_stores(self, stores: list[Store]) -> dict[str, int]: + """ + Add multiple stores in a batch operation. + + Args: + stores: List of Store objects to add or update. + + Returns: + Dictionary mapping store codes to their database IDs. + """ + pass + @abstractmethod async def update_store( self, @@ -191,6 +204,19 @@ async def add_ean(self, ean: str) -> int: """ pass + @abstractmethod + async def add_many_eans(self, eans: list[str]) -> dict[str, int]: + """ + Add multiple empty products with only EAN codes in a batch operation. + + Args: + eans: List of EAN codes to add. + + Returns: + Dictionary mapping EAN codes to their database IDs. + """ + pass + @abstractmethod async def get_products_by_ean(self, ean: list[str]) -> list[ProductWithId]: """ diff --git a/service/db/import.py b/service/db/import.py index a67c341..cd3fd9e 100644 --- a/service/db/import.py +++ b/service/db/import.py @@ -2,360 +2,14 @@ import argparse import asyncio import logging -import zipfile -from csv import DictReader -from datetime import date, datetime -from decimal import Decimal from pathlib import Path -from tempfile import TemporaryDirectory -from time import time -from typing import Any, Dict, List from service.config import settings -from service.db.models import Chain, ChainProduct, Price, Store -from service.db.stats import compute_stats - -logger = logging.getLogger("importer") +from service.db.importer.archive_handler import import_archive, import_directory db = settings.get_db() -async def read_csv(file_path: Path) -> List[Dict[str, str]]: - """ - Read a CSV file and return a list of dictionaries. - - Args: - file_path: Path to the CSV file. - - Returns: - List of dictionaries where each dictionary represents a row in the CSV. 
- """ - try: - with open(file_path, "r", encoding="utf-8") as f: - reader = DictReader(f) # type: ignore - return [row for row in reader] - except Exception as e: - logger.error(f"Error reading {file_path}: {e}") - return [] - - -async def process_stores(stores_path: Path, chain_id: int) -> dict[str, int]: - """ - Process stores CSV and import to database. - - Args: - stores_path: Path to the stores CSV file. - chain_id: ID of the chain to which these stores belong. - - Returns: - A dictionary mapping store codes to their database IDs. - """ - logger.debug(f"Importing stores from {stores_path}") - - stores_data = await read_csv(stores_path) - store_map = {} - - for store_row in stores_data: - store = Store( - chain_id=chain_id, - code=store_row["store_id"], - type=store_row.get("type"), - address=store_row.get("address"), - city=store_row.get("city"), - zipcode=store_row.get("zipcode"), - ) - - store_id = await db.add_store(store) - store_map[store.code] = store_id - - logger.debug(f"Processed {len(stores_data)} stores") - return store_map - - -async def process_products( - products_path: Path, - chain_id: int, - chain_code: str, - barcodes: dict[str, int], -) -> Dict[str, int]: - """ - Process products CSV and import to database. - - As a side effect, this function will also add any newly - created EAN codes to the provided `barcodes` dictionary. - - Args: - products_path: Path to the products CSV file. - chain_id: ID of the chain to which these products belong. - chain_code: Code of the retail chain. - barcodes: Dictionary mapping EAN codes to global product IDs. - - Returns: - A dictionary mapping product codes to their database IDs for the chain. - """ - logger.debug(f"Processing products from {products_path}") - - products_data = await read_csv(products_path) - chain_product_map = await db.get_chain_product_map(chain_id) - - # Ideally the CSV would already have valid barcodes, but some older - # archives contain invalid ones so we need to clean them up. 
- def clean_barcode(data: dict[str, Any]) -> dict: - barcode = data.get("barcode", "").strip() - - if ":" in barcode: - return data - - if len(barcode) >= 8 and barcode.isdigit(): - return data - - product_id = data.get("product_id", "") - if not product_id: - logger.warning(f"Product has no barcode: {data}") - return data - - # Construct a chain-specific barcode - data["barcode"] = f"{chain_code}:{product_id}" - return data - - new_products = [ - clean_barcode(p) - for p in products_data - if p["product_id"] not in chain_product_map - ] - - if not new_products: - return chain_product_map - - logger.debug( - f"Found {len(new_products)} new products out of {len(products_data)} total" - ) - - n_new_barcodes = 0 - for product in new_products: - barcode = product["barcode"] - if barcode in barcodes: - continue - - global_product_id = await db.add_ean(barcode) - barcodes[barcode] = global_product_id - n_new_barcodes += 1 - - if n_new_barcodes: - logger.debug(f"Added {n_new_barcodes} new barcodes to global products") - - products_to_create = [] - for product in new_products: - barcode = product["barcode"] - code = product["product_id"] - global_product_id = barcodes[barcode] - - products_to_create.append( - ChainProduct( - chain_id=chain_id, - product_id=global_product_id, - code=code, - name=product["name"], - brand=(product["brand"] or "").strip() or None, - category=(product["category"] or "").strip() or None, - unit=(product["unit"] or "").strip() or None, - quantity=(product["quantity"] or "").strip() or None, - ) - ) - - n_inserts = await db.add_many_chain_products(products_to_create) - if n_inserts != len(new_products): - logger.warning( - f"Expected to insert {len(new_products)} products, but inserted {n_inserts}." 
- ) - logger.debug(f"Imported {len(new_products)} new chain products") - - chain_product_map = await db.get_chain_product_map(chain_id) - return chain_product_map - - -async def process_prices( - price_date: date, - prices_path: Path, - chain_id: int, - store_map: dict[str, int], - chain_product_map: dict[str, int], -) -> int: - """ - Process prices CSV and import to database. - - Args: - price_date: The date for which the prices are valid. - prices_path: Path to the prices CSV file. - chain_id: ID of the chain to which these prices belong. - store_map: Dictionary mapping store codes to their database IDs. - chain_product_map: Dictionary mapping product codes to their database IDs. - - Returns: - The number of prices successfully inserted into the database. - """ - logger.debug(f"Reading prices from {prices_path}") - - prices_data = await read_csv(prices_path) - - # Create price objects - prices_to_create = [] - - logger.debug(f"Found {len(prices_data)} price entries, preparing to import") - - def clean_price(value: str) -> Decimal | None: - if value is None: - return None - value = value.strip() - if value == "": - return None - dval = Decimal(value) - if dval == 0: - return None - return dval - - for price_row in prices_data: - store_id = store_map[price_row["store_id"]] - product_id = chain_product_map.get(price_row["product_id"]) - if product_id is None: - # Price for a product that wasn't added, perhaps because the - # barcode is invalid - logger.warning( - f"Skipping price for unknown product {price_row['product_id']}" - ) - continue - - prices_to_create.append( - Price( - chain_product_id=product_id, - store_id=store_id, - price_date=price_date, - regular_price=Decimal(price_row["price"]), - special_price=clean_price(price_row.get("special_price") or ""), - unit_price=clean_price(price_row["unit_price"]), - best_price_30=clean_price(price_row["best_price_30"]), - anchor_price=clean_price(price_row["anchor_price"]), - ) - ) - - logger.debug(f"Importing 
{len(prices_to_create)} prices") - n_inserted = await db.add_many_prices(prices_to_create) - return n_inserted - - -async def process_chain( - price_date: date, - chain_dir: Path, - barcodes: dict[str, int], -) -> None: - """ - Process a single retail chain and import its data. - - The expected directory structure and CSV columns are documented in - `crawler/store/archive_info.txt`. - - Note: updates the `barcodes` dictionary with any new EAN codes found - (see the `process_products` function). - - Args: - price_date: The date for which the prices are valid. - chain_dir: Path to the directory containing the chain's CSV files. - barcodes: Dictionary mapping EAN codes to global product IDs. - - """ - code = chain_dir.name - - stores_path = chain_dir / "stores.csv" - if not stores_path.exists(): - logger.warning(f"No stores.csv found for chain {code}") - return - - products_path = chain_dir / "products.csv" - if not products_path.exists(): - logger.warning(f"No products.csv found for chain {code}") - return - - prices_path = chain_dir / "prices.csv" - if not prices_path.exists(): - logger.warning(f"No prices.csv found for chain {code}") - return - - logger.debug(f"Processing chain: {code}") - - chain = Chain(code=code) - chain_id = await db.add_chain(chain) - - store_map = await process_stores(stores_path, chain_id) - chain_product_map = await process_products(products_path, chain_id, code, barcodes) - - n_new_prices = await process_prices( - price_date, - prices_path, - chain_id, - store_map, - chain_product_map, - ) - - logger.info(f"Imported {n_new_prices} new prices for {code}") - - -async def import_archive(path: Path, compute_stats_flag: bool = True): - """Import data from all chain directories in the given zip archive.""" - try: - price_date = datetime.strptime(path.stem, "%Y-%m-%d") - except ValueError: - logger.error(f"`{path.stem}` is not a valid date in YYYY-MM-DD format") - return - - with TemporaryDirectory() as temp_dir: # type: ignore - 
logger.debug(f"Extracting archive {path} to {temp_dir}") - with zipfile.ZipFile(path, "r") as zip_ref: - zip_ref.extractall(temp_dir) - await _import(Path(temp_dir), price_date, compute_stats_flag) - - -async def import_directory(path: Path, compute_stats_flag: bool = True) -> None: - """Import data from all chain directories in the given directory.""" - if not path.is_dir(): - logger.error(f"`{path}` does not exist or is not a directory") - return - - try: - price_date = datetime.strptime(path.name, "%Y-%m-%d") - except ValueError: - logger.error( - f"Directory `{path.name}` is not a valid date in YYYY-MM-DD format" - ) - return - - await _import(path, price_date, compute_stats_flag) - - -async def _import( - path: Path, price_date: datetime, compute_stats_flag: bool = True -) -> None: - chain_dirs = [d.resolve() for d in path.iterdir() if d.is_dir()] - if not chain_dirs: - logger.warning(f"No chain directories found in {path}") - return - - logger.debug(f"Importing {len(chain_dirs)} chains from {path}") - - t0 = time() - - barcodes = await db.get_product_barcodes() - for chain_dir in chain_dirs: - await process_chain(price_date, chain_dir, barcodes) - - dt = int(time() - t0) - logger.info(f"Imported {len(chain_dirs)} chains in {dt} seconds") - - if compute_stats_flag: - await compute_stats(price_date) - else: - logger.debug(f"Skipping statistics computation for {price_date:%Y-%m-%d}") - - async def main(): """ Import price data from directories or zip archives. @@ -414,7 +68,9 @@ async def main(): elif path.suffix.lower() == ".zip": await import_archive(path, compute_stats_flag) else: - logger.error(f"Path `{path}` is neither a directory nor a zip archive.") + logging.error( + f"Path `{path}` is neither a directory nor a zip archive." 
+ ) finally: await db.close() diff --git a/service/db/importer/__init__.py b/service/db/importer/__init__.py new file mode 100644 index 0000000..b7d3707 --- /dev/null +++ b/service/db/importer/__init__.py @@ -0,0 +1,32 @@ +""" +Import package for retail chain price data. + +This package provides functionality to import price data from CSV files +organized in directories or zip archives. The data is processed and stored +in the database according to the schema defined in service.db.models. + +Main entry points: +- cli.main(): Command-line interface for importing data +- archive_handler.import_archive(): Import from zip archive +- archive_handler.import_directory(): Import from directory + +The package is organized into the following modules: +- cli: Command-line interface +- archive_handler: Archive and directory processing +- chain_importer: Chain-level import logic +- processors: Data processing functions (stores, products, prices) +- csv_reader: CSV file reading utilities +""" + +from .archive_handler import import_archive, import_directory +from .processors import process_stores, process_products, process_prices +from .csv_reader import read_csv + +__all__ = [ + "import_archive", + "import_directory", + "process_stores", + "process_products", + "process_prices", + "read_csv", +] diff --git a/service/db/importer/archive_handler.py b/service/db/importer/archive_handler.py new file mode 100644 index 0000000..0333ef4 --- /dev/null +++ b/service/db/importer/archive_handler.py @@ -0,0 +1,150 @@ +import asyncio +import logging +import zipfile +from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory +from time import time + +from service.config import settings +from service.db.stats import compute_stats +from .chain_importer import process_chain_products_only, process_chain_stores_and_prices + +logger = logging.getLogger("importer.archive_handler") + +db = settings.get_db() + + +def parse_date_from_path(path: Path) -> datetime: + 
""" + Parse date from path name in YYYY-MM-DD format. + + Args: + path: Path with date in name. + + Returns: + Parsed datetime object. + + Raises: + ValueError: If date format is invalid. + """ + date_str = path.stem if path.is_file() else path.name + return datetime.strptime(date_str, "%Y-%m-%d") + + +async def import_archive(path: Path, compute_stats_flag: bool = True) -> None: + """Import data from all chain directories in the given zip archive.""" + try: + price_date = parse_date_from_path(path) + except ValueError: + logger.error(f"`{path.stem}` is not a valid date in YYYY-MM-DD format") + return + + try: + with TemporaryDirectory() as temp_dir: # type: ignore + logger.debug(f"Extracting archive {path} to {temp_dir}") + with zipfile.ZipFile(path, "r") as zip_ref: + zip_ref.extractall(temp_dir) + await _import(Path(temp_dir), price_date, compute_stats_flag) + except zipfile.BadZipFile: + logger.error(f"Invalid or corrupted zip file: {path}") + return + except FileNotFoundError: + logger.error(f"Archive file not found: {path}") + return + except PermissionError: + logger.error(f"Permission denied accessing archive: {path}") + return + except Exception as e: + logger.error(f"Unexpected error processing archive {path}: {e}") + return + + +async def import_directory(path: Path, compute_stats_flag: bool = True) -> None: + """Import data from all chain directories in the given directory.""" + if not path.is_dir(): + logger.error(f"`{path}` does not exist or is not a directory") + return + + try: + price_date = parse_date_from_path(path) + except ValueError: + logger.error( + f"Directory `{path.name}` is not a valid date in YYYY-MM-DD format" + ) + return + + await _import(path, price_date, compute_stats_flag) + + +async def _import( + path: Path, price_date: datetime, compute_stats_flag: bool = True +) -> None: + """ + Import data from chain directories in the given path. + + Args: + path: Path containing chain directories. 
+ price_date: Date for which the prices are valid. + compute_stats_flag: Whether to compute statistics after import. + """ + chain_dirs = [d for d in path.iterdir() if d.is_dir()] + if not chain_dirs: + logger.warning(f"No chain directories found in {path}") + return + + logger.debug(f"Importing {len(chain_dirs)} chains from {path}") + + t0 = time() + + barcodes = await db.get_product_barcodes() + + # Phase 1: Sequential EAN processing to avoid deadlocks + logger.debug("Phase 1: Processing EAN codes sequentially") + await _process_eans_sequentially(chain_dirs, price_date, barcodes) + + # Phase 2: Parallel processing of stores and prices + logger.debug("Phase 2: Processing stores and prices in parallel") + await _process_stores_and_prices_parallel(chain_dirs, price_date, barcodes) + + dt = int(time() - t0) + logger.info(f"Imported {len(chain_dirs)} chains in {dt} seconds") + + if compute_stats_flag: + await compute_stats(price_date) + else: + logger.debug(f"Skipping statistics computation for {price_date:%Y-%m-%d}") + + +async def _process_eans_sequentially( + chain_dirs: list[Path], price_date: datetime, barcodes: dict[str, int] +) -> None: + """ + Process EAN codes sequentially to avoid database deadlocks. + + Args: + chain_dirs: List of chain directories to process. + price_date: Date for which the prices are valid. + barcodes: Dictionary of existing EAN codes and their product IDs. + """ + for chain_dir in chain_dirs: + await process_chain_products_only(price_date, chain_dir, barcodes) + + +async def _process_stores_and_prices_parallel( + chain_dirs: list[Path], price_date: datetime, barcodes: dict[str, int] +) -> None: + """ + Process stores and prices in parallel since they don't share resources. + + Args: + chain_dirs: List of chain directories to process. + price_date: Date for which the prices are valid. + barcodes: Dictionary of existing EAN codes and their product IDs. 
+ """ + tasks = [] + for chain_dir in chain_dirs: + task = process_chain_stores_and_prices(price_date, chain_dir, barcodes) + tasks.append(task) + + await asyncio.gather(*tasks) diff --git a/service/db/importer/chain_importer.py b/service/db/importer/chain_importer.py new file mode 100644 index 0000000..c0229c1 --- /dev/null +++ b/service/db/importer/chain_importer.py @@ -0,0 +1,136 @@ +import logging +from datetime import date +from pathlib import Path +from typing import Dict, NamedTuple + +from service.config import settings +from service.db.models import Chain +from .processors import process_stores, process_products, process_prices + +logger = logging.getLogger("importer.chain_importer") + +db = settings.get_db() + + +class ChainFiles(NamedTuple): + """Container for chain CSV file paths.""" + + stores: Path + products: Path + prices: Path + + +def build_chain_csv_file_paths(chain_dir: Path) -> ChainFiles: + """ + Build the CSV file paths for a chain directory. + + Args: + chain_dir: Path to the chain directory. + + Returns: + ChainFiles containing the constructed paths to the CSV files. + """ + return ChainFiles( + stores=chain_dir / "stores.csv", + products=chain_dir / "products.csv", + prices=chain_dir / "prices.csv", + ) + + +def get_chain_files_if_all_exist(chain_dir: Path) -> ChainFiles | None: + """ + Get chain CSV file paths if all required files exist in the directory. + + Args: + chain_dir: Path to the chain directory. + + Returns: + ChainFiles if all required files exist, None otherwise. + """ + code = chain_dir.name + files = build_chain_csv_file_paths(chain_dir) + + for file_path in files: + if not file_path.exists(): + logger.warning(f"No {file_path.name} found for chain {code}") + return None + + return files + + +async def register_chain_in_database(chain_dir: Path) -> int: + """ + Register a chain in the database and return its ID. + + Args: + chain_dir: Path to the chain directory. + + Returns: + The database ID of the registered chain. 
+ """ + code = chain_dir.name + chain = Chain(code=code) + return await db.add_chain(chain) + + +async def process_chain_products_only( + price_date: date, + chain_dir: Path, + barcodes: Dict[str, int], +) -> None: + """ + Process only the products/EAN codes for a chain to avoid deadlocks. + + Args: + price_date: The date for which the prices are valid. + chain_dir: Path to the directory containing the chain's CSV files. + barcodes: Dictionary mapping EAN codes to global product IDs. + """ + files = get_chain_files_if_all_exist(chain_dir) + if files is None: + return + + code = chain_dir.name + logger.debug(f"Processing products for chain: {code}") + + chain_id = await register_chain_in_database(chain_dir) + + # Only process products to add EAN codes sequentially + await process_products(files.products, chain_id, code, barcodes) + + +async def process_chain_stores_and_prices( + price_date: date, + chain_dir: Path, + barcodes: Dict[str, int], +) -> None: + """ + Process stores and prices for a chain (EAN codes should already be processed). + + Args: + price_date: The date for which the prices are valid. + chain_dir: Path to the directory containing the chain's CSV files. + barcodes: Dictionary mapping EAN codes to global product IDs. 
+ """ + files = get_chain_files_if_all_exist(chain_dir) + if files is None: + return + + code = chain_dir.name + logger.debug(f"Processing stores and prices for chain: {code}") + + chain_id = await register_chain_in_database(chain_dir) + + # Process stores and prices (products should already be processed) + store_map = await process_stores(files.stores, chain_id) + chain_product_map = await db.get_chain_product_map(chain_id) + + n_new_prices = await process_prices( + price_date, + files.prices, + chain_id, + store_map, + chain_product_map, + ) + + logger.info(f"Imported {n_new_prices} new prices for {code}") diff --git a/service/db/importer/csv_reader.py b/service/db/importer/csv_reader.py new file mode 100644 index 0000000..a190f67 --- /dev/null +++ b/service/db/importer/csv_reader.py @@ -0,0 +1,35 @@ +import logging +from csv import DictReader +from pathlib import Path +from typing import Dict, List + +logger = logging.getLogger("importer.csv_reader") + + +async def read_csv(file_path: Path) -> List[Dict[str, str]]: + """ + Read a CSV file and return a list of dictionaries. + + Args: + file_path: Path to the CSV file. + + Returns: + List of dictionaries where each dictionary represents a row in the CSV. + Returns empty list if file cannot be read. 
+ """ + try: + with open(file_path, "r", encoding="utf-8") as f: + reader = DictReader(f) # type: ignore + return [row for row in reader] + except FileNotFoundError: + logger.error(f"CSV file not found: {file_path}") + return [] + except PermissionError: + logger.error(f"Permission denied reading CSV file: {file_path}") + return [] + except UnicodeDecodeError as e: + logger.error(f"Encoding error reading CSV file {file_path}: {e}") + return [] + except Exception as e: + logger.error(f"Unexpected error reading CSV file {file_path}: {e}") + return [] diff --git a/service/db/importer/processors.py b/service/db/importer/processors.py new file mode 100644 index 0000000..0499370 --- /dev/null +++ b/service/db/importer/processors.py @@ -0,0 +1,302 @@ +import logging +from datetime import date +from pathlib import Path +from typing import Any, Dict + +from service.config import settings +from service.db.models import ChainProduct, Store +from .csv_reader import read_csv + +logger = logging.getLogger("importer.processors") + +db = settings.get_db() + + +async def process_stores(stores_path: Path, chain_id: int) -> Dict[str, int]: + """ + Process stores CSV and import to database. + + Args: + stores_path: Path to the stores CSV file. + chain_id: ID of the chain to which these stores belong. + + Returns: + A dictionary mapping store codes to their database IDs. 
+ """ + logger.debug(f"Importing stores from {stores_path}") + + stores_data = await read_csv(stores_path) + + # Prepare all stores for bulk insertion + stores_to_create = [] + for store_row in stores_data: + store = Store( + chain_id=chain_id, + code=store_row["store_id"], + type=store_row.get("type"), + address=store_row.get("address"), + city=store_row.get("city"), + zipcode=store_row.get("zipcode"), + ) + stores_to_create.append(store) + + # Insert all stores in bulk + store_map = await db.add_many_stores(stores_to_create) + + logger.debug(f"Processed {len(stores_data)} stores") + return store_map + + +def validate_and_fix_barcode(data: Dict[str, Any], chain_code: str) -> Dict[str, Any]: + """ + Validate barcode data and generate a chain-specific barcode if needed. + + Args: + data: Product data dictionary. + chain_code: Code of the retail chain. + + Returns: + Updated product data dictionary with valid barcode. + """ + barcode = data.get("barcode", "").strip() + + if ":" in barcode: + return data + + if len(barcode) >= 8 and barcode.isdigit(): + return data + + product_id = data.get("product_id", "") + if not product_id: + logger.warning(f"Product has no barcode: {data}") + return data + + # Construct a chain-specific barcode + data["barcode"] = f"{chain_code}:{product_id}" + return data + + +def _filter_new_products( + products_data: list[Dict[str, Any]], + chain_product_map: Dict[str, int], + chain_code: str, +) -> list[Dict[str, Any]]: + """ + Filter products that don't exist in the chain and clean their barcodes. + + Args: + products_data: Raw product data from CSV. + chain_product_map: Existing product codes mapped to their IDs. + chain_code: Code of the retail chain. + + Returns: + List of new products with cleaned barcodes. 
+ """ + return [ + validate_and_fix_barcode(p, chain_code) + for p in products_data + if p["product_id"] not in chain_product_map + ] + + +def _extract_missing_barcodes( + products: list[Dict[str, Any]], existing_barcodes: Dict[str, int] +) -> list[str]: + """ + Extract barcodes from products that don't exist in the existing barcodes dictionary. + + Args: + products: List of product dictionaries. + existing_barcodes: Dictionary mapping barcodes to global product IDs. + + Returns: + List of missing barcodes that need to be registered. + """ + new_barcodes = [] + for product in products: + barcode = product["barcode"] + if barcode not in existing_barcodes: + new_barcodes.append(barcode) + return new_barcodes + + +async def _register_missing_barcodes_to_database( + new_products: list[Dict[str, Any]], barcodes_dict: Dict[str, int] +) -> None: + """ + Register missing barcodes to database and update the global barcodes dictionary. + + Args: + new_products: List of new product dictionaries. + barcodes_dict: Dictionary mapping barcodes to global product IDs (modified in place). + """ + new_barcodes = _extract_missing_barcodes(new_products, barcodes_dict) + + if new_barcodes: + new_barcode_ids = await db.add_many_eans(new_barcodes) + barcodes_dict.update(new_barcode_ids) + logger.debug(f"Added {len(new_barcodes)} new barcodes to global products") + + +def _sanitize_product_optional_fields(product: Dict[str, Any]) -> Dict[str, str | None]: + """ + Sanitize optional product fields by converting empty strings to None. + + Args: + product: Raw product dictionary from CSV. + + Returns: + Dictionary with sanitized optional field values. 
+ """ + return { + "brand": (product["brand"] or "").strip() or None, + "category": (product["category"] or "").strip() or None, + "unit": (product["unit"] or "").strip() or None, + "quantity": (product["quantity"] or "").strip() or None, + } + + +def _create_chain_product_objects( + products: list[Dict[str, Any]], chain_id: int, barcodes_dict: Dict[str, int] +) -> list[ChainProduct]: + """ + Create ChainProduct objects from product dictionaries. + + Args: + products: List of product dictionaries. + chain_id: ID of the chain. + barcodes_dict: Dictionary mapping barcodes to global product IDs. + + Returns: + List of ChainProduct objects ready for database insertion. + """ + products_to_create = [] + for product in products: + barcode = product["barcode"] + code = product["product_id"] + global_product_id = barcodes_dict[barcode] + + validated_data = _sanitize_product_optional_fields(product) + + products_to_create.append( + ChainProduct( + chain_id=chain_id, + product_id=global_product_id, + code=code, + name=product["name"], + brand=validated_data["brand"], + category=validated_data["category"], + unit=validated_data["unit"], + quantity=validated_data["quantity"], + ) + ) + + return products_to_create + + +async def _insert_chain_products(products_to_create: list[ChainProduct]) -> int: + """ + Insert ChainProduct objects into the database with validation. + + Args: + products_to_create: List of ChainProduct objects to insert. + + Returns: + Number of products successfully inserted. + """ + if not products_to_create: + return 0 + + n_inserts = await db.add_many_chain_products(products_to_create) + if n_inserts != len(products_to_create): + logger.warning( + f"Expected to insert {len(products_to_create)} products, but inserted {n_inserts}." 
+ ) + logger.debug(f"Imported {len(products_to_create)} new chain products") + return n_inserts + + +async def _fetch_updated_chain_product_map(chain_id: int) -> Dict[str, int]: + """ + Fetch the updated chain product map from the database. + + Args: + chain_id: ID of the chain. + + Returns: + Dictionary mapping product codes to their database IDs for the chain. + """ + return await db.get_chain_product_map(chain_id) + + +async def process_products( + products_path: Path, + chain_id: int, + chain_code: str, + barcodes: Dict[str, int], +) -> Dict[str, int]: + """ + Process products CSV and import to database. + + As a side effect, this function will also add any newly + created EAN codes to the provided `barcodes` dictionary. + + Args: + products_path: Path to the products CSV file. + chain_id: ID of the chain to which these products belong. + chain_code: Code of the retail chain. + barcodes: Dictionary mapping EAN codes to global product IDs. + + Returns: + A dictionary mapping product codes to their database IDs for the chain. + """ + logger.debug(f"Processing products from {products_path}") + + products_data = await read_csv(products_path) + chain_product_map = await db.get_chain_product_map(chain_id) + + new_products = _filter_new_products(products_data, chain_product_map, chain_code) + if not new_products: + return chain_product_map + + logger.debug( + f"Found {len(new_products)} new products out of {len(products_data)} total" + ) + + await _register_missing_barcodes_to_database(new_products, barcodes) + + products_to_create = _create_chain_product_objects(new_products, chain_id, barcodes) + await _insert_chain_products(products_to_create) + + return await _fetch_updated_chain_product_map(chain_id) + + +async def process_prices( + price_date: date, + prices_path: Path, + chain_id: int, + store_map: Dict[str, int], + chain_product_map: Dict[str, int], +) -> int: + """ + Process prices CSV and import to database using direct CSV streaming. 
def _store_copy_records(self, stores: list[Store]) -> list[tuple]:
    """
    Build the rows to COPY into the staging table, one per (chain_id, code).

    PostgreSQL rejects an ``INSERT ... ON CONFLICT DO UPDATE`` that would
    touch the same target row twice within one statement, so repeated
    (chain_id, code) pairs in the input are collapsed here. The last
    occurrence wins; the position of the first occurrence is kept.

    Args:
        stores: Store objects to convert.

    Returns:
        List of (chain_id, code, type, address, city, zipcode) tuples with
        empty address/city/zipcode values normalised to None.
    """
    latest = {(store.chain_id, store.code): store for store in stores}
    return [
        (
            store.chain_id,
            store.code,
            store.type,
            store.address or None,
            store.city or None,
            store.zipcode or None,
        )
        for store in latest.values()
    ]


async def add_many_stores(self, stores: list[Store]) -> dict[str, int]:
    """
    Add multiple stores in a batch operation.

    The stores are copied into a temporary table, upserted into ``stores``
    in a single statement (existing values are kept wherever the incoming
    value is NULL) and the resulting IDs are read back.

    Args:
        stores: List of Store objects to add or update. Repeated
            (chain_id, code) pairs are allowed; the last one is used.

    Returns:
        Dictionary mapping store codes to their database IDs.
    """
    if not stores:
        return {}

    records = self._store_copy_records(stores)

    async with self._atomic() as conn:
        # Create temporary table for bulk insert
        await conn.execute(
            """
            CREATE TEMP TABLE temp_stores (
                chain_id INTEGER,
                code VARCHAR(100),
                type VARCHAR(100),
                address VARCHAR(255),
                city VARCHAR(100),
                zipcode VARCHAR(20)
            )
            """
        )

        # Insert all stores into temporary table
        await conn.copy_records_to_table("temp_stores", records=records)

        # Perform bulk upsert
        await conn.execute(
            """
            INSERT INTO stores (chain_id, code, type, address, city, zipcode)
            SELECT chain_id, code, type, address, city, zipcode
            FROM temp_stores
            ON CONFLICT (chain_id, code) DO UPDATE SET
                type = COALESCE(EXCLUDED.type, stores.type),
                address = COALESCE(EXCLUDED.address, stores.address),
                city = COALESCE(EXCLUDED.city, stores.city),
                zipcode = COALESCE(EXCLUDED.zipcode, stores.zipcode)
            """
        )

        # Fetch all store IDs for the provided stores
        rows = await conn.fetch(
            """
            SELECT s.id, s.code
            FROM stores s
            JOIN temp_stores t ON s.chain_id = t.chain_id AND s.code = t.code
            """
        )

        # Clean up temporary table
        await conn.execute("DROP TABLE temp_stores")

    return {row["code"]: row["id"] for row in rows}


async def add_many_eans(self, eans: list[str]) -> dict[str, int]:
    """
    Add multiple empty products with only EAN codes in a batch operation.

    EAN codes that already exist are left untouched; their existing IDs are
    still included in the result.

    Args:
        eans: List of EAN codes to add. Duplicates are allowed.

    Returns:
        Dictionary mapping EAN codes to their database IDs.
    """
    if not eans:
        return {}

    # Repeated codes add nothing (ON CONFLICT DO NOTHING would discard
    # them), so send each one only once; dict.fromkeys keeps input order.
    unique_eans = list(dict.fromkeys(eans))

    async with self._atomic() as conn:
        # Create temporary table for bulk insert
        await conn.execute(
            """
            CREATE TEMP TABLE temp_eans (
                ean VARCHAR(50)
            )
            """
        )

        # Insert all EAN codes into temporary table
        await conn.copy_records_to_table(
            "temp_eans",
            records=[(ean,) for ean in unique_eans],
        )

        # Insert new EAN codes (ignoring conflicts for existing ones)
        await conn.execute(
            """
            INSERT INTO products (ean)
            SELECT ean FROM temp_eans
            ON CONFLICT (ean) DO NOTHING
            """
        )

        # Fetch all product IDs for the requested EAN codes
        rows = await conn.fetch(
            """
            SELECT id, ean FROM products
            WHERE ean IN (SELECT ean FROM temp_eans)
            """
        )

        # Clean up temporary table
        await conn.execute("DROP TABLE temp_eans")

    return {row["ean"]: row["id"] for row in rows}
def _clean_price(self, value: str) -> str:
    """
    Normalise an optional price field for embedding in COPY CSV input.

    Args:
        value: Raw price text from the CSV; may be empty or None.

    Returns:
        The stripped price text, or '\\N' (the null marker passed to COPY)
        when the value is missing, is not a plain non-negative number with
        at most two decimal places, or is zero in any spelling.
    """
    cleaned = value.strip() if value else ""

    # Plain string validation, no Decimal construction. The character
    # class is ASCII-only on purpose: a bare \d also matches other Unicode
    # digits, which would then be copied verbatim into the COPY payload.
    if re.fullmatch(r"[0-9]+(\.[0-9]{1,2})?", cleaned) is None:
        return "\\N"

    # A validated value made up solely of zeros and the decimal point is
    # zero ("0", "0.00", "00.0", ...); zero prices are stored as NULL.
    if not cleaned.strip("0."):
        return "\\N"

    return cleaned


def _encode_price_rows(
    self,
    rows,
    price_date: date,
    store_map: dict[str, int],
    chain_product_map: dict[str, int],
) -> tuple[io.BytesIO, int, int]:
    """
    Turn parsed price rows into a CSV payload for COPY into temp_prices.

    Rows whose store or product code has no mapping, and rows whose
    mandatory ``price`` is not a finite decimal number, are logged and
    left out of the payload.

    Args:
        rows: Iterable of dict-like rows with the keys ``store_id``,
            ``product_id`` and ``price``; the optional price columns are
            read with ``.get``.
        price_date: The date for which the prices are valid.
        store_map: Dictionary mapping store codes to their database IDs.
        chain_product_map: Dictionary mapping product codes to their database IDs.

    Returns:
        Tuple of (payload rewound to the start, number of rows skipped for
        missing mappings, number of rows skipped for an invalid price).

    Raises:
        KeyError: If a row lacks one of the mandatory columns.
    """
    # Decimal signals unparsable text with InvalidOperation, which is an
    # ArithmeticError and therefore not covered by ValueError/TypeError.
    from decimal import InvalidOperation

    payload = io.BytesIO()
    unmapped_count = 0
    invalid_count = 0

    for row in rows:
        store_id = store_map.get(row["store_id"])
        product_id = chain_product_map.get(row["product_id"])

        if store_id is None or product_id is None:
            unmapped_count += 1
            self.logger.warning(
                f"Skipped price row due to missing store/product ({store_id}/{product_id}) mappings"
            )
            continue

        try:
            regular_price = Decimal(row["price"].strip())
        except (InvalidOperation, ValueError, TypeError, AttributeError):
            regular_price = None

        # "Infinity" and "NaN" parse successfully but are not prices.
        if regular_price is None or not regular_price.is_finite():
            invalid_count += 1
            self.logger.warning(
                f"Skipped price row due to invalid price: {row['price']} - {row}"
            )
            continue

        csv_line = (
            f"{product_id},{store_id},{price_date},"
            f"{regular_price},"
            f"{self._clean_price(row.get('special_price', ''))},"
            f"{self._clean_price(row.get('unit_price', ''))},"
            f"{self._clean_price(row.get('best_price_30', ''))},"
            f"{self._clean_price(row.get('anchor_price', ''))}\n"
        )
        payload.write(csv_line.encode("utf-8"))

    payload.seek(0)
    return payload, unmapped_count, invalid_count


async def add_many_prices_direct_csv(
    self,
    csv_path: Path,
    price_date: date,
    store_map: dict[str, int],
    chain_product_map: dict[str, int],
) -> int:
    """
    Add multiple prices directly from a CSV file.

    The file is transformed into a COPY payload, loaded into a temporary
    table and inserted into ``prices`` with conflicting rows ignored.

    Args:
        csv_path: Path to the CSV file containing price data.
        price_date: The date for which the prices are valid.
        store_map: Dictionary mapping store codes to their database IDs.
        chain_product_map: Dictionary mapping product codes to their database IDs.

    Returns:
        The number of prices successfully inserted into the database.
    """
    # Parse the file before opening the transaction so that no connection
    # is held while the CSV is being read and transformed.
    with open(csv_path, "r", encoding="utf-8") as f:
        reader = DictReader(f)  # type: ignore[no-matching-overload]
        csv_data, unmapped_count, invalid_count = self._encode_price_rows(
            reader, price_date, store_map, chain_product_map
        )

    if unmapped_count or invalid_count:
        self.logger.warning(
            f"Skipped {unmapped_count + invalid_count} price rows "
            f"({unmapped_count} with missing store/product mappings, "
            f"{invalid_count} with invalid prices)"
        )

    async with self._atomic() as conn:
        await conn.execute(
            """
            CREATE TEMP TABLE temp_prices (
                chain_product_id INTEGER,
                store_id INTEGER,
                price_date DATE,
                regular_price DECIMAL(10, 2),
                special_price DECIMAL(10, 2),
                unit_price DECIMAL(10, 2),
                best_price_30 DECIMAL(10, 2),
                anchor_price DECIMAL(10, 2)
            )
            """
        )

        await conn.copy_to_table(
            "temp_prices", source=csv_data, format="csv", delimiter=",", null="\\N"
        )

        result = await conn.execute(
            """
            INSERT INTO prices(
                chain_product_id,
                store_id,
                price_date,
                regular_price,
                special_price,
                unit_price,
                best_price_30,
                anchor_price
            )
            SELECT * from temp_prices
            ON CONFLICT DO NOTHING
            """
        )
        await conn.execute("DROP TABLE temp_prices")

    # The command status has the form "INSERT <oid> <rows>".
    _, _, rowcount = result.split(" ")
    return int(rowcount)