Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions backend/alembic/versions/0073_sync_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def upgrade() -> None:
connection = op.get_bind()
if is_postgresql(connection):
rom_user_status_enum = ENUM(
sync_session_status_enum = ENUM(
"PENDING",
"IN_PROGRESS",
"COMPLETED",
Expand All @@ -30,9 +30,9 @@ def upgrade() -> None:
name="syncsessionstatus",
create_type=False,
)
rom_user_status_enum.create(connection, checkfirst=False)
sync_session_status_enum.create(connection, checkfirst=False)
else:
rom_user_status_enum = sa.Enum(
sync_session_status_enum = sa.Enum(
"PENDING",
"IN_PROGRESS",
"COMPLETED",
Expand Down Expand Up @@ -86,7 +86,11 @@ def upgrade() -> None:
"updated_at",
sa.TIMESTAMP(timezone=True),
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
server_default=(
sa.text("CURRENT_TIMESTAMP")
if is_postgresql(connection)
else sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
),
sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
Expand All @@ -107,3 +111,7 @@ def downgrade() -> None:
op.drop_index("ix_sync_sessions_device_id", table_name="sync_sessions")

op.drop_table("sync_sessions")

connection = op.get_bind()
if is_postgresql(connection):
ENUM(name="syncsessionstatus").drop(connection, checkfirst=True)
16 changes: 15 additions & 1 deletion backend/endpoints/responses/device.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from pydantic import ConfigDict
from typing import Any

from pydantic import ConfigDict, field_serializer

from models.device import SyncMode

from .base import BaseModel, UTCDatetime

SENSITIVE_SYNC_CONFIG_KEYS = {"ssh_password", "ssh_key_path"}


class DeviceSyncSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)
Expand Down Expand Up @@ -34,6 +38,16 @@ class DeviceSchema(BaseModel):
created_at: UTCDatetime
updated_at: UTCDatetime

@field_serializer("sync_config")
@classmethod
def mask_sensitive_fields(cls, v: dict | None) -> dict[str, Any] | None:
    """Redact secret entries of sync_config before serialization.

    Any key listed in SENSITIVE_SYNC_CONFIG_KEYS (SSH credentials) is
    replaced with a fixed mask; every other entry passes through
    unchanged. An empty dict or None is returned as-is.
    """
    if not v:
        return v
    masked: dict[str, Any] = {}
    for key, value in v.items():
        masked[key] = "********" if key in SENSITIVE_SYNC_CONFIG_KEYS else value
    return masked


class DeviceCreateResponse(BaseModel):
device_id: str
Expand Down
5 changes: 3 additions & 2 deletions backend/endpoints/responses/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def from_orm_with_request(
if not db_user:
return None

db_user.current_device_id = request.session.get("device_id") # type: ignore
return cls.model_validate(db_user)
schema = cls.model_validate(db_user)
schema.current_device_id = request.session.get("device_id")
return schema


class InviteLinkSchema(BaseModel):
Expand Down
33 changes: 11 additions & 22 deletions backend/endpoints/saves.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def _resolve_device(
return device


def _increment_session_counter(session_id: int) -> None:
    """Best-effort bump of a sync session's operations-completed counter.

    Any failure is logged with its traceback instead of being raised, so a
    broken bookkeeping update never aborts the save operation that called us.
    """
    try:
        db_sync_session_handler.increment_operations_completed(session_id=session_id)
    except Exception:
        log.warning(f"Failed to update sync session {session_id}", exc_info=True)


router = APIRouter(
prefix="/saves",
tags=["saves"],
Expand Down Expand Up @@ -247,17 +256,7 @@ async def add_save(
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)

if session_id:
try:
session = db_sync_session_handler.get_session(
session_id=session_id, user_id=request.user.id
)
if session:
db_sync_session_handler.update_session(
session_id=session_id,
data={"operations_completed": session.operations_completed + 1},
)
except Exception:
log.warning(f"Failed to update sync session {session_id}")
_increment_session_counter(session_id)

if slot and autocleanup:
slot_saves = db_save_handler.get_saves(
Expand Down Expand Up @@ -454,17 +453,7 @@ def download_save(
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)

if session_id:
try:
session = db_sync_session_handler.get_session(
session_id=session_id, user_id=request.user.id
)
if session:
db_sync_session_handler.update_session(
session_id=session_id,
data={"operations_completed": session.operations_completed + 1},
)
except Exception:
log.warning(f"Failed to update sync session {session_id}")
_increment_session_counter(session_id)

return FileResponse(path=str(file_path), filename=save.file_name)

Expand Down
3 changes: 2 additions & 1 deletion backend/endpoints/sync.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime

from fastapi import HTTPException, Request, status
from pydantic import BaseModel

from config import TASK_TIMEOUT
from decorators.auth import protected_route
from endpoints.responses.base import BaseModel
from endpoints.responses.sync import (
SyncNegotiateResponse,
SyncOperationSchema,
Expand Down Expand Up @@ -359,6 +359,7 @@ def trigger_push_pull(
high_prio_queue.enqueue(
"tasks.sync_push_pull_task.run_push_pull_sync",
device_id=device.id,
session_id=sync_session.id,
force=True,
job_timeout=TASK_TIMEOUT,
meta={
Expand Down
31 changes: 28 additions & 3 deletions backend/handler/database/sync_sessions_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from decorators.database import begin_session
Expand Down Expand Up @@ -75,7 +76,25 @@ def update_session(
.values(**data)
.execution_options(synchronize_session="evaluate")
)
return session.query(SyncSession).filter_by(id=session_id).one()
result = session.scalar(select(SyncSession).filter_by(id=session_id))
if not result:
raise NoResultFound(f"SyncSession {session_id} not found after update")
return result

@begin_session
def increment_operations_completed(
    self,
    session_id: int,
    session: Session = None,  # type: ignore
) -> None:
    """Atomically increment operations_completed for one sync session.

    The increment is expressed in SQL (column = column + 1), so concurrent
    callers cannot lose updates the way a read-modify-write round trip would.
    """
    stmt = (
        update(SyncSession)
        .where(SyncSession.id == session_id)
        .values(operations_completed=SyncSession.operations_completed + 1)
        .execution_options(synchronize_session="evaluate")
    )
    session.execute(stmt)

@begin_session
def complete_session(
Expand All @@ -96,7 +115,10 @@ def complete_session(
)
.execution_options(synchronize_session="evaluate")
)
return session.query(SyncSession).filter_by(id=session_id).one()
result = session.scalar(select(SyncSession).filter_by(id=session_id))
if not result:
raise NoResultFound(f"SyncSession {session_id} not found after complete")
return result

@begin_session
def fail_session(
Expand All @@ -115,7 +137,10 @@ def fail_session(
)
.execution_options(synchronize_session="evaluate")
)
return session.query(SyncSession).filter_by(id=session_id).one()
result = session.scalar(select(SyncSession).filter_by(id=session_id))
if not result:
raise NoResultFound(f"SyncSession {session_id} not found after fail")
return result

@begin_session
def cancel_active_sessions(
Expand Down
14 changes: 5 additions & 9 deletions backend/handler/filesystem/sync_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,30 @@ def __init__(self) -> None:
def build_incoming_path(
self, device_id: str, platform_slug: str | None = None
) -> str:
"""Build the relative incoming path for a device (and optional platform)."""
parts = [self.base_path, device_id, "incoming"]
parts = [device_id, "incoming"]
if platform_slug:
parts.append(platform_slug)
return os.path.join(*parts)

def build_outgoing_path(
self, device_id: str, platform_slug: str | None = None
) -> str:
"""Build the relative outgoing path for a device (and optional platform)."""
parts = [self.base_path, device_id, "outgoing"]
parts = [device_id, "outgoing"]
if platform_slug:
parts.append(platform_slug)
return os.path.join(*parts)

def build_conflicts_path(
self, device_id: str, platform_slug: str | None = None
) -> str:
"""Build the relative conflicts path for a device (and optional platform)."""
parts = [self.base_path, device_id, "conflicts"]
parts = [device_id, "conflicts"]
if platform_slug:
parts.append(platform_slug)
return os.path.join(*parts)

def ensure_device_directories(self, device_id: str) -> None:
"""Create incoming/outgoing directory structure for a device."""
incoming = Path(self.build_incoming_path(device_id))
outgoing = Path(self.build_outgoing_path(device_id))
incoming = self.base_path / self.build_incoming_path(device_id)
outgoing = self.base_path / self.build_outgoing_path(device_id)

incoming.mkdir(parents=True, exist_ok=True)
outgoing.mkdir(parents=True, exist_ok=True)
Expand Down
4 changes: 4 additions & 0 deletions backend/handler/sync/ssh_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ async def connect(
"provide ssh_key_path/ssh_password in sync_config."
)

log.warning(
f"SSH host key verification disabled for {host} -- "
"connection is vulnerable to MITM attacks"
)
log.info(f"Connecting to {username}@{host}:{port}")
return await asyncssh.connect(**connect_kwargs)

Expand Down
18 changes: 6 additions & 12 deletions backend/sync_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,26 @@
def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None:
"""Extract device_id, platform_slug, and filename from a sync incoming path.

Expected path format: {build_incoming_path(device_id, platform_slug)}/filename.ext
i.e. {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext
Expected path format: {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext
"""
try:
rel_path = os.path.relpath(path)
rel_path = os.path.relpath(path, start=str(fs_sync_handler.base_path))
parts = rel_path.split(os.sep)
# Minimum: device_id / incoming / platform_slug / filename
if len(parts) < 4 or parts[1] != "incoming":
return None
device_id = parts[0]
platform_slug = parts[2]
filename = parts[-1]

# Validate path matches the canonical incoming path structure
expected_prefix = fs_sync_handler.build_incoming_path(device_id, platform_slug)
if not rel_path.startswith(expected_prefix):
return None

return (device_id, platform_slug, filename)
except (ValueError, IndexError):
return None


def _ensure_conflicts_dir(device_id: str, platform_slug: str) -> str:
"""Ensure the conflicts directory exists and return its path."""
conflicts_dir = fs_sync_handler.build_conflicts_path(device_id, platform_slug)
conflicts_dir = str(
fs_sync_handler.base_path
/ fs_sync_handler.build_conflicts_path(device_id, platform_slug)
)
os.makedirs(conflicts_dir, exist_ok=True)
return conflicts_dir

Expand Down
Loading
Loading