From 67e873db898349152f2d9daee9effb1d794ab4da Mon Sep 17 00:00:00 2001
From: fszontagh <szf@fsociety.hu>
Date: Sat, 24 Feb 2024 14:15:35 +0000
Subject: [PATCH] progress handler

---
 util.cpp           |   11 +++++++++++
 stable-diffusion.h |    2 ++
 2 files changed, 13 insertions(+), 0 deletions(-)

diff --git a/stable-diffusion.h b/stable-diffusion.h
index 01ba332..984e28f 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -89,8 +89,10 @@
 };
 
 typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
+typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
 
 SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
+SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
 SD_API int32_t get_num_physical_cores();
 SD_API const char* sd_get_system_info();
 
diff --git a/util.cpp b/util.cpp
index f68607f..16d8118 100644
--- a/util.cpp
+++ b/util.cpp
@@ -160,6 +160,9 @@
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
+static sd_progress_cb_t sd_progress_cb  = NULL;
+void* sd_progress_cb_data               = NULL;
+
 std::u32string utf8_to_utf32(const std::string& utf8_str) {
     std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
     return converter.from_bytes(utf8_str);
@@ -207,6 +210,10 @@
     if (step == 0) {
         return;
     }
+if (sd_progress_cb) {
+        sd_progress_cb(step,steps,time, sd_progress_cb_data);
+        return;
+    }    
     std::string progress = "  |";
     int max_progress     = 50;
     int32_t current      = (int32_t)(step * 1.f * max_progress / steps);
@@ -285,6 +292,10 @@
     sd_log_cb      = cb;
     sd_log_cb_data = data;
 }
+void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
+    sd_progress_cb      = cb;
+    sd_progress_cb_data = data;
+}
 
 const char* sd_get_system_info() {
     static char buffer[1024];

--
Gitblit v1.9.3