🌐 AI搜索 & 代理 主页
Skip to content

Commit c8cc665

Browse files
committed
add cache-mode and cache-option
1 parent 24304f9 commit c8cc665

File tree

2 files changed

+59
-133
lines changed

2 files changed

+59
-133
lines changed

examples/cli/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,6 @@ Generation Options:
124124
--skip-layers layers to skip for SLG steps (default: [7,8,9])
125125
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126126
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
127-
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
128-
--ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95)
127+
--cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)
128+
--cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)
129129
```

examples/cli/main.cpp

Lines changed: 57 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,10 +1061,9 @@ struct SDGenerationParams {
10611061
std::vector<int> high_noise_skip_layers = {7, 8, 9};
10621062
sd_sample_params_t high_noise_sample_params;
10631063

1064-
std::string easycache_option;
1064+
std::string cache_mode;
1065+
std::string cache_option;
10651066
sd_easycache_params_t easycache_params;
1066-
1067-
std::string ucache_option;
10681067
sd_ucache_params_t ucache_params;
10691068

10701069
float moe_boundary = 0.875f;
@@ -1385,68 +1384,24 @@ struct SDGenerationParams {
13851384
return 1;
13861385
};
13871386

1388-
auto on_easycache_arg = [&](int argc, const char** argv, int index) {
1389-
const std::string default_values = "0.2,0.15,0.95";
1390-
auto looks_like_value = [](const std::string& token) {
1391-
if (token.empty()) {
1392-
return false;
1393-
}
1394-
if (token[0] != '-') {
1395-
return true;
1396-
}
1397-
if (token.size() == 1) {
1398-
return false;
1399-
}
1400-
unsigned char next = static_cast<unsigned char>(token[1]);
1401-
return std::isdigit(next) || token[1] == '.';
1402-
};
1403-
1404-
std::string option_value;
1405-
int consumed = 0;
1406-
if (index + 1 < argc) {
1407-
std::string next_arg = argv[index + 1];
1408-
if (looks_like_value(next_arg)) {
1409-
option_value = argv_to_utf8(index + 1, argv);
1410-
consumed = 1;
1411-
}
1387+
auto on_cache_mode_arg = [&](int argc, const char** argv, int index) {
1388+
if (++index >= argc) {
1389+
return -1;
14121390
}
1413-
if (option_value.empty()) {
1414-
option_value = default_values;
1391+
cache_mode = argv_to_utf8(index, argv);
1392+
if (cache_mode != "easycache" && cache_mode != "ucache") {
1393+
fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n", cache_mode.c_str());
1394+
return -1;
14151395
}
1416-
easycache_option = option_value;
1417-
return consumed;
1396+
return 1;
14181397
};
14191398

1420-
auto on_ucache_arg = [&](int argc, const char** argv, int index) {
1421-
const std::string default_values = "1.0,0.15,0.95";
1422-
auto looks_like_value = [](const std::string& token) {
1423-
if (token.empty()) {
1424-
return false;
1425-
}
1426-
if (token[0] != '-') {
1427-
return true;
1428-
}
1429-
if (token.size() == 1) {
1430-
return false;
1431-
}
1432-
unsigned char next = static_cast<unsigned char>(token[1]);
1433-
return std::isdigit(next) || token[1] == '.';
1434-
};
1435-
1436-
std::string option_value;
1437-
int consumed = 0;
1438-
if (index + 1 < argc) {
1439-
std::string next_arg = argv[index + 1];
1440-
if (looks_like_value(next_arg)) {
1441-
option_value = argv_to_utf8(index + 1, argv);
1442-
consumed = 1;
1443-
}
1444-
}
1445-
if (option_value.empty()) {
1446-
option_value = default_values;
1399+
auto on_cache_option_arg = [&](int argc, const char** argv, int index) {
1400+
if (++index >= argc) {
1401+
return -1;
14471402
}
1448-
ucache_option = option_value;
1449-
return consumed;
1403+
cache_option = argv_to_utf8(index, argv);
1404+
return 1;
14501405
};
14511406

14521407
options.manual_options = {
@@ -1481,13 +1436,13 @@ struct SDGenerationParams {
14811436
"reference image for Flux Kontext models (can be used multiple times)",
14821437
on_ref_image_arg},
14831438
{"",
1484-
"--easycache",
1485-
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
1486-
on_easycache_arg},
1439+
"--cache-mode",
1440+
"caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)",
1441+
on_cache_mode_arg},
14871442
{"",
1488-
"--ucache",
1489-
"enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)",
1490-
on_ucache_arg},
1443+
"--cache-option",
1444+
"cache parameters \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)",
1445+
on_cache_option_arg},
14911446

14921447
};
14931448

@@ -1600,62 +1555,21 @@ struct SDGenerationParams {
16001555
return false;
16011556
}
16021557

1603-
if (!easycache_option.empty()) {
1604-
float values[3] = {0.0f, 0.0f, 0.0f};
1605-
std::stringstream ss(easycache_option);
1606-
std::string token;
1607-
int idx = 0;
1608-
while (std::getline(ss, token, ',')) {
1609-
auto trim = [](std::string& s) {
1610-
const char* whitespace = " \t\r\n";
1611-
auto start = s.find_first_not_of(whitespace);
1612-
if (start == std::string::npos) {
1613-
s.clear();
1614-
return;
1615-
}
1616-
auto end = s.find_last_not_of(whitespace);
1617-
s = s.substr(start, end - start + 1);
1618-
};
1619-
trim(token);
1620-
if (token.empty()) {
1621-
fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str());
1622-
return false;
1623-
}
1624-
if (idx >= 3) {
1625-
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1626-
return false;
1627-
}
1628-
try {
1629-
values[idx] = std::stof(token);
1630-
} catch (const std::exception&) {
1631-
fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
1632-
return false;
1558+
easycache_params.enabled = false;
1559+
ucache_params.enabled = false;
1560+
1561+
if (!cache_mode.empty()) {
1562+
std::string option_str = cache_option;
1563+
if (option_str.empty()) {
1564+
if (cache_mode == "easycache") {
1565+
option_str = "0.2,0.15,0.95";
1566+
} else {
1567+
option_str = "1.0,0.15,0.95";
16331568
}
1634-
idx++;
1635-
}
1636-
if (idx != 3) {
1637-
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1638-
return false;
16391569
}
1640-
if (values[0] < 0.0f) {
1641-
fprintf(stderr, "error: easycache threshold must be non-negative\n");
1642-
return false;
1643-
}
1644-
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1645-
fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1646-
return false;
1647-
}
1648-
easycache_params.enabled = true;
1649-
easycache_params.reuse_threshold = values[0];
1650-
easycache_params.start_percent = values[1];
1651-
easycache_params.end_percent = values[2];
1652-
} else {
1653-
easycache_params.enabled = false;
1654-
}
16551570

1656-
if (!ucache_option.empty()) {
16571571
float values[3] = {0.0f, 0.0f, 0.0f};
1658-
std::stringstream ss(ucache_option);
1572+
std::stringstream ss(option_str);
16591573
std::string token;
16601574
int idx = 0;
16611575
while (std::getline(ss, token, ',')) {
@@ -1671,39 +1585,45 @@ struct SDGenerationParams {
16711585
};
16721586
trim(token);
16731587
if (token.empty()) {
1674-
fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str());
1588+
fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str());
16751589
return false;
16761590
}
16771591
if (idx >= 3) {
1678-
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1592+
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
16791593
return false;
16801594
}
16811595
try {
16821596
values[idx] = std::stof(token);
16831597
} catch (const std::exception&) {
1684-
fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str());
1598+
fprintf(stderr, "error: invalid cache option value '%s'\n", token.c_str());
16851599
return false;
16861600
}
16871601
idx++;
16881602
}
16891603
if (idx != 3) {
1690-
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1604+
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
16911605
return false;
16921606
}
16931607
if (values[0] < 0.0f) {
1694-
fprintf(stderr, "error: ucache threshold must be non-negative\n");
1608+
fprintf(stderr, "error: cache threshold must be non-negative\n");
16951609
return false;
16961610
}
16971611
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1698-
fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1612+
fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
16991613
return false;
17001614
}
1701-
ucache_params.enabled = true;
1702-
ucache_params.reuse_threshold = values[0];
1703-
ucache_params.start_percent = values[1];
1704-
ucache_params.end_percent = values[2];
1705-
} else {
1706-
ucache_params.enabled = false;
1615+
1616+
if (cache_mode == "easycache") {
1617+
easycache_params.enabled = true;
1618+
easycache_params.reuse_threshold = values[0];
1619+
easycache_params.start_percent = values[1];
1620+
easycache_params.end_percent = values[2];
1621+
} else {
1622+
ucache_params.enabled = true;
1623+
ucache_params.reuse_threshold = values[0];
1624+
ucache_params.start_percent = values[1];
1625+
ucache_params.end_percent = values[2];
1626+
}
17071627
}
17081628

17091629
sample_params.guidance.slg.layers = skip_layers.data();
@@ -1798,12 +1718,18 @@ struct SDGenerationParams {
17981718
<< " sample_params: " << sample_params_str << ",\n"
17991719
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
18001720
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
1801-
<< " easycache_option: \"" << easycache_option << "\",\n"
1721+
<< " cache_mode: \"" << cache_mode << "\",\n"
1722+
<< " cache_option: \"" << cache_option << "\",\n"
18021723
<< " easycache: "
18031724
<< (easycache_params.enabled ? "enabled" : "disabled")
18041725
<< " (threshold=" << easycache_params.reuse_threshold
18051726
<< ", start=" << easycache_params.start_percent
18061727
<< ", end=" << easycache_params.end_percent << "),\n"
1728+
<< " ucache: "
1729+
<< (ucache_params.enabled ? "enabled" : "disabled")
1730+
<< " (threshold=" << ucache_params.reuse_threshold
1731+
<< ", start=" << ucache_params.start_percent
1732+
<< ", end=" << ucache_params.end_percent << "),\n"
18071733
<< " moe_boundary: " << moe_boundary << ",\n"
18081734
<< " video_frames: " << video_frames << ",\n"
18091735
<< " fps: " << fps << ",\n"

0 commit comments

Comments
 (0)