Skip to content

Allow user to update identity values #1518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from http.cookies import Morsel

from tornado import escape, httputil, web
from traitlets import Bool, Dict, Type, Unicode, default
from traitlets import Bool, Dict, Enum, List, TraitError, Type, Unicode, default, validate
from traitlets.config import LoggingConfigurable

from jupyter_server.transutils import _i18n
Expand All @@ -31,6 +31,10 @@
_non_alphanum = re.compile(r"[^A-Za-z0-9]")


# Define the User properties that can be updated
UpdatableField = t.Literal["name", "display_name", "initials", "avatar_url", "color"]


@dataclass
class User:
"""Object representing a User
Expand Down Expand Up @@ -188,6 +192,14 @@ class IdentityProvider(LoggingConfigurable):
help=_i18n("The logout handler class to use."),
)

# Define the fields that can be updated
updatable_fields = List(
trait=Enum(list(t.get_args(UpdatableField))),
default_value=["color"], # Default updatable field
config=True,
help=_i18n("List of fields in the User model that can be updated."),
)

token_generated = False

@default("token")
Expand All @@ -207,6 +219,18 @@ def _token_default(self):
self.token_generated = True
return binascii.hexlify(os.urandom(24)).decode("ascii")

@validate("updatable_fields")
def _validate_updatable_fields(self, proposal):
"""Validate that all fields in updatable_fields are valid."""
valid_updatable_fields = list(t.get_args(UpdatableField))
invalid_fields = [
field for field in proposal["value"] if field not in valid_updatable_fields
]
if invalid_fields:
msg = f"Invalid fields in updatable_fields: {invalid_fields}"
raise TraitError(msg)
return proposal["value"]

need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True)

def get_user(self, handler: web.RequestHandler) -> User | None | t.Awaitable[User | None]:
Expand Down Expand Up @@ -269,6 +293,26 @@ async def _get_user(self, handler: web.RequestHandler) -> User | None:

return user

def update_user(
self, handler: web.RequestHandler, user_data: dict[UpdatableField, str]
) -> User:
"""Update user information."""
current_user = t.cast(User, handler.current_user)

for field in user_data:
if field not in self.updatable_fields:
msg = f"Field {field} is not updatable"
raise ValueError(msg)

# Update fields
for field in self.updatable_fields:
if field in user_data:
setattr(current_user, field, user_data[field])

# Persist changes (if applicable)
self.set_login_cookie(handler, current_user) # Save updated user to cookie/session
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation assumes that get_user is not overridden, which is the main thing IdentityProvider subclasses (e.g. JupyterHub) will do.

The result of this change as-is for JupyterHub will be that:

  • the PATCH request succeeds
  • an unused cookie is set
  • the user model returned by /api/me is not actually changed

Probably the right thing to happen for JupyterHub which overrides get_user but not this method (yet) is to hit a NotImplementedError. I see two ways to go about that:

  1. move this default implementation into the PasswordIdentityProvider, leaving the base class with NotImplementedError
  2. split the last set_login_cookie step to a persist_user_model method, so the field validation and user model persistence don't have to be overridden at the same time

This method really does 3 things:

  1. validate keys (override via config works as you have it, base class can define this and it should work for all subclasses)
  2. validate values (not currently possible without full override of update_user)
  3. persist changes (not currently possible with full override of update_user)

I think keeping it as a single method is okay, but that means the base class shouldn't have an implementation of it by default (It can still have it as a method, so subclasses can opt-in to default behavior). But if you want to slice it up so e.g. the key validation happens outside the typically overridden field, e.g.

def _update_user(self, ...):
    # check updatable_fields
    # only call update_user after validating
    return self.update_user(...) # responsible for persistence, possibly _value_ validation

that would reduce the duplication required by subclasses. If you want, validate_update_user could also be an overridable method.

return current_user

def identity_model(self, user: User) -> dict[str, t.Any]:
"""Return a User as an Identity model"""
# TODO: validate?
Expand Down Expand Up @@ -617,6 +661,16 @@ class PasswordIdentityProvider(IdentityProvider):
def _need_token_default(self):
return not bool(self.hashed_password)

@default("updatable_fields")
def _default_updatable_fields(self):
return [
"name",
"display_name",
"initials",
"avatar_url",
"color",
]

@property
def login_available(self) -> bool:
"""Whether a LoginHandler is needed - and therefore whether the login page should be displayed."""
Expand Down
26 changes: 24 additions & 2 deletions jupyter_server/services/api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# Distributed under the terms of the Modified BSD License.
import json
import os
from typing import Any
from typing import Any, cast

from jupyter_core.utils import ensure_async
from tornado import web

from jupyter_server._tz import isoformat, utcfromtimestamp
from jupyter_server.auth.decorator import authorized
from jupyter_server.auth.identity import IdentityProvider, UpdatableField

from ...base.handlers import APIHandler, JupyterHandler

Expand Down Expand Up @@ -70,7 +71,7 @@ async def get(self):


class IdentityHandler(APIHandler):
"""Get the current user's identity model"""
"""Get or patch the current user's identity model"""

@web.authenticated
async def get(self):
Expand Down Expand Up @@ -110,9 +111,30 @@ async def get(self):
model = {
"identity": identity,
"permissions": permissions,
"updatable_fields": self.identity_provider.updatable_fields,
}
self.write(json.dumps(model))

@web.authenticated
async def patch(self):
"""Update user information."""
user_data = cast(dict[UpdatableField, str], self.get_json_body())
if not user_data:
raise web.HTTPError(400, "Invalid or missing JSON body")

# Update user information
identity_provider = self.settings["identity_provider"]
if not isinstance(identity_provider, IdentityProvider):
raise web.HTTPError(500, "Identity provider not configured properly")

try:
updated_user = identity_provider.update_user(self, user_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably pass self.current_user here rather than self, so it's not part of the update_user API how we store the current user on the Handler. Then pass `self on the end for IdentityProviders that need to use Handler for persistence (some won't, e.g. JupyterHub).

self.write(
{"status": "success", "identity": identity_provider.identity_model(updated_user)}
)
except ValueError as e:
raise web.HTTPError(400, str(e)) from e


default_handlers = [
(r"/api/spec.yaml", APISpecHandler),
Expand Down
86 changes: 84 additions & 2 deletions tests/services/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def test_identity(jp_fetch, identity, expected, identity_provider):

assert r.code == 200
response = json.loads(r.body.decode())
assert set(response.keys()) == {"identity", "permissions"}
assert set(response.keys()) == {"identity", "permissions", "updatable_fields"}
identity_model = response["identity"]
print(identity_model)
for key, value in expected.items():
Expand All @@ -117,6 +117,88 @@ async def test_identity(jp_fetch, identity, expected, identity_provider):
assert set(identity_model.keys()) == set(User.__dataclass_fields__)


@pytest.mark.parametrize("identity", [{"username": "user.username"}])
async def test_update_user_success(jp_fetch, identity, identity_provider):
"""Test successful user update."""
identity_provider.mock_user = MockUser(**identity)
payload = {
"color": "#000000",
}
r = await jp_fetch(
"/api/me",
method="PATCH",
body=json.dumps(payload),
headers={"Content-Type": "application/json"},
)
assert r.code == 200
response = json.loads(r.body.decode())
assert response["status"] == "success"
assert response["identity"]["color"] == "#000000"


@pytest.mark.parametrize("identity", [{"username": "user.username"}])
async def test_update_user_raise(jp_fetch, identity, identity_provider):
"""Test failing user update."""
identity_provider.mock_user = MockUser(**identity)
payload = {
"name": "Updated Name",
"color": "#000000",
}
with pytest.raises(HTTPError) as exc:
await jp_fetch(
"/api/me",
method="PATCH",
body=json.dumps(payload),
headers={"Content-Type": "application/json"},
)


@pytest.mark.parametrize(
"identity, expected",
[
(
{"username": "user.username"},
{
"username": "user.username",
"name": "Updated Name",
"display_name": "Updated Display Name",
"color": "#000000",
},
)
],
)
async def test_update_user_success_custom_updatable_fields(
jp_fetch, identity, expected, identity_provider
):
"""Test successful user update."""
identity_provider.mock_user = MockUser(**identity)
identity_provider.updatable_fields = ["name", "display_name", "color"]
payload = {
"name": expected["name"],
"display_name": expected["display_name"],
"color": expected["color"],
}
r = await jp_fetch(
"/api/me",
method="PATCH",
body=json.dumps(payload),
headers={"Content-Type": "application/json"},
)
assert r.code == 200
response = json.loads(r.body.decode())
identity_model = response["identity"]
for key, value in expected.items():
assert identity_model[key] == value

# Test GET request to ensure the updated fields are returned
r = await jp_fetch("api/me")
assert r.code == 200
response = json.loads(r.body.decode())
identity_model = response["identity"]
for key, value in expected.items():
assert identity_model[key] == value


@pytest.mark.parametrize(
"have_permissions, check_permissions, expected",
[
Expand Down Expand Up @@ -157,7 +239,7 @@ async def test_identity_permissions(
assert r is not None
assert r.code == 200
response = json.loads(r.body.decode())
assert set(response.keys()) == {"identity", "permissions"}
assert set(response.keys()) == {"identity", "permissions", "updatable_fields"}
assert response["permissions"] == expected


Expand Down