Skip to content

Commit e37ccb9

Browse files
authored
Merge pull request #31 from jurasofish/draft-handling-of-complex-inputs
Handle complex inputs
2 parents 7bf0e1e + aa56d36 commit e37ccb9

File tree

7 files changed

+700
-16
lines changed

7 files changed

+700
-16
lines changed

Diff for: README.md

+21
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,27 @@ async def fetch_weather(city: str) -> str:
212212
return response.text
213213
```
214214

215+
Complex input handling example:
216+
```python
217+
from pydantic import BaseModel, Field
218+
from typing import Annotated
219+
220+
class ShrimpTank(BaseModel):
221+
class Shrimp(BaseModel):
222+
name: Annotated[str, Field(max_length=10)]
223+
224+
shrimp: list[Shrimp]
225+
226+
@mcp.tool()
227+
def name_shrimp(
228+
tank: ShrimpTank,
229+
# You can use pydantic Field in function signatures for validation.
230+
extra_names: Annotated[list[str], Field(max_length=10)],
231+
) -> list[str]:
232+
"""List all shrimp names in the tank"""
233+
return [shrimp.name for shrimp in tank.shrimp] + extra_names
234+
```
235+
215236
### Prompts
216237

217238
Prompts are reusable templates that help LLMs interact with your server effectively. They're like "best practices" encoded into your server. A prompt can be as simple as a string:

Diff for: examples/complex_inputs.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
FastMCP Complex inputs Example
3+
4+
Demonstrates validation via pydantic with complex models.
5+
"""
6+
7+
from pydantic import BaseModel, Field
8+
from typing import Annotated
9+
from fastmcp.server import FastMCP
10+
11+
mcp = FastMCP("Shrimp Tank")
12+
13+
14+
class ShrimpTank(BaseModel):
15+
class Shrimp(BaseModel):
16+
name: Annotated[str, Field(max_length=10)]
17+
18+
shrimp: list[Shrimp]
19+
20+
21+
@mcp.tool()
22+
def name_shrimp(
23+
tank: ShrimpTank,
24+
# You can use pydantic Field in function signatures for validation.
25+
extra_names: Annotated[list[str], Field(max_length=10)],
26+
) -> list[str]:
27+
"""List all shrimp names in the tank"""
28+
return [shrimp.name for shrimp in tank.shrimp] + extra_names

Diff for: src/fastmcp/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ class ResourceError(FastMCPError):
1515

1616
class ToolError(FastMCPError):
1717
"""Error in tool operations."""
18+
19+
20+
class InvalidSignature(Exception):
21+
"""Invalid signature for use with FastMCP."""

Diff for: src/fastmcp/tools/base.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import fastmcp
22
from fastmcp.exceptions import ToolError
33

4-
5-
from pydantic import BaseModel, Field, TypeAdapter, validate_call
4+
from fastmcp.utilities.func_metadata import func_metadata, FuncMetadata
5+
from pydantic import BaseModel, Field
66

77

88
import inspect
@@ -19,6 +19,9 @@ class Tool(BaseModel):
1919
name: str = Field(description="Name of the tool")
2020
description: str = Field(description="Description of what the tool does")
2121
parameters: dict = Field(description="JSON schema for tool parameters")
22+
fn_metadata: FuncMetadata = Field(
23+
description="Metadata about the function including a pydantic model for tool arguments"
24+
)
2225
is_async: bool = Field(description="Whether the tool is async")
2326
context_kwarg: Optional[str] = Field(
2427
None, description="Name of the kwarg that should receive context"
@@ -41,9 +44,6 @@ def from_function(
4144
func_doc = description or fn.__doc__ or ""
4245
is_async = inspect.iscoroutinefunction(fn)
4346

44-
# Get schema from TypeAdapter - will fail if function isn't properly typed
45-
parameters = TypeAdapter(fn).json_schema()
46-
4747
# Find context parameter if it exists
4848
if context_kwarg is None:
4949
sig = inspect.signature(fn)
@@ -52,28 +52,32 @@ def from_function(
5252
context_kwarg = param_name
5353
break
5454

55-
# ensure the arguments are properly cast
56-
fn = validate_call(fn)
55+
func_arg_metadata = func_metadata(
56+
fn,
57+
skip_names=[context_kwarg] if context_kwarg is not None else [],
58+
)
59+
parameters = func_arg_metadata.arg_model.model_json_schema()
5760

5861
return cls(
5962
fn=fn,
6063
name=func_name,
6164
description=func_doc,
6265
parameters=parameters,
66+
fn_metadata=func_arg_metadata,
6367
is_async=is_async,
6468
context_kwarg=context_kwarg,
6569
)
6670

6771
async def run(self, arguments: dict, context: Optional["Context"] = None) -> Any:
6872
"""Run the tool with arguments."""
6973
try:
70-
# Inject context if needed
71-
if self.context_kwarg:
72-
arguments[self.context_kwarg] = context
73-
74-
# Call function with proper async handling
75-
if self.is_async:
76-
return await self.fn(**arguments)
77-
return self.fn(**arguments)
74+
return await self.fn_metadata.call_fn_with_arg_validation(
75+
self.fn,
76+
self.is_async,
77+
arguments,
78+
{self.context_kwarg: context}
79+
if self.context_kwarg is not None
80+
else None,
81+
)
7882
except Exception as e:
7983
raise ToolError(f"Error executing tool {self.name}: {e}") from e

Diff for: src/fastmcp/utilities/func_metadata.py

+200
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import inspect
2+
from collections.abc import Callable, Sequence, Awaitable
3+
from typing import (
4+
Annotated,
5+
Any,
6+
Dict,
7+
ForwardRef,
8+
)
9+
from pydantic import Field
10+
from fastmcp.exceptions import InvalidSignature
11+
from pydantic._internal._typing_extra import try_eval_type
12+
import json
13+
from pydantic import BaseModel
14+
from pydantic.fields import FieldInfo
15+
from pydantic import ConfigDict, create_model
16+
from pydantic import WithJsonSchema
17+
from pydantic_core import PydanticUndefined
18+
from fastmcp.utilities.logging import get_logger
19+
20+
21+
logger = get_logger(__name__)
22+
23+
24+
class ArgModelBase(BaseModel):
25+
"""A model representing the arguments to a function."""
26+
27+
def model_dump_one_level(self) -> dict[str, Any]:
28+
"""Return a dict of the model's fields, one level deep.
29+
30+
That is, sub-models etc are not dumped - they are kept as pydantic models.
31+
"""
32+
kwargs: dict[str, Any] = {}
33+
for field_name in self.model_fields.keys():
34+
kwargs[field_name] = getattr(self, field_name)
35+
return kwargs
36+
37+
model_config = ConfigDict(
38+
arbitrary_types_allowed=True,
39+
)
40+
41+
42+
class FuncMetadata(BaseModel):
43+
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
44+
# We can add things in the future like
45+
# - Maybe some args are excluded from attempting to parse from JSON
46+
# - Maybe some args are special (like context) for dependency injection
47+
48+
async def call_fn_with_arg_validation(
49+
self,
50+
fn: Callable | Awaitable,
51+
fn_is_async: bool,
52+
arguments_to_validate: dict[str, Any],
53+
arguments_to_pass_directly: dict[str, Any] | None,
54+
) -> Any:
55+
"""Call the given function with arguments validated and injected.
56+
57+
Arguments are first attempted to be parsed from JSON, then validated against
58+
the argument model, before being passed to the function.
59+
"""
60+
arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
61+
arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
62+
arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()
63+
64+
arguments_parsed_dict |= arguments_to_pass_directly or {}
65+
66+
if fn_is_async:
67+
return await fn(**arguments_parsed_dict)
68+
return fn(**arguments_parsed_dict)
69+
70+
def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
71+
"""Pre-parse data from JSON.
72+
73+
Return a dict with same keys as input but with values parsed from JSON
74+
if appropriate.
75+
76+
This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
77+
a string rather than an actual list. Claude desktop is prone to this - in fact
78+
it seems incapable of NOT doing this. For sub-models, it tends to pass
79+
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
80+
"""
81+
new_data = data.copy() # Shallow copy
82+
for field_name, field_info in self.arg_model.model_fields.items():
83+
if field_name not in data.keys():
84+
continue
85+
if isinstance(data[field_name], str):
86+
try:
87+
pre_parsed = json.loads(data[field_name])
88+
except json.JSONDecodeError:
89+
continue # Not JSON - skip
90+
if isinstance(pre_parsed, str):
91+
# This is likely that the raw value is e.g. `"hello"` which we
92+
# Should really be parsed as '"hello"' in Python - but if we parse
93+
# it as JSON it'll turn into just 'hello'. So we skip it.
94+
continue
95+
new_data[field_name] = pre_parsed
96+
assert new_data.keys() == data.keys()
97+
return new_data
98+
99+
model_config = ConfigDict(
100+
arbitrary_types_allowed=True,
101+
)
102+
103+
104+
def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
105+
"""Given a function, return metadata including a pydantic model representing its signature.
106+
107+
The use case for this is
108+
```
109+
meta = func_to_pyd(func)
110+
validated_args = meta.arg_model.model_validate(some_raw_data_dict)
111+
return func(**validated_args.model_dump_one_level())
112+
```
113+
114+
**critically** it also provides pre-parse helper to attempt to parse things from JSON.
115+
116+
Args:
117+
func: The function to convert to a pydantic model
118+
skip_names: A list of parameter names to skip. These will not be included in
119+
the model.
120+
Returns:
121+
A pydantic model representing the function's signature.
122+
"""
123+
sig = _get_typed_signature(func)
124+
params = sig.parameters
125+
dynamic_pydantic_model_params: dict[str, Any] = {}
126+
for param in params.values():
127+
if param.name.startswith("_"):
128+
raise InvalidSignature(
129+
f"Parameter {param.name} of {func.__name__} may not start with an underscore"
130+
)
131+
if param.name in skip_names:
132+
continue
133+
annotation = param.annotation
134+
135+
# `x: None` / `x: None = None`
136+
if annotation is None:
137+
annotation = Annotated[
138+
None,
139+
Field(
140+
default=param.default
141+
if param.default is not inspect.Parameter.empty
142+
else PydanticUndefined
143+
),
144+
]
145+
146+
# Untyped field
147+
if annotation is inspect.Parameter.empty:
148+
annotation = Annotated[
149+
Any,
150+
Field(),
151+
# 🤷
152+
WithJsonSchema({"title": param.name, "type": "string"}),
153+
]
154+
155+
field_info = FieldInfo.from_annotated_attribute(
156+
annotation,
157+
param.default
158+
if param.default is not inspect.Parameter.empty
159+
else PydanticUndefined,
160+
)
161+
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
162+
continue
163+
164+
arguments_model = create_model(
165+
f"{func.__name__}Arguments",
166+
**dynamic_pydantic_model_params,
167+
__base__=ArgModelBase,
168+
)
169+
resp = FuncMetadata(arg_model=arguments_model)
170+
return resp
171+
172+
173+
def _get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
174+
if isinstance(annotation, str):
175+
annotation = ForwardRef(annotation)
176+
annotation, status = try_eval_type(annotation, globalns, globalns)
177+
178+
# This check and raise could perhaps be skipped, and we (FastMCP) just call
179+
# model_rebuild right before using it 🤷
180+
if status is False:
181+
raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")
182+
183+
return annotation
184+
185+
186+
def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
187+
"""Get function signature while evaluating forward references"""
188+
signature = inspect.signature(call)
189+
globalns = getattr(call, "__globals__", {})
190+
typed_params = [
191+
inspect.Parameter(
192+
name=param.name,
193+
kind=param.kind,
194+
default=param.default,
195+
annotation=_get_typed_annotation(param.annotation, globalns),
196+
)
197+
for param in signature.parameters.values()
198+
]
199+
typed_signature = inspect.Signature(typed_params)
200+
return typed_signature

0 commit comments

Comments
 (0)