Skip to content

Commit 616c41c

Browse files
fix: include missing warn for Poetry (#758)
The previous warning implementation missed Poetry; this fix unifies the behavior.
1 parent 9459d60 commit 616c41c

File tree

5 files changed

+479
-65
lines changed

5 files changed

+479
-65
lines changed

safety/tool/mixins.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from typing import Any, List, Protocol, Tuple, Dict, Optional, runtime_checkable
2+
import typer
3+
from rich.padding import Padding
4+
5+
from .base import EnvironmentDiffTracker
6+
7+
from safety.console import main_console as console
8+
from safety.init.render import render_header, progressive_print
9+
from safety.models import ToolResult
10+
import logging
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
@runtime_checkable
16+
class AuditableCommand(Protocol):
17+
"""
18+
Protocol defining the contract for classes that can be audited for packages.
19+
"""
20+
21+
@property
22+
def _diff_tracker(self) -> "EnvironmentDiffTracker":
23+
"""
24+
Provides package tracking functionality.
25+
"""
26+
...
27+
28+
@property
29+
def _packages(self) -> List[Tuple[str, Optional[str]]]:
30+
"""
31+
Provides the target package list.
32+
"""
33+
...
34+
35+
36+
class InstallationAuditMixin:
37+
"""
38+
Mixin providing installation audit functionality for command classes.
39+
40+
This mixin can be used by any command class that needs to audit
41+
installation and show warnings.
42+
43+
Classes using this mixin should conform to the AuditableCommand protocol.
44+
"""
45+
46+
def render_installation_warnings(
47+
self, ctx: typer.Context, packages_audit: Dict[str, Any]
48+
):
49+
"""
50+
Render installation warnings based on package audit results.
51+
52+
Args:
53+
ctx: The typer context
54+
packages_audit: pre-fetched audit data
55+
"""
56+
57+
warning_messages = []
58+
for audited_package in packages_audit.get("audit", {}).get("packages", []):
59+
vulnerabilities = audited_package.get("vulnerabilities", {})
60+
critical_vulnerabilities = vulnerabilities.get("critical", 0)
61+
total_vulnerabilities = 0
62+
for count in vulnerabilities.values():
63+
total_vulnerabilities += count
64+
65+
if total_vulnerabilities == 0:
66+
continue
67+
68+
warning_message = f"[[yellow]Warning[/yellow]] {audited_package.get('package_specifier')} contains {total_vulnerabilities} vulnerabilities"
69+
if critical_vulnerabilities > 0:
70+
warning_message += f", including {critical_vulnerabilities} critical severity vulnerabilities"
71+
72+
warning_message += "."
73+
warning_messages.append(warning_message)
74+
75+
if len(warning_messages) > 0:
76+
console.print()
77+
render_header(" Safety Report")
78+
progressive_print(warning_messages)
79+
console.line()
80+
81+
def render_package_details(self: "AuditableCommand"):
82+
"""
83+
Render details for installed packages.
84+
"""
85+
for package_name, _ in self._packages:
86+
console.print(
87+
Padding(
88+
f"Learn more: [link]https://data.safetycli.com/packages/pypi/{package_name}/[/link]",
89+
(0, 0, 0, 1),
90+
),
91+
emoji=True,
92+
)
93+
94+
def audit_packages(self, ctx: typer.Context) -> Dict[str, Any]:
95+
"""
96+
Audit packages based on environment diff tracking.
97+
Override this method in your command class if needed.
98+
99+
Args:
100+
ctx: The typer context
101+
102+
Returns:
103+
Dict containing audit results
104+
"""
105+
try:
106+
# Check if the instance has a diff tracker and can get a diff
107+
# Using getattr to avoid lint errors
108+
diff_tracker = getattr(self, "_diff_tracker", None)
109+
if diff_tracker and hasattr(diff_tracker, "get_diff"):
110+
added, _, updated = diff_tracker.get_diff()
111+
packages = {**added, **updated}
112+
113+
if hasattr(ctx.obj, "auth") and hasattr(ctx.obj.auth, "client"):
114+
return ctx.obj.auth.client.audit_packages(
115+
[
116+
f"{package_name}=={version[-1] if isinstance(version, tuple) else version}"
117+
for (package_name, version) in packages.items()
118+
]
119+
)
120+
except Exception:
121+
logger.debug("Audit API failed with error", exc_info=True)
122+
123+
# Always return a dict to satisfy the return type
124+
return dict()
125+
126+
def handle_installation_audit(self, ctx: typer.Context, result: ToolResult):
127+
"""
128+
Handle installation audit and rendering warnings/details.
129+
This is an explicit method that can be called from a command's after method.
130+
131+
Usage example:
132+
def after(self, ctx, result):
133+
super().after(ctx, result)
134+
self.handle_installation_audit(ctx, result)
135+
136+
Args:
137+
ctx: The typer context
138+
result: The tool result
139+
"""
140+
141+
if not isinstance(self, AuditableCommand):
142+
raise TypeError(
143+
"handle_installation_audit can only be called on instances of AuditableCommand"
144+
)
145+
146+
packages_audit = self.audit_packages(ctx)
147+
self.render_installation_warnings(ctx, packages_audit)
148+
149+
# If command failed, show package details
150+
if not result.process or result.process.returncode != 0:
151+
self.render_package_details()

safety/tool/pip/command.py

Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
from pathlib import Path
33
import re
44
from tempfile import mkstemp
5-
from typing import TYPE_CHECKING, Any, List, Optional
5+
from typing import TYPE_CHECKING, List, Optional
66

77
import logging
8-
from rich.padding import Padding
98
import typer
109

1110
from safety.models import ToolResult
@@ -15,10 +14,9 @@
1514
from ..intents import ToolIntentionType
1615
from safety_schemas.models.events.types import ToolType
1716
from ..environment_diff import EnvironmentDiffTracker, PipEnvironmentDiffTracker
17+
from ..mixins import InstallationAuditMixin
1818
from ..utils import Pip
1919

20-
from safety.console import main_console as console
21-
from ...init.render import render_header, progressive_print
2220

2321
PIP_LOCK = "safety-pip.lock"
2422

@@ -80,10 +78,10 @@ class PipGenericCommand(PipCommand):
8078
pass
8179

8280

83-
class PipInstallCommand(PipCommand):
81+
class PipInstallCommand(PipCommand, InstallationAuditMixin):
8482
def __init__(self, *args, **kwargs) -> None:
8583
super().__init__(*args, **kwargs)
86-
self.__packages = []
84+
self._packages = []
8785
self.__index_url = None
8886

8987
def before(self, ctx: typer.Context):
@@ -92,7 +90,7 @@ def before(self, ctx: typer.Context):
9290

9391
if self._intention:
9492
for pkg in self._intention.packages:
95-
self.__packages.append((pkg.name, pkg.version_constraint))
93+
self._packages.append((pkg.name, pkg.version_constraint))
9694

9795
if index_opt := self._intention.options.get(
9896
"index-url"
@@ -137,65 +135,9 @@ def before(self, ctx: typer.Context):
137135

138136
def after(self, ctx: typer.Context, result: ToolResult):
139137
super().after(ctx, result)
140-
141-
self.__render_installation_warnings(ctx)
142-
143-
if not result.process or result.process.returncode != 0:
144-
self.__render_package_details()
138+
self.handle_installation_audit(ctx, result)
145139

146140
def env(self, ctx: typer.Context) -> dict:
147141
env = super().env(ctx)
148142
env["PIP_INDEX_URL"] = Pip.build_index_url(ctx, self.__index_url)
149143
return env
150-
151-
def __render_installation_warnings(self, ctx: typer.Context):
152-
packages_audit = self.__audit_packages(ctx)
153-
154-
warning_messages = []
155-
for audited_package in packages_audit.get("audit", {}).get("packages", []):
156-
vulnerabilities = audited_package.get("vulnerabilities", {})
157-
critical_vulnerabilities = vulnerabilities.get("critical", 0)
158-
total_vulnerabilities = 0
159-
for count in vulnerabilities.values():
160-
total_vulnerabilities += count
161-
162-
if total_vulnerabilities == 0:
163-
continue
164-
165-
warning_message = f"[[yellow]Warning[/yellow]] {audited_package.get('package_specifier')} contains {total_vulnerabilities} vulnerabilities"
166-
if critical_vulnerabilities > 0:
167-
warning_message += f", including {critical_vulnerabilities} critical severity vulnerabilities"
168-
169-
warning_message += "."
170-
warning_messages.append(warning_message)
171-
172-
if len(warning_messages) > 0:
173-
console.print()
174-
render_header(" Safety Report")
175-
progressive_print(warning_messages)
176-
177-
def __render_package_details(self):
178-
for package_name, version_specifier in self.__packages:
179-
console.print(
180-
Padding(
181-
f"Learn more: [link]https://data.safetycli.com/packages/pypi/{package_name}/[/link]",
182-
(0, 0, 0, 1),
183-
),
184-
emoji=True,
185-
)
186-
187-
def __audit_packages(self, ctx: typer.Context) -> Any:
188-
try:
189-
added, _, updated = self._diff_tracker.get_diff()
190-
packages = {**added, **updated}
191-
192-
return ctx.obj.auth.client.audit_packages(
193-
[
194-
f"{package_name}=={version[-1] if isinstance(version, tuple) else version}"
195-
for (package_name, version) in packages.items()
196-
]
197-
)
198-
except Exception:
199-
logger.debug("Audit API failed with error", exc_info=True)
200-
# do not propagate the error in case the audit failed
201-
return dict()

safety/tool/poetry/command.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111

1212
from ..auth import index_credentials
1313
from ..base import BaseCommand, ToolIntentionType
14+
from ..mixins import InstallationAuditMixin
1415
from ..environment_diff import EnvironmentDiffTracker, PipEnvironmentDiffTracker
1516
from safety_schemas.models.events.types import ToolType
1617

1718
from safety.console import main_console as console
19+
from safety.models import ToolResult
1820

1921
PO_LOCK = "safety-po.lock"
2022

@@ -161,7 +163,11 @@ class PoetryGenericCommand(PoetryCommand):
161163
pass
162164

163165

164-
class PoetryAddCommand(PoetryCommand):
166+
class PoetryAddCommand(PoetryCommand, InstallationAuditMixin):
167+
def __init__(self, *args, **kwargs) -> None:
168+
super().__init__(*args, **kwargs)
169+
self._packages = []
170+
165171
def patch_source_option(
166172
self, args: List[str], new_source: str = "safety"
167173
) -> Tuple[Optional[str], List[str]]:
@@ -197,3 +203,19 @@ def before(self, ctx: typer.Context):
197203

198204
_, modified_args = self.patch_source_option(self._args)
199205
self._args = modified_args
206+
207+
# Extract packages from intention for rendering later
208+
if self._intention and self._intention.packages:
209+
for pkg in self._intention.packages:
210+
self._packages.append((pkg.name, pkg.version_constraint))
211+
212+
def after(self, ctx: typer.Context, result: ToolResult):
213+
"""
214+
Run after the command execution. Handle installation audit via mixin.
215+
216+
Args:
217+
ctx: The typer context
218+
result: The tool result
219+
"""
220+
super().after(ctx, result)
221+
self.handle_installation_audit(ctx, result)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# type: ignore
2+
3+
import pytest
4+
from unittest.mock import MagicMock, patch
5+
6+
import typer
7+
8+
from safety.tool.pip.command import PipInstallCommand
9+
from safety.tool.uv.command import UvInstallCommand
10+
from safety.tool.poetry.command import PoetryAddCommand
11+
12+
13+
class TestInstallationCommandsAudit:
14+
"""
15+
Test suite for verifying installation audit functionality in command classes.
16+
"""
17+
18+
def setup_method(self):
19+
"""
20+
Set up test fixtures.
21+
"""
22+
self.ctx = MagicMock(spec=typer.Context)
23+
self.ctx.obj = MagicMock()
24+
self.result = MagicMock(duration_ms=100, process=MagicMock(returncode=0))
25+
26+
@pytest.mark.parametrize(
27+
"command_class,command_args",
28+
[
29+
(PipInstallCommand, ["install", "requests"]),
30+
(UvInstallCommand, ["pip", "install", "requests"]),
31+
(PoetryAddCommand, ["add", "requests"]),
32+
],
33+
)
34+
@patch("safety.tool.base.BaseCommand._handle_command_result")
35+
def test_installation_command_calls_audit(
36+
self, mock_handle_result, command_class, command_args
37+
):
38+
"""
39+
Test that all installation commands call handle_installation_audit in after().
40+
"""
41+
command = command_class(command_args)
42+
43+
with patch.object(
44+
command_class, "handle_installation_audit"
45+
) as mock_handle_audit:
46+
command.after(self.ctx, self.result)
47+
48+
mock_handle_audit.assert_called_once_with(self.ctx, self.result)

0 commit comments

Comments
 (0)