Skip to content

Python: added num_records parameter to text memory skill #2236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 17, 2023
28 changes: 23 additions & 5 deletions python/semantic_kernel/core_skills/text_memory_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ class TextMemorySkill(PydanticField):
COLLECTION_PARAM = "collection"
RELEVANCE_PARAM = "relevance"
KEY_PARAM = "key"
NUM_RECORDS_PARAM = "num_records"
DEFAULT_COLLECTION = "generic"
DEFAULT_RELEVANCE = 0.75
DEFAULT_NUM_RECORDS = 1

# @staticmethod
@sk_function(
Expand All @@ -28,6 +30,11 @@ class TextMemorySkill(PydanticField):
description="The relevance score, from 0.0 to 1.0; 1.0 means perfect match",
default_value=DEFAULT_RELEVANCE,
)
@sk_function_context_parameter(
name=NUM_RECORDS_PARAM,
description="The number of records to retrieve, default is 1.",
default_value=DEFAULT_NUM_RECORDS,
)
async def recall_async(self, ask: str, context: SKContext) -> str:
"""
Recall a fact from the long term memory.
Expand All @@ -39,10 +46,11 @@ async def recall_async(self, ask: str, context: SKContext) -> str:
Args:
ask -- The question to ask the memory
context -- Contains the 'collection' to search for information
and the 'relevance' score to use when searching
, the 'relevance' score to use when searching
and the 'num_records' to retrieve.

Returns:
The nearest item from the memory store
The nearest item from the memory store as a comma-separated string or empty string if not found.
"""
if context.variables is None:
raise ValueError("Context has no variables")
Expand All @@ -65,15 +73,25 @@ async def recall_async(self, ask: str, context: SKContext) -> str:
if relevance is None or str(relevance).strip() == "":
relevance = TextMemorySkill.DEFAULT_RELEVANCE

num_records = (
context.variables[TextMemorySkill.NUM_RECORDS_PARAM]
if context.variables.contains_key(TextMemorySkill.NUM_RECORDS_PARAM)
else TextMemorySkill.DEFAULT_NUM_RECORDS
)
if num_records is None or str(num_records).strip() == "":
num_records = TextMemorySkill.DEFAULT_NUM_RECORDS

results = await context.memory.search_async(
collection, ask, min_relevance_score=float(relevance)
collection=collection,
query=ask,
limit=int(num_records),
min_relevance_score=float(relevance),
)
if results is None or len(results) == 0:
if context.log is not None:
context.log.warning(f"Memory not found in collection: {collection}")
return ""

return results[0].text if results[0].text is not None else ""
return ", ".join([result.text for result in results if result.text is not None])

@sk_function(
description="Save information to semantic memory",
Expand Down