From 2088e1b7aa6419dec58800bc1d0cb24f5808affe Mon Sep 17 00:00:00 2001
From: Ferenc Szontágh <szf@fsociety.hu>
Date: Sat, 03 Feb 2024 21:19:43 +0000
Subject: [PATCH] added queue handler against simple start
---
ui/MainWindowUI.cpp | 331 ++++++++++++++++++++++++++++++++-----------------------
1 files changed, 193 insertions(+), 138 deletions(-)
diff --git a/ui/MainWindowUI.cpp b/ui/MainWindowUI.cpp
index 09a1143..a6b39c7 100644
--- a/ui/MainWindowUI.cpp
+++ b/ui/MainWindowUI.cpp
@@ -6,6 +6,7 @@
this->ini_path = wxStandardPaths::Get().GetUserConfigDir() + wxFileName::GetPathSeparator() + "sd.ui.config.ini";
this->sd_params = new sd_gui_utils::SDParams;
this->notification = new wxNotificationMessage();
+ this->JobTableItems = new std::map<int, QM::QueueItem>();
this->notification->SetParent(this);
@@ -15,7 +16,9 @@
this->m_joblist->AppendTextColumn("Id");
this->m_joblist->AppendTextColumn("Created");
- this->m_joblist->AppendTextColumn("Update");
+ this->m_joblist->AppendTextColumn("Model");
+ this->m_joblist->AppendTextColumn("Sampler");
+ this->m_joblist->AppendTextColumn("Seed");
this->m_joblist->AppendTextColumn("Status");
this->SetTitle(this->GetTitle() + SD_GUI_VERSION);
@@ -97,12 +100,6 @@
// add the selected vae
this->sd_params->vae_path = this->VaeFiles.at(selection.ToStdString());
}
- // just select the vae file and add the live paramaeters, do not load the model, just add job to the queu on press the queue button
- /* if (this->m_generate->IsEnabled() == true)
- {
- wxPostEvent(this->m_model, event);
- }
- */
}
void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
@@ -146,10 +143,28 @@
void MainWindowUI::onGenerate(wxCommandEvent &event)
{
- // this->StartGeneration();
- // dont start the generation, just add the job to the queue, to handle it the queue manager
- // so it will possible to add more job afterall
- this->qmanager->AddItem(this->sd_params);
+ // prepare params
+ this->sd_params->model_path = this->ModelFiles.at(this->m_model->GetStringSelection().ToStdString());
+ this->sd_params->lora_model_dir = this->cfg->lora;
+ this->sd_params->embeddings_path = this->cfg->embedding;
+
+ this->sd_params->prompt = this->m_prompt->GetValue().ToStdString();
+ this->sd_params->negative_prompt = this->m_neg_prompt->GetValue().ToStdString();
+
+ this->sd_params->cfg_scale = static_cast<float>(this->m_cfg->GetValue());
+ this->sd_params->seed = this->m_seed->GetValue();
+ this->sd_params->clip_skip = this->m_clip_skip->GetValue();
+ this->sd_params->sample_steps = this->m_steps->GetValue();
+
+ this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetCurrentSelection();
+
+ this->sd_params->batch_count = this->m_batch_count->GetValue();
+
+ this->sd_params->width = this->m_width->GetValue();
+ this->sd_params->height = this->m_height->GetValue();
+
+ // add the queue item
+ auto id = this->qmanager->AddItem(this->sd_params);
}
void MainWindowUI::onSavePreset(wxCommandEvent &event)
@@ -243,29 +258,30 @@
this->LoadPresets();
}
}
+/*
+ this->m_joblist->AppendTextColumn("Id");
+ this->m_joblist->AppendTextColumn("Created");
+ this->m_joblist->AppendTextColumn("Model");
+ this->m_joblist->AppendTextColumn("Sampler");
+ this->m_joblist->AppendTextColumn("Seed");
+ this->m_joblist->AppendTextColumn("Status");
+*/
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
wxVector<wxVariant> data;
- data.push_back(wxVariant(item.id));
- data.push_back(wxVariant(item.created_at));
- data.push_back(wxVariant(item.updated_at));
+ data.push_back(wxVariant(std::to_string(item.id)));
+ data.push_back(wxVariant(std::to_string(item.created_at)));
+ data.push_back(wxVariant(item.params.model_path));
+ data.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
+ data.push_back(wxVariant(std::to_string(item.params.seed)));
data.push_back(wxVariant(QM::QueueStatus_str[item.status]));
this->m_joblist->AppendItem(data);
- /*
- for (auto model : this->ModelFiles)
- {
- // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
- auto size = std::filesystem::file_size(model.second);
- wxVector<wxVariant> data;
- data.push_back(wxVariant(model.first));
- data.push_back(wxVariant(std::to_string(size)));
- this->m_data_model_list->AppendItem(data);
- }
- this->m_data_model_list->Refresh();*/
+ auto store = this->m_joblist->GetStore();
+ (*this->JobTableItems)[store->GetCount() - 1] = item;
}
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
@@ -396,22 +412,22 @@
void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
{
+ e.Skip();
auto msg = e.GetString().ToStdString();
std::string token = msg.substr(0, msg.find(":"));
std::string content = msg.substr(msg.find(":") + 1);
- // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
+ // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
if (token == "QUEUE")
{
-
- m_statusBar166->SetStatusText("got QUEUE cmd: " + msg);
// only numbers here...
QM::QueueEvents event = (QM::QueueEvents)std::stoi(content);
// only handle the QUEUE messages, what this class generate
// alway QM::EueueItem the payload, with the new data
- auto payload = e.GetPayload<QM::QueueItem>();
+ QM::QueueItem payload;
+ payload = e.GetPayload<QM::QueueItem>();
switch (event)
{
// new item added
@@ -420,11 +436,14 @@
break;
// item status changed
case QM::QueueEvents::ITEM_STATUS_CHANGED:
- this->OnQueueItemManagerItemAdded(payload);
+ this->OnQueueItemManagerItemStatusChanged(payload);
break;
// item updated... ? ? ?
case QM::QueueEvents::ITEM_UPDATED:
- this->OnQueueItemManagerItemAdded(payload);
+ this->OnQueueItemManagerItemUpdated(payload);
+ break;
+ case QM::QueueEvents::ITEM_START:
+ this->StartGeneration(payload);
break;
default:
@@ -433,12 +452,13 @@
}
if (token == "MODEL_LOAD_DONE")
{
- this->m_generate->Enable();
- this->m_model->Enable();
- this->m_vae->Enable();
- this->m_refresh->Enable();
+ // this->m_generate->Enable();
+ // this->m_model->Enable();
+ // this->m_vae->Enable();
+ // this->m_refresh->Enable();
- this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
+ // this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
+ this->modelLoaded = true;
this->sd_ctx = e.GetPayload<sd_ctx_t *>();
if (!this->IsShownOnScreen())
{
@@ -450,19 +470,20 @@
}
if (token == "MODEL_LOAD_START")
{
- this->m_generate->Disable();
- this->m_model->Disable();
- this->m_vae->Disable();
- this->m_refresh->Disable();
+ // this->m_generate->Disable();
+ // this->m_model->Disable();
+ // this->m_vae->Disable();
+ // this->m_refresh->Disable();
this->logs->AppendText(fmt::format("Model load start: {}\n", content));
}
if (token == "MODEL_LOAD_ERROR")
{
- this->m_generate->Disable();
- this->m_model->Enable();
- this->m_vae->Disable();
- this->m_refresh->Enable();
+ // this->m_generate->Disable();
+ // this->m_model->Enable();
+ // this->m_vae->Disable();
+ // this->m_refresh->Enable();
this->logs->AppendText(fmt::format("Model load error: {}\n", content));
+ this->modelLoaded = false;
if (!this->IsShownOnScreen())
{
this->notification->SetFlags(wxICON_ERROR);
@@ -474,62 +495,62 @@
if (token == "GENERATION_START")
{
- sd_gui_utils::SDParams *params = e.GetPayload<sd_gui_utils::SDParams *>();
+ auto myjob = e.GetPayload<QM::QueueItem>();
- this->m_generate->Disable();
- this->m_model->Disable();
- this->m_vae->Disable();
- this->m_refresh->Disable();
- this->logs->AppendText(fmt::format("Difusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
- params->seed,
- params->batch_count,
- params->width,
- params->height,
- params->cfg_scale,
- params->sample_steps));
+ // this->m_generate->Disable();
+ // this->m_model->Disable();
+ // this->m_vae->Disable();
+ // this->m_refresh->Disable();
+ this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
+ myjob.params.seed,
+ myjob.params.batch_count,
+ myjob.params.width,
+ myjob.params.height,
+ myjob.params.cfg_scale,
+ myjob.params.sample_steps));
}
// never, not implemented in sd.cpp
if (token == "GENERATION_PROGRESS")
{
- this->m_generate->Disable();
- this->logs->AppendText(fmt::format("Generation progress: {}\n", content));
+ // this->m_generate->Disable();
+ // this->logs->AppendText(fmt::format("Generation progress: {}\n", content));
}
if (token == "GENERATION_DONE")
{
- this->m_generate->Enable();
- this->m_model->Enable();
- this->m_vae->Enable();
- this->m_refresh->Enable();
- sd_image_t *results = e.GetPayload<sd_image_t *>();
+ // this->m_generate->Enable();
+ // this->m_model->Enable();
+ // this->m_vae->Enable();
+ // this->m_refresh->Enable();
+ // sd_image_t *results = e.GetPayload<sd_image_t *>();
// show images in new window...
- for (int i = 0; i < this->sd_params->batch_count; i++)
- {
- MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
- // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
- wxImage img(results[i].width, results[i].height, results[i].data);
+ /* for (int i = 0; i < this->sd_params->batch_count; i++)
+ {
+ MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
+ // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
+ wxImage img(results[i].width, results[i].height, results[i].data);
- wxBitmapBundle wxBmapB(img);
- imgWindow->m_bitmap->SetBitmap(wxBmapB);
- imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
- imgWindow->SetSize(results[i].width + 200, results[i].height);
+ wxBitmapBundle wxBmapB(img);
+ imgWindow->m_bitmap->SetBitmap(wxBmapB);
+ imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
+ imgWindow->SetSize(results[i].width + 200, results[i].height);
- std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
- this->sd_params->prompt, this->sd_params->negative_prompt,
- this->sd_params->seed + i, this->sd_params->cfg_scale,
- this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
- results[i].width, results[i].height);
- imgWindow->m_textCtrl4->AppendText(wxString(details));
- imgWindow->Show();
+ std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
+ this->sd_params->prompt, this->sd_params->negative_prompt,
+ this->sd_params->seed + i, this->sd_params->cfg_scale,
+ this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
+ results[i].width, results[i].height);
+ imgWindow->m_textCtrl4->AppendText(wxString(details));
+ imgWindow->Show();
- // imgWindow->m_bitmap->SetBitmap(img);
- /// imgWindow->m_bitmap->Set
- }
+ // imgWindow->m_bitmap->SetBitmap(img);
+ /// imgWindow->m_bitmap->Set
+ }*/
}
if (token == "GENERATION_ERROR")
{
- this->m_generate->Enable();
- this->m_model->Enable();
- this->m_vae->Enable();
+ // this->m_generate->Enable();
+ // this->m_model->Enable();
+ // this->m_vae->Enable();
this->logs->AppendText(fmt::format("Generation error: {}\n", content));
if (!this->IsShownOnScreen())
{
@@ -560,10 +581,11 @@
}
}
-void MainWindowUI::LoadModel(wxEvtHandler *eventHandler)
+void MainWindowUI::LoadModel(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
wxThreadEvent *e = new wxThreadEvent();
e->SetString(wxString::Format("MODEL_LOAD_START:%s", this->sd_params->model_path));
+ e->SetPayload(myItem);
wxQueueEvent(eventHandler, e);
/*this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(),
@@ -580,22 +602,25 @@
this->sd_params->rng_type,
this->sd_params->schedule,
this->sd_params->control_net_cpu);*/
- this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(), this->sd_params->vae_path.c_str(), this->sd_params->taesd_path.c_str(), this->sd_params->lora_model_dir.c_str(), true, false, false, this->sd_params->n_threads, this->sd_params->wtype, this->sd_params->rng_type, this->sd_params->schedule);
+ // this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(), this->sd_params->vae_path.c_str(), this->sd_params->taesd_path.c_str(), this->sd_params->lora_model_dir.c_str(), true, false, false, this->sd_params->n_threads, this->sd_params->wtype, this->sd_params->rng_type, this->sd_params->schedule);
+ this->sd_ctx = new_sd_ctx(myItem.params.model_path.c_str(), myItem.params.vae_path.c_str(), myItem.params.taesd_path.c_str(), myItem.params.lora_model_dir.c_str(), true, false, false, myItem.params.n_threads, myItem.params.wtype, myItem.params.rng_type, myItem.params.schedule);
if (this->sd_ctx == NULL)
{
wxThreadEvent *c = new wxThreadEvent();
c->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", this->sd_params->model_path));
+ c->SetPayload(myItem);
wxQueueEvent(eventHandler, c);
+ this->modelLoaded = false;
return;
}
else
{
wxThreadEvent *c = new wxThreadEvent();
c->SetString(wxString::Format("MODEL_LOAD_DONE:%s", this->sd_params->model_path));
- // c->SetEventObject(this->sd_ctx);
c->SetPayload(this->sd_ctx);
wxQueueEvent(eventHandler, c);
+ this->modelLoaded = true;
return;
}
@@ -639,7 +664,7 @@
// populate data from sd_params as default...
- if (!this->m_generate->IsEnabled())
+ if (!this->modelLoaded)
{
this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
@@ -649,38 +674,11 @@
this->m_height->SetValue(this->sd_params->height);
this->m_batch_count->SetValue(this->sd_params->batch_count);
}
- // hide unusable configs...
- /*
- if (SD_CPP_VERSION == "c6071fa") {
- // .. nope, configs in another window...
- }*/
}
-void MainWindowUI::StartGeneration()
+void MainWindowUI::StartGeneration(QM::QueueItem myJob)
{
-
- // prepare params
- this->sd_params->model_path = this->ModelFiles.at(this->m_model->GetStringSelection().ToStdString());
- this->sd_params->lora_model_dir = this->cfg->lora;
- this->sd_params->embeddings_path = this->cfg->embedding;
-
- this->sd_params->prompt = this->m_prompt->GetValue().ToStdString();
- this->sd_params->negative_prompt = this->m_neg_prompt->GetValue().ToStdString();
-
- this->sd_params->cfg_scale = static_cast<float>(this->m_cfg->GetValue());
- this->sd_params->seed = this->m_seed->GetValue();
- this->sd_params->clip_skip = this->m_clip_skip->GetValue();
- this->sd_params->sample_steps = this->m_steps->GetValue();
-
- /* sample method */
- this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetCurrentSelection();
- /* sample method */
- this->sd_params->batch_count = this->m_batch_count->GetValue();
-
- this->sd_params->width = this->m_width->GetValue();
- this->sd_params->height = this->m_height->GetValue();
-
- this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler())));
+ this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
}
void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
@@ -717,31 +715,46 @@
this->m_data_model_list->Refresh();
}
-void MainWindowUI::Generate(wxEvtHandler *eventHandler)
+void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
- // calculate time
+ // @brief model loading is done in the same thread which generating stuffs... no need new thread
+ // @brief to load the model if no model loaded, or the loaded model is different from the job's model...
+ // @brief if all second job have different model, it will be slower than same model on all jobs
+
+ // TODO: abort job if model can not be loaded...
+ if (!this->modelLoaded)
+ {
+ this->LoadModel(eventHandler, myItem);
+ this->currentModel = myItem.params.model_path;
+ }
+ else
+ {
+ if (myItem.params.model_path != this->currentModel)
+ {
+ free_sd_ctx(this->sd_ctx);
+ this->LoadModel(eventHandler, myItem);
+ }
+ }
+ if (!this->modelLoaded)
+ {
+ wxThreadEvent *f = new wxThreadEvent();
+ f->SetString("GENERATION_ERROR:Model load failed...");
+ f->SetPayload(myItem);
+ wxQueueEvent(eventHandler, f);
+ return;
+ }
+
auto start = std::chrono::system_clock::now();
wxThreadEvent *e = new wxThreadEvent();
e->SetString(wxString::Format("GENERATION_START:%s", this->sd_params->model_path));
- e->SetPayload(this->sd_params);
+ e->SetPayload(myItem);
wxQueueEvent(eventHandler, e);
sd_image_t *control_image = NULL;
sd_image_t *results;
- /*results = txt2img(sd_ctx,
- this->sd_params->prompt.c_str(),
- this->sd_params->negative_prompt.c_str(),
- this->sd_params->clip_skip,
- this->sd_params->cfg_scale,
- this->sd_params->width,
- this->sd_params->height,
- this->sd_params->sample_method,
- this->sd_params->sample_steps,
- this->sd_params->seed,
- this->sd_params->batch_count,
- control_image,
- this->sd_params->control_strength);*/
+
+
results = txt2img(this->sd_ctx,
this->sd_params->prompt.c_str(), this->sd_params->negative_prompt.c_str(), this->sd_params->clip_skip, this->sd_params->cfg_scale, this->sd_params->width, this->sd_params->height, this->sd_params->sample_method, this->sd_params->sample_steps, this->sd_params->seed, this->sd_params->batch_count);
@@ -749,6 +762,7 @@
{
wxThreadEvent *f = new wxThreadEvent();
f->SetString("GENERATION_ERROR:Something wrong happened at image generation...");
+ f->SetPayload(myItem);
wxQueueEvent(eventHandler, f);
return;
}
@@ -785,7 +799,12 @@
{
wxThreadEvent *g = new wxThreadEvent();
g->SetString(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename));
+ g->SetPayload(myItem);
wxQueueEvent(eventHandler, g);
+ }
+ else
+ {
+ myItem.images.emplace_back(filename);
}
// handle data??
@@ -798,18 +817,29 @@
wxThreadEvent *h = new wxThreadEvent();
auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
h->SetString(wxString(msg.c_str()));
+ h->SetPayload(myItem);
wxQueueEvent(eventHandler, h);
- // send to reset the buttons
wxThreadEvent *i = new wxThreadEvent();
i->SetString(wxString::Format("GENERATION_DONE:ok"));
i->SetPayload(results);
wxQueueEvent(eventHandler, i);
+
+ // send to the queue manager
+ wxThreadEvent *j = new wxThreadEvent();
+ j->SetString(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED));
+ j->SetPayload(myItem);
+ wxQueueEvent(eventHandler, j);
+
return;
}
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
{
+ if (level != sd_log_level_t::SD_LOG_INFO)
+ {
+ return;
+ }
// wxEvtHandler *eventHandler = (wxEvtHandler *)data;
MainWindowUI *ui = (MainWindowUI *)data;
wxEvtHandler *eventHandler = ui->GetEventHandler();
@@ -821,12 +851,37 @@
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
{
+ /*
+ this->m_joblist->AppendTextColumn("Id");
+ this->m_joblist->AppendTextColumn("Created");
+ this->m_joblist->AppendTextColumn("Model");
+ this->m_joblist->AppendTextColumn("Sampler");
+ this->m_joblist->AppendTextColumn("Seed");
+ this->m_joblist->AppendTextColumn("Status");
+
+ */
+ // TODO: how to update an item in the table...
+ // auto item = this->m_joblist->RowToItem();
+ auto store = this->m_joblist->GetStore();
+
+ for (auto it = this->JobTableItems->begin(); it != this->JobTableItems->end(); ++it)
+ {
+ if (it->second.id == item.id)
+ {
+ // auto store = *this->m_joblist()->GetStore();
+ int lastCol = this->m_joblist->GetColumnCount();
+ // always update the last col, so the last col is always need to be the status
+ store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), it->first, lastCol - 1);
+ this->m_joblist->Refresh();
+ break;
+ }
+ }
}
MainWindowUI::~MainWindowUI()
{
// clean up things...
- if (this->m_generate->IsEnabled())
+ if (this->modelLoaded)
{
free_sd_ctx(this->sd_ctx);
}
@@ -861,7 +916,7 @@
free_sd_ctx(this->sd_ctx);
return;
}
- if (this->m_generate->IsEnabled())
+ if (this->modelLoaded)
{
free_sd_ctx(this->sd_ctx);
}
@@ -871,5 +926,5 @@
this->m_refresh->Disable();
this->sd_params->model_path = this->ModelFiles.at(selection.ToStdString());
this->sd_params->lora_model_dir = this->fileConfig->Read("/paths/lora", "").ToStdString();
- this->threads.push_back(std::thread(std::bind(&MainWindowUI::LoadModel, this, this->GetEventHandler())));
+ // this->threads.push_back(std::thread(std::bind(&MainWindowUI::LoadModel, this, this->GetEventHandler())));
}
\ No newline at end of file
--
Gitblit v1.9.3