From 4be0319a1ab0628a4a067493ddee7fd3a5f498be Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 16 Nov 2023 11:52:51 -0800 Subject: [PATCH] Update extension for cloud streaming use --- pgml-extension/src/bindings/transformers/transformers.py | 4 +++- pgml-extension/src/bindings/transformers/transformers.rs | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 2117cb9f6..0c5a8f00c 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -312,7 +312,7 @@ def create_pipeline(task): return pipe -def transform_using(pipeline, args, inputs): +def transform_using(pipeline, args, inputs, stream=False): args = orjson.loads(args) inputs = orjson.loads(inputs) @@ -320,6 +320,8 @@ def transform_using(pipeline, args, inputs): inputs = [orjson.loads(input) for input in inputs] convert_eos_token(pipeline.tokenizer, args) + if stream: + return pipeline.stream(inputs, **args) return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode() diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs index 6b89dd2a8..a03c0d751 100644 --- a/pgml-extension/src/bindings/transformers/transformers.rs +++ b/pgml-extension/src/bindings/transformers/transformers.rs @@ -12,7 +12,7 @@ pub struct TransformStreamIterator { } impl TransformStreamIterator { - fn new(python_iter: Py) -> Self { + pub fn new(python_iter: Py) -> Self { let locals = Python::with_gil(|py| -> Result, PyErr> { Ok([("python_iter", python_iter)].into_py_dict(py).into()) })