diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 83d953fdd..7bcecc8cc 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -41,6 +41,7 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
+    GPTQConfig
 )
 
 import threading
@@ -279,7 +280,13 @@ def __init__(self, model_name, **kwargs):
         elif self.task == "summarization" or self.task == "translation":
             self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
         elif self.task == "text-generation" or self.task == "conversational":
-            self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+            # See: https://huggingface.co/docs/transformers/main/quantization
+            if "quantization_config" in kwargs:
+                quantization_config = kwargs.pop("quantization_config")
+                quantization_config = GPTQConfig(**quantization_config)
+                self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
+            else:
+                self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
         else:
             raise PgMLException(f"Unhandled task: {self.task}")
 
@@ -409,10 +416,13 @@ def create_pipeline(task):
     else:
         try:
             pipe = StandardPipeline(model_name, **task)
-        except TypeError:
-            # some models fail when given "device" kwargs, remove and try again
-            task.pop("device")
-            pipe = StandardPipeline(model_name, **task)
+        except TypeError as error:
+            if "device" in task:
+                # some models fail when given "device" kwargs, remove and try again
+                task.pop("device")
+                pipe = StandardPipeline(model_name, **task)
+            else:
+                raise error
     return pipe
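
For context, here is a minimal standalone sketch of what the new quantization branch does once the patch is applied. Only the dict-to-`GPTQConfig` promotion mirrors the patched code; the model name and `bits` value are hypothetical examples (any GPTQ-quantized checkpoint with a matching config would do), and actually loading such a model also requires the `optimum` and `auto-gptq` packages at runtime.

```python
from transformers import AutoModelForCausalLM, GPTQConfig

def load_causal_lm(model_name, **kwargs):
    """Standalone version of the patched branch in StandardPipeline.__init__."""
    if "quantization_config" in kwargs:
        # A plain dict (e.g. parsed from the task JSON) is promoted to a
        # GPTQConfig before being handed to from_pretrained().
        quantization_config = GPTQConfig(**kwargs.pop("quantization_config"))
        return AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=quantization_config, **kwargs
        )
    return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

# Hypothetical call: "TheBloke/Llama-2-7B-GPTQ" is an example 4-bit
# checkpoint, not something referenced by the patch itself.
model = load_causal_lm(
    "TheBloke/Llama-2-7B-GPTQ",
    quantization_config={"bits": 4},
)
```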
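
The rework of the `except TypeError` block in `create_pipeline` fixes an error-masking bug: previously, a `TypeError` caused by anything other than the `device` kwarg still reached `task.pop("device")`, which raises `KeyError` when no `device` was supplied, burying the original error in the exception chain. A small sketch of the failure mode the guard prevents, with hypothetical task dicts and a deliberately raised stand-in `TypeError`:

```python
def old_retry(task):
    try:
        raise TypeError("unexpected keyword argument 'bogus'")  # stand-in for the failing constructor
    except TypeError:
        task.pop("device")  # KeyError if "device" was never in the task

def new_retry(task):
    try:
        raise TypeError("unexpected keyword argument 'bogus'")
    except TypeError as error:
        if "device" in task:
            task.pop("device")
        else:
            raise error  # the real cause now surfaces unchanged

# Old handler: the unrelated TypeError turns into a KeyError from task.pop().
try:
    old_retry({"model": "gpt2"})
except KeyError as e:
    print("old:", type(e).__name__, e)  # old: KeyError 'device'

# New handler: the original TypeError propagates as-is.
try:
    new_retry({"model": "gpt2"})
except TypeError as e:
    print("new:", type(e).__name__, e)  # new: TypeError unexpected keyword argument 'bogus'
```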