@@ -1484,7 +1484,7 @@ struct SDGenerationParams {
14841484 on_cache_mode_arg},
14851485 {" " ,
14861486 " --cache-option" ,
1487- " cache params - legacy: \" threshold,start,end[ ,decay,relative] \" , cache-dit: \" Fn ,Bn,threshold,warmup\" (default: 8,0,0.08,8) " ,
1487+ " named cache params: easycache/ucache: threshold= ,start= ,end= ,decay= ,relative= | cache-dit: Fn= ,Bn= ,threshold= ,warmup= " ,
14881488 on_cache_option_arg},
14891489 {" " ,
14901490 " --scm-mask" ,
@@ -1613,88 +1613,125 @@ struct SDGenerationParams {
16131613 cache_params.mode = SD_CACHE_DISABLED;
16141614
16151615 if (!cache_mode.empty ()) {
1616- std::string option_str = cache_option;
1617- if (option_str.empty ()) {
1618- if (cache_mode == " easycache" ) {
1619- option_str = " 0.2,0.15,0.95" ;
1620- } else if (cache_mode == " ucache" ) {
1621- option_str = " 1.0,0.15,0.95" ;
1622- } else if (cache_mode == " dbcache" || cache_mode == " taylorseer" || cache_mode == " cache-dit" ) {
1623- option_str = " 8,0,0.08,8" ;
1624- }
1625- }
1626-
1627- float values[5 ] = {0 .0f , 0 .0f , 0 .0f , 1 .0f , 1 .0f };
1628- std::stringstream ss (option_str);
1629- std::string token;
1630- int idx = 0 ;
16311616 auto trim = [](std::string& s) {
16321617 const char * whitespace = " \t\r\n " ;
1633- auto start = s.find_first_not_of (whitespace);
1634- if (start == std::string::npos) {
1635- s.clear ();
1636- return ;
1637- }
1618+ auto start = s.find_first_not_of (whitespace);
1619+ if (start == std::string::npos) { s.clear (); return ; }
16381620 auto end = s.find_last_not_of (whitespace);
1639- s = s.substr (start, end - start + 1 );
1621+ s = s.substr (start, end - start + 1 );
16401622 };
1641- while (std::getline (ss, token, ' ,' )) {
1642- trim (token);
1643- if (token.empty ()) {
1644- fprintf (stderr, " error: invalid cache option '%s'\n " , option_str.c_str ());
1645- return false ;
1646- }
1647- if (idx >= 5 ) {
1648- fprintf (stderr, " error: cache option expects 3-5 comma-separated values (threshold,start,end[,decay,relative])\n " );
1649- return false ;
1623+
1624+ auto parse_named_params = [&](const std::string& opt_str) -> bool {
1625+ std::stringstream ss (opt_str);
1626+ std::string token;
1627+ while (std::getline (ss, token, ' ,' )) {
1628+ trim (token);
1629+ if (token.empty ()) continue ;
1630+
1631+ size_t eq_pos = token.find (' =' );
1632+ if (eq_pos == std::string::npos) {
1633+ fprintf (stderr, " error: invalid named parameter '%s', expected key=value\n " , token.c_str ());
1634+ return false ;
1635+ }
1636+
1637+ std::string key = token.substr (0 , eq_pos);
1638+ std::string val = token.substr (eq_pos + 1 );
1639+ trim (key);
1640+ trim (val);
1641+
1642+ if (key.empty () || val.empty ()) {
1643+ fprintf (stderr, " error: invalid named parameter '%s'\n " , token.c_str ());
1644+ return false ;
1645+ }
1646+
1647+ try {
1648+ if (key == " threshold" ) {
1649+ if (cache_mode == " easycache" || cache_mode == " ucache" ) {
1650+ cache_params.reuse_threshold = std::stof (val);
1651+ } else {
1652+ cache_params.residual_diff_threshold = std::stof (val);
1653+ }
1654+ } else if (key == " start" ) {
1655+ cache_params.start_percent = std::stof (val);
1656+ } else if (key == " end" ) {
1657+ cache_params.end_percent = std::stof (val);
1658+ } else if (key == " decay" ) {
1659+ cache_params.error_decay_rate = std::stof (val);
1660+ } else if (key == " relative" ) {
1661+ cache_params.use_relative_threshold = (std::stof (val) != 0 .0f );
1662+ } else if (key == " Fn" || key == " fn" ) {
1663+ cache_params.Fn_compute_blocks = std::stoi (val);
1664+ } else if (key == " Bn" || key == " bn" ) {
1665+ cache_params.Bn_compute_blocks = std::stoi (val);
1666+ } else if (key == " warmup" ) {
1667+ cache_params.max_warmup_steps = std::stoi (val);
1668+ } else {
1669+ fprintf (stderr, " error: unknown cache parameter '%s'\n " , key.c_str ());
1670+ return false ;
1671+ }
1672+ } catch (const std::exception&) {
1673+ fprintf (stderr, " error: invalid value '%s' for parameter '%s'\n " , val.c_str (), key.c_str ());
1674+ return false ;
1675+ }
16501676 }
1651- try {
1652- values[idx] = std::stof (token);
1653- } catch (const std::exception&) {
1654- fprintf (stderr, " error: invalid cache option value '%s'\n " , token.c_str ());
1677+ return true ;
1678+ };
1679+
1680+ if (cache_mode == " easycache" ) {
1681+ cache_params.mode = SD_CACHE_EASYCACHE;
1682+ cache_params.reuse_threshold = 0 .2f ;
1683+ cache_params.start_percent = 0 .15f ;
1684+ cache_params.end_percent = 0 .95f ;
1685+ cache_params.error_decay_rate = 1 .0f ;
1686+ cache_params.use_relative_threshold = true ;
1687+ } else if (cache_mode == " ucache" ) {
1688+ cache_params.mode = SD_CACHE_UCACHE;
1689+ cache_params.reuse_threshold = 1 .0f ;
1690+ cache_params.start_percent = 0 .15f ;
1691+ cache_params.end_percent = 0 .95f ;
1692+ cache_params.error_decay_rate = 1 .0f ;
1693+ cache_params.use_relative_threshold = true ;
1694+ } else if (cache_mode == " dbcache" ) {
1695+ cache_params.mode = SD_CACHE_DBCACHE;
1696+ cache_params.Fn_compute_blocks = 8 ;
1697+ cache_params.Bn_compute_blocks = 0 ;
1698+ cache_params.residual_diff_threshold = 0 .08f ;
1699+ cache_params.max_warmup_steps = 8 ;
1700+ } else if (cache_mode == " taylorseer" ) {
1701+ cache_params.mode = SD_CACHE_TAYLORSEER;
1702+ cache_params.Fn_compute_blocks = 8 ;
1703+ cache_params.Bn_compute_blocks = 0 ;
1704+ cache_params.residual_diff_threshold = 0 .08f ;
1705+ cache_params.max_warmup_steps = 8 ;
1706+ } else if (cache_mode == " cache-dit" ) {
1707+ cache_params.mode = SD_CACHE_CACHE_DIT;
1708+ cache_params.Fn_compute_blocks = 8 ;
1709+ cache_params.Bn_compute_blocks = 0 ;
1710+ cache_params.residual_diff_threshold = 0 .08f ;
1711+ cache_params.max_warmup_steps = 8 ;
1712+ } else {
1713+ fprintf (stderr, " error: unknown cache mode '%s'\n " , cache_mode.c_str ());
1714+ return false ;
1715+ }
1716+
1717+ if (!cache_option.empty ()) {
1718+ if (!parse_named_params (cache_option)) {
16551719 return false ;
16561720 }
1657- idx++;
16581721 }
1722+
16591723 if (cache_mode == " easycache" || cache_mode == " ucache" ) {
1660- if (idx < 3 ) {
1661- fprintf (stderr, " error: cache option expects at least 3 comma-separated values (threshold,start,end)\n " );
1662- return false ;
1663- }
1664- if (values[0 ] < 0 .0f ) {
1724+ if (cache_params.reuse_threshold < 0 .0f ) {
16651725 fprintf (stderr, " error: cache threshold must be non-negative\n " );
16661726 return false ;
16671727 }
1668- if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1728+ if (cache_params.start_percent < 0 .0f || cache_params.start_percent >= 1 .0f ||
1729+ cache_params.end_percent <= 0 .0f || cache_params.end_percent > 1 .0f ||
1730+ cache_params.start_percent >= cache_params.end_percent ) {
16691731 fprintf (stderr, " error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
16701732 return false ;
16711733 }
16721734 }
1673-
1674- if (cache_mode == " easycache" || cache_mode == " ucache" ) {
1675- cache_params.reuse_threshold = values[0 ];
1676- cache_params.start_percent = values[1 ];
1677- cache_params.end_percent = values[2 ];
1678- cache_params.error_decay_rate = values[3 ];
1679- cache_params.use_relative_threshold = (values[4 ] != 0 .0f );
1680- if (cache_mode == " easycache" ) {
1681- cache_params.mode = SD_CACHE_EASYCACHE;
1682- } else {
1683- cache_params.mode = SD_CACHE_UCACHE;
1684- }
1685- } else {
1686- cache_params.Fn_compute_blocks = (idx >= 1 ) ? static_cast <int >(values[0 ]) : 8 ;
1687- cache_params.Bn_compute_blocks = (idx >= 2 ) ? static_cast <int >(values[1 ]) : 0 ;
1688- cache_params.residual_diff_threshold = (idx >= 3 ) ? values[2 ] : 0 .08f ;
1689- cache_params.max_warmup_steps = (idx >= 4 ) ? static_cast <int >(values[3 ]) : 8 ;
1690- if (cache_mode == " dbcache" ) {
1691- cache_params.mode = SD_CACHE_DBCACHE;
1692- } else if (cache_mode == " taylorseer" ) {
1693- cache_params.mode = SD_CACHE_TAYLORSEER;
1694- } else {
1695- cache_params.mode = SD_CACHE_CACHE_DIT;
1696- }
1697- }
16981735 }
16991736
17001737 if (cache_params.mode == SD_CACHE_DBCACHE ||
0 commit comments