diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 2117cb9f6..0c5a8f00c 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -312,7 +312,7 @@ def create_pipeline(task): return pipe -def transform_using(pipeline, args, inputs): +def transform_using(pipeline, args, inputs, stream=False): args = orjson.loads(args) inputs = orjson.loads(inputs) @@ -320,6 +320,8 @@ def transform_using(pipeline, args, inputs): inputs = [orjson.loads(input) for input in inputs] convert_eos_token(pipeline.tokenizer, args) + if stream: + return pipeline.stream(inputs, **args) return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode() diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs index 6b89dd2a8..a03c0d751 100644 --- a/pgml-extension/src/bindings/transformers/transformers.rs +++ b/pgml-extension/src/bindings/transformers/transformers.rs @@ -12,7 +12,7 @@ pub struct TransformStreamIterator { } impl TransformStreamIterator { - fn new(python_iter: Py) -> Self { + pub fn new(python_iter: Py) -> Self { let locals = Python::with_gil(|py| -> Result, PyErr> { Ok([("python_iter", python_iter)].into_py_dict(py).into()) })