Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Teo Koon Peng <[email protected]>
  • Loading branch information
koonpeng committed Jul 4, 2024
1 parent 966ec5c commit 3a033e9
Show file tree
Hide file tree
Showing 33 changed files with 236 additions and 208 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ disable=raw-checker-failed,
line-too-long,
global-statement,
unnecessary-dunder-call,
too-many-lines,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
3 changes: 0 additions & 3 deletions packages/api-server/api_server/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import logging
import os

import uvicorn
import uvicorn.logging
from uvicorn.config import LOGGING_CONFIG
Expand Down
19 changes: 12 additions & 7 deletions packages/api-server/api_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Coroutine, Union

import schedule
from fastapi import Depends, FastAPI
from fastapi import Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import (
get_redoc_html,
Expand All @@ -23,7 +23,7 @@
from . import gateway, ros, routes
from .app_config import app_config
from .authenticator import AuthenticationError, authenticator, user_dep
from .fast_io import FastIO, SioSession
from .fast_io import FastIO
from .logging import default_logger
from .models import DispenserState, DoorState, IngestorState, LiftState, User
from .models import tortoise_models as ttm
Expand All @@ -32,7 +32,7 @@
from .types import is_coroutine


async def on_sio_connect(sid: str, environ: dict, auth: dict | None = None):
async def on_sio_connect(_sid: str, _environ: dict, auth: dict | None = None):
token = None
if auth:
token = auth["token"]
Expand Down Expand Up @@ -90,9 +90,8 @@ async def lifespan(_app: FastIO):
cached_files, ros_node.node, rmf_events, RmfRepository(User.get_system_user())
)
await stack.enter_async_context(gateway.RmfGateway.set_instance(rmf_gateway))
await stack.enter_async_context(
TasksService.set_instance(TasksService(ros_node.node))
)
tasks_service = TasksService(ros_node.node)
await stack.enter_async_context(TasksService.set_instance(tasks_service))

# shutdown event is not called when the app crashes, this can cause the app to be
# "locked up" as some dependencies like tortoise does not allow python to exit until
Expand Down Expand Up @@ -126,7 +125,13 @@ def on_signal(sig, frame):
default_logger.warning(f"user [{t.created_by}] does not exist")
continue
task_repo = TaskRepository(user, default_logger)
await routes.scheduled_tasks.schedule_task(t, task_repo, default_logger)
await routes.scheduled_tasks.schedule_task(
t,
RmfRepository(User.get_system_user()),
task_repo,
tasks_service,
default_logger,
)
scheduled += 1
default_logger.info(f"loaded {scheduled} tasks")
default_logger.info("successfully started scheduler")
Expand Down
10 changes: 5 additions & 5 deletions packages/api-server/api_server/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import urllib.parse
from dataclasses import dataclass
from importlib.abc import Loader
from typing import Any, List, Optional, cast
from typing import Any, cast


@dataclass
Expand All @@ -16,11 +16,11 @@ class AppConfig:
cache_directory: str
log_level: str
builtin_admin: str
jwt_public_key: Optional[str]
oidc_url: Optional[str]
jwt_public_key: str | None
oidc_url: str | None
aud: str
iss: Optional[str]
ros_args: List[str]
iss: str | None
ros_args: list[str]

def __post_init__(self):
self.public_url = urllib.parse.urlparse(cast(str, self.public_url))
Expand Down
12 changes: 6 additions & 6 deletions packages/api-server/api_server/authenticator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
import logging
from typing import Any, Callable, Coroutine, Optional, Protocol, Union
from typing import Any, Callable, Coroutine, Protocol

import jwt
from fastapi import Depends, Header, HTTPException
Expand All @@ -16,9 +16,9 @@ class AuthenticationError(Exception):


class Authenticator(Protocol):
async def verify_token(self, token: Optional[str]) -> User: ...
async def verify_token(self, token: str | None) -> User: ...

def fastapi_dep(self) -> Callable[..., Union[Coroutine[Any, Any, User], User]]: ...
def fastapi_dep(self) -> Callable[..., Coroutine[Any, Any, User] | User]: ...


class JwtAuthenticator:
Expand Down Expand Up @@ -62,7 +62,7 @@ async def _get_user(self, claims: dict) -> User:

return user

async def verify_token(self, token: Optional[str]) -> User:
async def verify_token(self, token: str | None) -> User:
if not token:
raise AuthenticationError("authentication required")
try:
Expand All @@ -79,7 +79,7 @@ async def verify_token(self, token: Optional[str]) -> User:
except jwt.InvalidTokenError as e:
raise AuthenticationError(str(e)) from e

def fastapi_dep(self) -> Callable[..., Union[Coroutine[Any, Any, User], User]]:
def fastapi_dep(self) -> Callable[..., Coroutine[Any, Any, User] | User]:
async def dep(
auth_header: str = Depends(OpenIdConnect(openIdConnectUrl=self.oidc_url)),
):
Expand All @@ -101,7 +101,7 @@ class StubAuthenticator(Authenticator):
WITHOUT verifying the signature and authenticated as the user given.
"""

async def verify_token(self, token: Optional[str]):
async def verify_token(self, token: str | None):
if not token:
return User(username="stub", is_admin=True)
# decode the jwt without verifying signature
Expand Down
3 changes: 1 addition & 2 deletions packages/api-server/api_server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from api_server import ros_time

from .fast_io import SubscriptionRequest
from .models import Pagination, User
from .models import Pagination


def pagination_query(
Expand Down
15 changes: 5 additions & 10 deletions packages/api-server/api_server/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
import logging
from datetime import datetime
from typing import Any, List, Optional, cast
from typing import Any, cast

import rclpy
import rclpy.client
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
rmf_events: RmfEvents,
rmf_repo: RmfRepository,
*,
logger: Optional[logging.Logger] = None,
logger: logging.Logger | None = None,
):
self._cached_files = cached_files
self._ros_node = ros_node
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
),
)

self._subscriptions: List[Subscription] = []
self._subscriptions: list[Subscription] = []

self._subscribe_all()

Expand Down Expand Up @@ -300,11 +300,6 @@ def handle_delivery_alert(msg):
self._subscriptions.append(delivery_alert_request_sub)

def handle_fire_alarm_trigger(msg):
async def save(delivery_alert: DeliveryAlert):
# await self._rmf_repo.(delivery_alert)
self._rmf_events.delivery_alerts.on_next(delivery_alert)
logging.debug("%s", delivery_alert)

msg = cast(BoolMsg, msg)
if msg.data:
logging.info("Fire alarm triggered")
Expand Down Expand Up @@ -350,7 +345,7 @@ def request_lift(
destination: str,
request_type: int,
door_mode: int,
additional_session_ids: List[str],
additional_session_ids: list[str],
):
msg = RmfLiftRequest(
lift_name=lift_name,
Expand Down Expand Up @@ -386,7 +381,7 @@ def respond_to_delivery_alert(

def manual_release_mutex_groups(
self,
mutex_groups: List[str],
mutex_groups: list[str],
fleet: str,
robot: str,
):
Expand Down
1 change: 0 additions & 1 deletion packages/api-server/api_server/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import typing
from logging import LoggerAdapter

from fastapi import Depends
from fastapi.requests import HTTPConnection
from termcolor import colored
from termcolor._types import Color
Expand Down
2 changes: 1 addition & 1 deletion packages/api-server/api_server/models/lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, Field

from . import tortoise_models as ttm
from .ros_pydantic import builtin_interfaces, rmf_building_map_msgs, rmf_lift_msgs
from .ros_pydantic import builtin_interfaces, rmf_building_map_msgs

Lift = rmf_building_map_msgs.Lift

Expand Down
10 changes: 5 additions & 5 deletions packages/api-server/api_server/models/task_favorite.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Dict

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class TaskFavorite(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: str
name: str
unix_millis_earliest_start_time: int
priority: Dict | None
priority: dict | None
category: str
description: Dict | None
description: dict | None
user: str
task_definition_id: str
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os

from api_server.app_config import app_config
from api_server.fast_io.singleton_dep import SingletonDep


Expand Down
12 changes: 6 additions & 6 deletions packages/api-server/api_server/repositories/fleets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, List, Optional, Sequence, Tuple, cast
from typing import Annotated, Sequence, cast

from fastapi import Depends
from tortoise.exceptions import IntegrityError
Expand All @@ -20,20 +20,20 @@ def __init__(
self.user = user
self.logger = logger

async def get_all_fleets(self) -> List[FleetState]:
async def get_all_fleets(self) -> list[FleetState]:
db_states = await ttm.FleetState.all().values_list("data")
return [FleetState(**s[0]) for s in db_states]

async def get_fleet_state(self, name: str) -> Optional[FleetState]:
async def get_fleet_state(self, name: str) -> FleetState | None:
# TODO: enforce with authz
result = await ttm.FleetState.get_or_none(name=name)
if result is None:
return None
return FleetState(**cast(dict, result.data))

async def get_fleet_log(
self, name: str, between: Tuple[int, int]
) -> Optional[FleetLog]:
self, name: str, between: tuple[int, int]
) -> FleetLog | None:
"""
:param between: The period in unix millis to fetch.
"""
Expand All @@ -42,7 +42,7 @@ async def get_fleet_log(
"unix_millis_time__lte": between[1],
}
result = cast(
Optional[ttm.FleetLog],
ttm.FleetLog | None,
await ttm.FleetLog.get_or_none(name=name).prefetch_related(
Prefetch(
"log",
Expand Down
18 changes: 9 additions & 9 deletions packages/api-server/api_server/repositories/rmf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, List, Literal, Optional, cast
from typing import Annotated, Literal, cast

from fastapi import Depends
from tortoise.queryset import ValuesListQuery
Expand Down Expand Up @@ -49,13 +49,13 @@ async def save_building_map(self, building_map: BuildingMap) -> None:
{"data": building_map.model_dump()}, id_=building_map.name
)

async def get_doors(self) -> List[Door]:
async def get_doors(self) -> list[Door]:
building_map = await self.get_bulding_map()
if building_map is None:
return []
return [door for level in building_map.levels for door in level.doors]

async def get_door_state(self, door_name: str) -> Optional[DoorState]:
async def get_door_state(self, door_name: str) -> DoorState | None:
door_state = await ttm.DoorState.get_or_none(id_=door_name)
if door_state is None:
return None
Expand All @@ -66,13 +66,13 @@ async def save_door_state(self, door_state: DoorState) -> None:
{"data": door_state.model_dump()}, id_=door_state.door_name
)

async def get_lifts(self) -> List[Lift]:
async def get_lifts(self) -> list[Lift]:
building_map = await self.get_bulding_map()
if building_map is None:
return []
return building_map.lifts

async def get_lift_state(self, lift_name: str) -> Optional[LiftState]:
async def get_lift_state(self, lift_name: str) -> LiftState | None:
lift_state = await ttm.LiftState.get_or_none(id_=lift_name)
if lift_state is None:
return None
Expand All @@ -83,7 +83,7 @@ async def save_lift_state(self, lift_state: LiftState) -> None:
{"data": lift_state.model_dump()}, id_=lift_state.lift_name
)

async def get_dispensers(self) -> List[Dispenser]:
async def get_dispensers(self) -> list[Dispenser]:
states = await ttm.DispenserState.all()
return [Dispenser.model_validate(state.data) for state in states]

Expand All @@ -92,13 +92,13 @@ async def save_dispenser_state(self, dispenser_state: DispenserState) -> None:
{"data": dispenser_state.model_dump()}, id_=dispenser_state.guid
)

async def get_dispenser_state(self, guid: str) -> Optional[DispenserState]:
async def get_dispenser_state(self, guid: str) -> DispenserState | None:
dispenser_state = await ttm.DispenserState.get_or_none(id_=guid)
if dispenser_state is None:
return None
return DispenserState.model_validate(dispenser_state.data)

async def get_ingestors(self) -> List[Ingestor]:
async def get_ingestors(self) -> list[Ingestor]:
states = await ttm.IngestorState.all()
return [Ingestor.model_validate(state.data) for state in states]

Expand All @@ -107,7 +107,7 @@ async def save_ingestor_state(self, ingestor_state: IngestorState) -> None:
{"data": ingestor_state.model_dump()}, id_=ingestor_state.guid
)

async def get_ingestor_state(self, guid: str) -> Optional[IngestorState]:
async def get_ingestor_state(self, guid: str) -> IngestorState | None:
ingestor_state = await ttm.IngestorState.get_or_none(id_=guid)
if ingestor_state is None:
return None
Expand Down
Loading

0 comments on commit 3a033e9

Please sign in to comment.