🌐 AI搜索 & 代理 主页
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Generation Options:
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
Expand Down
12 changes: 10 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,15 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
}
parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) {
if (!gen_params.custom_sigmas.empty()) {
parameter_string += ", Custom Sigmas: [";
for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
}
parameter_string += "]";
} else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas
parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
}
parameter_string += ", ";
Expand Down Expand Up @@ -806,4 +814,4 @@ int main(int argc, const char* argv[]) {
release_all_resources();

return 0;
}
}
46 changes: 46 additions & 0 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,8 @@ struct SDGenerationParams {
std::vector<int> high_noise_skip_layers = {7, 8, 9};
sd_sample_params_t high_noise_sample_params;

std::vector<float> custom_sigmas;

std::string easycache_option;
sd_easycache_params_t easycache_params;

Expand Down Expand Up @@ -1201,6 +1203,43 @@ struct SDGenerationParams {
return 1;
};

// Handler for `--sigmas`: parses the argument that follows the flag as a
// comma-separated list of floats and stores the result in custom_sigmas.
// One leading '[' and one trailing ']' are tolerated, whitespace around each
// entry is ignored, and empty entries (e.g. "1,,2") are skipped.
//
// Returns 1 on success and -1 when the value is missing or any entry is not a
// valid float.
//
// Entries are collected in a local vector and committed only once the whole
// list has parsed, so a failed parse leaves custom_sigmas untouched and a
// repeated `--sigmas` replaces the earlier list instead of appending to it.
auto on_sigmas_arg = [&](int argc, const char** argv, int index) {
    if (++index >= argc) {
        return -1;
    }

    std::string sigmas_str = argv[index];
    if (!sigmas_str.empty() && sigmas_str.front() == '[') {
        sigmas_str.erase(0, 1);
    }
    if (!sigmas_str.empty() && sigmas_str.back() == ']') {
        sigmas_str.pop_back();
    }

    std::vector<float> parsed_sigmas;
    std::stringstream ss(sigmas_str);
    std::string item;
    while (std::getline(ss, item, ',')) {
        // Trim surrounding whitespace; an all-whitespace entry becomes empty
        // (find_first_not_of returns npos, which erases the whole string).
        item.erase(0, item.find_first_not_of(" \t\n\r\f\v"));
        item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1);
        if (item.empty()) {
            continue;
        }

        try {
            size_t consumed = 0;
            const float value = std::stof(item, &consumed);
            // std::stof stops at the first character it cannot use, so
            // "7.8abc" would otherwise be accepted as 7.8. Require that the
            // entire entry was consumed.
            if (consumed != item.size()) {
                fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
                return -1;
            }
            parsed_sigmas.push_back(value);
        } catch (const std::invalid_argument&) {
            fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
            return -1;
        } catch (const std::out_of_range&) {
            fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str());
            return -1;
        }
    }

    // Non-empty input that produced no values (e.g. ",,") is an error.
    if (parsed_sigmas.empty() && !sigmas_str.empty()) {
        fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[index]);
        return -1;
    }

    // Commit only after the whole list has been validated.
    custom_sigmas.swap(parsed_sigmas);
    return 1;
};

auto on_ref_image_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
Expand Down Expand Up @@ -1260,6 +1299,10 @@ struct SDGenerationParams {
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
on_scheduler_arg},
{"",
"--sigmas",
"custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").",
on_sigmas_arg},
{"",
"--skip-layers",
"layers to skip for SLG steps (default: [7,8,9])",
Expand Down Expand Up @@ -1509,6 +1552,8 @@ struct SDGenerationParams {

sample_params.guidance.slg.layers = skip_layers.data();
sample_params.guidance.slg.layer_count = skip_layers.size();
sample_params.custom_sigmas = custom_sigmas.data();
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data();
high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size();

Expand Down Expand Up @@ -1603,6 +1648,7 @@ struct SDGenerationParams {
<< " sample_params: " << sample_params_str << ",\n"
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
<< " easycache_option: \"" << easycache_option << "\",\n"
<< " easycache: "
<< (easycache_params.enabled ? "enabled" : "disabled")
Expand Down
1 change: 1 addition & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ Default Generation Options:
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
Expand Down
48 changes: 41 additions & 7 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2600,6 +2600,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_COUNT;
sample_params->sample_steps = 20;
sample_params->custom_sigmas = nullptr;
sample_params->custom_sigmas_count = 0;
}

char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
Expand Down Expand Up @@ -3194,11 +3196,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);

int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas;
if (sd_img_gen_params->sample_params.custom_sigmas_count > 0) {
sigmas = std::vector<float>(sd_img_gen_params->sample_params.custom_sigmas,
sd_img_gen_params->sample_params.custom_sigmas + sd_img_gen_params->sample_params.custom_sigmas_count);
if (sample_steps != sigmas.size() - 1) {
sample_steps = static_cast<int>(sigmas.size()) - 1;
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
}
} else {
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
}

ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr;
Expand Down Expand Up @@ -3461,7 +3473,29 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps;
}
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, 0, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);

std::vector<float> sigmas;
if (sd_vid_gen_params->sample_params.custom_sigmas_count > 0) {
sigmas = std::vector<float>(sd_vid_gen_params->sample_params.custom_sigmas,
sd_vid_gen_params->sample_params.custom_sigmas + sd_vid_gen_params->sample_params.custom_sigmas_count);
if (total_steps != sigmas.size() - 1) {
total_steps = static_cast<int>(sigmas.size()) - 1;
LOG_WARN("total_steps != custom_sigmas_count - 1, set total_steps to %d", total_steps);
if (sample_steps >= total_steps) {
sample_steps = total_steps;
LOG_WARN("total_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
}
if (high_noise_sample_steps > 0) {
high_noise_sample_steps = total_steps - sample_steps;
LOG_WARN("total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d", high_noise_sample_steps);
}
}
} else {
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
0,
sd_vid_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
}

if (high_noise_sample_steps < 0) {
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
Expand Down Expand Up @@ -3841,4 +3875,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
LOG_INFO("generate_video completed in %.2fs", (t5 - t0) * 1.0f / 1000);

return result_images;
}
}
2 changes: 2 additions & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ typedef struct {
int sample_steps;
float eta;
int shifted_timestep;
float* custom_sigmas;
int custom_sigmas_count;
} sd_sample_params_t;

typedef struct {
Expand Down
Loading