Skip to content

Python: provide methods to register single native function to the kernel #2390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

33 changes: 33 additions & 0 deletions python/semantic_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,39 @@ def register_semantic_function(

return function

def register_native_function(
self,
skill_name: Optional[str],
sk_function: Callable,
) -> SKFunctionBase:
if not hasattr(sk_function, "__sk_function__"):
raise KernelException(
KernelException.ErrorCodes.InvalidFunctionType,
"sk_function argument must be decorated with @sk_function",
)
function_name = sk_function.__sk_function_name__

if skill_name is None or skill_name == "":
skill_name = SkillCollection.GLOBAL_SKILL
assert skill_name is not None # for type checker

validate_skill_name(skill_name)
validate_function_name(function_name)

function = SKFunction.from_native_method(sk_function, skill_name, self.logger)

if self.skills.has_function(skill_name, function_name):
raise KernelException(
KernelException.ErrorCodes.FunctionOverloadNotSupported,
"Overloaded functions are not supported, "
"please differentiate function names.",
)

function.set_default_skill_collection(self.skills)
self._skill_collection.add_native_function(function)

return function

async def run_stream_async(
self,
*functions: Any,
Expand Down
61 changes: 61 additions & 0 deletions python/tests/unit/kernel_extensions/test_register_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft. All rights reserved.


import pytest

from semantic_kernel import Kernel
from semantic_kernel.kernel_exception import KernelException
from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
from semantic_kernel.skill_definition.sk_function_decorator import sk_function
from semantic_kernel.skill_definition.skill_collection import SkillCollection


def not_decorated_native_function(arg1: str) -> str:
return "test"


@sk_function(name="getLightStatus")
def decorated_native_function(arg1: str) -> str:
return "test"


def test_register_valid_native_function():
kernel = Kernel()

registered_func = kernel.register_native_function(
"TestSkill", decorated_native_function
)

assert isinstance(registered_func, SKFunctionBase)
assert (
kernel.skills.get_native_function("TestSkill", "getLightStatus")
== registered_func
)
assert registered_func.invoke("testtest").result == "test"


def test_register_undecorated_native_function():
kernel = Kernel()

with pytest.raises(KernelException):
kernel.register_native_function("TestSkill", not_decorated_native_function)


def test_register_with_none_skill_name():
kernel = Kernel()

registered_func = kernel.register_native_function(None, decorated_native_function)
assert registered_func.skill_name == SkillCollection.GLOBAL_SKILL


def test_register_overloaded_native_function():
kernel = Kernel()

kernel.register_native_function("TestSkill", decorated_native_function)

with pytest.raises(KernelException):
kernel.register_native_function("TestSkill", decorated_native_function)


if __name__ == "__main__":
pytest.main([__file__])