@@ -3,7 +3,6 @@
 import shutil
 import time
 import queue
-import sys
 
 import datasets
 from InstructorEmbedding import INSTRUCTOR
@@ -42,7 +41,6 @@
     Trainer,
 )
 from threading import Thread
-from typing import Optional
 
 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
@@ -393,42 +391,28 @@ def transform(task, args, inputs, stream=False):
     return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()
 
 
-def create_embedding(transformer):
+def embed(transformer, inputs, kwargs):
+    kwargs = orjson.loads(kwargs)
+    ensure_device(kwargs)
     instructor = transformer.startswith("hkunlp/instructor")
-    klass = INSTRUCTOR if instructor else SentenceTransformer
-    return klass(transformer)
 
+    # Cache the model
+    if transformer not in __cache_sentence_transformer_by_name:
+        klass = INSTRUCTOR if instructor else SentenceTransformer
+        __cache_sentence_transformer_by_name[transformer] = klass(transformer)
+    model = __cache_sentence_transformer_by_name[transformer]
 
-def embed_using(model, transformer, inputs, kwargs):
-    if isinstance(kwargs, str):
-        kwargs = orjson.loads(kwargs)
-
-    instructor = transformer.startswith("hkunlp/instructor")
+    # Handle instruction encoding
     if instructor:
         texts_with_instructions = []
         instruction = kwargs.pop("instruction")
         for text in inputs:
             texts_with_instructions.append([instruction, text])
-
         inputs = texts_with_instructions
 
     return model.encode(inputs, **kwargs)
 
 
-def embed(transformer, inputs, kwargs):
-    kwargs = orjson.loads(kwargs)
-
-    ensure_device(kwargs)
-
-    if transformer not in __cache_sentence_transformer_by_name:
-        __cache_sentence_transformer_by_name[transformer] = create_embedding(
-            transformer
-        )
-    model = __cache_sentence_transformer_by_name[transformer]
-
-    return embed_using(model, transformer, inputs, kwargs)
-
-
 def clear_gpu_cache(memory_usage: None):
     if not torch.cuda.is_available():
         raise PgMLException(f"No GPU available")
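For context, a minimal usage sketch of the consolidated embed after this change. This assumes it runs inside the module above, where ensure_device and the __cache_sentence_transformer_by_name cache are defined; the model names and kwargs values are illustrative, not taken from the commit:

# Hypothetical calls to the refactored embed; kwargs arrives as a JSON
# string and is decoded once with orjson.loads at the top of the function.
import orjson

# A plain SentenceTransformer model. The "device" key is illustrative;
# when omitted, ensure_device is expected to fill it in.
vectors = embed(
    "sentence-transformers/all-MiniLM-L6-v2",
    ["an example sentence to embed"],
    orjson.dumps({"device": "cpu"}).decode(),
)

# An Instructor model: embed pops "instruction" from kwargs and pairs it
# with every input before calling model.encode.
vectors = embed(
    "hkunlp/instructor-base",
    ["an example sentence to embed"],
    orjson.dumps({"instruction": "Represent the sentence for retrieval:"}).decode(),
)

Folding create_embedding and embed_using into a single embed also removes the isinstance check on kwargs, since the JSON string is now decoded exactly once on entry, and keeps the model-cache lookup next to the construction it guards.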