diff --git a/spextra/libraries.py b/spextra/libraries.py index 641f947..939c3bb 100644 --- a/spextra/libraries.py +++ b/spextra/libraries.py @@ -45,30 +45,36 @@ def __init__(self, name): self._items = yamldict.get(self.aliases["items"], {}) self.spectral_coverage = yamldict.get("spectral_coverage", []) - self.read_kwargs = self._get_read_kwargs(yamldict) - self._validate_units() - self.file_extension = yamldict.get("file_extension", None) + self.read_kwargs = self._get_read_kwargs(yamldict) def _get_read_kwargs(self, yamldict) -> dict: read_kwargs = { "wave_col": yamldict.get(self.aliases["wave_col"], "WAVELENGTH"), "flux_col": yamldict.get(self.aliases["flux_col"], "FLUX"), - "wave_unit": yamldict.get("wave_unit", "Angstrom"), - "flux_unit": yamldict.get(self.aliases["flux_unit"], "FLAM"), } + + # FITS stores units in TUNIT keywords + if self.data_type != "fits": + read_kwargs.update({ + "wave_unit": yamldict.get("wave_unit", "Angstrom"), + "flux_unit": yamldict.get(self.aliases["flux_unit"], "FLAM"), + }) + read_kwargs = self._validate_units(read_kwargs) return read_kwargs - def _validate_units(self) -> None: + @staticmethod + def _validate_units(read_kwargs: dict) -> dict: for key in ("wave_unit", "flux_unit"): - value = self.read_kwargs[key] + value = read_kwargs[key] # HACK to understand ext. units if value == "Av/E(B-V)": value = "mag" try: - self.read_kwargs[key] = validate_unit(value) + read_kwargs[key] = validate_unit(value) except (SynphotError, ValueError) as err: raise UnitError(f"{value} not understood.") from err + return read_kwargs @property def is_in_database(self) -> bool: