From 67e873db898349152f2d9daee9effb1d794ab4da Mon Sep 17 00:00:00 2001
From: fszontagh <szf@fsociety.hu>
Date: Sat, 24 Feb 2024 14:15:35 +0000
Subject: [PATCH] progress handler
---
util.cpp | 11 +++++++++++
stable-diffusion.h | 2 ++
2 files changed, 13 insertions(+), 0 deletions(-)
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 01ba332..984e28f 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -89,8 +89,10 @@
};
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
+typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
+SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
SD_API int32_t get_num_physical_cores();
SD_API const char* sd_get_system_info();
diff --git a/util.cpp b/util.cpp
index f68607f..16d8118 100644
--- a/util.cpp
+++ b/util.cpp
@@ -160,6 +160,9 @@
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}
+static sd_progress_cb_t sd_progress_cb = NULL;
+void* sd_progress_cb_data = NULL;
+
std::u32string utf8_to_utf32(const std::string& utf8_str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
return converter.from_bytes(utf8_str);
@@ -207,6 +210,10 @@
if (step == 0) {
return;
}
+if (sd_progress_cb) {
+ sd_progress_cb(step,steps,time, sd_progress_cb_data);
+ return;
+ }
std::string progress = " |";
int max_progress = 50;
int32_t current = (int32_t)(step * 1.f * max_progress / steps);
@@ -285,6 +292,10 @@
sd_log_cb = cb;
sd_log_cb_data = data;
}
+void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
+ sd_progress_cb = cb;
+ sd_progress_cb_data = data;
+}
const char* sd_get_system_info() {
static char buffer[1024];
--
Gitblit v1.9.3