Skip to content

Commit dbe7ff1

Browse files
committed
fix: add validation to avoid path-traversal vulnerabilities
1 parent ffe766f commit dbe7ff1

File tree

1 file changed

+58
-11
lines changed
  • libs/ktem/ktem/index/file

1 file changed

+58
-11
lines changed

libs/ktem/ktem/index/file/ui.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import shutil
5+
import stat
56
import tempfile
67
import zipfile
78
from copy import deepcopy
@@ -1059,17 +1060,53 @@ def _may_extract_zip(self, files, zip_dir: str):
10591060
"""Handle zip files"""
10601061
zip_files = [file for file in files if file.endswith(".zip")]
10611062
remaining_files = [file for file in files if not file.endswith(".zip")]
1063+
errors = []
10621064

10631065
# Clean-up <zip_dir> before unzip to remove old files
10641066
shutil.rmtree(zip_dir, ignore_errors=True)
10651067

1068+
# Unzip
1069+
unsafe_zip_files = []
10661070
for zip_file in zip_files:
10671071
# Prepare new zip output dir, separated for each files
10681072
basename = os.path.splitext(os.path.basename(zip_file))[0]
10691073
zip_out_dir = os.path.join(zip_dir, basename)
10701074
os.makedirs(zip_out_dir, exist_ok=True)
1075+
10711076
with zipfile.ZipFile(zip_file, "r") as zip_ref:
1072-
zip_ref.extractall(zip_out_dir)
1077+
# Check for symlinks and path traversal attacks at zip level
1078+
is_safe = True
1079+
1080+
for member in zip_ref.infolist():
1081+
# Disallow symlinks
1082+
if stat.S_ISLNK(member.external_attr >> 16):
1083+
# Skipping zip file with symlink
1084+
is_safe = False
1085+
break
1086+
1087+
# Check for path traversal attacks
1088+
target_path = os.path.join(zip_out_dir, member.filename)
1089+
abs_zip_out_dir = os.path.abspath(zip_out_dir)
1090+
abs_target_path = os.path.abspath(target_path)
1091+
1092+
if not (
1093+
abs_target_path.startswith(abs_zip_out_dir + os.sep)
1094+
or abs_target_path == abs_zip_out_dir
1095+
):
1096+
# Skip zip file with path traversal file
1097+
is_safe = False
1098+
break
1099+
1100+
if is_safe:
1101+
zip_ref.extractall(zip_out_dir)
1102+
else:
1103+
unsafe_zip_files.append(zip_file)
1104+
1105+
if unsafe_zip_files:
1106+
str_error = ", ".join(unsafe_zip_files)
1107+
errors.append(
1108+
f"Unsafe zip files (contains symlinks or path traversal): {str_error}"
1109+
)
10731110

10741111
n_zip_file = 0
10751112
for root, dirs, files in os.walk(zip_dir):
@@ -1084,7 +1121,7 @@ def _may_extract_zip(self, files, zip_dir: str):
10841121
if n_zip_file > 0:
10851122
print(f"Update zip files: {n_zip_file}")
10861123

1087-
return remaining_files
1124+
return remaining_files, errors
10881125

10891126
def index_fn(
10901127
self, files, urls, reindex: bool, settings, user_id
@@ -1100,20 +1137,22 @@ def index_fn(
11001137
"""
11011138
if urls:
11021139
files = [it.strip() for it in urls.split("\n")]
1103-
errors = []
1140+
errors = self.validate_urls(files)
11041141
else:
11051142
if not files:
11061143
gr.Info("No uploaded file")
11071144
yield "", ""
11081145
return
1146+
files, unzip_errors = self._may_extract_zip(
1147+
files, flowsettings.KH_ZIP_INPUT_DIR
1148+
)
1149+
errors = self.validate_files(files)
1150+
errors.extend(unzip_errors)
11091151

1110-
files = self._may_extract_zip(files, flowsettings.KH_ZIP_INPUT_DIR)
1111-
1112-
errors = self.validate(files)
1113-
if errors:
1114-
gr.Warning(", ".join(errors))
1115-
yield "", ""
1116-
return
1152+
if errors:
1153+
gr.Warning(", ".join(errors))
1154+
yield "", ""
1155+
return
11171156

11181157
gr.Info(f"Start indexing {len(files)} files...")
11191158

@@ -1569,7 +1608,7 @@ def interact_group_list(self, list_groups, ev: gr.SelectData):
15691608
selected_item["files"],
15701609
)
15711610

1572-
def validate(self, files: list[str]):
1611+
def validate_files(self, files: list[str]):
15731612
"""Validate if the files are valid"""
15741613
paths = [Path(file) for file in files]
15751614
errors = []
@@ -1598,6 +1637,14 @@ def validate(self, files: list[str]):
15981637

15991638
return errors
16001639

1640+
def validate_urls(self, urls: list[str]):
    """Validate that every entry is an absolute HTTP(S) URL.

    Args:
        urls: candidate URL strings to check.

    Returns:
        A list of error messages, one per rejected entry, in input order;
        empty when every entry is accepted.
    """
    # Require the full "scheme://" prefix. A bare "http" prefix test also
    # accepts strings such as "httpfoo", and makes a separate "https" test
    # redundant because every "https..." string already starts with "http".
    allowed_prefixes = ("http://", "https://")
    return [
        f"Invalid url `{url}`"
        for url in urls
        if not url.startswith(allowed_prefixes)
    ]
1647+
16011648

16021649
class FileSelector(BasePage):
16031650
"""File selector UI in the Chat page"""

0 commit comments

Comments
 (0)