Commit b8b6371

Fix TextGenerationResponse import from hfh (#129)
1 parent 9bf8cf4

2 files changed (+8, -8 lines)

pyproject.toml (1 addition, 1 deletion)

@@ -53,7 +53,7 @@ keywords = ["evaluation", "nlp", "llm"]
 dependencies = [
     # Base dependencies
     "transformers>=4.38.0",
-    "huggingface_hub>=0.21.2",
+    "huggingface_hub>=0.22.0",
     "torch>=2.0",
     "GitPython>=3.1.41", # for logging
     "datasets>=2.14.0",

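The floor on huggingface_hub moves from 0.21.2 to 0.22.0, presumably because TextGenerationOutput is only importable from the top-level package starting with that release. A quick sanity check of the installed version (a minimal sketch, not part of the commit):

import huggingface_hub

# The import change in endpoint_model.py below requires huggingface_hub >= 0.22.0.
print(huggingface_hub.__version__)
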
src/lighteval/models/endpoint_model.py (7 additions, 7 deletions)

@@ -29,10 +29,10 @@
     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointTimeoutError,
+    TextGenerationOutput,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.inference._text_generation import TextGenerationResponse
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -148,7 +148,7 @@ def max_length(self):

     def __async_process_request(
         self, context: str, stop_tokens: list[str], max_tokens: int
-    ) -> Coroutine[None, list[TextGenerationResponse], str]:
+    ) -> Coroutine[None, list[TextGenerationOutput], str]:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.async_client.text_generation(
@@ -162,7 +162,7 @@ def __async_process_request(

         return generated_text

-    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationResponse:
+    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.client.text_generation(
@@ -179,7 +179,7 @@ def __process_request(self, context: str, stop_tokens: list[str], max_tokens: in
     async def __async_process_batch_generate(
         self,
         requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest],
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
                 self.__async_process_request(
@@ -194,7 +194,7 @@ async def __async_process_batch_generate(
     def __process_batch_generate(
         self,
         requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest],
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return [
             self.__process_request(
                 context=request.context,
@@ -206,7 +206,7 @@ def __process_batch_generate(

     async def __async_process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
                 self.__async_process_request(
@@ -220,7 +220,7 @@ async def __async_process_batch_logprob(

     def __process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return [
             self.__process_request(
                 context=request.context if rolling else request.context + request.choice,

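Taken together, the commit drops the private import of TextGenerationResponse from huggingface_hub.inference._text_generation and uses the publicly exported TextGenerationOutput as the return type of the request helpers instead. A minimal sketch of the new pattern, assuming huggingface_hub >= 0.22.0; the endpoint URL and helper name are illustrative only:

from huggingface_hub import InferenceClient, TextGenerationOutput

# Hypothetical endpoint URL, purely for illustration.
client = InferenceClient(model="https://my-endpoint.example.com")

def process_request(context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
    # With details=True, text_generation returns a structured output object
    # (TextGenerationOutput in huggingface_hub >= 0.22.0) rather than a plain string.
    return client.text_generation(
        context,
        max_new_tokens=max_tokens,
        stop_sequences=stop_tokens,
        details=True,
    )
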