diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py index 46558b7d..1d109b60 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/create/sources/grib_index.py @@ -18,7 +18,9 @@ import earthkit.data as ekd import tqdm +from anemoi.transform.fields import new_field_from_grid from anemoi.transform.flavour import RuleBasedFlavour +from anemoi.transform.grids import grid_registry from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray @@ -102,6 +104,21 @@ def __init__( self.warnings = {} self.cache = {} + def _quote_column(self, column: str) -> str: + """Quote a column name for use in SQL queries. + + Parameters + ---------- + column : str + The column name to quote. + + Returns + ------- + str + The quoted column name. + """ + return f'"{column}"' + def _create_tables(self) -> None: """Create the necessary tables in the database.""" assert self.update @@ -123,7 +140,7 @@ def _create_tables(self) -> None: _path_id INTEGER not null, _offset INTEGER not null, _length INTEGER not null, - {', '.join(f"{key} TEXT not null default ''" for key in columns)}, + {', '.join(f"{self._quote_column(key)} TEXT not null default ''" for key in columns)}, FOREIGN KEY(_path_id) REFERENCES paths(id)) """) # , @@ -134,13 +151,13 @@ def _create_tables(self) -> None: self.cursor.execute(f""" CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys - ON grib_index ({', '.join(columns)}) + ON grib_index ({', '.join(self._quote_column(col) for col in columns)}) """) for key in columns: self.cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_grib_index_{key} - ON grib_index ({key}) + CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')} + ON grib_index ({self._quote_column(key)}) """) self._commit() @@ -195,7 +212,7 @@ def _add_grib(self, **kwargs: Any) -> None: self.cursor.execute( f""" - INSERT INTO grib_index ({', '.join(kwargs.keys())}) + INSERT INTO grib_index ({', '.join(self._quote_column(k) for k in kwargs.keys())}) VALUES ({', '.join('?' for _ in kwargs)}) """, tuple(kwargs.values()), @@ -208,7 +225,8 @@ def _add_grib(self, **kwargs: Any) -> None: for n in ("_path_id", "_offset", "_length"): kwargs.pop(n) self.cursor.execute( - "SELECT * FROM grib_index WHERE " + " AND ".join(f"{key} = ?" for key in kwargs.keys()), + "SELECT * FROM grib_index WHERE " + + " AND ".join(f"{self._quote_column(key)} = ?" for key in kwargs.keys()), tuple(kwargs.values()), ) existing_record = self.cursor.fetchone() @@ -252,20 +270,22 @@ def _ensure_columns(self, columns: list[str]) -> None: self._columns = None for column in new_columns: - self.cursor.execute(f"ALTER TABLE grib_index ADD COLUMN {column} TEXT not null default ''") + self.cursor.execute( + f"ALTER TABLE grib_index ADD COLUMN {self._quote_column(column)} TEXT not null default ''" + ) self.cursor.execute("""DROP INDEX IF EXISTS idx_grib_index_all_keys""") all_columns = self._all_columns() self.cursor.execute(f""" CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys - ON grib_index ({', '.join(all_columns)}) + ON grib_index ({', '.join(self._quote_column(col) for col in all_columns)}) """) for key in all_columns: self.cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_grib_index_{key} - ON grib_index ({key}) + CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')} + ON grib_index ({self._quote_column(key)}) """) def add_grib_file(self, path: str) -> None: @@ -301,6 +321,8 @@ def add_grib_file(self, path: str) -> None: self._unknown(path, field, i, param) self.warnings[param] = True + continue + self._ensure_columns(list(keys.keys())) self._add_grib( @@ -536,15 +558,14 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]: LOG.warning(f"Warning : {k} not in database columns, key discarded") continue if isinstance(v, list): - query += f" AND {k} IN ({', '.join('?' for _ in v)})" + query += f" AND {self._quote_column(k)} IN ({', '.join('?' for _ in v)})" params.extend([str(_) for _ in v]) else: - query += f" AND {k} = ?" + query += f" AND {self._quote_column(k)} = ?" params.append(str(v)) print("SELECT (query)", query) print("SELECT (params)", params) - self.cursor.execute(query, params) fetch = self.cursor.fetchall() @@ -593,6 +614,11 @@ def _execute( FieldArray An array of retrieved GRIB fields. """ + + grid_definition = kwargs.pop("grid_definition", None) + if grid_definition: + grid_definition = grid_registry.from_config(grid_definition) + index = GribIndex(indexdb) if flavour is not None: @@ -623,6 +649,9 @@ def _execute( field = flavour.apply(field) result.append(field) + if grid_definition is not None: + result = [new_field_from_grid(field, grid_definition) for field in result] + return FieldArray(result)