Skip to content

Commit 25f11e1

Browse files
Merge branch 'main' into users/markwallace/handle_prompt_list_strings
2 parents dbed1f9 + 74ca404 commit 25f11e1

File tree

7 files changed

+1767
-802
lines changed

7 files changed

+1767
-802
lines changed

dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs

+1
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ private async Task<PromptRenderingResult> RenderPromptAsync(
533533
/// <summary>
534534
/// Captures usage details, including token information.
535535
/// </summary>
536+
[ExcludeFromCodeCoverage]
536537
private void CaptureUsageDetails(string? modelId, IReadOnlyDictionary<string, object?>? metadata, ILogger logger)
537538
{
538539
if (!logger.IsEnabled(LogLevel.Information) &&

dotnet/src/SemanticKernel.Core/KernelExtensions.cs

+1
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,7 @@ public static KernelPlugin CreatePluginFromPromptDirectory(
929929
/// <summary>Creates a plugin containing one function per child directory of the specified <paramref name="pluginDirectory"/>.</summary>
930930
[RequiresUnreferencedCode("Uses reflection to handle various aspects of the function creation and invocation, making it incompatible with AOT scenarios.")]
931931
[RequiresDynamicCode("Uses reflection to handle various aspects of the function creation and invocation, making it incompatible with AOT scenarios.")]
932+
[ExcludeFromCodeCoverage]
932933
private static KernelPlugin CreatePluginFromPromptDirectory(
933934
string pluginDirectory,
934935
string? pluginName = null,

dotnet/src/SemanticKernel.Core/Memory/MemoryBuilder.cs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ namespace Microsoft.SemanticKernel.Memory;
1313
/// A builder for Memory plugin.
1414
/// </summary>
1515
[Experimental("SKEXP0001")]
16+
[ExcludeFromCodeCoverage]
1617
public sealed class MemoryBuilder
1718
{
1819
private Func<IMemoryStore>? _memoryStoreFactory = null;

dotnet/src/SemanticKernel.Core/Memory/SemanticTextMemory.cs

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace Microsoft.SemanticKernel.Memory;
1616
/// in a semantic memory store.
1717
/// </summary>
1818
[Experimental("SKEXP0001")]
19+
[ExcludeFromCodeCoverage]
1920
public sealed class SemanticTextMemory : ISemanticTextMemory
2021
{
2122
private readonly ITextEmbeddingGenerationService _embeddingGenerator;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from unittest.mock import AsyncMock, MagicMock, patch
4+
5+
import pytest
6+
from pydantic import BaseModel, ValidationError
7+
from pydantic_core import InitErrorDetails
8+
from pymongo import AsyncMongoClient
9+
10+
import semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_collection as cosmos_collection
11+
import semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_settings as cosmos_settings
12+
from semantic_kernel.data.const import DistanceFunction, IndexKind
13+
from semantic_kernel.data.record_definition import (
14+
VectorStoreRecordDataField,
15+
VectorStoreRecordDefinition,
16+
VectorStoreRecordKeyField,
17+
VectorStoreRecordVectorField,
18+
)
19+
from semantic_kernel.exceptions import VectorStoreInitializationException
20+
21+
22+
async def test_constructor_with_mongo_client_provided() -> None:
23+
"""
24+
Test the constructor of AzureCosmosDBforMongoDBCollection when a mongo_client
25+
is directly provided. Expect that the class is successfully initialized
26+
and doesn't attempt to manage the client.
27+
"""
28+
mock_client = AsyncMock(spec=AsyncMongoClient)
29+
collection_name = "test_collection"
30+
fake_definition = VectorStoreRecordDefinition(
31+
fields={
32+
"id": VectorStoreRecordKeyField(),
33+
"content": VectorStoreRecordDataField(),
34+
"vector": VectorStoreRecordVectorField(),
35+
}
36+
)
37+
38+
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
39+
collection_name=collection_name,
40+
data_model_type=dict,
41+
mongo_client=mock_client,
42+
data_model_definition=fake_definition,
43+
)
44+
45+
assert collection.mongo_client == mock_client
46+
assert collection.collection_name == collection_name
47+
assert not collection.managed_client, "Should not be managing client when provided"
48+
49+
50+
async def test_constructor_raises_exception_on_validation_error() -> None:
51+
"""
52+
Test that the constructor raises VectorStoreInitializationException when
53+
AzureCosmosDBforMongoDBSettings.create fails with ValidationError.
54+
"""
55+
56+
mock_data_model_definition = VectorStoreRecordDefinition(
57+
fields={
58+
"id": VectorStoreRecordKeyField(),
59+
"content": VectorStoreRecordDataField(),
60+
"vector": VectorStoreRecordVectorField(),
61+
}
62+
)
63+
64+
class DummyModel(BaseModel):
65+
connection_string: str
66+
67+
error = InitErrorDetails(
68+
type="missing",
69+
loc=("connection_string",),
70+
msg="Field required",
71+
input=None,
72+
) # type: ignore
73+
74+
validation_error = ValidationError.from_exception_data("DummyModel", [error])
75+
76+
with patch.object(
77+
cosmos_settings.AzureCosmosDBforMongoDBSettings,
78+
"create",
79+
side_effect=validation_error,
80+
):
81+
with pytest.raises(VectorStoreInitializationException) as exc_info:
82+
cosmos_collection.AzureCosmosDBforMongoDBCollection(
83+
collection_name="test_collection",
84+
data_model_type=dict,
85+
data_model_definition=mock_data_model_definition,
86+
database_name="",
87+
)
88+
assert "The Azure CosmosDB for MongoDB connection string is required." in str(exc_info.value)
89+
90+
91+
async def test_constructor_raises_exception_if_no_connection_string() -> None:
92+
"""
93+
Ensure that a VectorStoreInitializationException is raised if the
94+
AzureCosmosDBforMongoDBSettings.connection_string is None.
95+
"""
96+
# Mock settings without a connection string
97+
mock_settings = AsyncMock(spec=cosmos_settings.AzureCosmosDBforMongoDBSettings)
98+
mock_settings.connection_string = None
99+
mock_settings.database_name = "some_database"
100+
101+
with patch.object(cosmos_settings.AzureCosmosDBforMongoDBSettings, "create", return_value=mock_settings):
102+
with pytest.raises(VectorStoreInitializationException) as exc_info:
103+
cosmos_collection.AzureCosmosDBforMongoDBCollection(collection_name="test_collection", data_model_type=dict)
104+
assert "The Azure CosmosDB for MongoDB connection string is required." in str(exc_info.value)
105+
106+
107+
async def test_create_collection_calls_database_methods() -> None:
108+
"""
109+
Test create_collection to verify that it first creates a collection, then
110+
calls the appropriate command to create a vector index.
111+
"""
112+
# Setup
113+
mock_database = AsyncMock()
114+
mock_database.create_collection = AsyncMock()
115+
mock_database.command = AsyncMock()
116+
117+
mock_client = AsyncMock(spec=AsyncMongoClient)
118+
mock_client.get_database = MagicMock(return_value=mock_database)
119+
120+
mock_data_model_definition = AsyncMock(spec=VectorStoreRecordDefinition)
121+
# Simulate a data_model_definition with certain fields & vector_fields
122+
mock_field = AsyncMock(spec=VectorStoreRecordDataField)
123+
type(mock_field).name = "test_field"
124+
type(mock_field).is_filterable = True
125+
type(mock_field).is_full_text_searchable = True
126+
127+
type(mock_field).property_type = "str"
128+
129+
mock_vector_field = AsyncMock()
130+
type(mock_vector_field).dimensions = 128
131+
type(mock_vector_field).name = "embedding"
132+
type(mock_vector_field).distance_function = DistanceFunction.COSINE_SIMILARITY
133+
type(mock_vector_field).index_kind = IndexKind.IVF_FLAT
134+
type(mock_vector_field).property_type = "float"
135+
136+
mock_data_model_definition.fields = {"test_field": mock_field}
137+
mock_data_model_definition.vector_fields = [mock_vector_field]
138+
mock_data_model_definition.key_field = mock_field
139+
140+
# Instantiate
141+
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
142+
collection_name="test_collection",
143+
data_model_type=dict,
144+
data_model_definition=mock_data_model_definition,
145+
mongo_client=mock_client,
146+
database_name="test_db",
147+
)
148+
149+
# Act
150+
await collection.create_collection(customArg="customValue")
151+
152+
# Assert
153+
mock_database.create_collection.assert_awaited_once_with("test_collection", customArg="customValue")
154+
mock_database.command.assert_awaited()
155+
command_args = mock_database.command.call_args.kwargs["command"]
156+
157+
assert command_args["createIndexes"] == "test_collection"
158+
assert len(command_args["indexes"]) == 2, "One for the data field, one for the vector field"
159+
# Check the data field index
160+
assert command_args["indexes"][0]["name"] == "test_field_"
161+
# Check the vector field index creation
162+
assert command_args["indexes"][1]["name"] == "embedding_"
163+
assert command_args["indexes"][1]["key"] == {"embedding": "cosmosSearch"}
164+
assert command_args["indexes"][1]["cosmosSearchOptions"]["kind"] == "vector-ivf"
165+
assert command_args["indexes"][1]["cosmosSearchOptions"]["similarity"] is not None
166+
assert command_args["indexes"][1]["cosmosSearchOptions"]["dimensions"] == 128
167+
168+
169+
async def test_context_manager_calls_aconnect_and_close_when_managed() -> None:
170+
"""
171+
Test that the context manager in AzureCosmosDBforMongoDBCollection calls 'aconnect' and
172+
'close' when the client is managed (i.e., created internally).
173+
"""
174+
mock_client = AsyncMock(spec=AsyncMongoClient)
175+
176+
mock_data_model_definition = VectorStoreRecordDefinition(
177+
fields={
178+
"id": VectorStoreRecordKeyField(),
179+
"content": VectorStoreRecordDataField(),
180+
"vector": VectorStoreRecordVectorField(),
181+
}
182+
)
183+
184+
with patch(
185+
"semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_mongodb_collection.AsyncMongoClient",
186+
return_value=mock_client,
187+
):
188+
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
189+
collection_name="test_collection",
190+
data_model_type=dict,
191+
connection_string="mongodb://fake",
192+
data_model_definition=mock_data_model_definition,
193+
)
194+
195+
# "__aenter__" should call 'aconnect'
196+
async with collection as c:
197+
mock_client.aconnect.assert_awaited_once()
198+
assert c is collection
199+
200+
# "__aexit__" should call 'close' if managed
201+
mock_client.close.assert_awaited_once()
202+
203+
204+
async def test_context_manager_does_not_close_when_not_managed() -> None:
205+
"""
206+
Test that the context manager in AzureCosmosDBforMongoDBCollection does not call 'close'
207+
when the client is not managed (i.e., provided externally).
208+
"""
209+
mock_data_model_definition = VectorStoreRecordDefinition(
210+
fields={
211+
"id": VectorStoreRecordKeyField(),
212+
"content": VectorStoreRecordDataField(),
213+
"vector": VectorStoreRecordVectorField(),
214+
}
215+
)
216+
217+
external_client = AsyncMock(spec=AsyncMongoClient, name="external_client", value=None)
218+
external_client.aconnect = AsyncMock(name="aconnect")
219+
external_client.close = AsyncMock(name="close")
220+
221+
collection = cosmos_collection.AzureCosmosDBforMongoDBCollection(
222+
collection_name="test_collection",
223+
data_model_type=dict,
224+
mongo_client=external_client,
225+
data_model_definition=mock_data_model_definition,
226+
)
227+
228+
# "__aenter__" scenario
229+
async with collection as c:
230+
external_client.aconnect.assert_awaited()
231+
assert c is collection
232+
233+
# "__aexit__" should NOT call "close" when not managed
234+
external_client.close.assert_not_awaited()

0 commit comments

Comments
 (0)