Skip to content

Commit

Permalink
♻️ Is922/select default wallet and pricing plan in the backend (#4851)
Browse files Browse the repository at this point in the history
Co-authored-by: Pedro Crespo <[email protected]>
  • Loading branch information
matusdrobuliak66 and pcrespov authored Oct 13, 2023
1 parent a422201 commit 6b9e988
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from models_library.clusters import ClusterID
from models_library.projects import ProjectID
from models_library.users import UserID
from models_library.wallets import WalletInfo
from models_library.wallets import WalletID, WalletInfo
from pydantic import BaseModel, Field, ValidationError, parse_obj_as
from pydantic.types import NonNegativeInt
from servicelib.aiohttp.rest_responses import create_error_response, get_http_error
Expand All @@ -26,8 +26,10 @@
from .._meta import API_VTAG as VTAG
from ..application_settings import get_settings
from ..login.decorators import login_required
from ..products import api as products_api
from ..projects import api as projects_api
from ..security.decorators import permission_required
from ..users import preferences_api as user_preferences_api
from ..utils_aiohttp import envelope_json_response
from ..version_control.models import CommitID
from ..wallets import api as wallets_api
Expand Down Expand Up @@ -98,18 +100,41 @@ async def start_computation(request: web.Request) -> web.Response:

# Get wallet information
wallet_info = None
project_wallet = await projects_api.get_project_wallet(
request.app, project_id=project_id
)
product = products_api.get_current_product(request)
app_settings = get_settings(request.app)
if project_wallet and app_settings.WEBSERVER_CREDIT_COMPUTATION_ENABLED:
# Check whether user has access to the wallet
await wallets_api.get_wallet_by_user(
request.app, req_ctx.user_id, project_wallet.wallet_id, req_ctx.product_name
if product.is_payment_enabled and app_settings.WEBSERVER_CREDIT_COMPUTATION_ENABLED:
project_wallet = await projects_api.get_project_wallet(
request.app, project_id=project_id
)
wallet_info = WalletInfo(
wallet_id=project_wallet.wallet_id, wallet_name=project_wallet.name
if project_wallet is None:
user_default_wallet_preference = await user_preferences_api.get_user_preference(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
preference_class=user_preferences_api.PreferredWalletIdFrontendUserPreference,
)
if user_default_wallet_preference is None:
raise ValueError(
"User does not have default wallet - this should not happen"
)
project_wallet_id = parse_obj_as(
WalletID, user_default_wallet_preference.value
)
await projects_api.connect_wallet_to_project(
request.app,
product_name=req_ctx.product_name,
project_id=project_id,
user_id=req_ctx.user_id,
wallet_id=project_wallet_id,
)
else:
project_wallet_id = project_wallet.wallet_id

# Check whether user has access to the wallet
wallet = await wallets_api.get_wallet_by_user(
request.app, req_ctx.user_id, project_wallet_id, req_ctx.product_name
)
wallet_info = WalletInfo(wallet_id=project_wallet_id, wallet_name=wallet.name)

options = {
"start_pipeline": True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ async def get_product(self, product_name: str) -> Product | None:
)
row: RowProxy | None = await result.first()
if row:
# NOTE: MD Observation: Currently we are not defensive, we assume automatically
# that the product is not billable when there is no product in the products_prices table
# or it's price is 0. We should change it and always assume that the product is billable, unless
# explicitely stated that it is free
enabled = await is_payment_enabled(conn, product_name=row.name)
return Product(**dict(row.items()), is_payment_enabled=enabled)
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from ..login.decorators import login_required
from ..security.decorators import permission_required
from ..users.api import get_user_role
from ..users.exceptions import UserDefaultWalletNotFoundError
from ..utils_aiohttp import envelope_json_response
from . import projects_api
from ._common_models import ProjectPathParams, RequestContext
Expand All @@ -74,7 +75,11 @@ async def wrapper(request: web.Request) -> web.StreamResponse:
try:
return await handler(request)

except (ProjectNotFoundError, NodeNotFoundError) as exc:
except (
ProjectNotFoundError,
NodeNotFoundError,
UserDefaultWalletNotFoundError,
) as exc:
raise web.HTTPNotFound(reason=f"{exc}") from exc

return wrapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from simcore_postgres_database.models.users import UserRole
from simcore_postgres_database.webserver_models import ProjectType
from simcore_service_webserver.users.exceptions import UserDefaultWalletNotFoundError
from simcore_service_webserver.utils_aiohttp import envelope_json_response

from .._meta import API_VTAG as VTAG
Expand Down Expand Up @@ -54,7 +55,7 @@ async def _wrapper(request: web.Request) -> web.StreamResponse:
try:
return await handler(request)

except ProjectNotFoundError as exc:
except (ProjectNotFoundError, UserDefaultWalletNotFoundError) as exc:
raise web.HTTPNotFound(reason=f"{exc}") from exc

except ProjectInvalidRightsError as exc:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from models_library.api_schemas_webserver.wallets import WalletGet
from models_library.products import ProductName
from models_library.projects import ProjectID
from models_library.wallets import WalletDB
from models_library.users import UserID
from models_library.wallets import WalletDB, WalletID

from ..wallets import _api as wallet_api
from .db import ProjectDBAPI


Expand All @@ -11,3 +14,25 @@ async def get_project_wallet(app, project_id: ProjectID):
wallet_db: WalletDB | None = await db.get_project_wallet(project_uuid=project_id)
wallet: WalletGet | None = WalletGet(**wallet_db.dict()) if wallet_db else None
return wallet


async def connect_wallet_to_project(
app,
*,
product_name: ProductName,
project_id: ProjectID,
user_id: UserID,
wallet_id: WalletID,
) -> WalletGet:
db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(app)

# ensure the wallet can be used by the user
wallet: WalletGet = await wallet_api.get_wallet_by_user(
app,
user_id=user_id,
wallet_id=wallet_id,
product_name=product_name,
)

await db.connect_wallet_to_project(project_uuid=project_id, wallet_id=wallet_id)
return wallet
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
from .._meta import API_VTAG
from ..login.decorators import login_required
from ..security.decorators import permission_required
from ..wallets import _api as wallet_api
from ..wallets.errors import WalletAccessForbiddenError
from . import _wallets_api as wallets_api
from . import projects_api
from ._common_models import ProjectPathParams, RequestContext
from .db import ProjectDBAPI
from .exceptions import ProjectNotFoundError

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,7 +82,6 @@ class Config:
@permission_required("project.wallet.*")
@_handle_project_wallet_exceptions
async def connect_wallet_to_project(request: web.Request):
db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(request.app)
req_ctx = RequestContext.parse_obj(request)
path_params = parse_request_path_parameters_as(_ProjectWalletPathParams, request)

Expand All @@ -95,16 +92,12 @@ async def connect_wallet_to_project(request: web.Request):
user_id=req_ctx.user_id,
include_state=False,
)
# ensure the wallet can be used by the user
wallet: WalletGet = await wallet_api.get_wallet_by_user(
wallet: WalletGet = await wallets_api.connect_wallet_to_project(
request.app,
product_name=req_ctx.product_name,
project_id=path_params.project_id,
user_id=req_ctx.user_id,
wallet_id=path_params.wallet_id,
product_name=req_ctx.product_name,
)

await db.connect_wallet_to_project(
project_uuid=path_params.project_id, wallet_id=path_params.wallet_id
)

return envelope_json_response(wallet)
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

from ._permalink_api import ProjectPermalink
from ._permalink_api import register_factory as register_permalink_factory
from ._wallets_api import get_project_wallet
from ._wallets_api import connect_wallet_to_project, get_project_wallet

__all__: tuple[str, ...] = (
"register_permalink_factory",
"ProjectPermalink",
"get_project_wallet",
"connect_wallet_to_project",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from models_library.services_resources import ServiceResourcesDict
from models_library.users import UserID
from models_library.utils.fastapi_encoders import jsonable_encoder
from models_library.wallets import WalletInfo
from models_library.wallets import WalletID, WalletInfo
from pydantic import parse_obj_as
from servicelib.aiohttp.application_keys import APP_FIRE_AND_FORGET_TASKS_KEY
from servicelib.common_headers import (
Expand All @@ -55,6 +55,7 @@
from ..application_settings import get_settings
from ..catalog import client as catalog_client
from ..director_v2 import api as director_v2_api
from ..products import api as products_api
from ..products.api import get_product_name
from ..redis import get_redis_lock_manager_client_sdk
from ..resource_manager.user_sessions import (
Expand All @@ -72,10 +73,15 @@
from ..storage import api as storage_api
from ..users.api import UserNameDict, get_user_name, get_user_role
from ..users.exceptions import UserNotFoundError
from ..users.preferences_api import (
PreferredWalletIdFrontendUserPreference,
UserDefaultWalletNotFoundError,
get_user_preference,
)
from ..wallets import api as wallets_api
from . import _crud_api_delete, _nodes_api
from ._nodes_utils import set_reservation_same_as_limit, validate_new_service_resources
from ._wallets_api import get_project_wallet
from ._wallets_api import connect_wallet_to_project, get_project_wallet
from .db import APP_PROJECT_DBAPI, ProjectDBAPI
from .exceptions import (
NodeNotFoundError,
Expand Down Expand Up @@ -266,15 +272,43 @@ async def _start_dynamic_service(

# Get wallet information
wallet_info = None
project_wallet = await get_project_wallet(request.app, project_id=project_uuid)
product = products_api.get_current_product(request)
app_settings = get_settings(request.app)
if project_wallet and app_settings.WEBSERVER_CREDIT_COMPUTATION_ENABLED:
if (
product.is_payment_enabled
and app_settings.WEBSERVER_CREDIT_COMPUTATION_ENABLED
):
project_wallet = await get_project_wallet(
request.app, project_id=project_uuid
)
if project_wallet is None:
user_default_wallet_preference = await get_user_preference(
request.app,
user_id=user_id,
product_name=product_name,
preference_class=PreferredWalletIdFrontendUserPreference,
)
if user_default_wallet_preference is None:
raise UserDefaultWalletNotFoundError(uid=user_id)
project_wallet_id = parse_obj_as(
WalletID, user_default_wallet_preference.value
)
await connect_wallet_to_project(
request.app,
product_name=product_name,
project_id=project_uuid,
user_id=user_id,
wallet_id=project_wallet_id,
)
else:
project_wallet_id = project_wallet.wallet_id

# Check whether user has access to the wallet
await wallets_api.get_wallet_by_user(
request.app, user_id, project_wallet.wallet_id, product_name
wallet = await wallets_api.get_wallet_by_user(
request.app, user_id, project_wallet_id, product_name
)
wallet_info = WalletInfo(
wallet_id=project_wallet.wallet_id, wallet_name=project_wallet.name
wallet_id=project_wallet_id, wallet_name=wallet.name
)

await director_v2_api.run_dynamic_service(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import NonNegativeInt, parse_obj_as
from servicelib.utils import logged_gather

from ._preferences_db import get_user_preference, set_user_preference
from . import _preferences_db
from ._preferences_models import (
ALL_FRONTEND_PREFERENCES,
get_preference_identifier_to_preference_name_map,
Expand All @@ -38,7 +38,7 @@ async def _get_frontend_user_preferences(
) -> list[FrontendUserPreference]:
saved_user_preferences: list[FrontendUserPreference | None] = await logged_gather(
*(
get_user_preference(
_preferences_db.get_user_preference(
app,
user_id=user_id,
product_name=product_name,
Expand All @@ -57,6 +57,20 @@ async def _get_frontend_user_preferences(
]


async def get_user_preference(
app: web.Application,
user_id: UserID,
product_name: ProductName,
preference_class: AnyUserPreference,
) -> AnyUserPreference | None:
return await _preferences_db.get_user_preference(
app,
user_id=user_id,
product_name=product_name,
preference_class=preference_class,
)


async def get_frontend_user_preferences_aggregation(
app: web.Application, *, user_id: UserID, product_name: ProductName
) -> AggregatedPreferences:
Expand Down Expand Up @@ -93,7 +107,7 @@ async def set_frontend_user_preference(
FrontendUserPreference.get_preference_class_from_name(preference_name),
)

await set_user_preference(
await _preferences_db.set_user_preference(
app,
user_id=user_id,
preference=parse_obj_as(preference_class, {"value": value}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ class TokenNotFoundError(UsersException):
def __init__(self, service_id: str):
super().__init__(f"Token for service {service_id} not found")
self.service_id = service_id


class UserDefaultWalletNotFoundError(UsersException):
def __init__(self, uid: int | None = None):
super().__init__(f"Default wallet for user {uid} not found")
self.uid = uid
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ._preferences_api import get_user_preference, set_frontend_user_preference
from ._preferences_models import PreferredWalletIdFrontendUserPreference
from .exceptions import UserDefaultWalletNotFoundError

__all__ = (
"get_user_preference",
"PreferredWalletIdFrontendUserPreference",
"set_frontend_user_preference",
"UserDefaultWalletNotFoundError",
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from servicelib.aiohttp.observer import register_observer, setup_observer_registry

from ..resource_usage.api import add_credits_to_wallet
from ..users import preferences_api
from ..users.api import get_user_name_and_email
from ._api import any_wallet_owned_by_user, create_wallet

Expand Down Expand Up @@ -43,6 +44,17 @@ async def _auto_add_default_wallet(
created_at=wallet.created,
)

preference_id = (
preferences_api.PreferredWalletIdFrontendUserPreference().preference_identifier
)
await preferences_api.set_frontend_user_preference(
app,
user_id=user_id,
product_name=product_name,
frontend_preference_identifier=preference_id,
value=wallet.wallet_id,
)


async def _on_user_confirmation(
app: web.Application,
Expand Down

0 comments on commit 6b9e988

Please sign in to comment.