Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from fastapi import APIRouter, Depends

from app.classify import classify, detect
from app.forecast import daily_flare_forecast
from app.forecast import pt_daily_flare_forecast, ts_flare_forecast
from app.schemas import (
ARCutoutClassificationInput,
ARCutoutClassificationResult,
ARDetection,
ARDetectionInput,
FlareForecast,
TSFlareForecast, PTFlareForecast,
)

classification_router = APIRouter()
Expand Down Expand Up @@ -80,21 +80,33 @@ async def full_disk_detection_post(
return _perform_detection(request)


@forecast_router.get("/flare_forecast", tags=["Flare Forecast"])
async def flare_forecast_get(
@forecast_router.get("/pt_flare_forecast", tags=["Flare Forecast"])
async def pt_flare_forecast_get(
request: ARDetectionInput = Depends(),
) -> FlareForecast:
) -> PTFlareForecast:
r"""
Flare forecast for next 24 hours
"""
forecast_result = daily_flare_forecast(request.time)
forecast_result = pt_daily_flare_forecast(request.time)
return forecast_result


@forecast_router.post("/flare_forecast", tags=["Flare Forecast"])
async def flare_forecast_post(request: ARDetectionInput) -> FlareForecast:
@forecast_router.post("/pt_flare_forecast", tags=["Flare Forecast"])
async def pt_flare_forecast_post(request: ARDetectionInput) -> PTFlareForecast:
r"""
Flare forecast for next 24 hours
"""
forecast_result = daily_flare_forecast(request.time)
forecast_result = pt_daily_flare_forecast(request.time)
return forecast_result


@forecast_router.get("/ts_flare_forecast", tags=["Flare Forecast"])
async def ts_flare_forecast_get(request: ARDetectionInput = Depends()) -> TSFlareForecast:
forecast_result = ts_flare_forecast(request.time)
return forecast_result


@forecast_router.post("/ts_flare_forecast", tags=["Flare Forecast"])
async def ts_flare_forecast_post(request: ARDetectionInput) -> TSFlareForecast:
forecast_result = ts_flare_forecast(request.time)
return forecast_result
35 changes: 29 additions & 6 deletions app/forecast.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import List

from app.schemas import FlareForecast, ARFlareForecast
from app.schemas import PTFlareForecast, ARPTFlareForecast, ARFlareProbability, TSFlareForecast, TSARFlareForecast


def daily_flare_forecast(time: datetime) -> FlareForecast:
mock_forecast = FlareForecast(
def pt_daily_flare_forecast(time: datetime) -> PTFlareForecast:
mock_forecast = PTFlareForecast(
timestamp=time,
forecasts=[
ARFlareForecast(
ARPTFlareForecast(
noaa=13664,
c=0.45,
m=0.25,
x=0.10,
),
ARFlareForecast(
ARPTFlareForecast(
noaa=13666,
c=0.50,
m=0.30,
Expand All @@ -24,3 +24,26 @@ def daily_flare_forecast(time: datetime) -> FlareForecast:
)

return mock_forecast

def ts_flare_forecast(time: datetime) -> TSFlareForecast:
# Single AR forecast — high activity, 1-hour steps over 24 hours
mock_forecast = TSFlareForecast(
timestamp=datetime(2024, 11, 1, 6, 0, 0, tzinfo=timezone.utc),
step_minutes=60,
forecasts=[
TSARFlareForecast(
noaa=13490,
probabilities=[
ARFlareProbability(
offset_minutes=offset,
c=round(0.80 - i * 0.01, 3),
m=round(0.40 - i * 0.005, 3),
x=round(0.08 - i * 0.001, 3),
)
for i, offset in enumerate(range(0, 25 * 60, 60)) # 0, 60, 120, ... 1440
],
)
],
)

return mock_forecast
77 changes: 70 additions & 7 deletions app/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timedelta
from typing import List, Optional

from fastapi import Query
Expand All @@ -13,9 +13,11 @@
"HeliographicStonyhurstCoordinate",
"BoundingBox",
"ARDetection",
"FlareForecast",
"DailyFlareForecast",
"ActiveRegionForecast",
"ARPTFlareForecast",
"PTFlareForecast",
"ARFlareProbability",
"TSARFlareForecast",
"TSFlareForecast",
]


Expand Down Expand Up @@ -114,7 +116,7 @@ class ARDetection(BaseModel):
confidence: float = Field(title="Confidence", example="0.90")


class ARFlareForecast(BaseModel):
class ARPTFlareForecast(BaseModel):
noaa: int = Field(..., gt=0, description="Positive NOAA active region number")
c: float = Field(..., ge=0.0, le=1.0, description="C-class flare probability")
m: float = Field(..., ge=0.0, le=1.0, description="M-class flare probability")
Expand All @@ -127,6 +129,67 @@ def check_flare_hierarchy(self):
return self


class FlareForecast(BaseModel):
class PTFlareForecast(BaseModel):
timestamp: datetime = Field(..., description="Forecast timestamp (UTC)")
forecasts: List[ARFlareForecast]
forecasts: List[ARPTFlareForecast]


class ARFlareProbability(BaseModel):
"""Flare probabilities at a single point in the forecast horizon."""
offset_minutes: int = Field(..., ge=0, description="Minutes from forecast timestamp")
c: float = Field(..., ge=0.0, le=1.0, description="C-class flare probability")
m: float = Field(..., ge=0.0, le=1.0, description="M-class flare probability")
x: float = Field(..., ge=0.0, le=1.0, description="X-class flare probability")

@model_validator(mode="after")
def check_flare_hierarchy(self):
if not (self.x <= self.m <= self.c):
raise ValueError("Flare probabilities must satisfy: x <= m <= c")
return self


class TSARFlareForecast(BaseModel):
noaa: int = Field(..., gt=0, description="Positive NOAA active region number")
probabilities: List[ARFlareProbability] = Field(
..., min_length=1, description="Time-ordered probability series"
)

@model_validator(mode="after")
def check_offsets_ordered_and_unique(self):
offsets = [p.offset_minutes for p in self.probabilities]
if offsets != sorted(set(offsets)):
raise ValueError("offset_minutes must be strictly increasing and unique")
return self

def at_offset(self, offset_minutes: int) -> ARFlareProbability | None:
return next((p for p in self.probabilities if p.offset_minutes == offset_minutes), None)

@property
def horizon_minutes(self) -> int:
return self.probabilities[-1].offset_minutes


class TSFlareForecast(BaseModel):
timestamp: datetime = Field(..., description="Forecast timestamp (UTC)")
step_minutes: int = Field(..., gt=0, description="Expected interval between steps (informational)")
forecasts: List[TSARFlareForecast]

@model_validator(mode="after")
def check_consistent_offsets(self):
if not self.forecasts:
return self
reference = [p.offset_minutes for p in self.forecasts[0].probabilities]
for ar in self.forecasts[1:]:
if [p.offset_minutes for p in ar.probabilities] != reference:
raise ValueError(
f"AR {ar.noaa} has different time offsets than AR {self.forecasts[0].noaa}"
)
return self

def absolute_times(self) -> List[datetime]:
if not self.forecasts:
return []
return [
self.timestamp + timedelta(minutes=p.offset_minutes)
for p in self.forecasts[0].probabilities
]