Skip to content
55 changes: 42 additions & 13 deletions src/anemoi/datasets/create/sources/grib_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import earthkit.data as ekd
import tqdm
from anemoi.transform.fields import new_field_from_grid
from anemoi.transform.flavour import RuleBasedFlavour
from anemoi.transform.grids import grid_registry
from cachetools import LRUCache
from earthkit.data.indexing.fieldlist import FieldArray

Expand Down Expand Up @@ -102,6 +104,21 @@ def __init__(
self.warnings = {}
self.cache = {}

def _quote_column(self, column: str) -> str:
"""Quote a column name for use in SQL queries.

Parameters
----------
column : str
The column name to quote.

Returns
-------
str
The quoted column name.
"""
return f'"{column}"'

def _create_tables(self) -> None:
"""Create the necessary tables in the database."""
assert self.update
Expand All @@ -123,7 +140,7 @@ def _create_tables(self) -> None:
_path_id INTEGER not null,
_offset INTEGER not null,
_length INTEGER not null,
{', '.join(f"{key} TEXT not null default ''" for key in columns)},
{', '.join(f"{self._quote_column(key)} TEXT not null default ''" for key in columns)},
FOREIGN KEY(_path_id) REFERENCES paths(id))
""") # ,

Expand All @@ -134,13 +151,13 @@ def _create_tables(self) -> None:

self.cursor.execute(f"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
ON grib_index ({', '.join(columns)})
ON grib_index ({', '.join(self._quote_column(col) for col in columns)})
""")

for key in columns:
self.cursor.execute(f"""
CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
ON grib_index ({key})
CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')}
ON grib_index ({self._quote_column(key)})
""")

self._commit()
Expand Down Expand Up @@ -195,7 +212,7 @@ def _add_grib(self, **kwargs: Any) -> None:

self.cursor.execute(
f"""
INSERT INTO grib_index ({', '.join(kwargs.keys())})
INSERT INTO grib_index ({', '.join(self._quote_column(k) for k in kwargs.keys())})
VALUES ({', '.join('?' for _ in kwargs)})
""",
tuple(kwargs.values()),
Expand All @@ -208,7 +225,8 @@ def _add_grib(self, **kwargs: Any) -> None:
for n in ("_path_id", "_offset", "_length"):
kwargs.pop(n)
self.cursor.execute(
"SELECT * FROM grib_index WHERE " + " AND ".join(f"{key} = ?" for key in kwargs.keys()),
"SELECT * FROM grib_index WHERE "
+ " AND ".join(f"{self._quote_column(key)} = ?" for key in kwargs.keys()),
tuple(kwargs.values()),
)
existing_record = self.cursor.fetchone()
Expand Down Expand Up @@ -252,20 +270,22 @@ def _ensure_columns(self, columns: list[str]) -> None:
self._columns = None

for column in new_columns:
self.cursor.execute(f"ALTER TABLE grib_index ADD COLUMN {column} TEXT not null default ''")
self.cursor.execute(
f"ALTER TABLE grib_index ADD COLUMN {self._quote_column(column)} TEXT not null default ''"
)

self.cursor.execute("""DROP INDEX IF EXISTS idx_grib_index_all_keys""")
all_columns = self._all_columns()

self.cursor.execute(f"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
ON grib_index ({', '.join(all_columns)})
ON grib_index ({', '.join(self._quote_column(col) for col in all_columns)})
""")

for key in all_columns:
self.cursor.execute(f"""
CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
ON grib_index ({key})
CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')}
ON grib_index ({self._quote_column(key)})
""")

def add_grib_file(self, path: str) -> None:
Expand Down Expand Up @@ -301,6 +321,8 @@ def add_grib_file(self, path: str) -> None:
self._unknown(path, field, i, param)
self.warnings[param] = True

continue

self._ensure_columns(list(keys.keys()))

self._add_grib(
Expand Down Expand Up @@ -536,15 +558,14 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]:
LOG.warning(f"Warning : {k} not in database columns, key discarded")
continue
if isinstance(v, list):
query += f" AND {k} IN ({', '.join('?' for _ in v)})"
query += f" AND {self._quote_column(k)} IN ({', '.join('?' for _ in v)})"
params.extend([str(_) for _ in v])
else:
query += f" AND {k} = ?"
query += f" AND {self._quote_column(k)} = ?"
params.append(str(v))

print("SELECT (query)", query)
print("SELECT (params)", params)

self.cursor.execute(query, params)

fetch = self.cursor.fetchall()
Expand Down Expand Up @@ -593,6 +614,11 @@ def _execute(
FieldArray
An array of retrieved GRIB fields.
"""

grid_definition = kwargs.pop("grid_definition", None)
if grid_definition:
grid_definition = grid_registry.from_config(grid_definition)

index = GribIndex(indexdb)

if flavour is not None:
Expand Down Expand Up @@ -623,6 +649,9 @@ def _execute(
field = flavour.apply(field)
result.append(field)

if grid_definition is not None:
result = [new_field_from_grid(field, grid_definition) for field in result]

return FieldArray(result)


Expand Down
Loading