diff --git a/examples/cli/README.md b/examples/cli/README.md index f6a427851..02650f703 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -121,6 +121,7 @@ Generation Options: ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete + --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index eaa2591e6..417d211aa 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -258,7 +258,15 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", "; } parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method)); - if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { + if (!gen_params.custom_sigmas.empty()) { + parameter_string += ", Custom Sigmas: ["; + for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i]; + parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", "); + } + parameter_string += "]"; + } else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler)); } parameter_string += ", "; @@ -806,4 +814,4 @@ int main(int argc, const char* argv[]) { release_all_resources(); return 0; -} +} \ No newline at end of file diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 558817eea..9d9a5b685 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -883,6 +883,8 @@ struct SDGenerationParams { std::vector high_noise_skip_layers = {7, 8, 9}; sd_sample_params_t high_noise_sample_params; + std::vector custom_sigmas; + std::string easycache_option; sd_easycache_params_t easycache_params; @@ -1201,6 +1203,43 @@ struct SDGenerationParams { return 1; }; + auto on_sigmas_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string sigmas_str = argv[index]; + if (!sigmas_str.empty() && sigmas_str.front() == '[') { + sigmas_str.erase(0, 1); + } + if (!sigmas_str.empty() && sigmas_str.back() == ']') { + sigmas_str.pop_back(); + } + + std::stringstream ss(sigmas_str); + std::string item; + while (std::getline(ss, item, ',')) { + item.erase(0, item.find_first_not_of(" \t\n\r\f\v")); + item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1); + if (!item.empty()) { + try { + custom_sigmas.push_back(std::stof(item)); + } catch (const std::invalid_argument& e) { + fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str()); + return -1; + } catch (const std::out_of_range& e) { + fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str()); + return -1; + } + } + } + + if (custom_sigmas.empty() && !sigmas_str.empty()) { + fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[index]); + return -1; + } + return 1; + }; + auto on_ref_image_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; @@ -1260,6 +1299,10 @@ struct SDGenerationParams { "--scheduler", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", on_scheduler_arg}, + {"", + "--sigmas", + "custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").", + on_sigmas_arg}, {"", "--skip-layers", "layers to skip for SLG steps (default: [7,8,9])", @@ -1509,6 +1552,8 @@ struct SDGenerationParams { sample_params.guidance.slg.layers = skip_layers.data(); sample_params.guidance.slg.layer_count = skip_layers.size(); + sample_params.custom_sigmas = custom_sigmas.data(); + sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data(); high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); @@ -1603,6 +1648,7 @@ struct SDGenerationParams { << " sample_params: " << sample_params_str << ",\n" << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" + << " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n" << " easycache_option: \"" << easycache_option << "\",\n" << " easycache: " << (easycache_params.enabled ? "enabled" : "disabled") diff --git a/examples/server/README.md b/examples/server/README.md index 6393d841d..43c5d5f57 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -115,6 +115,7 @@ Default Generation Options: ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete + --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1ef851247..2cb588213 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2600,6 +2600,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) { sample_params->scheduler = SCHEDULER_COUNT; sample_params->sample_method = SAMPLE_METHOD_COUNT; sample_params->sample_steps = 20; + sample_params->custom_sigmas = nullptr; + sample_params->custom_sigmas_count = 0; } char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) { @@ -3194,11 +3196,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g } LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); - int sample_steps = sd_img_gen_params->sample_params.sample_steps; - std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, - sd_ctx->sd->get_image_seq_len(height, width), - sd_img_gen_params->sample_params.scheduler, - sd_ctx->sd->version); + int sample_steps = sd_img_gen_params->sample_params.sample_steps; + std::vector sigmas; + if (sd_img_gen_params->sample_params.custom_sigmas_count > 0) { + sigmas = std::vector(sd_img_gen_params->sample_params.custom_sigmas, + sd_img_gen_params->sample_params.custom_sigmas + sd_img_gen_params->sample_params.custom_sigmas_count); + if (sample_steps != sigmas.size() - 1) { + sample_steps = static_cast(sigmas.size()) - 1; + LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps); + } + } else { + sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, + sd_ctx->sd->get_image_seq_len(height, width), + sd_img_gen_params->sample_params.scheduler, + sd_ctx->sd->version); + } ggml_tensor* init_latent = nullptr; ggml_tensor* concat_latent = nullptr; @@ -3461,7 +3473,29 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s if (high_noise_sample_steps > 0) { total_steps += high_noise_sample_steps; } - std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, 0, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version); + + std::vector sigmas; + if (sd_vid_gen_params->sample_params.custom_sigmas_count > 0) { + sigmas = std::vector(sd_vid_gen_params->sample_params.custom_sigmas, + sd_vid_gen_params->sample_params.custom_sigmas + sd_vid_gen_params->sample_params.custom_sigmas_count); + if (total_steps != sigmas.size() - 1) { + total_steps = static_cast(sigmas.size()) - 1; + LOG_WARN("total_steps != custom_sigmas_count - 1, set total_steps to %d", total_steps); + if (sample_steps >= total_steps) { + sample_steps = total_steps; + LOG_WARN("total_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps); + } + if (high_noise_sample_steps > 0) { + high_noise_sample_steps = total_steps - sample_steps; + LOG_WARN("total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d", high_noise_sample_steps); + } + } + } else { + sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, + 0, + sd_vid_gen_params->sample_params.scheduler, + sd_ctx->sd->version); + } if (high_noise_sample_steps < 0) { // timesteps ∝ sigmas for Flow models (like wan2.2 a14b) @@ -3841,4 +3875,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s LOG_INFO("generate_video completed in %.2fs", (t5 - t0) * 1.0f / 1000); return result_images; -} +} \ No newline at end of file diff --git a/stable-diffusion.h b/stable-diffusion.h index 2da70bd77..e4abc8dcd 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -225,6 +225,8 @@ typedef struct { int sample_steps; float eta; int shifted_timestep; + float* custom_sigmas; + int custom_sigmas_count; } sd_sample_params_t; typedef struct {