@@ -1061,10 +1061,9 @@ struct SDGenerationParams {
10611061 std::vector<int > high_noise_skip_layers = {7 , 8 , 9 };
10621062 sd_sample_params_t high_noise_sample_params;
10631063
1064- std::string easycache_option;
1064+ std::string cache_mode;
1065+ std::string cache_option;
10651066 sd_easycache_params_t easycache_params;
1066-
1067- std::string ucache_option;
10681067 sd_ucache_params_t ucache_params;
10691068
10701069 float moe_boundary = 0 .875f ;
@@ -1385,68 +1384,24 @@ struct SDGenerationParams {
13851384 return 1 ;
13861385 };
13871386
1388- auto on_easycache_arg = [&](int argc, const char ** argv, int index) {
1389- const std::string default_values = " 0.2,0.15,0.95" ;
1390- auto looks_like_value = [](const std::string& token) {
1391- if (token.empty ()) {
1392- return false ;
1393- }
1394- if (token[0 ] != ' -' ) {
1395- return true ;
1396- }
1397- if (token.size () == 1 ) {
1398- return false ;
1399- }
1400- unsigned char next = static_cast <unsigned char >(token[1 ]);
1401- return std::isdigit (next) || token[1 ] == ' .' ;
1402- };
1403-
1404- std::string option_value;
1405- int consumed = 0 ;
1406- if (index + 1 < argc) {
1407- std::string next_arg = argv[index + 1 ];
1408- if (looks_like_value (next_arg)) {
1409- option_value = argv_to_utf8 (index + 1 , argv);
1410- consumed = 1 ;
1411- }
1387+ auto on_cache_mode_arg = [&](int argc, const char ** argv, int index) {
1388+ if (++index >= argc) {
1389+ return -1 ;
14121390 }
1413- if (option_value.empty ()) {
1414- option_value = default_values;
1391+ cache_mode = argv_to_utf8 (index, argv);
1392+ if (cache_mode != " easycache" && cache_mode != " ucache" ) {
1393+ fprintf (stderr, " error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n " , cache_mode.c_str ());
1394+ return -1 ;
14151395 }
1416- easycache_option = option_value;
1417- return consumed;
1396+ return 1 ;
14181397 };
14191398
1420- auto on_ucache_arg = [&](int argc, const char ** argv, int index) {
1421- const std::string default_values = " 1.0,0.15,0.95" ;
1422- auto looks_like_value = [](const std::string& token) {
1423- if (token.empty ()) {
1424- return false ;
1425- }
1426- if (token[0 ] != ' -' ) {
1427- return true ;
1428- }
1429- if (token.size () == 1 ) {
1430- return false ;
1431- }
1432- unsigned char next = static_cast <unsigned char >(token[1 ]);
1433- return std::isdigit (next) || token[1 ] == ' .' ;
1434- };
1435-
1436- std::string option_value;
1437- int consumed = 0 ;
1438- if (index + 1 < argc) {
1439- std::string next_arg = argv[index + 1 ];
1440- if (looks_like_value (next_arg)) {
1441- option_value = argv_to_utf8 (index + 1 , argv);
1442- consumed = 1 ;
1443- }
1444- }
1445- if (option_value.empty ()) {
1446- option_value = default_values;
1399+ auto on_cache_option_arg = [&](int argc, const char ** argv, int index) {
1400+ if (++index >= argc) {
1401+ return -1 ;
14471402 }
1448- ucache_option = option_value ;
1449- return consumed ;
1403+ cache_option = argv_to_utf8 (index, argv) ;
1404+ return 1 ;
14501405 };
14511406
14521407 options.manual_options = {
@@ -1481,13 +1436,13 @@ struct SDGenerationParams {
14811436 " reference image for Flux Kontext models (can be used multiple times)" ,
14821437 on_ref_image_arg},
14831438 {" " ,
1484- " --easycache " ,
1485- " enable EasyCache for DiT models with optional \" threshold,start_percent,end_percent \" (default: 0.2,0.15,0.95 )" ,
1486- on_easycache_arg },
1439+ " --cache-mode " ,
1440+ " caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL )" ,
1441+ on_cache_mode_arg },
14871442 {" " ,
1488- " --ucache " ,
1489- " enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \" threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)" ,
1490- on_ucache_arg },
1443+ " --cache-option " ,
1444+ " cache parameters \" threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache )" ,
1445+ on_cache_option_arg },
14911446
14921447 };
14931448
@@ -1600,62 +1555,21 @@ struct SDGenerationParams {
16001555 return false ;
16011556 }
16021557
1603- if (!easycache_option.empty ()) {
1604- float values[3 ] = {0 .0f , 0 .0f , 0 .0f };
1605- std::stringstream ss (easycache_option);
1606- std::string token;
1607- int idx = 0 ;
1608- while (std::getline (ss, token, ' ,' )) {
1609- auto trim = [](std::string& s) {
1610- const char * whitespace = " \t\r\n " ;
1611- auto start = s.find_first_not_of (whitespace);
1612- if (start == std::string::npos) {
1613- s.clear ();
1614- return ;
1615- }
1616- auto end = s.find_last_not_of (whitespace);
1617- s = s.substr (start, end - start + 1 );
1618- };
1619- trim (token);
1620- if (token.empty ()) {
1621- fprintf (stderr, " error: invalid easycache option '%s'\n " , easycache_option.c_str ());
1622- return false ;
1623- }
1624- if (idx >= 3 ) {
1625- fprintf (stderr, " error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1626- return false ;
1627- }
1628- try {
1629- values[idx] = std::stof (token);
1630- } catch (const std::exception&) {
1631- fprintf (stderr, " error: invalid easycache value '%s'\n " , token.c_str ());
1632- return false ;
1558+ easycache_params.enabled = false ;
1559+ ucache_params.enabled = false ;
1560+
1561+ if (!cache_mode.empty ()) {
1562+ std::string option_str = cache_option;
1563+ if (option_str.empty ()) {
1564+ if (cache_mode == " easycache" ) {
1565+ option_str = " 0.2,0.15,0.95" ;
1566+ } else {
1567+ option_str = " 1.0,0.15,0.95" ;
16331568 }
1634- idx++;
1635- }
1636- if (idx != 3 ) {
1637- fprintf (stderr, " error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1638- return false ;
16391569 }
1640- if (values[0 ] < 0 .0f ) {
1641- fprintf (stderr, " error: easycache threshold must be non-negative\n " );
1642- return false ;
1643- }
1644- if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1645- fprintf (stderr, " error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
1646- return false ;
1647- }
1648- easycache_params.enabled = true ;
1649- easycache_params.reuse_threshold = values[0 ];
1650- easycache_params.start_percent = values[1 ];
1651- easycache_params.end_percent = values[2 ];
1652- } else {
1653- easycache_params.enabled = false ;
1654- }
16551570
1656- if (!ucache_option.empty ()) {
16571571 float values[3 ] = {0 .0f , 0 .0f , 0 .0f };
1658- std::stringstream ss (ucache_option );
1572+ std::stringstream ss (option_str );
16591573 std::string token;
16601574 int idx = 0 ;
16611575 while (std::getline (ss, token, ' ,' )) {
@@ -1671,39 +1585,45 @@ struct SDGenerationParams {
16711585 };
16721586 trim (token);
16731587 if (token.empty ()) {
1674- fprintf (stderr, " error: invalid ucache option '%s'\n " , ucache_option .c_str ());
1588+ fprintf (stderr, " error: invalid cache option '%s'\n " , option_str .c_str ());
16751589 return false ;
16761590 }
16771591 if (idx >= 3 ) {
1678- fprintf (stderr, " error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1592+ fprintf (stderr, " error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n " );
16791593 return false ;
16801594 }
16811595 try {
16821596 values[idx] = std::stof (token);
16831597 } catch (const std::exception&) {
1684- fprintf (stderr, " error: invalid ucache value '%s'\n " , token.c_str ());
1598+ fprintf (stderr, " error: invalid cache option value '%s'\n " , token.c_str ());
16851599 return false ;
16861600 }
16871601 idx++;
16881602 }
16891603 if (idx != 3 ) {
1690- fprintf (stderr, " error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1604+ fprintf (stderr, " error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n " );
16911605 return false ;
16921606 }
16931607 if (values[0 ] < 0 .0f ) {
1694- fprintf (stderr, " error: ucache threshold must be non-negative\n " );
1608+ fprintf (stderr, " error: cache threshold must be non-negative\n " );
16951609 return false ;
16961610 }
16971611 if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1698- fprintf (stderr, " error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
1612+ fprintf (stderr, " error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
16991613 return false ;
17001614 }
1701- ucache_params.enabled = true ;
1702- ucache_params.reuse_threshold = values[0 ];
1703- ucache_params.start_percent = values[1 ];
1704- ucache_params.end_percent = values[2 ];
1705- } else {
1706- ucache_params.enabled = false ;
1615+
1616+ if (cache_mode == " easycache" ) {
1617+ easycache_params.enabled = true ;
1618+ easycache_params.reuse_threshold = values[0 ];
1619+ easycache_params.start_percent = values[1 ];
1620+ easycache_params.end_percent = values[2 ];
1621+ } else {
1622+ ucache_params.enabled = true ;
1623+ ucache_params.reuse_threshold = values[0 ];
1624+ ucache_params.start_percent = values[1 ];
1625+ ucache_params.end_percent = values[2 ];
1626+ }
17071627 }
17081628
17091629 sample_params.guidance .slg .layers = skip_layers.data ();
@@ -1798,12 +1718,18 @@ struct SDGenerationParams {
17981718 << " sample_params: " << sample_params_str << " ,\n "
17991719 << " high_noise_skip_layers: " << vec_to_string (high_noise_skip_layers) << " ,\n "
18001720 << " high_noise_sample_params: " << high_noise_sample_params_str << " ,\n "
1801- << " easycache_option: \" " << easycache_option << " \" ,\n "
1721+ << " cache_mode: \" " << cache_mode << " \" ,\n "
1722+ << " cache_option: \" " << cache_option << " \" ,\n "
18021723 << " easycache: "
18031724 << (easycache_params.enabled ? " enabled" : " disabled" )
18041725 << " (threshold=" << easycache_params.reuse_threshold
18051726 << " , start=" << easycache_params.start_percent
18061727 << " , end=" << easycache_params.end_percent << " ),\n "
1728+ << " ucache: "
1729+ << (ucache_params.enabled ? " enabled" : " disabled" )
1730+ << " (threshold=" << ucache_params.reuse_threshold
1731+ << " , start=" << ucache_params.start_percent
1732+ << " , end=" << ucache_params.end_percent << " ),\n "
18071733 << " moe_boundary: " << moe_boundary << " ,\n "
18081734 << " video_frames: " << video_frames << " ,\n "
18091735 << " fps: " << fps << " ,\n "
0 commit comments