From 90f494351037bfd371854ddacc51f62f35d5fa7a Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 16 Nov 2023 13:24:01 -0800 Subject: [PATCH] Added support for llama and mistral GPTQ models --- pgml-extension/src/bindings/transformers/transformers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 0c5a8f00c..143f6d393 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -294,13 +294,16 @@ def create_pipeline(task): ensure_device(task) convert_dtype(task) model_name = task.get("model", None) + model_type = None + if "model_type" in task: + model_type = task["model_type"] if model_name: lower = model_name.lower() else: lower = None if lower and ("-ggml" in lower or "-gguf" in lower): pipe = GGMLPipeline(model_name, **task) - elif lower and "-gptq" in lower: + elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"): pipe = GPTQPipeline(model_name, **task) else: try: