🌐 AI搜索 & 代理 主页
Skip to content

Commit 8f05f5b

Browse files
rmatif and leejet authored
feat: add support for custom scheduler (#694)
--------- Co-authored-by: leejet <leejet714@gmail.com>
1 parent 15d0f82 commit 8f05f5b

File tree

6 files changed

+101
-9
lines changed

6 files changed

+101
-9
lines changed

examples/cli/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ Generation Options:
121121
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
122122
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
123123
default: discrete
124+
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
124125
--skip-layers layers to skip for SLG steps (default: [7,8,9])
125126
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126127
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)

examples/cli/main.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,15 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
258258
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
259259
}
260260
parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
261-
if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) {
261+
if (!gen_params.custom_sigmas.empty()) {
262+
parameter_string += ", Custom Sigmas: [";
263+
for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
264+
std::ostringstream oss;
265+
oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
266+
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
267+
}
268+
parameter_string += "]";
269+
} else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas
262270
parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
263271
}
264272
parameter_string += ", ";
@@ -806,4 +814,4 @@ int main(int argc, const char* argv[]) {
806814
release_all_resources();
807815

808816
return 0;
809-
}
817+
}

examples/common/common.hpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,8 @@ struct SDGenerationParams {
883883
std::vector<int> high_noise_skip_layers = {7, 8, 9};
884884
sd_sample_params_t high_noise_sample_params;
885885

886+
std::vector<float> custom_sigmas;
887+
886888
std::string easycache_option;
887889
sd_easycache_params_t easycache_params;
888890

@@ -1201,6 +1203,43 @@ struct SDGenerationParams {
12011203
return 1;
12021204
};
12031205

1206+
auto on_sigmas_arg = [&](int argc, const char** argv, int index) {
1207+
if (++index >= argc) {
1208+
return -1;
1209+
}
1210+
std::string sigmas_str = argv[index];
1211+
if (!sigmas_str.empty() && sigmas_str.front() == '[') {
1212+
sigmas_str.erase(0, 1);
1213+
}
1214+
if (!sigmas_str.empty() && sigmas_str.back() == ']') {
1215+
sigmas_str.pop_back();
1216+
}
1217+
1218+
std::stringstream ss(sigmas_str);
1219+
std::string item;
1220+
while (std::getline(ss, item, ',')) {
1221+
item.erase(0, item.find_first_not_of(" \t\n\r\f\v"));
1222+
item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1);
1223+
if (!item.empty()) {
1224+
try {
1225+
custom_sigmas.push_back(std::stof(item));
1226+
} catch (const std::invalid_argument& e) {
1227+
fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
1228+
return -1;
1229+
} catch (const std::out_of_range& e) {
1230+
fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str());
1231+
return -1;
1232+
}
1233+
}
1234+
}
1235+
1236+
if (custom_sigmas.empty() && !sigmas_str.empty()) {
1237+
fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[index]);
1238+
return -1;
1239+
}
1240+
return 1;
1241+
};
1242+
12041243
auto on_ref_image_arg = [&](int argc, const char** argv, int index) {
12051244
if (++index >= argc) {
12061245
return -1;
@@ -1260,6 +1299,10 @@ struct SDGenerationParams {
12601299
"--scheduler",
12611300
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
12621301
on_scheduler_arg},
1302+
{"",
1303+
"--sigmas",
1304+
"custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").",
1305+
on_sigmas_arg},
12631306
{"",
12641307
"--skip-layers",
12651308
"layers to skip for SLG steps (default: [7,8,9])",
@@ -1512,6 +1555,8 @@ struct SDGenerationParams {
15121555

15131556
sample_params.guidance.slg.layers = skip_layers.data();
15141557
sample_params.guidance.slg.layer_count = skip_layers.size();
1558+
sample_params.custom_sigmas = custom_sigmas.data();
1559+
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
15151560
high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data();
15161561
high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size();
15171562

@@ -1606,6 +1651,7 @@ struct SDGenerationParams {
16061651
<< " sample_params: " << sample_params_str << ",\n"
16071652
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
16081653
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
1654+
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
16091655
<< " easycache_option: \"" << easycache_option << "\",\n"
16101656
<< " easycache: "
16111657
<< (easycache_params.enabled ? "enabled" : "disabled")

examples/server/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Default Generation Options:
115115
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
116116
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
117117
default: discrete
118+
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
118119
--skip-layers layers to skip for SLG steps (default: [7,8,9])
119120
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
120121
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)

stable-diffusion.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2600,6 +2600,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
26002600
sample_params->scheduler = SCHEDULER_COUNT;
26012601
sample_params->sample_method = SAMPLE_METHOD_COUNT;
26022602
sample_params->sample_steps = 20;
2603+
sample_params->custom_sigmas = nullptr;
2604+
sample_params->custom_sigmas_count = 0;
26032605
}
26042606

26052607
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@@ -3194,11 +3196,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
31943196
}
31953197
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
31963198

3197-
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
3198-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
3199-
sd_ctx->sd->get_image_seq_len(height, width),
3200-
sd_img_gen_params->sample_params.scheduler,
3201-
sd_ctx->sd->version);
3199+
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
3200+
std::vector<float> sigmas;
3201+
if (sd_img_gen_params->sample_params.custom_sigmas_count > 0) {
3202+
sigmas = std::vector<float>(sd_img_gen_params->sample_params.custom_sigmas,
3203+
sd_img_gen_params->sample_params.custom_sigmas + sd_img_gen_params->sample_params.custom_sigmas_count);
3204+
if (sample_steps != sigmas.size() - 1) {
3205+
sample_steps = static_cast<int>(sigmas.size()) - 1;
3206+
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
3207+
}
3208+
} else {
3209+
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
3210+
sd_ctx->sd->get_image_seq_len(height, width),
3211+
sd_img_gen_params->sample_params.scheduler,
3212+
sd_ctx->sd->version);
3213+
}
32023214

32033215
ggml_tensor* init_latent = nullptr;
32043216
ggml_tensor* concat_latent = nullptr;
@@ -3461,7 +3473,29 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
34613473
if (high_noise_sample_steps > 0) {
34623474
total_steps += high_noise_sample_steps;
34633475
}
3464-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, 0, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);
3476+
3477+
std::vector<float> sigmas;
3478+
if (sd_vid_gen_params->sample_params.custom_sigmas_count > 0) {
3479+
sigmas = std::vector<float>(sd_vid_gen_params->sample_params.custom_sigmas,
3480+
sd_vid_gen_params->sample_params.custom_sigmas + sd_vid_gen_params->sample_params.custom_sigmas_count);
3481+
if (total_steps != sigmas.size() - 1) {
3482+
total_steps = static_cast<int>(sigmas.size()) - 1;
3483+
LOG_WARN("total_steps != custom_sigmas_count - 1, set total_steps to %d", total_steps);
3484+
if (sample_steps >= total_steps) {
3485+
sample_steps = total_steps;
3486+
LOG_WARN("total_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
3487+
}
3488+
if (high_noise_sample_steps > 0) {
3489+
high_noise_sample_steps = total_steps - sample_steps;
3490+
LOG_WARN("total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d", high_noise_sample_steps);
3491+
}
3492+
}
3493+
} else {
3494+
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
3495+
0,
3496+
sd_vid_gen_params->sample_params.scheduler,
3497+
sd_ctx->sd->version);
3498+
}
34653499

34663500
if (high_noise_sample_steps < 0) {
34673501
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
@@ -3841,4 +3875,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38413875
LOG_INFO("generate_video completed in %.2fs", (t5 - t0) * 1.0f / 1000);
38423876

38433877
return result_images;
3844-
}
3878+
}

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ typedef struct {
225225
int sample_steps;
226226
float eta;
227227
int shifted_timestep;
228+
float* custom_sigmas;
229+
int custom_sigmas_count;
228230
} sd_sample_params_t;
229231

230232
typedef struct {

0 commit comments

Comments
 (0)