🌐 AI搜索 & 代理 主页
Skip to content

Commit debd9ae

Browse files
Separate embedding kwargs into init kwargs and encode kwargs (#1555)
Co-authored-by: Montana Low <montana.low@gmail.com>
1 parent fec164a commit debd9ae

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
\timing on
2+
3+
SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"trust_remote_code": true}');
4+
SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cuda", "trust_remote_code": true}');
5+
SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cpu", "trust_remote_code": true}');
6+
SELECT pgml.embed('hkunlp/instructor-xl', 'hi mom', '{"instruction": "Encode it with love"}');
7+
SELECT pgml.embed('mixedbread-ai/mxbai-embed-large-v1', 'test', '{"prompt": "test prompt: "}');

pgml-extension/examples/image_classification.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperpara
6666

6767
-- runtimes
6868
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'python');
69-
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust');
69+
--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust');
7070

7171
--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'python', hyperparams => '{"n_estimators": 10}'); -- too slow
7272
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'rust', hyperparams => '{"n_estimators": 10}');

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,8 @@ def rank(transformer, query, documents, kwargs):
527527
return rank_using(model, query, documents, kwargs)
528528

529529

530-
def create_embedding(transformer):
531-
return SentenceTransformer(transformer)
530+
def create_embedding(transformer, kwargs):
531+
return SentenceTransformer(transformer, **kwargs)
532532

533533

534534
def embed_using(model, transformer, inputs, kwargs):
@@ -545,16 +545,32 @@ def embed_using(model, transformer, inputs, kwargs):
545545

546546
def embed(transformer, inputs, kwargs):
547547
kwargs = orjson.loads(kwargs)
548-
549548
ensure_device(kwargs)
550549

550+
init_kwarg_keys = [
551+
"device",
552+
"trust_remote_code",
553+
"revision",
554+
"model_kwargs",
555+
"tokenizer_kwargs",
556+
"config_kwargs",
557+
"truncate_dim",
558+
"token",
559+
]
560+
init_kwargs = {
561+
key: value for key, value in kwargs.items() if key in init_kwarg_keys
562+
}
563+
encode_kwargs = {
564+
key: value for key, value in kwargs.items() if key not in init_kwarg_keys
565+
}
566+
551567
if transformer not in __cache_sentence_transformer_by_name:
552568
__cache_sentence_transformer_by_name[transformer] = create_embedding(
553-
transformer
569+
transformer, init_kwargs
554570
)
555571
model = __cache_sentence_transformer_by_name[transformer]
556572

557-
return embed_using(model, transformer, inputs, kwargs)
573+
return embed_using(model, transformer, inputs, encode_kwargs)
558574

559575

560576
def clear_gpu_cache(memory_usage: None):

pgml-extension/tests/test.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@ SELECT pgml.load_dataset('wine');
3131
\i examples/vectors.sql
3232
\i examples/chunking.sql
3333
\i examples/preprocessing.sql
34+
\i examples/embedding.sql
3435
-- transformers are generally too slow to run in the test suite
3536
--\i examples/transformers.sql

0 commit comments

Comments
 (0)