Skip to content

Commit 1e4a19d

Browse files
committed
Added generator support
1 parent fc0c82e commit 1e4a19d

File tree

6 files changed

+69
-73
lines changed

6 files changed

+69
-73
lines changed

evaluations/raw.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22

3-
from verbalizer.nlp import LlamaModelParaphrase, ChatGptModelParaphrase
3+
from verbalizer.nlp import ChatGptModelParaphrase
44

55
examples = [
66
"""
@@ -556,9 +556,9 @@
556556
]
557557

558558
if __name__ == '__main__':
559-
llama_model = LlamaModelParaphrase('http://localhost:11434/v1', temperature=0.1)
559+
# llama_model = LlamaModelParaphrase('http://localhost:11434/v1', temperature=0.1)
560560
openai_model = ChatGptModelParaphrase(api_key=os.getenv('OPENAI_API_KEY'), model='gpt-4o', temperature=0.7)
561-
models = [openai_model, llama_model]
561+
models = [openai_model]
562562

563563
for model in models:
564564
print(f'Running on {model.name}:')

playground.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from verbalizer.nlp import ChatGptModelParaphrase, LlamaModelParaphrase
4+
from verbalizer.nlp import ChatGptModelParaphrase
55
from verbalizer.process import Processor
66
from verbalizer.sampler import Sampler
77
from verbalizer.verbalizer import Verbalizer
@@ -83,9 +83,8 @@
8383
}
8484

8585
if __name__ == '__main__':
86-
llama_model = LlamaModelParaphrase('http://localhost:11434/v1', temperature=0.1)
8786
openai_model = ChatGptModelParaphrase(api_key=os.getenv('OPENAI_API_KEY'), model='gpt-4o', temperature=0.7)
88-
models = [openai_model, llama_model]
87+
models = [openai_model]
8988

9089
sampler = Sampler(sample_n=100, seed=42)
9190

pyproject.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "ontology-verbalizer"
3-
version = "1.1.0"
3+
version = "1.1.1"
44
description = "A Python package for ontology verbalization"
55
authors = ["Antonio Zaitoun <[email protected]>"]
66
license = "MIT"
@@ -12,7 +12,6 @@ repository = "https://github.com/Minitour/ontology-verbalizer"
1212
[tool.poetry.dependencies]
1313
python = "^3.12"
1414
rdflib = "~7.0.0"
15-
openai = "~1.12.0"
1615
pandas = "~2.2.0"
1716
tqdm = "~4.66.2"
1817

tests/test_verbalization.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import types
12
import unittest
23

34
from rdflib import Graph
@@ -64,3 +65,16 @@ def test_verbalization_with_sampler(self):
6465

6566
# although we sampled 10, only 7 were applicable.
6667
self.assertEqual(7, len(results))
68+
69+
def test_verbalization_with_generator(self):
70+
ontology = Processor.from_file('./data/foaf.owl')
71+
72+
# create vocabulary
73+
vocab = Vocabulary(ontology, ignore=ignore_iri, rephrased=rename_iri)
74+
75+
# create verbalizer
76+
verbalizer = Verbalizer(vocab)
77+
78+
results = Processor.verbalize_with(verbalizer, namespace='foaf', as_generator=True)
79+
self.assertTrue(isinstance(results, types.GeneratorType))
80+
self.assertEqual(12, len(list(results)))

verbalizer/nlp.py

+37-61
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,40 @@
22
from abc import ABC, abstractmethod
33
from typing import Optional
44

5-
from openai import OpenAI
5+
try:
6+
from openai import OpenAI
67

7-
logging.getLogger("openai").setLevel(logging.ERROR)
8-
logging.getLogger("httpx").setLevel(logging.ERROR)
8+
logging.getLogger("openai").setLevel(logging.ERROR)
9+
logging.getLogger("httpx").setLevel(logging.ERROR)
10+
except ModuleNotFoundError as err:
11+
OpenAI = None
12+
13+
14+
class ParaphraseLanguageModel(ABC):
15+
16+
@abstractmethod
17+
def pseudo_to_text(self, pseudo_text: str, extra: str = None) -> str:
18+
"""
19+
Given a pseudo text or controlled natural language, return a rephrased version of that same text.
20+
:param pseudo_text: The CNL set of statements,
21+
:param extra: Additional context to include as part of the prompt.
22+
:return: Paraphrased text.
23+
"""
24+
return pseudo_text
25+
26+
@property
27+
def cost(self) -> float:
28+
"""
29+
The usage cost so far of the model.
30+
"""
31+
return 0.0
32+
33+
@property
34+
def name(self) -> str:
35+
"""
36+
The name of the model used.
37+
"""
38+
return 'Unknown'
939

1040

1141
def get_messages(pseudo_text: str, extra_context: Optional[str] = None):
@@ -50,33 +80,6 @@ def get_messages(pseudo_text: str, extra_context: Optional[str] = None):
5080
]
5181

5282

53-
class ParaphraseLanguageModel(ABC):
54-
55-
@abstractmethod
56-
def pseudo_to_text(self, pseudo_text: str, extra: str = None) -> str:
57-
"""
58-
Given a pseudo text or controlled natural language, return a rephrased version of that same text.
59-
:param pseudo_text: The CNL set of statements,
60-
:param extra: Additional context to include as part of the prompt.
61-
:return: Paraphrased text.
62-
"""
63-
return pseudo_text
64-
65-
@property
66-
def cost(self) -> float:
67-
"""
68-
The usage cost so far of the model.
69-
"""
70-
return 0.0
71-
72-
@property
73-
def name(self) -> str:
74-
"""
75-
The name of the model used.
76-
"""
77-
return 'Unknown'
78-
79-
8083
class ChatGptModelParaphrase(ParaphraseLanguageModel):
8184
"""
8285
OpenAI wrapper implementation.
@@ -138,6 +141,9 @@ class ChatGptModelParaphrase(ParaphraseLanguageModel):
138141
}
139142

140143
def __init__(self, api_key: str, model: str = 'gpt-3.5-turbo-0613', temperature=0.5):
144+
if not OpenAI:
145+
raise ModuleNotFoundError("OpenAI is not installed. Please install it with `pip install openai`")
146+
141147
self.model = model
142148
self.temperature = temperature
143149
self.client = OpenAI(api_key=api_key)
@@ -156,7 +162,7 @@ def pseudo_to_text(self, pseudo_text: str, extra: str = None) -> str:
156162

157163
@property
158164
def cost(self) -> float:
159-
model_pricing = self.models.get(self.model)
165+
model_pricing = self.models.get(self.model) or {'input': 0.0, 'output': 0.0}
160166

161167
in_tokens = self._in_token_usage / 1000
162168
out_tokens = self._out_token_usage / 1000
@@ -166,33 +172,3 @@ def cost(self) -> float:
166172
@property
167173
def name(self) -> str:
168174
return self.model
169-
170-
171-
class LlamaModelParaphrase(ParaphraseLanguageModel):
172-
"""
173-
Llama model wrapper implementation.
174-
"""
175-
176-
def __init__(self, base_url, model='llama3', temperature=0.5):
177-
self.temperature = temperature
178-
self.model = model
179-
self.client = OpenAI(
180-
base_url=base_url,
181-
api_key="sk-no-key-required"
182-
)
183-
184-
def pseudo_to_text(self, pseudo_text: str, extra: str = None) -> str:
185-
response = self.client.chat.completions.create(
186-
model=self.model,
187-
messages=get_messages(pseudo_text, extra),
188-
temperature=self.temperature
189-
)
190-
return response.choices[0].message.content.strip()
191-
192-
@property
193-
def cost(self) -> float:
194-
return 0.0
195-
196-
@property
197-
def name(self) -> str:
198-
return self.model

verbalizer/process.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ def verbalize_with(cls,
2727
namespace: str,
2828
output_dir: Optional[str] = None,
2929
chunk_size: int = 1000,
30-
sampler: Optional[Sampler] = None):
30+
sampler: Optional[Sampler] = None,
31+
as_generator: bool = False):
3132
"""
3233
Start the verbalization process.
3334
:param verbalizer: The verbalizer to use.
3435
:param namespace: Name of the directory to create under the output directory.
3536
:param output_dir: Name of the output directory.
3637
:param chunk_size: Number of entries (rows) per file. default = 1000
3738
:param sampler: A sampling configuration, use to sample large ontologies.
39+
:param as_generator: If True, returns a generator instead of a list.
3840
"""
3941

4042
# current timestamp
@@ -67,7 +69,7 @@ def verbalize_with(cls,
6769
if stats.statements == 0:
6870
continue
6971

70-
chunk_dataset.append({
72+
element = {
7173
'ontology': namespace,
7274
'root': entry,
7375
'fragment': fragment,
@@ -79,7 +81,12 @@ def verbalize_with(cls,
7981
'unique_relationships': len(stats.relationship_counter),
8082
'total_relationships': sum(stats.relationship_counter.values()),
8183
**stats.relationship_counter
82-
})
84+
}
85+
86+
chunk_dataset.append(element)
87+
88+
if as_generator:
89+
yield element
8390

8491
if len(chunk_dataset) != chunk_size:
8592
continue
@@ -104,7 +111,8 @@ def verbalize_with(cls,
104111
if llm:
105112
logger.info(f'LLM usage cost: ${llm.cost}')
106113

107-
return full_dataset
114+
if not as_generator:
115+
return full_dataset
108116

109117
@staticmethod
110118
def _get_classes(graph):

0 commit comments

Comments
 (0)