diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 0c5a8f00c..143f6d393 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -294,13 +294,16 @@ def create_pipeline(task): ensure_device(task) convert_dtype(task) model_name = task.get("model", None) + model_type = None + if "model_type" in task: + model_type = task["model_type"] if model_name: lower = model_name.lower() else: lower = None if lower and ("-ggml" in lower or "-gguf" in lower): pipe = GGMLPipeline(model_name, **task) - elif lower and "-gptq" in lower: + elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"): pipe = GPTQPipeline(model_name, **task) else: try: