🌐 AI搜索 & 代理 主页
Skip to content

Commit 96c3e64

Browse files
authored
refactor: optimize the handling of embedding (leejet#1068)
* optimize the handling of embedding
* support case-insensitive embedding names
1 parent 0392273 commit 96c3e64

File tree

5 files changed

+165
-82
lines changed

5 files changed

+165
-82
lines changed

clip.hpp

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "ggml_extend.hpp"
55
#include "model.h"
6+
#include "tokenize_util.h"
67

78
/*================================================== CLIPTokenizer ===================================================*/
89

@@ -72,6 +73,8 @@ class CLIPTokenizer {
7273
int encoder_len;
7374
int bpe_len;
7475

76+
std::vector<std::string> special_tokens;
77+
7578
public:
7679
const std::string UNK_TOKEN = "<|endoftext|>";
7780
const std::string BOS_TOKEN = "<|startoftext|>";
@@ -117,6 +120,15 @@ class CLIPTokenizer {
117120
return pairs;
118121
}
119122

123+
bool is_special_token(const std::string& token) {
124+
for (auto& special_token : special_tokens) {
125+
if (special_token == token) {
126+
return true;
127+
}
128+
}
129+
return false;
130+
}
131+
120132
public:
121133
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
122134
: PAD_TOKEN_ID(pad_token_id) {
@@ -125,6 +137,8 @@ class CLIPTokenizer {
125137
} else {
126138
load_from_merges(ModelLoader::load_merges());
127139
}
140+
add_special_token("<|startoftext|>");
141+
add_special_token("<|endoftext|>");
128142
}
129143

130144
void load_from_merges(const std::string& merges_utf8_str) {
@@ -201,6 +215,10 @@ class CLIPTokenizer {
201215
}
202216
}
203217

218+
void add_special_token(const std::string& token) {
219+
special_tokens.push_back(token);
220+
}
221+
204222
std::u32string bpe(const std::u32string& token) {
205223
std::vector<std::u32string> word;
206224

@@ -379,25 +397,54 @@ class CLIPTokenizer {
379397
return trim(text);
380398
}
381399

400+
std::vector<std::string> token_split(const std::string& text) {
401+
std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
402+
std::regex::icase);
403+
std::sregex_iterator iter(text.begin(), text.end(), pat);
404+
std::sregex_iterator end;
405+
406+
std::vector<std::string> result;
407+
for (; iter != end; ++iter) {
408+
result.emplace_back(iter->str());
409+
}
410+
411+
return result;
412+
}
413+
382414
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
383415
std::string original_text = text;
384416
std::vector<int32_t> bpe_tokens;
385417
text = whitespace_clean(text);
386418
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
387419

388-
std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
389-
std::regex::icase);
390-
391-
std::smatch matches;
392420
std::string str = text;
393421
std::vector<std::string> token_strs;
394-
while (std::regex_search(str, matches, pat)) {
395-
bool skip = on_new_token_cb(str, bpe_tokens);
396-
if (skip) {
422+
423+
auto splited_texts = split_with_special_tokens(text, special_tokens);
424+
425+
for (auto& splited_text : splited_texts) {
426+
LOG_DEBUG("token %s", splited_text.c_str());
427+
if (is_special_token(splited_text)) {
428+
LOG_DEBUG("special %s", splited_text.c_str());
429+
bool skip = on_new_token_cb(splited_text, bpe_tokens);
430+
if (skip) {
431+
token_strs.push_back(splited_text);
432+
continue;
433+
}
397434
continue;
398435
}
399-
for (auto& token : matches) {
400-
std::string token_str = token.str();
436+
437+
auto tokens = token_split(splited_text);
438+
for (auto& token : tokens) {
439+
if (on_new_token_cb != nullptr) {
440+
bool skip = on_new_token_cb(token, bpe_tokens);
441+
if (skip) {
442+
token_strs.push_back(token);
443+
continue;
444+
}
445+
}
446+
447+
std::string token_str = token;
401448
std::u32string utf32_token;
402449
for (int i = 0; i < token_str.length(); i++) {
403450
unsigned char b = token_str[i];
@@ -417,14 +464,13 @@ class CLIPTokenizer {
417464
bpe_tokens.push_back(encoder[bpe_str]);
418465
token_strs.push_back(utf32_to_utf8(bpe_str));
419466
}
420-
str = matches.suffix();
421-
}
422-
std::stringstream ss;
423-
ss << "[";
424-
for (auto token : token_strs) {
425-
ss << "\"" << token << "\", ";
426467
}
427-
ss << "]";
468+
// std::stringstream ss;
469+
// ss << "[";
470+
// for (auto token : token_strs) {
471+
// ss << "\"" << token << "\", ";
472+
// }
473+
// ss << "]";
428474
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
429475
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
430476
return bpe_tokens;

conditioner.hpp

Lines changed: 30 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
5656
std::shared_ptr<CLIPTextModelRunner> text_model2;
5757

5858
std::string trigger_word = "img"; // should be user settable
59-
std::string embd_dir;
59+
std::map<std::string, std::string> embedding_map;
6060
int32_t num_custom_embeddings = 0;
6161
int32_t num_custom_embeddings_2 = 0;
6262
std::vector<uint8_t> token_embed_custom;
@@ -65,11 +65,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6565
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
6666
bool offload_params_to_cpu,
6767
const String2TensorStorage& tensor_storage_map,
68-
const std::string& embd_dir,
68+
const std::map<std::string, std::string>& orig_embedding_map,
6969
SDVersion version = VERSION_SD1,
7070
PMVersion pv = PM_VERSION_1)
71-
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
72-
bool force_clip_f32 = embd_dir.size() > 0;
71+
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) {
72+
for (const auto& kv : orig_embedding_map) {
73+
std::string name = kv.first;
74+
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
75+
embedding_map[name] = kv.second;
76+
tokenizer.add_special_token(name);
77+
}
78+
bool force_clip_f32 = !embedding_map.empty();
7379
if (sd_version_is_sd1(version)) {
7480
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
7581
} else if (sd_version_is_sd2(version)) {
@@ -196,25 +202,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
196202

197203
std::vector<int> convert_token_to_id(std::string text) {
198204
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
199-
size_t word_end = str.find(",");
200-
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
201-
embd_name = trim(embd_name);
202-
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
203-
if (embd_path.size() == 0) {
204-
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
205+
auto iter = embedding_map.find(str);
206+
if (iter == embedding_map.end()) {
207+
return false;
205208
}
206-
if (embd_path.size() == 0) {
207-
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
208-
}
209-
if (embd_path.size() > 0) {
210-
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
211-
if (word_end != std::string::npos) {
212-
str = str.substr(word_end);
213-
} else {
214-
str = "";
215-
}
216-
return true;
217-
}
209+
std::string embedding_path = iter->second;
210+
if (load_embedding(str, embedding_path, bpe_tokens)) {
211+
return true;
218212
}
219213
return false;
220214
};
@@ -245,25 +239,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
245239
}
246240

247241
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
248-
size_t word_end = str.find(",");
249-
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
250-
embd_name = trim(embd_name);
251-
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
252-
if (embd_path.size() == 0) {
253-
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
254-
}
255-
if (embd_path.size() == 0) {
256-
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
242+
auto iter = embedding_map.find(str);
243+
if (iter == embedding_map.end()) {
244+
return false;
257245
}
258-
if (embd_path.size() > 0) {
259-
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
260-
if (word_end != std::string::npos) {
261-
str = str.substr(word_end);
262-
} else {
263-
str = "";
264-
}
265-
return true;
266-
}
246+
std::string embedding_path = iter->second;
247+
if (load_embedding(str, embedding_path, bpe_tokens)) {
248+
return true;
267249
}
268250
return false;
269251
};
@@ -376,25 +358,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
376358
}
377359

378360
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
379-
size_t word_end = str.find(",");
380-
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
381-
embd_name = trim(embd_name);
382-
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
383-
if (embd_path.size() == 0) {
384-
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
385-
}
386-
if (embd_path.size() == 0) {
387-
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
361+
auto iter = embedding_map.find(str);
362+
if (iter == embedding_map.end()) {
363+
return false;
388364
}
389-
if (embd_path.size() > 0) {
390-
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
391-
if (word_end != std::string::npos) {
392-
str = str.substr(word_end);
393-
} else {
394-
str = "";
395-
}
396-
return true;
397-
}
365+
std::string embedding_path = iter->second;
366+
if (load_embedding(str, embedding_path, bpe_tokens)) {
367+
return true;
398368
}
399369
return false;
400370
};
@@ -1728,7 +1698,7 @@ struct LLMEmbedder : public Conditioner {
17281698
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
17291699
std::pair<int, int> prompt_attn_range;
17301700
int prompt_template_encode_start_idx = 34;
1731-
int max_length = 0;
1701+
int max_length = 0;
17321702
std::set<int> out_layers;
17331703
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
17341704
LOG_INFO("QwenImageEditPlusPipeline");
@@ -1828,7 +1798,7 @@ struct LLMEmbedder : public Conditioner {
18281798
prompt += "[/INST]";
18291799
} else if (version == VERSION_OVIS_IMAGE) {
18301800
prompt_template_encode_start_idx = 28;
1831-
max_length = prompt_template_encode_start_idx + 256;
1801+
max_length = prompt_template_encode_start_idx + 256;
18321802

18331803
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
18341804

examples/cli/main.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,9 @@ struct SDContextParams {
501501
std::string tensor_type_rules;
502502
std::string lora_model_dir;
503503

504+
std::map<std::string, std::string> embedding_map;
505+
std::vector<sd_embedding_t> embedding_array;
506+
504507
rng_type_t rng_type = CUDA_RNG;
505508
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
506509
bool offload_params_to_cpu = false;
@@ -828,6 +831,37 @@ struct SDContextParams {
828831
return options;
829832
}
830833

834+
void build_embedding_map() {
835+
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
836+
837+
if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) {
838+
return;
839+
}
840+
841+
for (auto& p : fs::directory_iterator(embedding_dir)) {
842+
if (!p.is_regular_file())
843+
continue;
844+
845+
auto path = p.path();
846+
std::string ext = path.extension().string();
847+
848+
bool valid = false;
849+
for (auto& e : valid_ext) {
850+
if (ext == e) {
851+
valid = true;
852+
break;
853+
}
854+
}
855+
if (!valid)
856+
continue;
857+
858+
std::string key = path.stem().string();
859+
std::string value = path.string();
860+
861+
embedding_map[key] = value;
862+
}
863+
}
864+
831865
bool process_and_check(SDMode mode) {
832866
if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
833867
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
@@ -845,10 +879,24 @@ struct SDContextParams {
845879
n_threads = sd_get_num_physical_cores();
846880
}
847881

882+
build_embedding_map();
883+
848884
return true;
849885
}
850886

851887
std::string to_string() const {
888+
std::ostringstream emb_ss;
889+
emb_ss << "{\n";
890+
for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) {
891+
emb_ss << " \"" << it->first << "\": \"" << it->second << "\"";
892+
if (std::next(it) != embedding_map.end()) {
893+
emb_ss << ",";
894+
}
895+
emb_ss << "\n";
896+
}
897+
emb_ss << " }";
898+
899+
std::string embeddings_str = emb_ss.str();
852900
std::ostringstream oss;
853901
oss << "SDContextParams {\n"
854902
<< " n_threads: " << n_threads << ",\n"
@@ -866,6 +914,7 @@ struct SDContextParams {
866914
<< " esrgan_path: \"" << esrgan_path << "\",\n"
867915
<< " control_net_path: \"" << control_net_path << "\",\n"
868916
<< " embedding_dir: \"" << embedding_dir << "\",\n"
917+
<< " embeddings: " << embeddings_str << "\n"
869918
<< " wtype: " << sd_type_name(wtype) << ",\n"
870919
<< " tensor_type_rules: \"" << tensor_type_rules << "\",\n"
871920
<< " lora_model_dir: \"" << lora_model_dir << "\",\n"
@@ -898,6 +947,15 @@ struct SDContextParams {
898947
}
899948

900949
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
950+
embedding_array.clear();
951+
embedding_array.reserve(embedding_map.size());
952+
for (const auto& kv : embedding_map) {
953+
sd_embedding_t item;
954+
item.name = kv.first.c_str();
955+
item.path = kv.second.c_str();
956+
embedding_array.emplace_back(item);
957+
}
958+
901959
sd_ctx_params_t sd_ctx_params = {
902960
model_path.c_str(),
903961
clip_l_path.c_str(),
@@ -912,7 +970,8 @@ struct SDContextParams {
912970
taesd_path.c_str(),
913971
control_net_path.c_str(),
914972
lora_model_dir.c_str(),
915-
embedding_dir.c_str(),
973+
embedding_array.data(),
974+
static_cast<uint32_t>(embedding_array.size()),
916975
photo_maker_path.c_str(),
917976
tensor_type_rules.c_str(),
918977
vae_decode_only,

0 commit comments

Comments
 (0)