Skip to content

Commit 5440c03

Browse files
fix(app): directory traversal when deleting images
1 parent 358dbdb commit 5440c03

File tree

2 files changed

+68
-6
lines changed

2 files changed

+68
-6
lines changed

invokeai/app/services/image_files/image_files_disk.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,26 @@ def delete(self, image_name: str) -> None:
110110
except Exception as e:
111111
raise ImageFileDeleteException from e
112112

113-
# TODO: make this a bit more flexible for e.g. cloud storage
114113
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
115-
path = self.__output_folder / image_name
114+
base_folder = self.__thumbnails_folder if thumbnail else self.__output_folder
115+
filename = get_thumbnail_name(image_name) if thumbnail else image_name
116116

117-
if thumbnail:
118-
thumbnail_name = get_thumbnail_name(image_name)
119-
path = self.__thumbnails_folder / thumbnail_name
117+
# Strip any path information from the filename
118+
basename = Path(filename).name
119+
120+
if basename != filename:
121+
raise ValueError("Invalid image name, potential directory traversal detected")
122+
123+
image_path = base_folder / basename
124+
125+
# Ensure the image path is within the base folder to prevent directory traversal
126+
resolved_base = base_folder.resolve()
127+
resolved_image_path = image_path.resolve()
128+
129+
if not resolved_image_path.is_relative_to(resolved_base):
130+
raise ValueError("Image path outside outputs folder, potential directory traversal detected")
120131

121-
return path
132+
return resolved_image_path
122133

123134
def validate_path(self, path: Union[str, Path]) -> bool:
124135
"""Validates the path given for an image or thumbnail."""
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import platform
2+
from pathlib import Path
3+
4+
import pytest
5+
6+
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
7+
8+
9+
@pytest.fixture
10+
def image_names() -> list[str]:
11+
# Determine the platform and return a path that matches its format
12+
if platform.system() == "Windows":
13+
return [
14+
# Relative paths
15+
"folder\\evil.txt",
16+
"folder\\..\\evil.txt",
17+
# Absolute paths
18+
"\\folder\\evil.txt",
19+
"C:\\folder\\..\\evil.txt",
20+
]
21+
else:
22+
return [
23+
# Relative paths
24+
"folder/evil.txt",
25+
"folder/../evil.txt",
26+
# Absolute paths
27+
"/folder/evil.txt",
28+
"/folder/../evil.txt",
29+
]
30+
31+
32+
def test_directory_traversal_protection(tmp_path: Path, image_names: list[str]):
33+
"""Test that the image file storage prevents directory traversal attacks.
34+
35+
There are two safeguards in the `DiskImageFileStorage.get_path` method:
36+
1. Check if the image name contains any directory traversal characters
37+
2. Check if the resulting path is relative to the base folder
38+
39+
This test checks the first safeguard. I'd like to check the second but I cannot figure out a test case that would
40+
pass the first check but fail the second check.
41+
"""
42+
image_files_disk = DiskImageFileStorage(tmp_path)
43+
for name in image_names:
44+
with pytest.raises(ValueError, match="Invalid image name, potential directory traversal detected"):
45+
image_files_disk.get_path(name)
46+
47+
48+
def test_image_paths_relative_to_storage_dir(tmp_path: Path):
49+
image_files_disk = DiskImageFileStorage(tmp_path)
50+
path = image_files_disk.get_path("foo.png")
51+
assert path.is_relative_to(tmp_path)

0 commit comments

Comments
 (0)