    PegasusTokenizer,
    TrainingArguments,
    Trainer,
-    GPTQConfig
+    GPTQConfig,
+    PegasusForConditionalGeneration,
+    PegasusTokenizer,
 )
 import threading
 
@@ -254,6 +256,8 @@ def __init__(self, model_name, **kwargs):
         if "use_auth_token" in kwargs:
             kwargs["token"] = kwargs.pop("use_auth_token")
 
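+        # Assumption: the model name is saved so code outside __init__ can branch on it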
+        self.model_name = model_name
+
         if (
             "task" in kwargs
             and model_name is not None
@@ -278,29 +282,55 @@ def __init__(self, model_name, **kwargs):
                     model_name, **kwargs
                 )
             elif self.task == "summarization" or self.task == "translation":
-                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
+                if model_name == "google/pegasus-xsum":
+                    # The HF auto model doesn't detect GPUs for this model, so load it explicitly
+                    self.model = PegasusForConditionalGeneration.from_pretrained(
+                        model_name
+                    )
+                else:
+                    self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             elif self.task == "text-generation" or self.task == "conversational":
                 # See: https://huggingface.co/docs/transformers/main/quantization
                 if "quantization_config" in kwargs:
                     quantization_config = kwargs.pop("quantization_config")
                     quantization_config = GPTQConfig(**quantization_config)
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, quantization_config=quantization_config, **kwargs
+                    )
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")
 
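+            # Pop the auth token for pegasus so the tokenizer branch below falls through to PegasusTokenizer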
+            if model_name == "google/pegasus-xsum":
+                kwargs.pop("token", None)
+
             if "token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     model_name, token=kwargs["token"]
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
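+                # Pegasus ships a dedicated tokenizer class; see the pegasus docs link below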
+                if model_name == "google/pegasus-xsum":
+                    self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
+                else:
+                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
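+            # Gather pipeline args in one place so model-specific options can be added conditionally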
+            pipe_kwargs = {
+                "model": self.model,
+                "tokenizer": self.tokenizer,
+            }
+
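+            # Assumption: pegasus was loaded without device kwargs above, so give the pipeline an explicit device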
+            # https://huggingface.co/docs/transformers/en/model_doc/pegasus
+            if model_name == "google/pegasus-xsum":
+                pipe_kwargs["device"] = kwargs.get("device", "cpu")
 
             self.pipe = transformers.pipeline(
                 self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
+                **pipe_kwargs,
             )
         else:
             self.pipe = transformers.pipeline(**kwargs)
@@ -320,7 +350,7 @@ def stream(self, input, timeout=None, **kwargs):
                 self.tokenizer,
                 timeout=timeout,
                 skip_prompt=True,
-                skip_special_tokens=True
+                skip_special_tokens=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -343,9 +373,7 @@ def stream(self, input, timeout=None, **kwargs):
                 )
         else:
             streamer = TextIteratorStreamer(
-                self.tokenizer,
-                timeout=timeout,
-                skip_special_tokens=True
+                self.tokenizer, timeout=timeout, skip_special_tokens=True
             )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
@@ -496,7 +524,6 @@ def embed(transformer, inputs, kwargs):
     return embed_using(model, transformer, inputs, kwargs)
 
 
-
 def clear_gpu_cache(memory_usage: None):
     if not torch.cuda.is_available():
         raise PgMLException(f"No GPU available")