Skip to content

[backport -> release/3.6.x] feat(plugins): ai-transformer plugins #12426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 1 commit into release/3.6.x on Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ plugins/ai-prompt-template:
- changed-files:
- any-glob-to-any-file: kong/plugins/ai-prompt-template/**/*

plugins/ai-request-transformer:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-request-transformer/**/*', 'kong/llm/**/*']

plugins/ai-response-transformer:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-response-transformer/**/*', 'kong/llm/**/*']

plugins/aws-lambda:
- changed-files:
- any-glob-to-any-file: kong/plugins/aws-lambda/**/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: Introduced the new **AI Request Transformer** plugin that enables passing mid-flight consumer requests to an LLM for transformation or sanitization.
type: feature
scope: Plugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: Introduced the new **AI Response Transformer** plugin that enables passing mid-flight upstream responses to an LLM for transformation or sanitization.
type: feature
scope: Plugin
6 changes: 6 additions & 0 deletions kong-3.6.0-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@ build = {
["kong.plugins.ai-proxy.handler"] = "kong/plugins/ai-proxy/handler.lua",
["kong.plugins.ai-proxy.schema"] = "kong/plugins/ai-proxy/schema.lua",

["kong.plugins.ai-request-transformer.handler"] = "kong/plugins/ai-request-transformer/handler.lua",
["kong.plugins.ai-request-transformer.schema"] = "kong/plugins/ai-request-transformer/schema.lua",

["kong.plugins.ai-response-transformer.handler"] = "kong/plugins/ai-response-transformer/handler.lua",
["kong.plugins.ai-response-transformer.schema"] = "kong/plugins/ai-response-transformer/schema.lua",

["kong.llm"] = "kong/llm/init.lua",
["kong.llm.drivers.shared"] = "kong/llm/drivers/shared.lua",
["kong.llm.drivers.openai"] = "kong/llm/drivers/openai.lua",
Expand Down
2 changes: 2 additions & 0 deletions kong/constants.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ local plugins = {
"ai-proxy",
"ai-prompt-decorator",
"ai-prompt-template",
"ai-request-transformer",
"ai-response-transformer",
}

local plugin_map = {}
Expand Down
6 changes: 4 additions & 2 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
error("body must be table or string")
end

local url = fmt(
-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
Expand Down Expand Up @@ -241,7 +243,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
14 changes: 9 additions & 5 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
end

-- azure has non-standard URL format
local url = fmt(
"%s%s",
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s?api-version=%s",
ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id),
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path,
conf.model.options.azure_api_version or "2023-05-15"
)

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
Expand Down Expand Up @@ -71,7 +73,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down Expand Up @@ -111,7 +113,9 @@ function _M.configure_request(conf)
end

local query_table = kong.request.get_query()
query_table["api-version"] = conf.model.options.azure_api_version

-- technically min supported version
query_table["api-version"] = conf.model.options and conf.model.options.azure_api_version or "2023-05-15"

if auth_param_name and auth_param_value and auth_param_location == "query" then
query_table[auth_param_name] = auth_param_value
Expand Down
53 changes: 4 additions & 49 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ local cjson = require("cjson.safe")
local fmt = string.format
local ai_shared = require("kong.llm.drivers.shared")
local socket_url = require "socket.url"
local http = require("resty.http")
local table_new = require("table.new")
--

Expand Down Expand Up @@ -290,52 +289,6 @@ function _M.to_format(request_table, model_info, route_type)
return response_object, content_type, nil
end

-- Legacy Cohere sub-request helper: JSON-encodes body_table and POSTs it to
-- the Cohere upstream for the given route_type, with optional header auth.
-- NOTE(review): return arity is inconsistent -- see inline notes below.
function _M.subrequest(body_table, route_type, auth)
local body_string, err = cjson.encode(body_table)
if err then
return nil, nil, "failed to parse body to json: " .. err
end

local httpc = http.new()

-- build the upstream URL from the shared per-driver format and route path
local request_url = fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][route_type].path
)

local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
}

-- optional auth header, e.g. Authorization: Bearer <key>
if auth and auth.header_name then
headers[auth.header_name] = auth.header_value
end

local res, err = httpc:request_uri(
request_url,
{
method = "POST",
body = body_string,
headers = headers,
})
if not res then
-- NOTE(review): returns only 2 values here, while the success path returns
-- 3 (body, status, err) -- callers destructuring 3 values get a shifted err
return nil, "request failed: " .. err
end

-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
local status = res.status
local body = res.body

if status ~= 200 then
-- NOTE(review): status slot is dropped here too ("status code not 200"
-- lands in the status position); inconsistent with the final return below
return body, "status code not 200"
end

return body, res.status, nil
end

function _M.header_filter_hooks(body)
-- nothing to parse in header_filter phase
end
Expand Down Expand Up @@ -372,7 +325,9 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
return nil, nil, "body must be table or string"
end

local url = fmt(
-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
Expand Down Expand Up @@ -403,7 +358,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
4 changes: 2 additions & 2 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function _M.to_format(request_table, model_info, route_type)
model_info
)
if err or (not ok) then
return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type)
return nil, nil, fmt("error transforming to %s://%s/%s", model_info.provider, route_type, model_info.options.llama2_format)
end

return response_object, content_type, nil
Expand Down Expand Up @@ -231,7 +231,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
6 changes: 3 additions & 3 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ function _M.subrequest(body, conf, http_opts, return_res_table)

local url = conf.model.options.upstream_url

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
local method = "POST"

local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
["Content-Type"] = "application/json"
}

if conf.auth and conf.auth.header_name then
Expand All @@ -118,7 +118,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
2 changes: 1 addition & 1 deletion kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
8 changes: 4 additions & 4 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,20 @@ function _M.pre_request(conf, request_table)
end

-- if enabled AND request type is compatible, capture the input for analytics
if conf.logging.log_payloads then
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body())
end

return true, nil
end

function _M.post_request(conf, response_string)
if conf.logging.log_payloads then
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_string)
end

-- analytics and logging
if conf.logging.log_statistics then
if conf.logging and conf.logging.log_statistics then
-- check if we already have analytics in this context
local request_analytics = kong.ctx.shared.analytics

Expand Down Expand Up @@ -253,7 +253,7 @@ function _M.http_request(url, body, method, headers, http_opts)
method = method,
body = body,
headers = headers,
ssl_verify = http_opts.https_verify or true,
ssl_verify = http_opts.https_verify,
})
if not res then
return nil, "request failed: " .. err
Expand Down
22 changes: 15 additions & 7 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex
local new_request_body = ai_response.choices
and #ai_response.choices > 0
and ai_response.choices[1]
and ai_response.choices[1].message
and ai_response.choices[1].message.content
if not new_request_body then
return nil, "no response choices received from upstream AI service"
Expand All @@ -327,16 +328,23 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex
return new_request_body
end

function _M:parse_json_instructions(body_string)
local instructions, err = cjson.decode(body_string)
if err then
return nil, nil, nil, err
function _M:parse_json_instructions(in_body)
local err
if type(in_body) == "string" then
in_body, err = cjson.decode(in_body)
if err then
return nil, nil, nil, err
end
end

if type(in_body) ~= "table" then
return nil, nil, nil, "input not table or string"
end

return
instructions.headers,
instructions.body or body_string,
instructions.status or 200
in_body.headers,
in_body.body or in_body,
in_body.status or 200
end

function _M:new(conf, http_opts)
Expand Down
74 changes: 74 additions & 0 deletions kong/plugins/ai-request-transformer/handler.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
-- ai-request-transformer: passes the inbound request body through a
-- configured LLM for transformation before proxying upstream.
local _M = {}

-- imports
local kong_meta = require "kong.meta" -- source of the Kong version string
local fmt = string.format
local llm = require("kong.llm")       -- shared LLM connector framework
--

-- NOTE(review): 777 presumably places this just before ai-response-transformer
-- in Kong's plugin execution order -- confirm against the plugin priority table.
_M.PRIORITY = 777
_M.VERSION = kong_meta.version

-- Reject the client call: log the reason at info level, then short-circuit
-- the request with a 400 response carrying a JSON error payload.
local function bad_request(reason)
  kong.log.info(reason)
  local payload = { error = { message = reason } }
  return kong.response.exit(400, payload)
end

-- Abort with HTTP 500: log the reason at error level, then short-circuit
-- the request with a JSON error payload.
local function internal_server_error(reason)
  kong.log.err(reason)
  local payload = { error = { message = reason } }
  return kong.response.exit(500, payload)
end

-- Translate the plugin configuration into the http_opts table consumed by
-- the LLM driver's HTTP client: forward/TLS proxy endpoints, request
-- timeout, and TLS certificate verification.
local function create_http_opts(conf)
  local opts = {}
  local proxy_opts

  -- when a proxy host is configured, its port is guaranteed by a schema constraint
  if conf.http_proxy_host then
    proxy_opts = proxy_opts or {}
    proxy_opts.http_proxy = fmt("http://%s:%d", conf.http_proxy_host, conf.http_proxy_port)
  end

  if conf.https_proxy_host then
    proxy_opts = proxy_opts or {}
    proxy_opts.https_proxy = fmt("http://%s:%d", conf.https_proxy_host, conf.https_proxy_port)
  end

  -- assigning nil is a no-op, so proxy_opts is only present when a proxy is set
  opts.proxy_opts = proxy_opts
  opts.http_timeout = conf.http_timeout
  opts.https_verify = conf.https_verify

  return opts
end

-- Access phase: pass the inbound request body through the configured LLM
-- with the operator-supplied prompt, and replace the body with the result.
--
-- conf fields used: llm (LLM connector config), prompt (system prompt),
-- transformation_extract_pattern (optional pattern applied to the LLM
-- reply), plus the proxy/timeout/TLS fields read by create_http_opts().
-- Exits early with 500 (driver setup failure) or 400 (introspection failure).
function _M:access(conf)
  kong.service.request.enable_buffering()
  kong.ctx.shared.skip_response_transformer = true

  -- first find the configured LLM interface and driver
  local http_opts = create_http_opts(conf)
  local ai_driver, err = llm:new(conf.llm, http_opts)

  if not ai_driver then
    return internal_server_error(err)
  end

  -- if asked, introspect the request before proxying.
  -- FIX: call the method on the configured instance (ai_driver), not on the
  -- bare llm module -- the instance created by llm:new carries the conf.llm /
  -- http_opts state; previously ai_driver was created but never used.
  kong.log.debug("introspecting request with LLM")
  local new_request_body, introspect_err = ai_driver:ai_introspect_body(
    kong.request.get_raw_body(),
    conf.prompt,
    http_opts,
    conf.transformation_extract_pattern
  )

  if introspect_err then
    return bad_request(introspect_err)
  end

  -- set the body for later plugins
  kong.service.request.set_raw_body(new_request_body)

  -- continue into other plugins including ai-response-transformer,
  -- which may exit early with a sub-request
end


return _M
Loading