Skip to content

Commit 16cc0be

Browse files
authored
Merge branch 'main' into patch-1
2 parents b266df8 + ac13894 commit 16cc0be

File tree

57 files changed

+961
-114
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+961
-114
lines changed

ddtrace/_trace/trace_handlers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def _set_span_pointer(span: "Span", span_pointer_description: _SpanPointerDescri
823823
)
824824

825825

826-
def _set_azure_function_tags(span, azure_functions_config, function_name, trigger, span_kind=SpanKind.INTERNAL):
826+
def _set_azure_function_tags(span, azure_functions_config, function_name, trigger, span_kind):
827827
span.set_tag_str(COMPONENT, azure_functions_config.integration_name)
828828
span.set_tag_str(SPAN_KIND, span_kind)
829829
span.set_tag_str("aas.function.name", function_name) # codespell:ignore
@@ -857,9 +857,9 @@ def _on_azure_functions_start_response(ctx, azure_functions_config, res, functio
857857
)
858858

859859

860-
def _on_azure_functions_trigger_span_modifier(ctx, azure_functions_config, function_name, trigger):
860+
def _on_azure_functions_trigger_span_modifier(ctx, azure_functions_config, function_name, trigger, span_kind):
861861
span = ctx.get_item("trigger_span")
862-
_set_azure_function_tags(span, azure_functions_config, function_name, trigger)
862+
_set_azure_function_tags(span, azure_functions_config, function_name, trigger, span_kind)
863863

864864

865865
def listen():
@@ -968,6 +968,7 @@ def listen():
968968
"rq.job.perform",
969969
"rq.job.fetch_many",
970970
"azure.functions.patched_route_request",
971+
"azure.functions.patched_service_bus",
971972
"azure.functions.patched_timer",
972973
):
973974
core.on(f"context.started.start_span.{context_name}", _start_span)

ddtrace/_trace/tracer.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import contextmanager
22
import functools
3+
import inspect
34
from inspect import iscoroutinefunction
45
from itertools import chain
56
import logging
@@ -775,6 +776,56 @@ def flush(self):
775776
"""Flush the buffer of the trace writer. This does nothing if an unbuffered trace writer is used."""
776777
self._span_aggregator.writer.flush_queue()
777778

779+
def _wrap_generator(
780+
self,
781+
f: AnyCallable,
782+
span_name: str,
783+
service: Optional[str] = None,
784+
resource: Optional[str] = None,
785+
span_type: Optional[str] = None,
786+
) -> AnyCallable:
787+
"""Wrap a generator function with tracing."""
788+
789+
@functools.wraps(f)
790+
def func_wrapper(*args, **kwargs):
791+
if getattr(self, "_wrap_executor", None):
792+
return self._wrap_executor(
793+
self,
794+
f,
795+
args,
796+
kwargs,
797+
span_name,
798+
service=service,
799+
resource=resource,
800+
span_type=span_type,
801+
)
802+
803+
with self.trace(span_name, service=service, resource=resource, span_type=span_type):
804+
gen = f(*args, **kwargs)
805+
for value in gen:
806+
yield value
807+
808+
return func_wrapper
809+
810+
def _wrap_generator_async(
811+
self,
812+
f: AnyCallable,
813+
span_name: str,
814+
service: Optional[str] = None,
815+
resource: Optional[str] = None,
816+
span_type: Optional[str] = None,
817+
) -> AnyCallable:
818+
"""Wrap a generator function with tracing."""
819+
820+
@functools.wraps(f)
821+
async def func_wrapper(*args, **kwargs):
822+
with self.trace(span_name, service=service, resource=resource, span_type=span_type):
823+
agen = f(*args, **kwargs)
824+
async for value in agen:
825+
yield value
826+
827+
return func_wrapper
828+
778829
def wrap(
779830
self,
780831
name: Optional[str] = None,
@@ -812,6 +863,15 @@ async def coroutine():
812863
def coroutine():
813864
return 'executed'
814865
866+
>>> # or use it on generators
867+
@tracer.wrap()
868+
def gen():
869+
yield 'executed'
870+
871+
>>> @tracer.wrap()
872+
async def gen():
873+
yield 'executed'
874+
815875
You can access the current span using `tracer.current_span()` to set
816876
tags:
817877
@@ -825,10 +885,26 @@ def wrap_decorator(f: AnyCallable) -> AnyCallable:
825885
# FIXME[matt] include the class name for methods.
826886
span_name = name if name else "%s.%s" % (f.__module__, f.__name__)
827887

828-
# detect if the the given function is a coroutine to use the
829-
# right decorator; this initial check ensures that the
888+
# detect if the the given function is a coroutine and/or a generator
889+
# to use the right decorator; this initial check ensures that the
830890
# evaluation is done only once for each @tracer.wrap
831-
if iscoroutinefunction(f):
891+
if inspect.isgeneratorfunction(f):
892+
func_wrapper = self._wrap_generator(
893+
f,
894+
span_name,
895+
service=service,
896+
resource=resource,
897+
span_type=span_type,
898+
)
899+
elif inspect.isasyncgenfunction(f):
900+
func_wrapper = self._wrap_generator_async(
901+
f,
902+
span_name,
903+
service=service,
904+
resource=resource,
905+
span_type=span_type,
906+
)
907+
elif iscoroutinefunction(f):
832908
# create an async wrapper that awaits the coroutine and traces it
833909
@functools.wraps(f)
834910
async def func_wrapper(*args, **kwargs):

ddtrace/appsec/_iast/_patch_modules.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ddtrace.appsec._iast.secure_marks.sanitizers import path_traversal_sanitizer
77
from ddtrace.appsec._iast.secure_marks.sanitizers import sqli_sanitizer
88
from ddtrace.appsec._iast.secure_marks.sanitizers import xss_sanitizer
9+
from ddtrace.appsec._iast.secure_marks.validators import ssrf_validator
910
from ddtrace.appsec._iast.secure_marks.validators import unvalidated_redirect_validator
1011

1112

@@ -32,6 +33,7 @@ def patch_iast(patch_modules=IAST_PATCH):
3233
for module in (m for m, e in patch_modules.items() if e):
3334
when_imported("hashlib")(_on_import_factory(module, "ddtrace.appsec._iast.taint_sinks.%s", raise_errors=False))
3435

36+
# CMDI sanitizers
3537
when_imported("shlex")(
3638
lambda _: try_wrap_function_wrapper(
3739
"shlex",
@@ -40,6 +42,11 @@ def patch_iast(patch_modules=IAST_PATCH):
4042
)
4143
)
4244

45+
# SSRF
46+
when_imported("django.utils.http")(
47+
lambda _: try_wrap_function_wrapper("django.utils.http", "url_has_allowed_host_and_scheme", ssrf_validator)
48+
)
49+
4350
# SQL sanitizers
4451
when_imported("mysql.connector.conversion")(
4552
lambda _: try_wrap_function_wrapper(

ddtrace/appsec/_iast/secure_marks/validators.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,18 @@ def unvalidated_redirect_validator(wrapped: Callable, instance: Any, args: Seque
9393
True if validation passed, False otherwise
9494
"""
9595
return create_validator(VulnerabilityType.UNVALIDATED_REDIRECT, wrapped, instance, args, kwargs)
96+
97+
98+
def ssrf_validator(wrapped: Callable, instance: Any, args: Sequence, kwargs: dict) -> bool:
99+
"""Validator for ssrf functions.
100+
101+
Args:
102+
wrapped: The original validator function
103+
instance: The instance the function is bound to (if any)
104+
args: Positional arguments
105+
kwargs: Keyword arguments
106+
107+
Returns:
108+
True if validation passed, False otherwise
109+
"""
110+
return create_validator(VulnerabilityType.SSRF, wrapped, instance, args, kwargs)

ddtrace/appsec/_iast/taint_sinks/ssrf.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,19 @@
22

33
from ddtrace.appsec._constants import IAST_SPAN_TAGS
44
from ddtrace.appsec._iast._logs import iast_error
5+
from ddtrace.appsec._iast._logs import iast_propagation_sink_point_debug_log
56
from ddtrace.appsec._iast._metrics import _set_metric_iast_executed_sink
67
from ddtrace.appsec._iast._span_metrics import increment_iast_span_metric
78
from ddtrace.appsec._iast._taint_tracking import VulnerabilityType
9+
from ddtrace.appsec._iast._taint_tracking import get_ranges
810
from ddtrace.appsec._iast.constants import VULN_SSRF
911
from ddtrace.appsec._iast.taint_sinks._base import VulnerabilityBase
10-
from ddtrace.internal.logger import get_logger
1112
from ddtrace.internal.utils import ArgumentError
1213
from ddtrace.internal.utils import get_argument_value
1314
from ddtrace.internal.utils.importlib import func_name
1415
from ddtrace.settings.asm import config as asm_config
1516

1617

17-
log = get_logger(__name__)
18-
19-
2018
class SSRF(VulnerabilityBase):
2119
vulnerability_type = VULN_SSRF
2220
secure_mark = VulnerabilityType.SSRF
@@ -33,24 +31,41 @@ class SSRF(VulnerabilityBase):
3331

3432

3533
def _iast_report_ssrf(func: Callable, *args, **kwargs):
34+
"""
35+
Check and report potential SSRF (Server-Side Request Forgery) vulnerabilities in function calls.
36+
37+
This function analyzes calls to URL-handling functions to detect potential SSRF vulnerabilities.
38+
It checks if the URL argument is tainted (user-controlled) and reports it if conditions are met.
39+
URL fragments (parts after #) are handled specially - if all tainted parts are in the fragment,
40+
no vulnerability is reported.
41+
"""
3642
func_key = func_name(func)
3743
arg_pos, kwarg_name = _FUNC_TO_URL_ARGUMENT.get(func_key, (None, None))
3844
if arg_pos is None:
39-
log.debug("%s not found in list of functions supported for SSRF", func_key)
45+
iast_propagation_sink_point_debug_log("%s not found in list of functions supported for SSRF", func_key)
4046
return
4147

4248
try:
4349
kw = kwarg_name if kwarg_name else ""
4450
report_ssrf = get_argument_value(list(args), kwargs, arg_pos, kw)
4551
except ArgumentError:
46-
log.debug("Failed to get URL argument from _FUNC_TO_URL_ARGUMENT dict for function %s", func_key)
52+
iast_propagation_sink_point_debug_log(
53+
"Failed to get URL argument from _FUNC_TO_URL_ARGUMENT dict for function %s", func_key
54+
)
4755
return
48-
4956
if report_ssrf:
5057
if asm_config.is_iast_request_enabled:
5158
try:
5259
if SSRF.has_quota() and SSRF.is_tainted_pyobject(report_ssrf):
53-
SSRF.report(evidence_value=report_ssrf)
60+
valid_to_report = True
61+
fragment_start = report_ssrf.find("#")
62+
taint_ranges = get_ranges(report_ssrf)
63+
if fragment_start != -1:
64+
# If all taint ranges are in the fragment, do not report
65+
if all(r.start >= fragment_start for r in taint_ranges):
66+
valid_to_report = False
67+
if valid_to_report:
68+
SSRF.report(evidence_value=report_ssrf)
5469

5570
# Reports Span Metrics
5671
_set_metric_iast_executed_sink(SSRF.vulnerability_type)

0 commit comments

Comments
 (0)