forked from rommapp/romm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathssh_handler.py
More file actions
223 lines (182 loc) · 7.66 KB
/
ssh_handler.py
File metadata and controls
223 lines (182 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""SSH/SFTP handler for Push-Pull sync mode.
Provides methods to connect to remote devices via SSH, list remote save files,
and perform bidirectional file transfers using SFTP.
SSH keys are expected to be pre-mounted on the server (e.g. via Docker volume)
at the path configured by SYNC_SSH_KEYS_PATH. Keys are looked up by device_id
({SYNC_SSH_KEYS_PATH}/{device_id}.pem) or via an explicit ssh_key_path in the
device's sync_config.
"""
from __future__ import annotations

import hashlib
import os
import posixpath
import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import asyncssh

from config import SYNC_SSH_KEYS_PATH
from logger.logger import log
@dataclass()
class RemoteSaveInfo:
    """Metadata for a single save file discovered on a remote device.

    ``content_hash`` defaults to None; it is only populated when a caller
    computes a digest and assigns it.
    """

    # Full remote path of the save file.
    path: str
    # Final path component (the file's base name).
    file_name: str
    # Platform slug of the save directory the file was found in.
    platform_slug: str
    # Size in bytes as reported by the remote server.
    file_size: int
    # Last-modification time of the remote file.
    mtime: datetime
    # Optional content digest; None until computed.
    content_hash: str | None = None
class SSHSyncHandler:
    """Handles SSH/SFTP operations for push-pull sync mode.

    SSH keys are expected to be pre-mounted on the server filesystem at
    SYNC_SSH_KEYS_PATH. The handler looks up keys by device_id convention
    ({keys_path}/{device_id}.pem) or uses an explicit path from sync_config.
    """

    def __init__(self) -> None:
        self.keys_path = Path(SYNC_SSH_KEYS_PATH)
        # Best effort: the keys directory is expected to be mounted from
        # outside (possibly read-only). Failing to create it must not break
        # import of this module, since password authentication still works.
        try:
            self.keys_path.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            log.warning(
                f"Could not create SSH keys directory {self.keys_path}: {e}"
            )

    def _resolve_key_path(self, device_id: str, sync_config: dict) -> str | None:
        """Resolve the SSH private key path for a device.

        Checks, in order:
        1. Explicit ``ssh_key_path`` in sync_config, if it is an existing file.
        2. Convention-based path ``{SYNC_SSH_KEYS_PATH}/{device_id}.pem``
           (skipped when device_id is empty).

        Returns:
            The key file path as a string, or None if no key file was found.
        """
        explicit = sync_config.get("ssh_key_path")
        if explicit:
            # NOTE(review): the explicit path is not confined to keys_path; if
            # sync_config is editable by end users, consider restricting it.
            if os.path.isfile(explicit):
                return explicit
            log.warning(
                f"Configured ssh_key_path {explicit} is not a file; "
                "trying other authentication methods"
            )

        # Without a device_id the convention path would be the hidden file
        # "{keys_path}/.pem", silently sharing one key across devices.
        if device_id:
            convention_path = self.keys_path / f"{device_id}.pem"
            if convention_path.is_file():
                return str(convention_path)
        return None

    async def connect(
        self, sync_config: dict, device_id: str | None = None
    ) -> asyncssh.SSHClientConnection:
        """Establish an SSH connection using a device's sync_config.

        Key authentication is preferred: the key is resolved from
        sync_config["ssh_key_path"] first, then from
        {SYNC_SSH_KEYS_PATH}/{device_id}.pem. If no key file is found,
        ssh_password is used.

        sync_config keys:
            - ssh_host: hostname or IP (required)
            - ssh_port: port (default 22, also used when the value is null/empty)
            - ssh_username: username (default "root", same fallback rule)
            - ssh_key_path: explicit path to a private key (optional)
            - ssh_password: password (optional, used only if no key is found)
            - ssh_known_hosts: path to a known_hosts file used to verify the
              server host key (optional; when absent, verification is disabled)

        Raises:
            KeyError: if ssh_host is missing.
            ValueError: if neither a key file nor a password is available.
        """
        host = sync_config["ssh_host"]
        # ``or`` also covers keys that are present but stored as null/empty.
        port = sync_config.get("ssh_port") or 22
        username = sync_config.get("ssh_username") or "root"
        known_hosts = sync_config.get("ssh_known_hosts") or None

        connect_kwargs: dict[str, Any] = {
            "host": host,
            "port": port,
            "username": username,
            # None tells asyncssh to skip server host key verification.
            "known_hosts": known_hosts,
        }

        key_path = self._resolve_key_path(device_id or "", sync_config)
        if key_path:
            connect_kwargs["client_keys"] = [key_path]
        elif sync_config.get("ssh_password"):
            connect_kwargs["password"] = sync_config["ssh_password"]
        else:
            raise ValueError(
                f"No SSH authentication method available for {host}. "
                f"Mount a key at {self.keys_path}/{{device_id}}.pem or "
                "provide ssh_key_path/ssh_password in sync_config."
            )

        if known_hosts is None:
            log.warning(
                f"SSH host key verification disabled for {host} -- "
                "connection is vulnerable to MITM attacks"
            )
        log.info(f"Connecting to {username}@{host}:{port}")
        return await asyncssh.connect(**connect_kwargs)

    async def list_remote_saves(
        self,
        conn: asyncssh.SSHClientConnection,
        save_directories: list[dict],
    ) -> list[RemoteSaveInfo]:
        """List save files on a remote device.

        Args:
            conn: An open SSH connection.
            save_directories: Dicts with keys:
                - platform_slug: str
                - path: str (remote directory path)
                - extension: str (optional file-name suffix filter, e.g. ".srm")

        Returns:
            One RemoteSaveInfo per regular file found. A directory that cannot
            be listed, or an entry that cannot be stat'ed, is logged and
            skipped so that one bad path does not abort the whole listing.
        """
        results: list[RemoteSaveInfo] = []
        async with conn.start_sftp_client() as sftp:
            for dir_config in save_directories:
                platform_slug = dir_config["platform_slug"]
                remote_path = dir_config["path"]
                extension = dir_config.get("extension", "")

                try:
                    entries = await sftp.listdir(remote_path)
                except asyncssh.SFTPNoSuchFile:
                    log.warning(f"Remote directory not found: {remote_path}")
                    continue
                except asyncssh.SFTPError as e:
                    # e.g. permission denied, or the path is not a directory.
                    log.warning(
                        f"Failed to list remote directory {remote_path}: {e}"
                    )
                    continue

                for entry in entries:
                    if extension and not entry.endswith(extension):
                        continue
                    full_remote_path = f"{remote_path}/{entry}"
                    try:
                        attrs = await sftp.stat(full_remote_path)
                    except asyncssh.SFTPError as e:
                        log.warning(f"Failed to stat {full_remote_path}: {e}")
                        continue
                    # Skip directories and other non-regular entries (some
                    # servers also return "." and ".." from listdir).
                    if attrs.type != asyncssh.constants.FILEXFER_TYPE_REGULAR:
                        continue
                    results.append(
                        RemoteSaveInfo(
                            path=full_remote_path,
                            file_name=entry,
                            platform_slug=platform_slug,
                            file_size=attrs.size or 0,
                            mtime=datetime.fromtimestamp(
                                attrs.mtime or 0, tz=timezone.utc
                            ),
                        )
                    )
        return results

    @staticmethod
    def _md5_of_file(path: str, chunk_size: int = 8192) -> str:
        """Return the hex MD5 digest of the file at *path*, read in chunks.

        MD5 is used only as a content fingerprint, not for security.
        """
        hash_obj = hashlib.md5(usedforsecurity=False)
        with open(path, "rb") as f:
            while chunk := f.read(chunk_size):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()

    async def download_save(
        self,
        conn: asyncssh.SSHClientConnection,
        remote_path: str,
        local_path: str | None = None,
    ) -> tuple[str, str]:
        """Download a save file from a remote device.

        Args:
            conn: An open SSH connection.
            remote_path: Path of the file on the remote device.
            local_path: Destination path. When None, a new temporary file
                (prefix "romm_sync_") is created. On success it is not
                deleted here; the caller is responsible for it.

        Returns:
            (local_path, content_hash), where content_hash is the hex MD5 of
            the downloaded bytes.

        Raises:
            Any transfer or local I/O error. A temporary file created by this
            call is removed before the error is re-raised.
        """
        created_temp = local_path is None
        if created_temp:
            fd, local_path = tempfile.mkstemp(prefix="romm_sync_")
            os.close(fd)

        try:
            async with conn.start_sftp_client() as sftp:
                await sftp.get(remote_path, local_path)
            content_hash = self._md5_of_file(local_path)
        except BaseException:
            # BaseException so task cancellation also cleans up; always re-raised.
            if created_temp:
                try:
                    os.remove(local_path)
                except OSError:
                    pass
            raise
        return local_path, content_hash

    async def upload_save(
        self,
        conn: asyncssh.SSHClientConnection,
        local_path: str,
        remote_path: str,
    ) -> None:
        """Upload a save file to a remote device.

        Missing parent directories on the remote side are created first.
        """
        async with conn.start_sftp_client() as sftp:
            # SFTP paths use '/' separators, independent of the local os.path.
            remote_dir = posixpath.dirname(remote_path)
            if remote_dir:
                try:
                    await sftp.makedirs(remote_dir, exist_ok=True)
                except asyncssh.SFTPError as e:
                    # Best effort: if the directory really is unusable, put()
                    # below fails with its own error.
                    log.warning(
                        f"Could not create remote directory {remote_dir}: {e}"
                    )
            await sftp.put(local_path, remote_path)
        log.info(f"Uploaded {local_path} -> {remote_path}")

    async def delete_remote_save(
        self,
        conn: asyncssh.SSHClientConnection,
        remote_path: str,
    ) -> None:
        """Delete a save file from a remote device.

        SFTP errors (e.g. the file does not exist) propagate to the caller.
        """
        async with conn.start_sftp_client() as sftp:
            await sftp.remove(remote_path)
        log.info(f"Deleted remote file: {remote_path}")
# Shared module-level handler instance. It is constructed at import time, so
# SSHSyncHandler.__init__ runs then: it reads SYNC_SSH_KEYS_PATH and tries to
# create that directory.
ssh_sync_handler = SSHSyncHandler()