@@ -184,28 +184,45 @@ def get_model_from(task):
184184 return model [ty ][0 ]
185185
186186
187+ def create_pipeline (task ):
188+ if isinstance (task , str ):
189+ task = orjson .loads (task )
190+ ensure_device (task )
191+ convert_dtype (task )
192+ model_name = task .get ("model" , None )
193+ if model_name and "-ggml" in model_name :
194+ pipe = GGMLPipeline (model_name , ** task )
195+ elif model_name and "-gptq" in model_name :
196+ pipe = GPTQPipeline (model_name , ** task )
197+ else :
198+ try :
199+ pipe = StandardPipeline (model_name , ** task )
200+ except TypeError :
201+ # some models fail when given "device" kwargs, remove and try again
202+ task .pop ("device" )
203+ pipe = StandardPipeline (model_name , ** task )
204+ return pipe
205+
206+
207+ def transform_using (pipeline , args , inputs ):
208+ args = orjson .loads (args )
209+ inputs = orjson .loads (inputs )
210+
211+ if pipeline .task == "question-answering" :
212+ inputs = [orjson .loads (input ) for input in inputs ]
213+ convert_eos_token (pipeline .tokenizer , args )
214+
215+ return orjson .dumps (pipeline (inputs , ** args ), default = orjson_default ).decode ()
216+
217+
187218def transform (task , args , inputs ):
188219 task = orjson .loads (task )
189220 args = orjson .loads (args )
190221 inputs = orjson .loads (inputs )
191222
192223 key = "," .join ([f"{ key } :{ val } " for (key , val ) in sorted (task .items ())])
193224 if key not in __cache_transform_pipeline_by_task :
194- ensure_device (task )
195- convert_dtype (task )
196- model_name = task .get ("model" , None )
197- if model_name and "-ggml" in model_name :
198- pipe = GGMLPipeline (model_name , ** task )
199- elif model_name and "-gptq" in model_name :
200- pipe = GPTQPipeline (model_name , ** task )
201- else :
202- try :
203- pipe = StandardPipeline (model_name , ** task )
204- except TypeError :
205- # some models fail when given "device" kwargs, remove and try again
206- task .pop ("device" )
207- pipe = StandardPipeline (model_name , ** task )
208-
225+ pipe = create_pipeline (task )
209226 __cache_transform_pipeline_by_task [key ] = pipe
210227
211228 pipe = __cache_transform_pipeline_by_task [key ]
0 commit comments