diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index b4ad1de1d..a220f3368 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -2065,7 +2065,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml-dashboard" -version = "2.7.4" +version = "2.7.6" dependencies = [ "aho-corasick 0.7.20", "anyhow", diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 417a2e28a..94a7668be 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -73,15 +73,18 @@ pub trait Bindings: Send + Sync + Debug { Self: Sized; } -trait TracebackError { +pub trait TracebackError { fn format_traceback(self, py: Python<'_>) -> Result; } impl TracebackError for PyResult { fn format_traceback(self, py: Python<'_>) -> Result { - self.map_err(|e| { - let traceback = e.traceback(py).unwrap().format().unwrap(); - anyhow!("{traceback} {e}") + self.map_err(|e| match e.traceback(py) { + Some(traceback) => match traceback.format() { + Ok(traceback) => anyhow!("{traceback} {e}"), + Err(format_e) => anyhow!("{e} {format_e}"), + }, + None => anyhow!("{e}"), }) } } diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index fe2f6b3e7..0359085f5 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -184,6 +184,37 @@ def get_model_from(task): return model[ty][0] +def create_pipeline(task): + if isinstance(task, str): + task = orjson.loads(task) + ensure_device(task) + convert_dtype(task) + model_name = task.get("model", None) + if model_name and "-ggml" in model_name: + pipe = GGMLPipeline(model_name, **task) + elif model_name and "-gptq" in model_name: + pipe = GPTQPipeline(model_name, **task) + else: + try: + pipe = StandardPipeline(model_name, **task) + except TypeError: + # some models fail when given "device" kwargs, remove and try again + task.pop("device") + pipe = StandardPipeline(model_name, **task) + return pipe + + +def transform_using(pipeline, args, inputs): + args = orjson.loads(args) + inputs = orjson.loads(inputs) + + if pipeline.task == "question-answering": + inputs = [orjson.loads(input) for input in inputs] + convert_eos_token(pipeline.tokenizer, args) + + return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode() + + def transform(task, args, inputs): task = orjson.loads(task) args = orjson.loads(args) @@ -191,21 +222,7 @@ def transform(task, args, inputs): key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())]) if key not in __cache_transform_pipeline_by_task: - ensure_device(task) - convert_dtype(task) - model_name = task.get("model", None) - if model_name and "-ggml" in model_name: - pipe = GGMLPipeline(model_name, **task) - elif model_name and "-gptq" in model_name: - pipe = GPTQPipeline(model_name, **task) - else: - try: - pipe = StandardPipeline(model_name, **task) - except TypeError: - # some models fail when given "device" kwargs, remove and try again - task.pop("device") - pipe = StandardPipeline(model_name, **task) - + pipe = create_pipeline(task) __cache_transform_pipeline_by_task[key] = pipe pipe = __cache_transform_pipeline_by_task[key]