diff --git a/.github/workflows/migrations.yml b/.github/workflows/migrations.yml index 5adb3bc7ae..22bea85660 100644 --- a/.github/workflows/migrations.yml +++ b/.github/workflows/migrations.yml @@ -31,6 +31,11 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install mariadb connectors + run: | + sudo apt-get update + sudo apt-get install -y libmariadb3 libmariadb-dev + - name: Install uv uses: astral-sh/setup-uv@v6.7.0 diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 78c9879073..301e558d72 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -53,7 +53,6 @@ jobs: uses: actions/checkout@v4.3.0 - name: Install mariadb connectors - if: matrix.db == 'mariadb' run: | sudo apt-get update sudo apt-get install -y libmariadb3 libmariadb-dev diff --git a/backend/alembic/versions/0073_sync_sessions.py b/backend/alembic/versions/0073_sync_sessions.py index cd0b703124..79ccb10b82 100644 --- a/backend/alembic/versions/0073_sync_sessions.py +++ b/backend/alembic/versions/0073_sync_sessions.py @@ -21,7 +21,7 @@ def upgrade() -> None: connection = op.get_bind() if is_postgresql(connection): - rom_user_status_enum = ENUM( + sync_session_status_enum = ENUM( "PENDING", "IN_PROGRESS", "COMPLETED", @@ -30,9 +30,9 @@ def upgrade() -> None: name="syncsessionstatus", create_type=False, ) - rom_user_status_enum.create(connection, checkfirst=False) + sync_session_status_enum.create(connection, checkfirst=True) else: - rom_user_status_enum = sa.Enum( + sync_session_status_enum = sa.Enum( "PENDING", "IN_PROGRESS", "COMPLETED", @@ -48,14 +48,7 @@ def upgrade() -> None: sa.Column("user_id", sa.Integer(), nullable=False), sa.Column( "status", - sa.Enum( - "PENDING", - "IN_PROGRESS", - "COMPLETED", - "FAILED", - "CANCELLED", - name="syncsessionstatus", - ), + sync_session_status_enum, nullable=False, server_default="PENDING", ), @@ -86,7 +79,11 @@ def upgrade() -> None: "updated_at", sa.TIMESTAMP(timezone=True), nullable=False, - server_default=sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + server_default=( + sa.text("CURRENT_TIMESTAMP") + if is_postgresql(connection) + else sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), ), sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"), sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), @@ -107,3 +104,7 @@ def downgrade() -> None: op.drop_index("ix_sync_sessions_device_id", table_name="sync_sessions") op.drop_table("sync_sessions") + + connection = op.get_bind() + if is_postgresql(connection): + ENUM(name="syncsessionstatus").drop(connection, checkfirst=True) diff --git a/backend/endpoints/responses/device.py b/backend/endpoints/responses/device.py index 22870ff52d..9ac0226fd6 100644 --- a/backend/endpoints/responses/device.py +++ b/backend/endpoints/responses/device.py @@ -1,9 +1,13 @@ -from pydantic import ConfigDict +from typing import Any + +from pydantic import ConfigDict, field_serializer from models.device import SyncMode from .base import BaseModel, UTCDatetime +SENSITIVE_SYNC_CONFIG_KEYS = {"ssh_password", "ssh_key_path"} + class DeviceSyncSchema(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -34,6 +38,16 @@ class DeviceSchema(BaseModel): created_at: UTCDatetime updated_at: UTCDatetime + @field_serializer("sync_config") + @classmethod + def mask_sensitive_fields(cls, v: dict | None) -> dict[str, Any] | None: + if not v: + return v + return { + k: "********" if k in SENSITIVE_SYNC_CONFIG_KEYS else val + for k, val in v.items() + } + class DeviceCreateResponse(BaseModel): device_id: str diff --git a/backend/endpoints/responses/identity.py b/backend/endpoints/responses/identity.py index 674b2a8d7e..29a9b504ed 100644 --- a/backend/endpoints/responses/identity.py +++ b/backend/endpoints/responses/identity.py @@ -42,8 +42,9 @@ def from_orm_with_request( if not db_user: return None - db_user.current_device_id = request.session.get("device_id") # type: ignore - return cls.model_validate(db_user) + schema = cls.model_validate(db_user) + schema.current_device_id = request.session.get("device_id") + return schema class InviteLinkSchema(BaseModel): diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index 66eb616e30..07c4e1d042 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -96,6 +96,15 @@ def _resolve_device( return device +def _increment_session_counter(session_id: int) -> None: + try: + db_sync_session_handler.increment_operations_completed( + session_id=session_id, + ) + except Exception: + log.warning(f"Failed to update sync session {session_id}", exc_info=True) + + router = APIRouter( prefix="/saves", tags=["saves"], @@ -247,17 +256,7 @@ async def add_save( db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) if session_id: - try: - session = db_sync_session_handler.get_session( - session_id=session_id, user_id=request.user.id - ) - if session: - db_sync_session_handler.update_session( - session_id=session_id, - data={"operations_completed": session.operations_completed + 1}, - ) - except Exception: - log.warning(f"Failed to update sync session {session_id}") + _increment_session_counter(session_id) if slot and autocleanup: slot_saves = db_save_handler.get_saves( @@ -454,17 +453,7 @@ def download_save( db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) if session_id: - try: - session = db_sync_session_handler.get_session( - session_id=session_id, user_id=request.user.id - ) - if session: - db_sync_session_handler.update_session( - session_id=session_id, - data={"operations_completed": session.operations_completed + 1}, - ) - except Exception: - log.warning(f"Failed to update sync session {session_id}") + _increment_session_counter(session_id) return FileResponse(path=str(file_path), filename=save.file_name) diff --git a/backend/endpoints/sync.py b/backend/endpoints/sync.py index ff8cd35edf..bd659db7ce 100644 --- a/backend/endpoints/sync.py +++ b/backend/endpoints/sync.py @@ -1,10 +1,10 @@ from datetime import datetime from fastapi import HTTPException, Request, status -from pydantic import BaseModel from config import TASK_TIMEOUT from decorators.auth import protected_route +from endpoints.responses.base import BaseModel from endpoints.responses.sync import ( SyncNegotiateResponse, SyncOperationSchema, @@ -359,6 +359,7 @@ def trigger_push_pull( high_prio_queue.enqueue( "tasks.sync_push_pull_task.run_push_pull_sync", device_id=device.id, + session_id=sync_session.id, force=True, job_timeout=TASK_TIMEOUT, meta={ diff --git a/backend/handler/database/sync_sessions_handler.py b/backend/handler/database/sync_sessions_handler.py index 8fed624d56..cf9e666442 100644 --- a/backend/handler/database/sync_sessions_handler.py +++ b/backend/handler/database/sync_sessions_handler.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from sqlalchemy import select, update +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session from decorators.database import begin_session @@ -75,7 +76,25 @@ def update_session( .values(**data) .execution_options(synchronize_session="evaluate") ) - return session.query(SyncSession).filter_by(id=session_id).one() + result = session.scalar(select(SyncSession).filter_by(id=session_id)) + if not result: + raise NoResultFound(f"SyncSession {session_id} not found after update") + return result + + @begin_session + def increment_operations_completed( + self, + session_id: int, + session: Session = None, # type: ignore + ) -> None: + session.execute( + update(SyncSession) + .where(SyncSession.id == session_id) + .values( + operations_completed=SyncSession.operations_completed + 1, + ) + .execution_options(synchronize_session="evaluate") + ) @begin_session def complete_session( @@ -96,7 +115,10 @@ def complete_session( ) .execution_options(synchronize_session="evaluate") ) - return session.query(SyncSession).filter_by(id=session_id).one() + result = session.scalar(select(SyncSession).filter_by(id=session_id)) + if not result: + raise NoResultFound(f"SyncSession {session_id} not found after complete") + return result @begin_session def fail_session( @@ -115,7 +137,10 @@ def fail_session( ) .execution_options(synchronize_session="evaluate") ) - return session.query(SyncSession).filter_by(id=session_id).one() + result = session.scalar(select(SyncSession).filter_by(id=session_id)) + if not result: + raise NoResultFound(f"SyncSession {session_id} not found after fail") + return result @begin_session def cancel_active_sessions( diff --git a/backend/handler/filesystem/sync_handler.py b/backend/handler/filesystem/sync_handler.py index bf53b6c120..3aa4d9dbf7 100644 --- a/backend/handler/filesystem/sync_handler.py +++ b/backend/handler/filesystem/sync_handler.py @@ -17,8 +17,7 @@ def __init__(self) -> None: def build_incoming_path( self, device_id: str, platform_slug: str | None = None ) -> str: - """Build the relative incoming path for a device (and optional platform).""" - parts = [self.base_path, device_id, "incoming"] + parts = [device_id, "incoming"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) @@ -26,8 +25,7 @@ def build_incoming_path( def build_outgoing_path( self, device_id: str, platform_slug: str | None = None ) -> str: - """Build the relative outgoing path for a device (and optional platform).""" - parts = [self.base_path, device_id, "outgoing"] + parts = [device_id, "outgoing"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) @@ -35,16 +33,14 @@ def build_outgoing_path( def build_conflicts_path( self, device_id: str, platform_slug: str | None = None ) -> str: - """Build the relative conflicts path for a device (and optional platform).""" - parts = [self.base_path, device_id, "conflicts"] + parts = [device_id, "conflicts"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) def ensure_device_directories(self, device_id: str) -> None: - """Create incoming/outgoing directory structure for a device.""" - incoming = Path(self.build_incoming_path(device_id)) - outgoing = Path(self.build_outgoing_path(device_id)) + incoming = self.base_path / self.build_incoming_path(device_id) + outgoing = self.base_path / self.build_outgoing_path(device_id) incoming.mkdir(parents=True, exist_ok=True) outgoing.mkdir(parents=True, exist_ok=True) diff --git a/backend/handler/sync/ssh_handler.py b/backend/handler/sync/ssh_handler.py index ae2c5ce285..74bdabd56c 100644 --- a/backend/handler/sync/ssh_handler.py +++ b/backend/handler/sync/ssh_handler.py @@ -106,6 +106,10 @@ async def connect( "provide ssh_key_path/ssh_password in sync_config." ) + log.warning( + f"SSH host key verification disabled for {host} -- " + "connection is vulnerable to MITM attacks" + ) log.info(f"Connecting to {username}@{host}:{port}") return await asyncssh.connect(**connect_kwargs) diff --git a/backend/sync_watcher.py b/backend/sync_watcher.py index b83027897f..827cb7cf66 100644 --- a/backend/sync_watcher.py +++ b/backend/sync_watcher.py @@ -46,32 +46,26 @@ def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None: """Extract device_id, platform_slug, and filename from a sync incoming path. - Expected path format: {build_incoming_path(device_id, platform_slug)}/filename.ext - i.e. {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext + Expected path format: {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext """ try: - rel_path = os.path.relpath(path) + rel_path = os.path.relpath(path, start=str(fs_sync_handler.base_path)) parts = rel_path.split(os.sep) - # Minimum: device_id / incoming / platform_slug / filename if len(parts) < 4 or parts[1] != "incoming": return None device_id = parts[0] platform_slug = parts[2] filename = parts[-1] - - # Validate path matches the canonical incoming path structure - expected_prefix = fs_sync_handler.build_incoming_path(device_id, platform_slug) - if not rel_path.startswith(expected_prefix): - return None - return (device_id, platform_slug, filename) except (ValueError, IndexError): return None def _ensure_conflicts_dir(device_id: str, platform_slug: str) -> str: - """Ensure the conflicts directory exists and return its path.""" - conflicts_dir = fs_sync_handler.build_conflicts_path(device_id, platform_slug) + conflicts_dir = str( + fs_sync_handler.base_path + / fs_sync_handler.build_conflicts_path(device_id, platform_slug) + ) os.makedirs(conflicts_dir, exist_ok=True) return conflicts_dir diff --git a/backend/tasks/sync_push_pull_task.py b/backend/tasks/sync_push_pull_task.py index 952aafcd45..88cc018ef0 100644 --- a/backend/tasks/sync_push_pull_task.py +++ b/backend/tasks/sync_push_pull_task.py @@ -26,7 +26,11 @@ from tasks.tasks import PeriodicTask, TaskType -async def run_push_pull_sync(device_id: str | None = None, force: bool = False) -> dict: +async def run_push_pull_sync( + device_id: str | None = None, + session_id: int | None = None, + force: bool = False, +) -> dict: """Execute push-pull sync for one or all push_pull devices.""" if not ENABLE_SYNC_PUSH_PULL and not force: log.info("Push-pull sync not enabled, skipping") @@ -50,13 +54,13 @@ async def run_push_pull_sync(device_id: str | None = None, force: bool = False) for device in devices: if not device.sync_enabled: continue - result = await _sync_device(device) + result = await _sync_device(device, session_id=session_id) results.append(result) return {"status": "completed", "device_results": results} -async def _sync_device(device: Device) -> dict: +async def _sync_device(device: Device, session_id: int | None = None) -> dict: """Perform push-pull sync for a single device.""" sync_config = device.sync_config or {} if not sync_config.get("ssh_host"): @@ -71,10 +75,21 @@ async def _sync_device(device: Device) -> dict: emit_sync_started, ) - # Create sync session - sync_session = db_sync_session_handler.create_session( - device_id=device.id, user_id=device.user_id - ) + if session_id: + sync_session = db_sync_session_handler.get_session( + session_id=session_id, user_id=device.user_id + ) + if not sync_session: + log.warning( + f"Push-pull: session {session_id} not found, creating new session" + ) + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=device.user_id + ) + else: + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=device.user_id + ) await emit_sync_started( user_id=device.user_id, @@ -110,7 +125,6 @@ async def _sync_device(device: Device) -> dict: db_sync_session_handler.complete_session(session_id=sync_session.id) return {"device_id": device.id, "status": "no_directories"} - # List remote saves remote_saves = await ssh_sync_handler.list_remote_saves(conn, save_directories) log.info( f"Push-pull: found {len(remote_saves)} remote saves on device {device.id}" @@ -126,7 +140,6 @@ async def _sync_device(device: Device) -> dict: operations_planned = len(remote_saves) - # Process each remote save for remote_save in remote_saves: try: action = await _process_remote_save(device, conn, remote_save) @@ -158,14 +171,11 @@ async def _sync_device(device: Device) -> dict: current_file=remote_save.file_name, ) - # Check for server saves that need to be pushed to the device push_count = await _push_missing_saves( device, conn, remote_saves, save_directories ) completed += push_count - conn.close() - except Exception as e: log.error(f"Push-pull sync failed for device {device.id}: {e}", exc_info=True) db_sync_session_handler.fail_session( @@ -178,6 +188,8 @@ async def _sync_device(device: Device) -> dict: error_message=str(e), ) return {"device_id": device.id, "status": "failed", "error": str(e)} + finally: + conn.close() db_sync_session_handler.complete_session( session_id=sync_session.id, @@ -227,21 +239,11 @@ async def _process_remote_save( break if not matched_save: - # New save from device - download it - local_path, content_hash = await ssh_sync_handler.download_save( - conn, remote_save.path + log.info( + f"Push-pull: remote save {hl(remote_save.file_name)} " + f"on platform {remote_save.platform_slug} - no matching server save, skipping" ) - try: - # We have the file locally, but we need a ROM to attach it to. - # Without a clear ROM match, skip for now. - log.info( - f"Push-pull: new remote save {hl(remote_save.file_name)} " - f"on platform {remote_save.platform_slug} - no matching server save" - ) - return "skipped" - finally: - if os.path.exists(local_path): - os.unlink(local_path) + return "skipped" # Compare with existing save device_sync = db_device_save_sync_handler.get_sync( diff --git a/backend/tests/endpoints/test_device.py b/backend/tests/endpoints/test_device.py index 9ec724ac6c..85ef64eb59 100644 --- a/backend/tests/endpoints/test_device.py +++ b/backend/tests/endpoints/test_device.py @@ -475,3 +475,82 @@ def test_hostname_only_no_conflict_without_platform( ) assert response.status_code == status.HTTP_201_CREATED + + +class TestSyncConfigMasking: + def test_ssh_credentials_masked_in_response( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="mask-dev-1", + user_id=admin_user.id, + name="SSH Device", + sync_config={ + "ssh_host": "192.168.1.100", + "ssh_port": 22, + "ssh_username": "retro", + "ssh_password": "supersecret123", + "ssh_key_path": "/home/retro/.ssh/id_rsa", + "save_directories": [ + {"platform_slug": "gba", "path": "/saves/gba"} + ], + }, + ) + ) + + response = client.get( + "/api/devices/mask-dev-1", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + config = response.json()["sync_config"] + assert config["ssh_host"] == "192.168.1.100" + assert config["ssh_port"] == 22 + assert config["ssh_username"] == "retro" + assert config["ssh_password"] == "********" + assert config["ssh_key_path"] == "********" + assert config["save_directories"] == [ + {"platform_slug": "gba", "path": "/saves/gba"} + ] + + def test_null_sync_config_passes_through( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="mask-dev-2", + user_id=admin_user.id, + name="No Config Device", + ) + ) + + response = client.get( + "/api/devices/mask-dev-2", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.json()["sync_config"] is None + + def test_sync_config_without_sensitive_fields( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="mask-dev-3", + user_id=admin_user.id, + sync_config={"ssh_host": "10.0.0.1", "ssh_port": 2222}, + ) + ) + + response = client.get( + "/api/devices/mask-dev-3", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + config = response.json()["sync_config"] + assert config["ssh_host"] == "10.0.0.1" + assert config["ssh_port"] == 2222 diff --git a/backend/tests/endpoints/test_sync.py b/backend/tests/endpoints/test_sync.py index 0ff9e2cb32..8b80b782b1 100644 --- a/backend/tests/endpoints/test_sync.py +++ b/backend/tests/endpoints/test_sync.py @@ -391,3 +391,142 @@ def test_trigger_push_pull_device_not_found(self, client, access_token: str): headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_trigger_push_pull_passes_session_id( + self, client, access_token: str, admin_user: User + ): + device = db_device_handler.add_device( + Device( + id="pp-dev-sid", + user_id=admin_user.id, + sync_mode=SyncMode.PUSH_PULL, + sync_enabled=True, + ) + ) + + with mock.patch("endpoints.sync.high_prio_queue") as mock_queue: + response = client.post( + f"/api/sync/devices/{device.id}/push-pull", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_queue.enqueue.call_args + assert "session_id" in call_kwargs.kwargs + + +class TestNegotiateAdvanced: + def test_negotiate_untracked_save_returns_noop( + self, client, access_token: str, admin_user: User, save: Save + ): + device = db_device_handler.add_device( + Device(id="neg-untrack-dev", user_id=admin_user.id, sync_enabled=True) + ) + db_device_save_sync_handler.set_untracked( + device_id=device.id, save_id=save.id, untracked=True + ) + + response = client.post( + "/api/sync/negotiate", + json={ + "device_id": device.id, + "saves": [ + { + "rom_id": save.rom_id, + "file_name": save.file_name, + "content_hash": "different_hash", + "updated_at": "2026-03-01T00:00:00Z", + "file_size_bytes": 100, + } + ], + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + noop_ops = [op for op in data["operations"] if op["action"] == "no_op"] + assert len(noop_ops) >= 1 + + def test_negotiate_server_save_not_mentioned_by_client( + self, client, access_token: str, admin_user: User, save: Save + ): + device = db_device_handler.add_device( + Device(id="neg-miss-dev", user_id=admin_user.id, sync_enabled=True) + ) + + response = client.post( + "/api/sync/negotiate", + json={"device_id": device.id, "saves": []}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + download_ops = [op for op in data["operations"] if op["action"] == "download"] + assert len(download_ops) >= 1 + assert any(op["save_id"] == save.id for op in download_ops) + + def test_negotiate_deleted_by_client_skipped( + self, client, access_token: str, admin_user: User, save: Save + ): + device = db_device_handler.add_device( + Device(id="neg-del-dev", user_id=admin_user.id, sync_enabled=True) + ) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=save.id, + synced_at=datetime.now(timezone.utc), + ) + + response = client.post( + "/api/sync/negotiate", + json={"device_id": device.id, "saves": []}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + ops_for_save = [op for op in data["operations"] if op.get("save_id") == save.id] + assert len(ops_for_save) == 0 + + def test_complete_failed_session_rejected( + self, client, access_token: str, admin_user: User + ): + device = db_device_handler.add_device( + Device(id="sess-failed-dev", user_id=admin_user.id) + ) + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=admin_user.id + ) + db_sync_session_handler.fail_session(sync_session.id, error_message="test") + + response = client.post( + f"/api/sync/sessions/{sync_session.id}/complete", + json={"operations_completed": 0, "operations_failed": 0}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_complete_cancelled_session_rejected( + self, client, access_token: str, admin_user: User + ): + device = db_device_handler.add_device( + Device(id="sess-cancel-dev", user_id=admin_user.id) + ) + db_sync_session_handler.create_session( + device_id=device.id, user_id=admin_user.id + ) + db_sync_session_handler.cancel_active_sessions(device.id, admin_user.id) + + sessions = db_sync_session_handler.get_sessions( + admin_user.id, device_id=device.id + ) + cancelled = sessions[0] + + response = client.post( + f"/api/sync/sessions/{cancelled.id}/complete", + json={"operations_completed": 0, "operations_failed": 0}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/backend/tests/handler/database/test_sync_sessions_handler.py b/backend/tests/handler/database/test_sync_sessions_handler.py index c91b85b3e9..303c047e66 100644 --- a/backend/tests/handler/database/test_sync_sessions_handler.py +++ b/backend/tests/handler/database/test_sync_sessions_handler.py @@ -1,3 +1,6 @@ +import pytest +from sqlalchemy.exc import NoResultFound + from handler.database import db_device_handler, db_sync_session_handler from models.device import Device from models.sync_session import SyncSessionStatus @@ -126,6 +129,41 @@ def test_marks_failed_without_error(self, admin_user: User): assert result.error_message is None +class TestIncrementOperationsCompleted: + def test_increments_counter(self, admin_user: User): + device = db_device_handler.add_device( + Device(id="inc-dev-1", user_id=admin_user.id) + ) + session = db_sync_session_handler.create_session(device.id, admin_user.id) + assert session.operations_completed == 0 + + db_sync_session_handler.increment_operations_completed(session.id) + db_sync_session_handler.increment_operations_completed(session.id) + db_sync_session_handler.increment_operations_completed(session.id) + + result = db_sync_session_handler.get_session(session.id, admin_user.id) + assert result.operations_completed == 3 + + def test_noop_on_nonexistent_session(self, admin_user: User): + db_sync_session_handler.increment_operations_completed(999999) + + +class TestNoResultFoundOnMissingSession: + def test_update_session_raises(self, admin_user: User): + with pytest.raises(NoResultFound): + db_sync_session_handler.update_session( + 999999, {"status": SyncSessionStatus.IN_PROGRESS} + ) + + def test_complete_session_raises(self, admin_user: User): + with pytest.raises(NoResultFound): + db_sync_session_handler.complete_session(999999) + + def test_fail_session_raises(self, admin_user: User): + with pytest.raises(NoResultFound): + db_sync_session_handler.fail_session(999999, error_message="test") + + class TestCancelActiveSessions: def test_cancels_active_sessions(self, admin_user: User): device = db_device_handler.add_device( diff --git a/backend/tests/handler/filesystem/test_sync_handler.py b/backend/tests/handler/filesystem/test_sync_handler.py index 78ffcaa786..b6a156b711 100644 --- a/backend/tests/handler/filesystem/test_sync_handler.py +++ b/backend/tests/handler/filesystem/test_sync_handler.py @@ -27,36 +27,28 @@ def patch_base_path(self, handler: FSSyncHandler, temp_dir): def test_build_incoming_path(self, handler: FSSyncHandler): path = handler.build_incoming_path("device-1") - assert "device-1" in path - assert "incoming" in path + assert path == os.path.join("device-1", "incoming") def test_build_incoming_path_with_platform(self, handler): path = handler.build_incoming_path("device-1", "gba") - assert "device-1" in path - assert "incoming" in path - assert "gba" in path + assert path == os.path.join("device-1", "incoming", "gba") def test_build_outgoing_path(self, handler: FSSyncHandler): path = handler.build_outgoing_path("device-1") - assert "device-1" in path - assert "outgoing" in path + assert path == os.path.join("device-1", "outgoing") def test_build_outgoing_path_with_platform(self, handler: FSSyncHandler): path = handler.build_outgoing_path("device-1", "snes") - assert "device-1" in path - assert "outgoing" in path - assert "snes" in path + assert path == os.path.join("device-1", "outgoing", "snes") def test_build_conflicts_path(self, handler: FSSyncHandler): path = handler.build_conflicts_path("device-1", "gba") - assert "device-1" in path - assert "conflicts" in path - assert "gba" in path + assert path == os.path.join("device-1", "conflicts", "gba") def test_ensure_device_directories(self, handler: FSSyncHandler, temp_dir): handler.ensure_device_directories("test-device") - incoming = handler.build_incoming_path("test-device") - outgoing = handler.build_outgoing_path("test-device") + incoming = handler.base_path / handler.build_incoming_path("test-device") + outgoing = handler.base_path / handler.build_outgoing_path("test-device") assert os.path.isdir(incoming) assert os.path.isdir(outgoing) @@ -65,9 +57,10 @@ def test_list_incoming_files_empty(self, handler: FSSyncHandler): assert result == [] def test_list_incoming_files(self, handler: FSSyncHandler, temp_dir): - # Set up: create incoming/platform/file structure handler.ensure_device_directories("dev-1") - incoming_path = handler.build_incoming_path("dev-1", "gba") + incoming_path = str( + handler.base_path / handler.build_incoming_path("dev-1", "gba") + ) os.makedirs(incoming_path, exist_ok=True) test_file = os.path.join(incoming_path, "save.sav") with open(test_file, "wb") as f: @@ -114,7 +107,7 @@ def test_write_outgoing_file(self, handler: FSSyncHandler, temp_dir): def test_remove_incoming_file(self, handler: FSSyncHandler, temp_dir): handler.ensure_device_directories("dev-1") - incoming = handler.build_incoming_path("dev-1", "gba") + incoming = str(handler.base_path / handler.build_incoming_path("dev-1", "gba")) os.makedirs(incoming, exist_ok=True) test_file = os.path.join(incoming, "to_remove.sav") with open(test_file, "wb") as f: diff --git a/backend/tests/test_sync_watcher.py b/backend/tests/test_sync_watcher.py new file mode 100644 index 0000000000..e62326b6a6 --- /dev/null +++ b/backend/tests/test_sync_watcher.py @@ -0,0 +1,112 @@ +import os +import shutil +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from handler.filesystem.sync_handler import FSSyncHandler + + +class TestExtractDeviceAndPlatform: + @pytest.fixture + def temp_dir(self): + d = tempfile.mkdtemp() + yield d + shutil.rmtree(d, ignore_errors=True) + + @pytest.fixture + def handler(self): + return FSSyncHandler.__new__(FSSyncHandler) + + @pytest.fixture(autouse=True) + def patch_base_path(self, handler: FSSyncHandler, temp_dir): + handler.base_path = Path(temp_dir) + with patch("sync_watcher.fs_sync_handler", handler): + yield + + def test_extract_valid_incoming_path(self, temp_dir): + from sync_watcher import _extract_device_and_platform + + path = os.path.join(temp_dir, "device-1", "incoming", "gba", "save.sav") + result = _extract_device_and_platform(path) + assert result == ("device-1", "gba", "save.sav") + + def test_extract_non_incoming_path_returns_none(self, temp_dir): + from sync_watcher import _extract_device_and_platform + + path = os.path.join(temp_dir, "device-1", "outgoing", "gba", "save.sav") + result = _extract_device_and_platform(path) + assert result is None + + def test_extract_too_few_parts_returns_none(self, temp_dir): + from sync_watcher import _extract_device_and_platform + + path = os.path.join(temp_dir, "device-1", "incoming") + result = _extract_device_and_platform(path) + assert result is None + + def test_extract_deeply_nested_returns_leaf_filename(self, temp_dir): + from sync_watcher import _extract_device_and_platform + + path = os.path.join( + temp_dir, "device-1", "incoming", "gba", "subdir", "save.sav" + ) + result = _extract_device_and_platform(path) + assert result == ("device-1", "gba", "save.sav") + + def test_extract_path_outside_base_returns_none(self): + from sync_watcher import _extract_device_and_platform + + result = _extract_device_and_platform("/totally/different/path") + assert result is None + + +class TestEnsureConflictsDir: + @pytest.fixture + def temp_dir(self): + d = tempfile.mkdtemp() + yield d + shutil.rmtree(d, ignore_errors=True) + + @pytest.fixture + def handler(self): + return FSSyncHandler.__new__(FSSyncHandler) + + @pytest.fixture(autouse=True) + def patch_base_path(self, handler: FSSyncHandler, temp_dir): + handler.base_path = Path(temp_dir) + with patch("sync_watcher.fs_sync_handler", handler): + yield + + def test_creates_directory_and_returns_path(self, temp_dir): + from sync_watcher import _ensure_conflicts_dir + + result = _ensure_conflicts_dir("device-1", "gba") + expected = os.path.join(temp_dir, "device-1", "conflicts", "gba") + assert result == expected + assert os.path.isdir(expected) + + def test_idempotent_no_error_on_second_call(self, temp_dir): + from sync_watcher import _ensure_conflicts_dir + + _ensure_conflicts_dir("device-1", "gba") + result = _ensure_conflicts_dir("device-1", "gba") + expected = os.path.join(temp_dir, "device-1", "conflicts", "gba") + assert result == expected + assert os.path.isdir(expected) + + +class TestProcessSyncChanges: + def test_empty_changes_returns_immediately(self): + with patch("sync_watcher.ENABLE_SYNC_FOLDER_WATCHER", True): + from sync_watcher import process_sync_changes + + process_sync_changes([]) + + def test_disabled_watcher_returns_immediately(self): + with patch("sync_watcher.ENABLE_SYNC_FOLDER_WATCHER", False): + from sync_watcher import process_sync_changes + + process_sync_changes([("added", "/some/path/file.sav")]) diff --git a/backend/tests/test_utils_auth.py b/backend/tests/test_utils_auth.py new file mode 100644 index 0000000000..3d1e0a5867 --- /dev/null +++ b/backend/tests/test_utils_auth.py @@ -0,0 +1,73 @@ +from dataclasses import replace +from unittest.mock import MagicMock + +from ua_parser import parse as parse_ua + +from handler.database import db_device_handler +from models.user import User +from utils.auth import _get_device_name, create_or_find_web_device + +CHROME_MAC_UA = ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" +) + + +def _make_request(user_agent: str = CHROME_MAC_UA, forwarded_for: str = "1.2.3.4"): + request = MagicMock() + request.headers = {"user-agent": user_agent, "x-forwarded-for": forwarded_for} + request.client.host = "127.0.0.1" + return request + + +class TestGetDeviceName: + def test_browser_and_os(self): + result = parse_ua(CHROME_MAC_UA) + assert _get_device_name(result) == "Chrome on Mac OS X" + + def test_browser_only(self): + result = replace(parse_ua(CHROME_MAC_UA), os=None) + assert _get_device_name(result) == "Chrome" + + def test_os_only(self): + result = replace(parse_ua(CHROME_MAC_UA), user_agent=None) + assert _get_device_name(result) == "Mac OS X" + + def test_neither(self): + result = replace(parse_ua(CHROME_MAC_UA), user_agent=None, os=None) + assert _get_device_name(result) == "Web Browser" + + +class TestCreateOrFindWebDevice: + def test_creates_new_device(self, admin_user: User): + request = _make_request() + device = create_or_find_web_device(request, admin_user) + + assert device.id is not None + assert len(device.id) == 36 + assert device.user_id == admin_user.id + assert device.name == "Chrome on Mac OS X" + assert device.platform == "Web" + assert device.client == "web" + assert device.ip_address == "1.2.3.4" + assert device.hostname == "127.0.0.1" + assert device.last_seen is not None + + def test_returns_existing_device_on_matching_fingerprint(self, admin_user: User): + request = _make_request() + first = create_or_find_web_device(request, admin_user) + second = create_or_find_web_device(request, admin_user) + + assert second.id == first.id + + def test_updates_last_seen_on_existing_device(self, admin_user: User): + request = _make_request() + first = create_or_find_web_device(request, admin_user) + + second = create_or_find_web_device(request, admin_user) + refreshed = db_device_handler.get_device( + device_id=second.id, user_id=admin_user.id + ) + + assert refreshed.last_seen is not None + assert first.id == refreshed.id