Skip to content

Commit d7efd46

Browse files
feat(langchain): Support BaseCallbackManager
While implementing #4479, I noticed that our Langchain integration lacks support for the `local_callbacks` having type `BaseCallbackManager`, which according to the type hint is possible. This change adds support for this case.
1 parent 7804260 commit d7efd46

File tree

2 files changed

+173
-16
lines changed

2 files changed

+173
-16
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from langchain_core.callbacks import (
2424
manager,
2525
BaseCallbackHandler,
26+
BaseCallbackManager,
2627
Callbacks,
2728
)
2829
from langchain_core.agents import AgentAction, AgentFinish
@@ -434,12 +435,47 @@ def new_configure(
434435
**kwargs,
435436
)
436437

437-
callbacks_list = local_callbacks or []
438+
# Lambda for lazy initialization of the SentryLangchainCallback
439+
sentry_handler_factory = lambda: SentryLangchainCallback(
440+
integration.max_spans,
441+
integration.include_prompts,
442+
integration.tiktoken_encoding_name,
443+
)
444+
445+
local_callbacks = local_callbacks or []
446+
447+
# Handle each possible type of local_callbacks. For each type, we
448+
# extract the list of callbacks to check for SentryLangchainCallback,
449+
# and define a function that would add the SentryLangchainCallback
450+
# to the existing callbacks list.
451+
if isinstance(local_callbacks, BaseCallbackManager):
452+
callbacks_list = local_callbacks.handlers
453+
manager = local_callbacks
454+
455+
# For BaseCallbackManager, we want to copy the manager and add the
456+
# SentryLangchainCallback to the copy.
457+
def local_callbacks_with_sentry():
458+
# type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]]
459+
new_manager = manager.copy()
460+
new_manager.handlers = [*new_manager.handlers, sentry_handler_factory()]
461+
return new_manager
462+
463+
elif isinstance(local_callbacks, BaseCallbackHandler):
464+
callbacks_list = [local_callbacks]
438465

439-
if isinstance(callbacks_list, BaseCallbackHandler):
440-
callbacks_list = [callbacks_list]
441-
elif not isinstance(callbacks_list, list):
442-
logger.debug("Unknown callback type: %s", callbacks_list)
466+
def local_callbacks_with_sentry():
467+
# type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]]
468+
return [*callbacks_list, sentry_handler_factory()]
469+
470+
elif isinstance(local_callbacks, list):
471+
callbacks_list = local_callbacks
472+
473+
def local_callbacks_with_sentry():
474+
# type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]]
475+
return [*callbacks_list, sentry_handler_factory()]
476+
477+
else:
478+
logger.debug("Unknown callback type: %s", local_callbacks)
443479
# Just proceed with original function call
444480
return f(
445481
callback_manager_cls,
@@ -457,20 +493,12 @@ def new_configure(
457493
isinstance(cb, SentryLangchainCallback)
458494
for cb in itertools.chain(callbacks_list, inheritable_callbacks_list)
459495
):
460-
# Avoid mutating the existing callbacks list
461-
callbacks_list = [
462-
*callbacks_list,
463-
SentryLangchainCallback(
464-
integration.max_spans,
465-
integration.include_prompts,
466-
integration.tiktoken_encoding_name,
467-
),
468-
]
496+
local_callbacks = local_callbacks_with_sentry()
469497

470498
return f(
471499
callback_manager_cls,
472500
inheritable_callbacks,
473-
callbacks_list,
501+
local_callbacks,
474502
*args,
475503
**kwargs,
476504
)

tests/integrations/langchain/test_langchain.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional, Any, Iterator
2+
from unittest import mock
23
from unittest.mock import Mock
34

45
import pytest
@@ -12,7 +13,7 @@
1213
# Langchain < 0.2
1314
from langchain_community.chat_models import ChatOpenAI
1415

15-
from langchain_core.callbacks import CallbackManagerForLLMRun
16+
from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
1617
from langchain_core.messages import BaseMessage, AIMessageChunk
1718
from langchain_core.outputs import ChatGenerationChunk, ChatResult
1819
from langchain_core.runnables import RunnableConfig
@@ -428,3 +429,131 @@ def test_span_map_is_instance_variable():
428429
assert (
429430
callback1.span_map is not callback2.span_map
430431
), "span_map should be an instance variable, not shared between instances"
432+
433+
434+
def test_langchain_callback_manager(sentry_init):
435+
sentry_init(
436+
integrations=[LangchainIntegration()],
437+
traces_sample_rate=1.0,
438+
)
439+
local_manager = BaseCallbackManager(handlers=[])
440+
441+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
442+
mock_configure = mock_manager_module._configure
443+
444+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
445+
LangchainIntegration.setup_once()
446+
447+
callback_manager_cls = Mock()
448+
449+
mock_manager_module._configure(
450+
callback_manager_cls, local_callbacks=local_manager
451+
)
452+
453+
assert mock_configure.call_count == 1
454+
455+
call_args = mock_configure.call_args
456+
assert call_args.args[0] is callback_manager_cls
457+
458+
passed_manager = call_args.args[2]
459+
assert passed_manager is not local_manager
460+
assert local_manager.handlers == []
461+
462+
[handler] = passed_manager.handlers
463+
assert isinstance(handler, SentryLangchainCallback)
464+
465+
466+
def test_langchain_callback_manager_with_sentry_callback(sentry_init):
467+
sentry_init(
468+
integrations=[LangchainIntegration()],
469+
traces_sample_rate=1.0,
470+
)
471+
sentry_callback = SentryLangchainCallback(0, False)
472+
local_manager = BaseCallbackManager(handlers=[sentry_callback])
473+
474+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
475+
mock_configure = mock_manager_module._configure
476+
477+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
478+
LangchainIntegration.setup_once()
479+
480+
callback_manager_cls = Mock()
481+
482+
mock_manager_module._configure(
483+
callback_manager_cls, local_callbacks=local_manager
484+
)
485+
486+
assert mock_configure.call_count == 1
487+
488+
call_args = mock_configure.call_args
489+
assert call_args.args[0] is callback_manager_cls
490+
491+
passed_manager = call_args.args[2]
492+
assert passed_manager is local_manager
493+
494+
[handler] = passed_manager.handlers
495+
assert handler is sentry_callback
496+
497+
498+
def test_langchain_callback_list(sentry_init):
499+
sentry_init(
500+
integrations=[LangchainIntegration()],
501+
traces_sample_rate=1.0,
502+
)
503+
local_callbacks = []
504+
505+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
506+
mock_configure = mock_manager_module._configure
507+
508+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
509+
LangchainIntegration.setup_once()
510+
511+
callback_manager_cls = Mock()
512+
513+
mock_manager_module._configure(
514+
callback_manager_cls, local_callbacks=local_callbacks
515+
)
516+
517+
assert mock_configure.call_count == 1
518+
519+
call_args = mock_configure.call_args
520+
assert call_args.args[0] is callback_manager_cls
521+
522+
passed_callbacks = call_args.args[2]
523+
assert passed_callbacks is not local_callbacks
524+
assert local_callbacks == []
525+
526+
[handler] = passed_callbacks
527+
assert isinstance(handler, SentryLangchainCallback)
528+
529+
530+
def test_langchain_callback_list_existing_callback(sentry_init):
531+
sentry_init(
532+
integrations=[LangchainIntegration()],
533+
traces_sample_rate=1.0,
534+
)
535+
sentry_callback = SentryLangchainCallback(0, False)
536+
local_callbacks = [sentry_callback]
537+
538+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
539+
mock_configure = mock_manager_module._configure
540+
541+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
542+
LangchainIntegration.setup_once()
543+
544+
callback_manager_cls = Mock()
545+
546+
mock_manager_module._configure(
547+
callback_manager_cls, local_callbacks=local_callbacks
548+
)
549+
550+
assert mock_configure.call_count == 1
551+
552+
call_args = mock_configure.call_args
553+
assert call_args.args[0] is callback_manager_cls
554+
555+
passed_callbacks = call_args.args[2]
556+
assert passed_callbacks is local_callbacks
557+
558+
[handler] = passed_callbacks
559+
assert handler is sentry_callback

0 commit comments

Comments
 (0)