This repository was archived by the owner on Sep 11, 2023. It is now read-only.

[pre-commit.ci] pre-commit autoupdate #715

Open
wants to merge 2 commits into main
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -12,20 +12,20 @@ repos:
      - id: detect-private-key

  # python code formatting/linting
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
+  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
-    rev: "v0.0.253"
+    rev: "v0.0.287"
    hooks:
      - id: ruff
        args: [--fix]
  - repo: https://github.com/psf/black
-    rev: 22.6.0
+    rev: 23.7.0
    hooks:
      - id: black
        args: [--line-length, "100"]
  # yaml formatting
  - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v3.0.0-alpha.4
+    rev: v3.0.3
    hooks:
      - id: prettier
        types: [yaml]
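Most of the remaining hunks in this PR follow one pattern: the blank line that sat directly after a block opener (a def, class, for, if or with line) is deleted, which is consistent with the stable style of the black release pinned above and appears to come from re-running the updated hooks (for example with pre-commit run --all-files); a few hunks instead add a blank line after a docstring. A minimal before/after sketch of that pattern, using purely illustrative names that are not part of this repository:

# before the update: a blank line sits directly after the block opener
def normalise_before(values):

    # renormalize each value by the total
    total = sum(values)
    return [v / total for v in values]


# after running the updated hooks: the blank line after the def line is removed
def normalise_after(values):
    # renormalize each value by the total
    total = sum(values)
    return [v / total for v in values]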
1 change: 0 additions & 1 deletion notebooks/2021-08/2021-08-25/video.py
@@ -68,7 +68,6 @@
channel_indexes = [1, 8, 9]
satellite_data = []
for channel_index in channel_indexes:

# renormalize
satellite_data.append(
data["sat_data"][batch_index, :, :, :, channel_index] * SAT_STD.values[channel_index]
1 change: 0 additions & 1 deletion notebooks/2021-08/2021-08-26/video.py
@@ -88,7 +88,6 @@
channel_indexes = [1, 9, 8]
satellite_data = []
for channel_index in channel_indexes:

# renormalize
satellite_data.append(
data["sat_data"][batch_index, :, :, :, channel_index] * SAT_STD.values[channel_index]
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-13/remove_hash.py
@@ -20,7 +20,6 @@

for filenames in [train_filenames, validation_filenames]:
for file in train_filenames:

print(file)

filename = file.split("/")[-1]
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-14/gsp_centroid.py
@@ -15,7 +15,6 @@

# for index in range(0, len(shape_data_raw)):
for index in range(140, 150):

# just select the first one
shape_data = shape_data_raw.iloc[index : index + 1]
shapes_dict = json.loads(shape_data["geometry"].to_json())
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-14/gsp_duplicated.py
@@ -12,7 +12,6 @@
duplicated_raw["Amount"] = range(0, len(duplicated_raw))

for i in range(0, 8, 2):

# just select the first one
duplicated = duplicated_raw.iloc[i : i + 2]
shapes_dict = json.loads(duplicated["geometry"].to_json())
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-29/gsp_duplicated.py
@@ -14,7 +14,6 @@
duplicated_raw["Amount"] = range(0, len(duplicated_raw))

for i in range(0, 8, 2):

# just select the first one
duplicated = duplicated_raw.iloc[i : i + 2]
shapes_dict = json.loads(duplicated["geometry"].to_json())
2 changes: 0 additions & 2 deletions notebooks/2021-09/2021-09-29/video.py
@@ -41,7 +41,6 @@


def get_trace(dt):

# plot to check it looks right
return go.Choroplethmapbox(
geojson=shapes_dict,
@@ -54,7 +53,6 @@ def get_trace(dt):


def get_frame(dt):

# plot to check it looks right
return go.Choroplethmapbox(
z=gps_data[dt],
2 changes: 0 additions & 2 deletions notebooks/2021-10/2021-10-01/pydantic.py
@@ -11,7 +11,6 @@


class Satellite(BaseModel):

# width: int = Field(..., g=0, description="The width of the satellite image")
# height: int = Field(..., g=0, description="The width of the satellite image")
# num_channels: int = Field(..., g=0, description="The width of the satellite image")
@@ -49,7 +48,6 @@ class Config:


class Batch(BaseModel):

batch_size: int = Field(
...,
g=0,
1 change: 0 additions & 1 deletion notebooks/2021-10/2021-10-08/xr_compression.py
@@ -9,7 +9,6 @@
def get_satellite_xrarray_data_array(
batch_size, seq_length_5, satellite_image_size_pixels, number_sat_channels=10
):

r = np.random.randn(
# self.batch_size,
seq_length_5,
1 change: 0 additions & 1 deletion notebooks/2021-10/2021-10-08/xr_pydantic.py
@@ -25,7 +25,6 @@ def v_image_data(cls, v):


class Batch(BaseModel):

batch_size: int = 0
satellite: Satellite

1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/data_source.py
@@ -82,7 +82,6 @@ def __post_init__(self):
def _get_start_dt(
self, t0_datetime_utc: Union[pd.Timestamp, pd.DatetimeIndex]
) -> Union[pd.Timestamp, pd.DatetimeIndex]:

return t0_datetime_utc - self.history_duration

def _get_end_dt(
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/fake/batch.py
@@ -504,7 +504,6 @@ def topographic_fake(
# make batch of arrays
xr_arrays = []
for i in range(batch_size):

x, y = make_image_coords_osgb(
size_x=image_size_pixels_width,
size_y=image_size_pixels_height,
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/gsp/eso.py
@@ -164,7 +164,6 @@ def get_gsp_shape_from_eso(
shape_gpd["RegionID"] = range(1, len(shape_gpd) + 1)

if save_local_file:

# rename the columns to less than 10 characters
shape_gpd_to_save = shape_gpd.copy()
shape_gpd_to_save.rename(columns=rename_save_columns, inplace=True)
5 changes: 0 additions & 5 deletions nowcasting_dataset/data_sources/gsp/gsp_data_source.py
@@ -173,7 +173,6 @@ def get_all_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTim
if total_gsp_nan_count > 0:
assert Exception("There are nans in the GSP data. Can't get locations for all GSPs")
else:

t0_datetimes_utc.name = "t0_datetime_utc"

# get all locations
@@ -236,7 +235,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc

total_gsp_nan_count = self.gsp_power.isna().sum().sum()
if total_gsp_nan_count == 0:

# get random GSP metadata
indexes = sorted(
list(self.rng.integers(low=0, high=len(self.metadata), size=len(t0_datetimes_utc)))
@@ -249,7 +247,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc
ids = list(metadata.index)

else:

logger.warning(
"There are some nans in the gsp data, "
"so to get x,y locations we have to do a big loop"
@@ -262,7 +259,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc
ids = []

for t0_dt in t0_datetimes_utc:

# Choose start and end times
start_dt = self._get_start_dt(t0_dt)
end_dt = self._get_end_dt(t0_dt)
@@ -290,7 +286,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc

locations = []
for i in range(len(x_centers_osgb)):

locations.append(
SpaceTimeLocation(
t0_datetime_utc=t0_datetimes_utc[i],
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/gsp/pvlive.py
@@ -89,7 +89,6 @@ def load_pv_gsp_raw_data_from_pvlive(
future_tasks = []
with futures.ThreadPoolExecutor(max_workers=1) as executor:
for gsp_id in gsp_ids:

# set the first chunk start and end times
start_chunk = first_start_chunk
end_chunk = first_end_chunk
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/metadata/metadata_model.py
@@ -114,7 +114,6 @@ def save_to_csv(self, path):
metadata_df = pd.DataFrame(metadata_dict)

else:

metadata_df = pd.read_csv(filename)

metadata_df_extra = pd.DataFrame(metadata_dict)
2 changes: 0 additions & 2 deletions nowcasting_dataset/data_sources/pv/live.py
@@ -43,7 +43,6 @@ def get_metadata_from_database(providers: List[str] = None) -> pd.DataFrame:

pv_system_all_df = []
for provider in providers:

logger.debug(f"Get PV systems from database for {provider}")

with db_connection.get_session() as session:
@@ -136,7 +135,6 @@ def get_pv_power_from_database(
logger.debug(f"Found {len(pv_yields_df)} pv yields")

if len(pv_yields_df) == 0:

data = create_empty_pv_data(end_utc=now, providers=providers, start_utc=start_utc)

return data
3 changes: 1 addition & 2 deletions nowcasting_dataset/data_sources/pv/pv_data_source.py
@@ -98,7 +98,6 @@ def get_data_model_for_batch():
return PV

def _load_metadata(self):

logger.debug(f"Loading PV metadata from {self.files_groups}")

# collect all metadata together
@@ -155,7 +154,6 @@ def _load_metadata(self):
logger.debug(f"Found {len(pv_metadata)} pv systems")

def _load_pv_power(self):

logger.debug(f"Loading PV Power data from {self.files_groups}")

if not self.is_live:
@@ -453,6 +451,7 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc
Returns: x_locations, y_locations. Each has one entry per t0_datetime.
Locations are in OSGB coordinates.
"""

# Set this up as a separate function, so we can cache the result!
@functools.cache # functools.cache requires Python >= 3.9
def _get_pv_system_ids(t0_datetime: pd.Timestamp) -> pd.Int64Dtype:
3 changes: 0 additions & 3 deletions nowcasting_dataset/data_sources/sun/raw_data_load_save.py
@@ -49,16 +49,13 @@ def get_azimuth_and_elevation(
names = []
# loop over locations and find azimuth and elevation angles,
with futures.ThreadPoolExecutor() as executor:

logger.debug("Setting up jobs")

# Submit tasks to the executor.
future_azimuth_and_elevation_per_location = []
for i in tqdm(range(len(x_centers))):

name = x_y_to_name(x_centers[i], y_centers[i])
if name not in names:

lat, lon = geospatial.osgb_to_lat_lon(x=x_centers[i], y=y_centers[i])

future_azimuth_and_elevation = executor.submit(
3 changes: 0 additions & 3 deletions nowcasting_dataset/data_sources/sun/sun_data_source.py
@@ -69,7 +69,6 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
end_dt = self._get_end_dt(t0_datetime_utc)

if not self.load_live:

# The names of the columns get truncated when saving, therefore we need to look for the
# name of the columns near the location we are looking for
locations = np.array(
@@ -96,7 +95,6 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
elevation = self.elevation.loc[start_dt:end_dt][name]

else:

latitude, longitude = osgb_to_lat_lon(x=x_center_osgb, y=y_center_osgb)

datestamps = pd.date_range(start=start_dt, end=end_dt, freq="5T").tolist()
@@ -115,7 +113,6 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
return sun

def _load(self):

logger.info(f"Loading Sun data from {self.zarr_path}")

if not self.load_live:
2 changes: 0 additions & 2 deletions nowcasting_dataset/dataset/batch.py
@@ -137,7 +137,6 @@ def load_netcdf(

# loop over data sources
for data_source_name in data_sources_names:

local_netcdf_filename = os.path.join(
local_netcdf_path, data_source_name, get_netcdf_filename(batch_idx)
)
@@ -193,7 +192,6 @@ def load_netcdf(

# legacy NWP
if "nwp" in batch_dict.keys():

nwp_rename_dict = {
"x_index": "x_osgb_index",
"y_index": "y_osgb_index",
2 changes: 0 additions & 2 deletions nowcasting_dataset/dataset/split/method.py
@@ -85,7 +85,6 @@ def split_method(
test_periods = unique_periods[unique_periods["modulo"].isin(test_indexes)]["period"]

elif method == "random":

# randomly sort indexes
rng = np.random.default_rng(seed)
unique_periods_in_dataset = rng.permutation(unique_periods_in_dataset)
@@ -108,7 +107,6 @@ def split_method(
test_periods = pd.to_datetime(unique_periods_in_dataset[validation_test_split:])

elif method == "specific":

train_periods = unique_periods_in_dataset[
unique_periods_in_dataset.isin(train_test_validation_specific.train)
]
1 change: 0 additions & 1 deletion nowcasting_dataset/filesystem/utils.py
@@ -90,7 +90,6 @@ def delete_all_files_in_temp_path(path: Union[Path, str], delete_dirs: bool = Fa
else:
# loop over folder structure, but only delete files
for root, dirs, files in filesystem.walk(path):

for f in files:
filesystem.rm(f"{root}/{f}")

3 changes: 0 additions & 3 deletions nowcasting_dataset/manager/manager.py
@@ -273,15 +273,13 @@ def sample_spatial_and_temporal_locations_for_examples(
shuffled_t0_datetimes = pd.DatetimeIndex(shuffled_t0_datetimes)

if get_all_locations:

# note that the returned 'shuffled_t0_datetimes'
# has duplicate datetimes for each location
locations = self.data_source_which_defines_geospatial_locations.get_all_locations(
t0_datetimes_utc=shuffled_t0_datetimes
)

else:

locations = self.data_source_which_defines_geospatial_locations.get_locations(
shuffled_t0_datetimes
)
@@ -404,7 +402,6 @@ def create_batches(self, overwrite_batches: bool) -> None:
for worker_id, (data_source_name, data_source) in enumerate(
self.data_sources.items()
):

# Get indexes of first batch and example. And subset locations_for_split.
idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
2 changes: 0 additions & 2 deletions nowcasting_dataset/manager/manager_live.py
@@ -186,7 +186,6 @@ def create_batches(self, use_async: Optional[bool] = True) -> None:
async_results_from_create_batches = []
an_error_has_occured = multiprocessing.Event()
for worker_id, (data_source_name, data_source) in enumerate(self.data_sources.items()):

# Get indexes of first batch and example. And subset locations_for_split.
idx_of_first_batch = 0
locations = locations_for_each_example
@@ -226,7 +225,6 @@ def create_batches(self, use_async: Optional[bool] = True) -> None:
# Sometimes when debuggin it is easy to use non async
data_source.create_batches(**kwargs_for_create_batches)
else:

async_result = pool.apply_async(
data_source.create_batches,
kwds=kwargs_for_create_batches,
2 changes: 2 additions & 0 deletions nowcasting_dataset/utils.py
@@ -180,6 +180,7 @@ def shutdown(self, wait=True):

def arg_logger(func):
"""A function decorator to log all the args and kwargs passed into a function."""

# Adapted from https://stackoverflow.com/a/23983263/732596
@wraps(func)
def inner_func(*args, **kwargs):
@@ -191,6 +192,7 @@ def inner_func(*args, **kwargs):

def exception_logger(func):
"""A function decorator to log exceptions thrown by the inner function."""

# Adapted from
# www.blog.pythonlibrary.org/2016/06/09/python-how-to-create-an-exception-logging-decorator
@wraps(func)
1 change: 0 additions & 1 deletion scripts/generate_topographic_data.py
@@ -24,7 +24,6 @@
upscale_factor = 0.12 # 30m to 250m-ish, just making it small enough files to actually merge
for f in files:
with rasterio.open(f) as dataset:

# resample data to target shape
data = dataset.read(
out_shape=(
1 change: 0 additions & 1 deletion tests/config/test_config.py
@@ -53,7 +53,6 @@ def test_yaml_save():
"""

with tempfile.NamedTemporaryFile(suffix=".yaml") as fp:

filename = fp.name

# check that temp file cant be loaded