     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointTimeoutError,
+    TextGenerationOutput,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.inference._text_generation import TextGenerationResponse
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -148,7 +148,7 @@ def max_length(self):
 
     def __async_process_request(
         self, context: str, stop_tokens: list[str], max_tokens: int
-    ) -> Coroutine[None, list[TextGenerationResponse], str]:
+    ) -> Coroutine[None, list[TextGenerationOutput], str]:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.async_client.text_generation(
@@ -162,7 +162,7 @@ def __async_process_request(
 
         return generated_text
 
-    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationResponse:
+    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.client.text_generation(
@@ -179,7 +179,7 @@ def __process_request(self, context: str, stop_tokens: list[str], max_tokens: in
     async def __async_process_batch_generate(
         self,
         requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest],
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
                 self.__async_process_request(
@@ -194,7 +194,7 @@ async def __async_process_batch_generate(
     def __process_batch_generate(
         self,
         requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest],
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return [
             self.__process_request(
                 context=request.context,
@@ -206,7 +206,7 @@ def __process_batch_generate(
 
     async def __async_process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
                 self.__async_process_request(
@@ -220,7 +220,7 @@ async def __async_process_batch_logprob(
 
     def __process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return [
             self.__process_request(
                 context=request.context if rolling else request.context + request.choice,
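For context on the rename: newer releases of `huggingface_hub` expose the text-generation return type as `TextGenerationOutput` at the top level of the package, replacing the old private `huggingface_hub.inference._text_generation.TextGenerationResponse`. A minimal usage sketch of the renamed type (not part of this diff; the model id is illustrative, and it assumes a `huggingface_hub` version that exports `TextGenerationOutput`):

```python
from huggingface_hub import InferenceClient, TextGenerationOutput

# Hypothetical model id, used only to illustrate the call shape.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# With details=True, text_generation returns a TextGenerationOutput
# (previously TextGenerationResponse) instead of a plain string.
out: TextGenerationOutput = client.text_generation(
    "The capital of France is",
    max_new_tokens=16,
    stop_sequences=["\n"],
    details=True,
)

print(out.generated_text)         # the generated continuation
print(out.details.finish_reason)  # e.g. "length" or "stop_sequence"
```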