@@ -2600,6 +2600,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
26002600 sample_params->scheduler = SCHEDULER_COUNT;
26012601 sample_params->sample_method = SAMPLE_METHOD_COUNT;
26022602 sample_params->sample_steps = 20 ;
2603+ sample_params->custom_sigmas = nullptr ;
2604+ sample_params->custom_sigmas_count = 0 ;
26032605}
26042606
26052607char * sd_sample_params_to_str (const sd_sample_params_t * sample_params) {
@@ -3194,11 +3196,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
31943196 }
31953197 LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
31963198
3197- int sample_steps = sd_img_gen_params->sample_params .sample_steps ;
3198- std::vector<float > sigmas = sd_ctx->sd ->denoiser ->get_sigmas (sample_steps,
3199- sd_ctx->sd ->get_image_seq_len (height, width),
3200- sd_img_gen_params->sample_params .scheduler ,
3201- sd_ctx->sd ->version );
3199+ int sample_steps = sd_img_gen_params->sample_params .sample_steps ;
3200+ std::vector<float > sigmas;
3201+ if (sd_img_gen_params->sample_params .custom_sigmas_count > 0 ) {
3202+ sigmas = std::vector<float >(sd_img_gen_params->sample_params .custom_sigmas ,
3203+ sd_img_gen_params->sample_params .custom_sigmas + sd_img_gen_params->sample_params .custom_sigmas_count );
3204+ if (sample_steps != sigmas.size () - 1 ) {
3205+ sample_steps = static_cast <int >(sigmas.size ()) - 1 ;
3206+ LOG_WARN (" sample_steps != custom_sigmas_count - 1, set sample_steps to %d" , sample_steps);
3207+ }
3208+ } else {
3209+ sigmas = sd_ctx->sd ->denoiser ->get_sigmas (sample_steps,
3210+ sd_ctx->sd ->get_image_seq_len (height, width),
3211+ sd_img_gen_params->sample_params .scheduler ,
3212+ sd_ctx->sd ->version );
3213+ }
32023214
32033215 ggml_tensor* init_latent = nullptr ;
32043216 ggml_tensor* concat_latent = nullptr ;
@@ -3461,7 +3473,29 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
34613473 if (high_noise_sample_steps > 0 ) {
34623474 total_steps += high_noise_sample_steps;
34633475 }
3464- std::vector<float > sigmas = sd_ctx->sd ->denoiser ->get_sigmas (total_steps, 0 , sd_vid_gen_params->sample_params .scheduler , sd_ctx->sd ->version );
3476+
3477+ std::vector<float > sigmas;
3478+ if (sd_vid_gen_params->sample_params .custom_sigmas_count > 0 ) {
3479+ sigmas = std::vector<float >(sd_vid_gen_params->sample_params .custom_sigmas ,
3480+ sd_vid_gen_params->sample_params .custom_sigmas + sd_vid_gen_params->sample_params .custom_sigmas_count );
3481+ if (total_steps != sigmas.size () - 1 ) {
3482+ total_steps = static_cast <int >(sigmas.size ()) - 1 ;
3483+ LOG_WARN (" total_steps != custom_sigmas_count - 1, set total_steps to %d" , total_steps);
3484+ if (sample_steps >= total_steps) {
3485+ sample_steps = total_steps;
3486+ LOG_WARN (" total_steps != custom_sigmas_count - 1, set sample_steps to %d" , sample_steps);
3487+ }
3488+ if (high_noise_sample_steps > 0 ) {
3489+ high_noise_sample_steps = total_steps - sample_steps;
3490+ LOG_WARN (" total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d" , high_noise_sample_steps);
3491+ }
3492+ }
3493+ } else {
3494+ sigmas = sd_ctx->sd ->denoiser ->get_sigmas (total_steps,
3495+ 0 ,
3496+ sd_vid_gen_params->sample_params .scheduler ,
3497+ sd_ctx->sd ->version );
3498+ }
34653499
34663500 if (high_noise_sample_steps < 0 ) {
34673501 // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
@@ -3841,4 +3875,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38413875 LOG_INFO (" generate_video completed in %.2fs" , (t5 - t0) * 1 .0f / 1000 );
38423876
38433877 return result_images;
3844- }
3878+ }
0 commit comments