Skip to content

Commit a875154

Browse files
committed
feat: add retriever and query engine implementations
1 parent c6ad791 commit a875154

File tree

6 files changed

+178
-35
lines changed

6 files changed

+178
-35
lines changed

samples/simple-hitl-agent/uv.lock

Lines changed: 61 additions & 30 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .context_grounding_query_engine import ContextGroundingQueryEngine
2+
3+
__all__ = ["ContextGroundingQueryEngine"]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Optional
2+
3+
from llama_index.core.query_engine import CustomQueryEngine
4+
from llama_index.core.response_synthesizers import BaseSynthesizer
5+
from uipath import UiPath
6+
7+
from uipath_llamaindex.retrievers import ContextGroundingRetriever
8+
9+
10+
class ContextGroundingQueryEngine(CustomQueryEngine):
11+
"""RAG Query Engine."""
12+
13+
def __init__(
14+
self,
15+
response_synthesizer: BaseSynthesizer,
16+
index_name: str,
17+
folder_path: Optional[str] = None,
18+
folder_key: Optional[str] = None,
19+
uipath: Optional[UiPath] = None,
20+
number_of_results: Optional[int] = 10,
21+
**kwargs,
22+
):
23+
super().__init__()
24+
self._retriever = ContextGroundingRetriever(
25+
index_name=index_name,
26+
folder_path=folder_path,
27+
folder_key=folder_key,
28+
number_of_results=number_of_results,
29+
uipath=uipath,
30+
**kwargs,
31+
)
32+
self._response_synthesizer = response_synthesizer
33+
34+
def custom_query(self, query_str: str):
35+
nodes = self._retriever.retrieve(query_str)
36+
response_obj = self._response_synthesizer.synthesize(query_str, nodes)
37+
return response_obj
38+
39+
async def acustom_query(self, query_str: str):
40+
nodes = await self._retriever.aretrieve(query_str)
41+
response_obj = self._response_synthesizer.synthesize(query_str, nodes)
42+
return response_obj
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .context_grounding_retriever import ContextGroundingRetriever
2+
3+
__all__ = ["ContextGroundingRetriever"]
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import List, Optional
2+
3+
from llama_index.core.retrievers import (
4+
BaseRetriever,
5+
)
6+
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
7+
from uipath import UiPath
8+
from uipath.models import ContextGroundingQueryResponse
9+
10+
11+
class ContextGroundingRetriever(BaseRetriever):
12+
def __init__(
13+
self,
14+
index_name: str,
15+
folder_path: Optional[str] = None,
16+
folder_key: Optional[str] = None,
17+
uipath: Optional[UiPath] = None,
18+
number_of_results: Optional[int] = 10,
19+
**kwargs,
20+
):
21+
super().__init__()
22+
self._index_name = index_name
23+
self._folder_path = folder_path
24+
self._folder_key = folder_key
25+
self._uipath = uipath or UiPath()
26+
self._number_of_results = number_of_results
27+
self._results: list[ContextGroundingQueryResponse] = []
28+
29+
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
30+
self._results = self._uipath.context_grounding.search(
31+
self._index_name,
32+
query_bundle.query_str,
33+
self._number_of_results,
34+
folder_path=self._folder_path,
35+
folder_key=self._folder_key,
36+
)
37+
38+
return self._to_nodes_with_scores()
39+
40+
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
41+
self._results = await self._uipath.context_grounding.search_async(
42+
self._index_name,
43+
query_bundle.query_str,
44+
self._number_of_results,
45+
folder_path=self._folder_path,
46+
folder_key=self._folder_key,
47+
)
48+
49+
return self._to_nodes_with_scores()
50+
51+
def _to_nodes_with_scores(self) -> List[NodeWithScore]:
52+
nodes_with_scores = []
53+
for chunk in self._results:
54+
node = TextNode(
55+
text=chunk.content,
56+
metadata={
57+
"source_document_id": chunk.source_document_id,
58+
"source": chunk.source,
59+
"page_number": chunk.page_number,
60+
},
61+
)
62+
nodes_with_scores.append(NodeWithScore(node=node, score=chunk.score))
63+
return nodes_with_scores

uv.lock

Lines changed: 6 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)