Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions src/anemoi/datasets/create/sources/grib_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import earthkit.data as ekd
import tqdm
from anemoi.transform.fields import new_field_from_grid
from anemoi.transform.flavour import RuleBasedFlavour
from anemoi.transform.grids import grid_registry
from cachetools import LRUCache
from earthkit.data.indexing.fieldlist import FieldArray

Expand Down Expand Up @@ -102,6 +104,21 @@ def __init__(
self.warnings = {}
self.cache = {}

def _quote_column(self, column: str) -> str:
"""Quote a column name for use in SQL queries.

Parameters
----------
column : str
The column name to quote.

Returns
-------
str
The quoted column name.
"""
return f'"{column}"'

def _create_tables(self) -> None:
"""Create the necessary tables in the database."""
assert self.update
Expand All @@ -123,7 +140,7 @@ def _create_tables(self) -> None:
_path_id INTEGER not null,
_offset INTEGER not null,
_length INTEGER not null,
{', '.join(f"{key} TEXT not null default ''" for key in columns)},
{', '.join(f"{self._quote_column(key)} TEXT not null default ''" for key in columns)},
FOREIGN KEY(_path_id) REFERENCES paths(id))
""") # ,

Expand All @@ -134,13 +151,13 @@ def _create_tables(self) -> None:

self.cursor.execute(f"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
ON grib_index ({', '.join(columns)})
ON grib_index ({', '.join(self._quote_column(col) for col in columns)})
""")

for key in columns:
self.cursor.execute(f"""
CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
ON grib_index ({key})
CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')}
ON grib_index ({self._quote_column(key)})
""")

self._commit()
Expand Down Expand Up @@ -195,7 +212,7 @@ def _add_grib(self, **kwargs: Any) -> None:

self.cursor.execute(
f"""
INSERT INTO grib_index ({', '.join(kwargs.keys())})
INSERT INTO grib_index ({', '.join(self._quote_column(k) for k in kwargs.keys())})
VALUES ({', '.join('?' for _ in kwargs)})
""",
tuple(kwargs.values()),
Expand All @@ -208,7 +225,8 @@ def _add_grib(self, **kwargs: Any) -> None:
for n in ("_path_id", "_offset", "_length"):
kwargs.pop(n)
self.cursor.execute(
"SELECT * FROM grib_index WHERE " + " AND ".join(f"{key} = ?" for key in kwargs.keys()),
"SELECT * FROM grib_index WHERE "
+ " AND ".join(f"{self._quote_column(key)} = ?" for key in kwargs.keys()),
tuple(kwargs.values()),
)
existing_record = self.cursor.fetchone()
Expand Down Expand Up @@ -252,20 +270,22 @@ def _ensure_columns(self, columns: list[str]) -> None:
self._columns = None

for column in new_columns:
self.cursor.execute(f"ALTER TABLE grib_index ADD COLUMN {column} TEXT not null default ''")
self.cursor.execute(
f"ALTER TABLE grib_index ADD COLUMN {self._quote_column(column)} TEXT not null default ''"
)

self.cursor.execute("""DROP INDEX IF EXISTS idx_grib_index_all_keys""")
all_columns = self._all_columns()

self.cursor.execute(f"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
ON grib_index ({', '.join(all_columns)})
ON grib_index ({', '.join(self._quote_column(col) for col in all_columns)})
""")

for key in all_columns:
self.cursor.execute(f"""
CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
ON grib_index ({key})
CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')}
ON grib_index ({self._quote_column(key)})
""")

def add_grib_file(self, path: str) -> None:
Expand Down Expand Up @@ -301,6 +321,8 @@ def add_grib_file(self, path: str) -> None:
self._unknown(path, field, i, param)
self.warnings[param] = True

continue

self._ensure_columns(list(keys.keys()))

self._add_grib(
Expand Down Expand Up @@ -536,15 +558,14 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]:
LOG.warning(f"Warning : {k} not in database columns, key discarded")
continue
if isinstance(v, list):
query += f" AND {k} IN ({', '.join('?' for _ in v)})"
query += f" AND {self._quote_column(k)} IN ({', '.join('?' for _ in v)})"
params.extend([str(_) for _ in v])
else:
query += f" AND {k} = ?"
query += f" AND {self._quote_column(k)} = ?"
params.append(str(v))

print("SELECT (query)", query)
print("SELECT (params)", params)

self.cursor.execute(query, params)

fetch = self.cursor.fetchall()
Expand Down Expand Up @@ -593,6 +614,11 @@ def _execute(
FieldArray
An array of retrieved GRIB fields.
"""

grid_definition = kwargs.pop("grid_definition", None)
if grid_definition:
grid_definition = grid_registry.from_config(grid_definition)

index = GribIndex(indexdb)

if flavour is not None:
Expand Down Expand Up @@ -623,6 +649,9 @@ def _execute(
field = flavour.apply(field)
result.append(field)

if grid_definition is not None:
result = [new_field_from_grid(field, grid_definition) for field in result]

return FieldArray(result)


Expand Down
Loading