Skip to content

Commit

Permalink
Merge pull request #220 from voynow/219-speed-up-new-plan-generation
Browse files Browse the repository at this point in the history
219 speed up new plan generation
  • Loading branch information
voynow authored Jan 31, 2025
2 parents 17767d5 + e5c6b8e commit b5a5fcb
Show file tree
Hide file tree
Showing 12 changed files with 758 additions and 224 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,5 @@ tests/artifacts
.terraform.lock.hcl
*.tfvars
*.tfvars.json

observe.jsonl
2 changes: 2 additions & 0 deletions api/src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
DEFAULT_ATHLETE_ID = -1
DEFAULT_USER_ID = "default"
DEFAULT_JWT_TOKEN = "default"

OBSERVE_FILE = "observe.jsonl"
63 changes: 53 additions & 10 deletions api/src/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,66 @@
from typing import Dict, List, Optional, Type

from dotenv import load_dotenv
from openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from pydantic import BaseModel, ValidationError
from src.constants import OBSERVE_FILE

load_dotenv()
client = OpenAI()
client = AsyncOpenAI()


def _get_completion(
def observe(
generation_name: str,
messages: List[ChatCompletionMessage],
response: ChatCompletion,
duration: float,
):
with open(OBSERVE_FILE, "a") as f:
f.write(
json.dumps(
{
"generation_name": generation_name,
"messages": [message["content"] for message in messages],
"response_id": response.id,
"content": response.choices[0].message.content,
"model": response.model,
"completion_tokens": response.usage.completion_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"total_tokens": response.usage.total_tokens,
"duration": duration,
}
)
+ "\n"
)


async def _get_completion(
messages: List[ChatCompletionMessage],
model: str = "gpt-4o",
response_format: Optional[Dict] = None,
generation_name: Optional[str] = None,
):
response = client.chat.completions.create(
start_time = time.time()
response = await client.chat.completions.create(
model=model, messages=messages, response_format=response_format
)
duration = time.time() - start_time
observe(
generation_name=generation_name,
messages=messages,
response=response,
duration=duration,
)

return response.choices[0].message.content


def get_completion(
async def get_completion(
message: str,
model: str = "gpt-4o",
model: Optional[str] = "gpt-4o",
generation_name: Optional[str] = None,
):
"""
LLM completion with raw string response
Expand All @@ -34,15 +72,18 @@ def get_completion(
:return: The raw string response from the LLM.
"""
messages = [{"role": "user", "content": message}]
return _get_completion(messages=messages, model=model)
return await _get_completion(
messages=messages, model=model, generation_name=generation_name
)


def get_completion_json(
async def get_completion_json(
message: str,
response_model: Type[BaseModel],
model: str = "gpt-4o",
max_retries: int = 3,
retry_delay: float = 1.0,
generation_name: Optional[str] = None,
) -> BaseModel:
"""
Get a JSON completion from the LLM and parse it into a Pydantic model.
Expand All @@ -61,17 +102,19 @@ def get_completion_json(
messages = [
{
"role": "system",
"content": f"You are a helpful assistant designed to output JSON. {response_model_content}",
"content": f"You are a helpful assistant designed to output JSON. Do not use newline characters or spaces for json formatting. {response_model_content}",
},
{"role": "user", "content": message},
]

response_str = "Completion failed."
for attempt in range(max_retries):
try:
response_str = _get_completion(
response_str = await _get_completion(
model=model,
messages=messages,
response_format={"type": "json_object"},
generation_name=generation_name,
)
response = json.loads(response_str)
return response_model(**response)
Expand Down
53 changes: 22 additions & 31 deletions api/src/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import time
from typing import Callable, Optional

from fastapi import (
Expand All @@ -18,10 +17,9 @@
from src.types.feedback import FeedbackRow
from src.types.training_plan import TrainingPlan
from src.types.training_week import FullTrainingWeek
from src.types.update_pipeline import ExeType
from src.types.user import User
from src.types.webhook import StravaEvent
from src.update_pipeline import update_all_users, update_training_week
from src.update_pipeline import refresh_user_data, update_all_users

app = FastAPI()

Expand Down Expand Up @@ -89,25 +87,25 @@ async def update_preferences(
return {"success": True}


@app.get("/profile/")
async def get_profile(user: User = Depends(auth_manager.validate_user)) -> dict:
"""
Retrieve user profile information including Strava details
# @app.get("/profile/")
# async def get_profile(user: User = Depends(auth_manager.validate_user)) -> dict:
# """
# Retrieve user profile information including Strava details

:param user: The authenticated user
:return: Dictionary containing profile information
"""
athlete = auth_manager.get_strava_client(user.athlete_id).get_athlete()
return {
"success": True,
"profile": {
"firstname": athlete.firstname,
"lastname": athlete.lastname,
"profile": athlete.profile,
"email": user.email,
"preferences": user.preferences.json(),
},
}
# :param user: The authenticated user
# :return: Dictionary containing profile information
# """
# athlete = auth_manager.get_strava_client(user.athlete_id).get_athlete()
# return {
# "success": True,
# "profile": {
# "firstname": athlete.firstname,
# "lastname": athlete.lastname,
# "profile": athlete.profile,
# "email": user.email,
# "preferences": user.preferences.json(),
# },
# }


@app.get("/v2/profile/")
Expand Down Expand Up @@ -194,7 +192,7 @@ async def strava_webhook(request: Request, background_tasks: BackgroundTasks) ->


@app.post("/refresh/")
async def refresh_user_data(
async def refresh(
user: User = Depends(auth_manager.validate_user),
) -> dict:
"""
Expand All @@ -203,14 +201,7 @@ async def refresh_user_data(
:param user: The authenticated user
:return: Success status
"""
start_time = time.time()
update_training_week(user, ExeType.NEW_WEEK, dt=utils.get_last_sunday())
print(f"New week update time: {time.time() - start_time:.2f} seconds")

start_time = time.time()
update_training_week(user, ExeType.MID_WEEK, dt=utils.datetime_now_est())
print(f"Mid-week update time: {time.time() - start_time:.2f} seconds")

await refresh_user_data(user)
return {"success": True}


Expand All @@ -224,7 +215,7 @@ async def update_all_users_trigger(request: Request) -> dict:
if api_key != os.environ["API_KEY"]:
raise HTTPException(status_code=403, detail="Invalid API key")

update_all_users()
await update_all_users()
return {"success": True}


Expand Down
32 changes: 7 additions & 25 deletions api/src/mileage_recommendation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import logging
from typing import List, Tuple
from typing import List

from src import activities, supabase_client
from src.training_plan import gen_training_plan_pipeline
Expand All @@ -16,7 +16,7 @@
logger.setLevel(logging.INFO)


def gen_mileage_rec_wrapper(
async def gen_mileage_rec_wrapper(
user: User, daily_activity: List[DailyActivity], dt: datetime.datetime
) -> MileageRecommendation:
"""
Expand All @@ -34,7 +34,7 @@ def gen_mileage_rec_wrapper(
)

weekly_summaries = activities.get_weekly_summaries(daily_activity=daily_activity)
training_plan = gen_training_plan_pipeline(
training_plan = await gen_training_plan_pipeline(
user=user, weekly_summaries=weekly_summaries, dt=dt
)
next_week_plan = training_plan.training_plan_weeks[0]
Expand All @@ -45,7 +45,7 @@ def gen_mileage_rec_wrapper(
)


def create_new_mileage_recommendation(
async def create_new_mileage_recommendation(
user: User,
daily_activity: List[DailyActivity],
dt: datetime.datetime,
Expand All @@ -58,7 +58,7 @@ def create_new_mileage_recommendation(
:param dt: datetime injection, helpful for testing
:return: mileage recommendation entity
"""
mileage_recommendation = gen_mileage_rec_wrapper(
mileage_recommendation = await gen_mileage_rec_wrapper(
user=user, daily_activity=daily_activity, dt=dt
)
tomorrow = dt + datetime.timedelta(days=1)
Expand All @@ -77,25 +77,7 @@ def create_new_mileage_recommendation(
return mileage_recommendation


def get_week_of_year_and_year(
exe_type: ExeType, dt: datetime.datetime
) -> Tuple[int, int]:
"""
Get the week of year and year for the given datetime. We run NEW_WEEK on
Sunday night, so in order to upsert
:param exe_type: ExeType object
:param dt: datetime object
:return: Tuple[int, int]
"""
if exe_type == ExeType.NEW_WEEK:
tomorrow = dt + datetime.timedelta(days=1)
return tomorrow.isocalendar().week, tomorrow.isocalendar().year
else:
return


def get_or_gen_mileage_recommendation(
async def get_or_gen_mileage_recommendation(
user: User,
daily_activity: List[DailyActivity],
exe_type: ExeType,
Expand All @@ -111,7 +93,7 @@ def get_or_gen_mileage_recommendation(
:return: mileage recommendation entity
"""
if exe_type == ExeType.NEW_WEEK:
return create_new_mileage_recommendation(
return await create_new_mileage_recommendation(
user=user, daily_activity=daily_activity, dt=dt
)
else:
Expand Down
Loading

0 comments on commit b5a5fcb

Please sign in to comment.