Skip to content

Commit

Permalink
Merge pull request #120 from openclimatefix/start-end-times
Browse files Browse the repository at this point in the history
add start and end times on forecast and generation
  • Loading branch information
peterdudfield authored Jul 19, 2024
2 parents 83a2e6f + 7437276 commit ed26213
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 7 deletions.
15 changes: 12 additions & 3 deletions pv_site_api/_db_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _get_latest_forecast_by_sites(
session: Session,
site_uuids: list[str],
start_utc: Optional[dt.datetime] = None,
end_utc: Optional[dt.datetime] = None,
sum_by: Optional[str] = None,
) -> list[Row]:
"""Get the latest forecast for given site uuids."""
Expand All @@ -123,6 +124,9 @@ def _get_latest_forecast_by_sites(
if start_utc is not None:
query = query.filter(ForecastValueSQL.start_utc >= start_utc)

if end_utc is not None:
query = query.filter(ForecastValueSQL.end_utc <= end_utc)

query.order_by(forecast_subq.timestamp_utc, ForecastValueSQL.start_utc)

if sum_by is None:
Expand Down Expand Up @@ -155,6 +159,7 @@ def get_forecasts_by_sites(
horizon_minutes: int,
compact: bool = False,
sum_by: Optional[str] = None,
end_utc: Optional[dt.datetime] = None,
) -> Union[list[Forecast], ManyForecastCompact]:
"""Combination of the latest forecast and the past forecasts, for given sites.
Expand All @@ -163,20 +168,22 @@ def get_forecasts_by_sites(

logger.info(f"Getting forecast for {len(site_uuids)} sites")

end_utc = dt.datetime.utcnow()
end_utc_past = dt.datetime.utcnow()
if (end_utc is not None) and (end_utc < end_utc_past):
end_utc_past = end_utc

rows_past = _get_forecasts_for_horizon(
session,
site_uuids=site_uuids,
start_utc=start_utc,
end_utc=end_utc,
end_utc=end_utc_past,
horizon_minutes=horizon_minutes,
sum_by=sum_by,
)
logger.debug("Found %s past forecasts", len(rows_past))

rows_future = _get_latest_forecast_by_sites(
session=session, site_uuids=site_uuids, start_utc=start_utc, sum_by=sum_by
session=session, site_uuids=site_uuids, start_utc=start_utc, sum_by=sum_by, end_utc=end_utc
)
logger.debug("Found %s future forecasts", len(rows_future))

Expand All @@ -199,12 +206,14 @@ def get_generation_by_sites(
start_utc: dt.datetime,
compact: bool = False,
sum_by: Optional[str] = None,
end_utc: Optional[dt.datetime] = None,
) -> Union[list[MultiplePVActual], MultipleSitePVActualCompact]:
"""Get the generation since yesterday (midnight) for a list of sites."""
logger.info(f"Getting generation for {len(site_uuids)} sites")
rows = get_pv_generation_by_sites(
session=session,
start_utc=start_utc,
end_utc=end_utc,
site_uuids=[uuid.UUID(su) for su in site_uuids],
sum_by=sum_by,
)
Expand Down
30 changes: 27 additions & 3 deletions pv_site_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import time
import uuid
from datetime import datetime
from typing import Optional, Union

import pandas as pd
Expand Down Expand Up @@ -444,6 +445,8 @@ def get_pv_actual_many_sites(
sum_by: Optional[str] = None,
auth: dict = Depends(auth),
compact: bool = False,
start_utc: Optional[str] = None,
end_utc: Optional[str] = None,
):
"""
### Get the actual power generation for a list of sites.
Expand All @@ -468,6 +471,11 @@ def get_pv_actual_many_sites(
site_uuids = site_uuids.replace(" ", "")
site_uuids_list = site_uuids.split(",")

if start_utc is not None:
start_utc = datetime.fromisoformat(start_utc)
if end_utc is not None:
end_utc = datetime.fromisoformat(end_utc)

if is_fake():
return [make_fake_pv_generation(site_uuid) for site_uuid in site_uuids_list]

Expand All @@ -479,10 +487,16 @@ def get_pv_actual_many_sites(

check_user_has_access_to_sites(session=session, auth=auth, site_uuids=site_uuids_list)

start_utc = get_yesterday_midnight()
if start_utc is None:
start_utc = get_yesterday_midnight()

return get_generation_by_sites(
session, site_uuids=site_uuids_list, start_utc=start_utc, compact=compact, sum_by=sum_by
session,
site_uuids=site_uuids_list,
start_utc=start_utc,
compact=compact,
sum_by=sum_by,
end_utc=end_utc,
)


Expand Down Expand Up @@ -540,6 +554,8 @@ def get_pv_forecast_many_sites(
session: Session = Depends(get_session),
auth: dict = Depends(auth),
sum_by: Optional[str] = None,
start_utc: Optional[str] = None,
end_utc: Optional[str] = None,
compact: bool = False,
):
"""
Expand All @@ -565,12 +581,19 @@ def get_pv_forecast_many_sites(
if is_fake():
return [make_fake_forecast(fake_site_uuid)]

start_utc = get_yesterday_midnight()
if start_utc is not None:
start_utc = datetime.fromisoformat(start_utc)
if end_utc is not None:
end_utc = datetime.fromisoformat(end_utc)

if start_utc is None:
start_utc = get_yesterday_midnight()

if (site_uuids == "[]") or (site_uuids == ""):
return []

site_uuids = site_uuids.replace(" ", "")

site_uuids_list = site_uuids.split(",")

# check that uuids are given
Expand All @@ -587,6 +610,7 @@ def get_pv_forecast_many_sites(
session,
site_uuids=site_uuids_list,
start_utc=start_utc,
end_utc=end_utc,
horizon_minutes=0,
compact=compact,
sum_by=sum_by,
Expand Down
36 changes: 36 additions & 0 deletions tests/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,42 @@ def test_get_forecast_many_sites_late_forecast_one_day(db_session, client, forec
assert forecast_value.target_datetime_utc < one_day_from_now


def test_get_forecast_many_sites_late_forecast_start(db_session, client, forecast_values, sites):
"""Test the case where the forecast stop working 1 day ago"""
site_uuids = [str(s.site_uuid) for s in sites]
site_uuids_str = ",".join(site_uuids)
one_day_from_now = datetime.utcnow() + timedelta(days=1)
start_utc = (datetime.utcnow() - timedelta(minutes=5)).isoformat()

with freeze_time(one_day_from_now):
resp = client.get(f"/sites/pv_forecast?site_uuids={site_uuids_str}&start_utc={start_utc}")
assert resp.status_code == 200

forecasts = [Forecast(**x) for x in resp.json()]

assert len(forecasts) == len(sites)
# We have 10 forecasts
assert len(forecasts[0].forecast_values) == 11


def test_get_forecast_many_sites_late_forecast_end(db_session, client, forecast_values, sites):
"""Test the case where the forecast stop working 1 day ago"""
site_uuids = [str(s.site_uuid) for s in sites]
site_uuids_str = ",".join(site_uuids)
one_day_from_now = datetime.utcnow() + timedelta(days=1)
end_utc = (datetime.utcnow() - timedelta(minutes=5)).isoformat()

with freeze_time(one_day_from_now):
resp = client.get(f"/sites/pv_forecast?site_uuids={site_uuids_str}&end_utc={end_utc}")
assert resp.status_code == 200

forecasts = [Forecast(**x) for x in resp.json()]

assert len(forecasts) == len(sites)
# We have 10 forecasts
assert len(forecasts[0].forecast_values) == 9


def test_get_forecast_many_sites_late_forecast_one_day_compact(
db_session, client, forecast_values, sites
):
Expand Down
31 changes: 30 additions & 1 deletion tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import uuid
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone

from pvsite_datamodel.pydantic_models import GenerationSum
from pvsite_datamodel.sqlmodels import GenerationSQL
Expand Down Expand Up @@ -84,6 +84,35 @@ def test_pv_actual_many_sites_dno(client, sites, generations):
assert len(pv_actuals) == 30


def test_pv_actual_many_sites_start(client, sites, generations):
site_uuids = [str(s.site_uuid) for s in sites]
site_uuid_str = ",".join(site_uuids)
start_utc = (datetime.today() - timedelta(minutes=5)).isoformat()

resp = client.get(f"/sites/pv_actual?site_uuids={site_uuid_str}&start_utc={start_utc}")

assert resp.status_code == 200

pv_actuals = [MultiplePVActual(**x) for x in resp.json()]
assert len(pv_actuals) == len(sites)
assert len(pv_actuals[0].pv_actual_values) == 5


def test_pv_actual_many_sites_end(client, sites, generations):
site_uuids = [str(s.site_uuid) for s in sites]
site_uuid_str = ",".join(site_uuids)
end_utc = (datetime.today()).isoformat()

resp = client.get(f"/sites/pv_actual?site_uuids={site_uuid_str}&end_utc={end_utc}")

assert resp.status_code == 200

pv_actuals = [MultiplePVActual(**x) for x in resp.json()]
assert len(pv_actuals) == len(sites)
# only 5 generations are later than now, the other 5 all stop before now
assert len(pv_actuals[0].pv_actual_values) == 5


def test_pv_actual_many_sites_gsp(client, sites, generations):
site_uuids = [str(s.site_uuid) for s in sites]
site_uuid_str = ",".join(site_uuids)
Expand Down

0 comments on commit ed26213

Please sign in to comment.