2020from http .cookies import Morsel
2121
2222from tornado import escape , httputil , web
23- from traitlets import Bool , Dict , Type , Unicode , default
23+ from traitlets import Bool , Dict , Enum , List , TraitError , Type , Unicode , default , validate
2424from traitlets .config import LoggingConfigurable
2525
2626from jupyter_server .transutils import _i18n
3131_non_alphanum = re .compile (r"[^A-Za-z0-9]" )
3232
3333
34+ # Define the User properties that can be updated
35+ UpdatableField = t .Literal ["name" , "display_name" , "initials" , "avatar_url" , "color" ]
36+
37+
3438@dataclass
3539class User :
3640 """Object representing a User
@@ -188,6 +192,14 @@ class IdentityProvider(LoggingConfigurable):
188192 help = _i18n ("The logout handler class to use." ),
189193 )
190194
195+ # Define the fields that can be updated
196+ updatable_fields = List (
197+ trait = Enum (list (t .get_args (UpdatableField ))),
198+ default_value = ["color" ], # Default updatable field
199+ config = True ,
200+ help = _i18n ("List of fields in the User model that can be updated." ),
201+ )
202+
191203 token_generated = False
192204
193205 @default ("token" )
@@ -207,6 +219,18 @@ def _token_default(self):
207219 self .token_generated = True
208220 return binascii .hexlify (os .urandom (24 )).decode ("ascii" )
209221
222+ @validate ("updatable_fields" )
223+ def _validate_updatable_fields (self , proposal ):
224+ """Validate that all fields in updatable_fields are valid."""
225+ valid_updatable_fields = list (t .get_args (UpdatableField ))
226+ invalid_fields = [
227+ field for field in proposal ["value" ] if field not in valid_updatable_fields
228+ ]
229+ if invalid_fields :
230+ msg = f"Invalid fields in updatable_fields: { invalid_fields } "
231+ raise TraitError (msg )
232+ return proposal ["value" ]
233+
210234 need_token : bool | Bool [bool , t .Union [bool , int ]] = Bool (True )
211235
212236 def get_user (self , handler : web .RequestHandler ) -> User | None | t .Awaitable [User | None ]:
@@ -269,6 +293,31 @@ async def _get_user(self, handler: web.RequestHandler) -> User | None:
269293
270294 return user
271295
296+ def update_user (
297+ self , handler : web .RequestHandler , user_data : dict [UpdatableField , str ]
298+ ) -> User :
299+ """Update user information and persist the user model."""
300+ self .check_update (user_data )
301+ current_user = t .cast (User , handler .current_user )
302+ updated_user = self .update_user_model (current_user , user_data )
303+ self .persist_user_model (handler )
304+ return updated_user
305+
306+ def check_update (self , user_data : dict [UpdatableField , str ]) -> None :
307+ """Raises if some fields to update are not updatable."""
308+ for field in user_data :
309+ if field not in self .updatable_fields :
310+ msg = f"Field { field } is not updatable"
311+ raise ValueError (msg )
312+
313+ def update_user_model (self , current_user : User , user_data : dict [UpdatableField , str ]) -> User :
314+ """Update user information."""
315+ raise NotImplementedError
316+
317+ def persist_user_model (self , handler : web .RequestHandler ) -> None :
318+ """Persist the user model (i.e. a cookie)."""
319+ raise NotImplementedError
320+
272321 def identity_model (self , user : User ) -> dict [str , t .Any ]:
273322 """Return a User as an Identity model"""
274323 # TODO: validate?
@@ -617,6 +666,16 @@ class PasswordIdentityProvider(IdentityProvider):
617666 def _need_token_default (self ):
618667 return not bool (self .hashed_password )
619668
669+ @default ("updatable_fields" )
670+ def _default_updatable_fields (self ):
671+ return [
672+ "name" ,
673+ "display_name" ,
674+ "initials" ,
675+ "avatar_url" ,
676+ "color" ,
677+ ]
678+
620679 @property
621680 def login_available (self ) -> bool :
622681 """Whether a LoginHandler is needed - and therefore whether the login page should be displayed."""
@@ -627,6 +686,17 @@ def auth_enabled(self) -> bool:
627686 """Return whether any auth is enabled"""
628687 return bool (self .hashed_password or self .token )
629688
689+ def update_user_model (self , current_user : User , user_data : dict [UpdatableField , str ]) -> User :
690+ """Update user information."""
691+ for field in self .updatable_fields :
692+ if field in user_data :
693+ setattr (current_user , field , user_data [field ])
694+ return current_user
695+
696+ def persist_user_model (self , handler : web .RequestHandler ) -> None :
697+ """Persist the user model to a cookie."""
698+ self .set_login_cookie (handler , handler .current_user )
699+
630700 def passwd_check (self , password ):
631701 """Check password against our stored hashed password"""
632702 return passwd_check (self .hashed_password , password )
0 commit comments