Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/migrations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4

- name: Install mariadb connectors
run: |
sudo apt-get update
sudo apt-get install -y libmariadb3 libmariadb-dev

- name: Install uv
uses: astral-sh/setup-uv@v6.7.0

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ jobs:
uses: actions/checkout@v4.3.0

- name: Install mariadb connectors
if: matrix.db == 'mariadb'
run: |
sudo apt-get update
sudo apt-get install -y libmariadb3 libmariadb-dev
Expand Down
25 changes: 13 additions & 12 deletions backend/alembic/versions/0073_sync_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def upgrade() -> None:
connection = op.get_bind()
if is_postgresql(connection):
rom_user_status_enum = ENUM(
sync_session_status_enum = ENUM(
"PENDING",
"IN_PROGRESS",
"COMPLETED",
Expand All @@ -30,9 +30,9 @@ def upgrade() -> None:
name="syncsessionstatus",
create_type=False,
)
rom_user_status_enum.create(connection, checkfirst=False)
sync_session_status_enum.create(connection, checkfirst=True)
else:
rom_user_status_enum = sa.Enum(
sync_session_status_enum = sa.Enum(
"PENDING",
"IN_PROGRESS",
"COMPLETED",
Expand All @@ -48,14 +48,7 @@ def upgrade() -> None:
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column(
"status",
sa.Enum(
"PENDING",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncsessionstatus",
),
sync_session_status_enum,
nullable=False,
server_default="PENDING",
),
Expand Down Expand Up @@ -86,7 +79,11 @@ def upgrade() -> None:
"updated_at",
sa.TIMESTAMP(timezone=True),
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
server_default=(
sa.text("CURRENT_TIMESTAMP")
if is_postgresql(connection)
else sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
),
sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
Expand All @@ -107,3 +104,7 @@ def downgrade() -> None:
op.drop_index("ix_sync_sessions_device_id", table_name="sync_sessions")

op.drop_table("sync_sessions")

connection = op.get_bind()
if is_postgresql(connection):
ENUM(name="syncsessionstatus").drop(connection, checkfirst=True)
16 changes: 15 additions & 1 deletion backend/endpoints/responses/device.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from pydantic import ConfigDict
from typing import Any

from pydantic import ConfigDict, field_serializer

from models.device import SyncMode

from .base import BaseModel, UTCDatetime

SENSITIVE_SYNC_CONFIG_KEYS = {"ssh_password", "ssh_key_path"}


class DeviceSyncSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)
Expand Down Expand Up @@ -34,6 +38,16 @@ class DeviceSchema(BaseModel):
created_at: UTCDatetime
updated_at: UTCDatetime

@field_serializer("sync_config")
@classmethod
def mask_sensitive_fields(cls, v: dict | None) -> dict[str, Any] | None:
if not v:
return v
return {
k: "********" if k in SENSITIVE_SYNC_CONFIG_KEYS else val
for k, val in v.items()
}


class DeviceCreateResponse(BaseModel):
device_id: str
Expand Down
5 changes: 3 additions & 2 deletions backend/endpoints/responses/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def from_orm_with_request(
if not db_user:
return None

db_user.current_device_id = request.session.get("device_id") # type: ignore
return cls.model_validate(db_user)
schema = cls.model_validate(db_user)
schema.current_device_id = request.session.get("device_id")
return schema


class InviteLinkSchema(BaseModel):
Expand Down
33 changes: 11 additions & 22 deletions backend/endpoints/saves.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def _resolve_device(
return device


def _increment_session_counter(session_id: int) -> None:
try:
db_sync_session_handler.increment_operations_completed(
session_id=session_id,
)
except Exception:
log.warning(f"Failed to update sync session {session_id}", exc_info=True)


router = APIRouter(
prefix="/saves",
tags=["saves"],
Expand Down Expand Up @@ -247,17 +256,7 @@ async def add_save(
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)

if session_id:
try:
session = db_sync_session_handler.get_session(
session_id=session_id, user_id=request.user.id
)
if session:
db_sync_session_handler.update_session(
session_id=session_id,
data={"operations_completed": session.operations_completed + 1},
)
except Exception:
log.warning(f"Failed to update sync session {session_id}")
_increment_session_counter(session_id)

if slot and autocleanup:
slot_saves = db_save_handler.get_saves(
Expand Down Expand Up @@ -454,17 +453,7 @@ def download_save(
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)

if session_id:
try:
session = db_sync_session_handler.get_session(
session_id=session_id, user_id=request.user.id
)
if session:
db_sync_session_handler.update_session(
session_id=session_id,
data={"operations_completed": session.operations_completed + 1},
)
except Exception:
log.warning(f"Failed to update sync session {session_id}")
_increment_session_counter(session_id)

return FileResponse(path=str(file_path), filename=save.file_name)

Expand Down
3 changes: 2 additions & 1 deletion backend/endpoints/sync.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime

from fastapi import HTTPException, Request, status
from pydantic import BaseModel

from config import TASK_TIMEOUT
from decorators.auth import protected_route
from endpoints.responses.base import BaseModel
from endpoints.responses.sync import (
SyncNegotiateResponse,
SyncOperationSchema,
Expand Down Expand Up @@ -359,6 +359,7 @@ def trigger_push_pull(
high_prio_queue.enqueue(
"tasks.sync_push_pull_task.run_push_pull_sync",
device_id=device.id,
session_id=sync_session.id,
force=True,
job_timeout=TASK_TIMEOUT,
meta={
Expand Down
31 changes: 28 additions & 3 deletions backend/handler/database/sync_sessions_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from decorators.database import begin_session
Expand Down Expand Up @@ -75,7 +76,25 @@ def update_session(
.values(**data)
.execution_options(synchronize_session="evaluate")
)
return session.query(SyncSession).filter_by(id=session_id).one()
result = session.scalar(select(SyncSession).filter_by(id=session_id))
if not result:
raise NoResultFound(f"SyncSession {session_id} not found after update")
return result

@begin_session
def increment_operations_completed(
self,
session_id: int,
session: Session = None, # type: ignore
) -> None:
session.execute(
update(SyncSession)
.where(SyncSession.id == session_id)
.values(
operations_completed=SyncSession.operations_completed + 1,
)
.execution_options(synchronize_session="evaluate")
)

@begin_session
def complete_session(
Expand All @@ -96,7 +115,10 @@ def complete_session(
)
.execution_options(synchronize_session="evaluate")
)
return session.query(SyncSession).filter_by(id=session_id).one()
result = session.scalar(select(SyncSession).filter_by(id=session_id))
if not result:
raise NoResultFound(f"SyncSession {session_id} not found after complete")
return result

@begin_session
def fail_session(
Expand All @@ -115,7 +137,10 @@ def fail_session(
)
.execution_options(synchronize_session="evaluate")
)
return session.query(SyncSession).filter_by(id=session_id).one()
result = session.scalar(select(SyncSession).filter_by(id=session_id))
if not result:
raise NoResultFound(f"SyncSession {session_id} not found after fail")
return result

@begin_session
def cancel_active_sessions(
Expand Down
14 changes: 5 additions & 9 deletions backend/handler/filesystem/sync_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,30 @@ def __init__(self) -> None:
def build_incoming_path(
self, device_id: str, platform_slug: str | None = None
) -> str:
"""Build the relative incoming path for a device (and optional platform)."""
parts = [self.base_path, device_id, "incoming"]
parts = [device_id, "incoming"]
if platform_slug:
parts.append(platform_slug)
return os.path.join(*parts)

def build_outgoing_path(
self, device_id: str, platform_slug: str | None = None
) -> str:
"""Build the relative outgoing path for a device (and optional platform)."""
parts = [self.base_path, device_id, "outgoing"]
parts = [device_id, "outgoing"]
if platform_slug:
parts.append(platform_slug)
return os.path.join(*parts)

def build_conflicts_path(
self, device_id: str, platform_slug: str | None = None
) -> str:
"""Build the relative conflicts path for a device (and optional platform)."""
parts = [self.base_path, device_id, "conflicts"]
parts = [device_id, "conflicts"]
if platform_slug:
parts.append(platform_slug)
return os.path.join(*parts)

def ensure_device_directories(self, device_id: str) -> None:
"""Create incoming/outgoing directory structure for a device."""
incoming = Path(self.build_incoming_path(device_id))
outgoing = Path(self.build_outgoing_path(device_id))
incoming = self.base_path / self.build_incoming_path(device_id)
outgoing = self.base_path / self.build_outgoing_path(device_id)

incoming.mkdir(parents=True, exist_ok=True)
outgoing.mkdir(parents=True, exist_ok=True)
Expand Down
4 changes: 4 additions & 0 deletions backend/handler/sync/ssh_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ async def connect(
"provide ssh_key_path/ssh_password in sync_config."
)

log.warning(
f"SSH host key verification disabled for {host} -- "
"connection is vulnerable to MITM attacks"
)
log.info(f"Connecting to {username}@{host}:{port}")
return await asyncssh.connect(**connect_kwargs)

Expand Down
18 changes: 6 additions & 12 deletions backend/sync_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,26 @@
def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None:
"""Extract device_id, platform_slug, and filename from a sync incoming path.

Expected path format: {build_incoming_path(device_id, platform_slug)}/filename.ext
i.e. {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext
Expected path format: {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext
"""
try:
rel_path = os.path.relpath(path)
rel_path = os.path.relpath(path, start=str(fs_sync_handler.base_path))
parts = rel_path.split(os.sep)
# Minimum: device_id / incoming / platform_slug / filename
if len(parts) < 4 or parts[1] != "incoming":
return None
device_id = parts[0]
platform_slug = parts[2]
filename = parts[-1]

# Validate path matches the canonical incoming path structure
expected_prefix = fs_sync_handler.build_incoming_path(device_id, platform_slug)
if not rel_path.startswith(expected_prefix):
return None

return (device_id, platform_slug, filename)
except (ValueError, IndexError):
return None


def _ensure_conflicts_dir(device_id: str, platform_slug: str) -> str:
"""Ensure the conflicts directory exists and return its path."""
conflicts_dir = fs_sync_handler.build_conflicts_path(device_id, platform_slug)
conflicts_dir = str(
fs_sync_handler.base_path
/ fs_sync_handler.build_conflicts_path(device_id, platform_slug)
)
os.makedirs(conflicts_dir, exist_ok=True)
return conflicts_dir

Expand Down
Loading
Loading