From 1683c8a090c3efc51c43107d5ede0dcd5d506e3b Mon Sep 17 00:00:00 2001
From: Ferenc Szontágh <szf@fsociety.hu>
Date: Sun, 04 Feb 2024 20:45:08 +0000
Subject: [PATCH] added better queue manager, some clean-up and new feature: model management (wip), progressbar
---
ui/MainWindowUI.cpp | 1161 ++++++++++++++++++++++++++++++---------------------------
1 files changed, 612 insertions(+), 549 deletions(-)
diff --git a/ui/MainWindowUI.cpp b/ui/MainWindowUI.cpp
index a6b39c7..9788cfe 100644
--- a/ui/MainWindowUI.cpp
+++ b/ui/MainWindowUI.cpp
@@ -6,20 +6,23 @@
this->ini_path = wxStandardPaths::Get().GetUserConfigDir() + wxFileName::GetPathSeparator() + "sd.ui.config.ini";
this->sd_params = new sd_gui_utils::SDParams;
this->notification = new wxNotificationMessage();
- this->JobTableItems = new std::map<int, QM::QueueItem>();
+ // this->JobTableItems = new std::map<int, QM::QueueItem>();
this->notification->SetParent(this);
// prepare data list views
this->m_data_model_list->AppendTextColumn("Name", wxDATAVIEW_CELL_INERT, 200);
this->m_data_model_list->AppendTextColumn("Size");
+ this->m_data_model_list->AppendTextColumn("Hash");
this->m_joblist->AppendTextColumn("Id");
- this->m_joblist->AppendTextColumn("Created");
+ this->m_joblist->AppendTextColumn("Created at");
this->m_joblist->AppendTextColumn("Model");
this->m_joblist->AppendTextColumn("Sampler");
this->m_joblist->AppendTextColumn("Seed");
- this->m_joblist->AppendTextColumn("Status");
+ this->m_joblist->AppendProgressColumn("Progress"); // progressbar
+ this->m_joblist->AppendTextColumn("Speed"); // speed
+ this->m_joblist->AppendTextColumn("Status"); // status
this->SetTitle(this->GetTitle() + SD_GUI_VERSION);
@@ -30,7 +33,7 @@
this->qmanager = new QM::QueueManager(this->GetEventHandler(), this->cfg->jobs);
// set SD logger
- sd_set_log_callback(MainWindowUI::HandleSDLog, (void *)this);
+ sd_set_log_callback(MainWindowUI::HandleSDLog, (void *)this->GetEventHandler());
// load
this->LoadPresets();
@@ -102,11 +105,6 @@
}
}
-void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
-{
- this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetSelection();
-}
-
void MainWindowUI::onResolutionSwap(wxCommandEvent &event)
{
auto oldW = this->m_width->GetValue();
@@ -136,6 +134,16 @@
// TODO: Implement onJoblistItemActivated
}
+void MainWindowUI::onContextMenu(wxDataViewEvent &event)
+{
+
+ wxMenu menu;
+
+ menu.Append(0, "Calculate Hash");
+ menu.Append(1, "Download info from CivitAi.com");
+ PopupMenu(&menu);
+}
+
void MainWindowUI::onJoblistSelectionChanged(wxDataViewEvent &event)
{
// TODO: Implement onJoblistSelectionChanged
@@ -162,9 +170,22 @@
this->sd_params->width = this->m_width->GetValue();
this->sd_params->height = this->m_height->GetValue();
+ QM::QueueItem item;
+ item.params = *this->sd_params;
+ item.model = this->m_model->GetStringSelection().ToStdString();
+
+ if (item.params.seed == -1)
+ {
+ item.params.seed = sd_gui_utils::generateRandomInt(100000000, 999999999);
+ }
// add the queue item
- auto id = this->qmanager->AddItem(this->sd_params);
+ auto id = this->qmanager->AddItem(item);
+}
+
+void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
+{
+ this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetSelection();
}
void MainWindowUI::onSavePreset(wxCommandEvent &event)
@@ -258,30 +279,596 @@
this->LoadPresets();
}
}
-/*
- this->m_joblist->AppendTextColumn("Id");
- this->m_joblist->AppendTextColumn("Created");
- this->m_joblist->AppendTextColumn("Model");
- this->m_joblist->AppendTextColumn("Sampler");
- this->m_joblist->AppendTextColumn("Seed");
- this->m_joblist->AppendTextColumn("Status");
-*/
+void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
+{
+ if (level == sd_log_level_t::SD_LOG_INFO || level == sd_log_level_t::SD_LOG_ERROR)
+ {
+ auto *eventHandler = (wxEvtHandler *)data;
+ wxThreadEvent *e = new wxThreadEvent();
+ e->SetString(wxString::Format("SD_MESSAGE:%s", text));
+ e->SetPayload(level);
+ wxQueueEvent(eventHandler, e);
+ }
+}
+
+void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
+{
+ auto store = this->m_joblist->GetStore();
+
+ int lastCol = this->m_joblist->GetColumnCount() - 1;
+
+ for (unsigned int i = 0; i < store->GetItemCount(); i++)
+ {
+ auto _item = store->GetItem(i);
+ auto _item_data = store->GetItemData(_item);
+ auto *_qitem = reinterpret_cast<QM::QueueItem *>(_item_data);
+ if (_qitem->id == item.id)
+ {
+ store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), i, lastCol);
+ this->m_joblist->Refresh();
+ break;
+ }
+ }
+}
+
+void MainWindowUI::loadModelList()
+{
+ this->m_sampler->Clear();
+ for (auto sampler : sd_gui_utils::sample_method_str)
+ {
+ int _u = this->m_sampler->Append(sampler);
+
+ if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
+ {
+ this->m_sampler->Select(_u);
+ }
+ }
+
+ this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
+
+ for (auto model : this->ModelFiles)
+ {
+ // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
+ uintmax_t size = std::filesystem::file_size(model.second);
+ auto humanSize = sd_gui_utils::humanReadableFileSize(size);
+ auto hs = wxString::Format("%.1f %s", humanSize.first, humanSize.second);
+ wxVector<wxVariant> data;
+ data.push_back(wxVariant(model.first));
+ data.push_back(hs);
+ data.push_back("--");
+ this->m_data_model_list->AppendItem(data);
+ }
+ this->m_data_model_list->Refresh();
+}
+
+void MainWindowUI::StartGeneration(QM::QueueItem myJob)
+{
+
+ // here starts the trhead
+ // this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
+
+ // this->threads.emplace_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
+ // this->threads.emplace_back(std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob));
+ std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob);
+}
+
+void MainWindowUI::HandleSDProgress(int step, int steps, float time, void *data)
+{
+ sd_gui_utils::VoidHolder *objs = (sd_gui_utils::VoidHolder *)data;
+ wxEvtHandler *eventHandler = (wxEvtHandler *)objs->p1;
+ QM::QueueItem *myItem = (QM::QueueItem *)objs->p2;
+ myItem->step = step;
+ myItem->steps = steps;
+ myItem->time = time;
+ /*
+ format it/s
+ time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
+ progress.c_str(), step, steps,
+ time > 1.0f || time == 0 ? time : (1.0f / time)
+ */
+
+ wxThreadEvent *e = new wxThreadEvent();
+ e->SetString(wxString::Format("GENERATION_PROGRESS:%d/%d", step, steps));
+
+ e->SetPayload(myItem);
+ wxQueueEvent(eventHandler, e);
+}
+
+void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
+{
+ sd_gui_utils::VoidHolder *vparams = new sd_gui_utils::VoidHolder;
+ vparams->p1 = (void *)this->GetEventHandler();
+ vparams->p2 = (void *)&myItem;
+
+ sd_set_progress_callback(MainWindowUI::HandleSDProgress, (void *)vparams);
+
+ if (!this->modelLoaded)
+ {
+ this->sd_ctx = this->LoadModelv2(eventHandler, myItem);
+ this->currentModel = myItem.params.model_path;
+ }
+ else
+ {
+ if (myItem.params.model_path != this->currentModel)
+ {
+ free_sd_ctx(this->sd_ctx);
+ this->sd_ctx = this->LoadModelv2(eventHandler, myItem);
+ }
+ }
+ if (!this->modelLoaded || this->sd_ctx == nullptr)
+ {
+ wxThreadEvent *f = new wxThreadEvent();
+ f->SetString("GENERATION_ERROR:Model load failed...");
+ f->SetPayload(myItem);
+ wxQueueEvent(eventHandler, f);
+ return;
+ }
+
+ auto start = std::chrono::system_clock::now();
+
+ wxThreadEvent *e = new wxThreadEvent();
+ e->SetString(wxString::Format("GENERATION_START:%s", this->sd_params->model_path));
+ e->SetPayload(myItem);
+ wxQueueEvent(eventHandler, e);
+
+ sd_image_t *control_image = NULL;
+ sd_image_t *results;
+
+ // std::lock_guard<std::mutex> guard(this->sdMutex);
+ results = txt2img(this->sd_ctx,
+ myItem.params.prompt.c_str(),
+ myItem.params.negative_prompt.c_str(),
+ myItem.params.clip_skip,
+ myItem.params.cfg_scale,
+ myItem.params.width,
+ myItem.params.height,
+ myItem.params.sample_method,
+ myItem.params.sample_steps,
+ myItem.params.seed,
+ myItem.params.batch_count,
+ control_image,
+ myItem.params.control_strength);
+
+ if (results == NULL)
+ {
+ wxThreadEvent *f = new wxThreadEvent();
+ f->SetString("GENERATION_ERROR:Something wrong happened at image generation...");
+ f->SetPayload(myItem);
+ wxQueueEvent(eventHandler, f);
+ return;
+ }
+ if (!std::filesystem::exists(this->cfg->output))
+ {
+ std::filesystem::create_directories(this->cfg->output);
+ }
+ /* save image(s) */
+
+ const auto p1 = std::chrono::system_clock::now();
+ auto ctime = std::chrono::duration_cast<std::chrono::seconds>(p1.time_since_epoch()).count();
+
+ for (int i = 0; i < this->sd_params->batch_count; i++)
+ {
+ if (results[i].data == NULL)
+ {
+ continue;
+ }
+
+ // handle data??
+ wxImage *img = new wxImage(results[i].width, results[i].height, results[i].data);
+ std::string filename = this->cfg->output;
+ std::string extension = ".png";
+
+ if (this->sd_params->batch_count > 1)
+ {
+ filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + "_" + std::to_string(i) + extension;
+ }
+ else
+ {
+ filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + extension;
+ }
+ if (!img->SaveFile(filename))
+ {
+ wxThreadEvent *g = new wxThreadEvent();
+ g->SetString(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename));
+ g->SetPayload(myItem);
+ wxQueueEvent(eventHandler, g);
+ }
+ else
+ {
+ myItem.images.emplace_back(filename);
+ }
+ // handle data??
+ }
+
+ auto end = std::chrono::system_clock::now();
+ std::chrono::duration<double> elapsed_seconds = end - start;
+
+ // send to notify the user...
+ wxThreadEvent *h = new wxThreadEvent();
+ auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
+ h->SetString(wxString(msg.c_str()));
+ // h->SetPayload(myItem);
+ wxQueueEvent(eventHandler, h);
+
+ wxThreadEvent *i = new wxThreadEvent();
+ i->SetString(wxString::Format("GENERATION_DONE:ok"));
+ i->SetPayload(results);
+ wxQueueEvent(eventHandler, i);
+
+ // send to the queue manager
+ wxThreadEvent *j = new wxThreadEvent();
+ j->SetString(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED));
+ j->SetPayload(myItem);
+ wxQueueEvent(eventHandler, j);
+}
+
+void MainWindowUI::initConfig()
+{
+ wxString datapath = wxStandardPaths::Get().GetUserDataDir() + wxFileName::GetPathSeparator() + "sd_ui_data" + wxFileName::GetPathSeparator();
+ wxString imagespath = wxStandardPaths::Get().GetDocumentsDir() + wxFileName::GetPathSeparator() + "sd_ui_output" + wxFileName::GetPathSeparator();
+
+ wxString model_path = datapath;
+ model_path.append("checkpoints");
+
+ wxString vae_path = datapath;
+ vae_path.append("vae");
+
+ wxString lora_path = datapath;
+ lora_path.append("lora");
+
+ wxString embedding_path = datapath;
+ embedding_path.append("embedding");
+
+ wxString presets_path = datapath;
+ presets_path.append("presets");
+
+ wxString jobs_path = datapath;
+ jobs_path.append("queue_jobs");
+
+ this->cfg->lora = this->fileConfig->Read("/paths/lora", lora_path).ToStdString();
+ this->cfg->model = this->fileConfig->Read("/paths/model", model_path).ToStdString();
+ this->cfg->vae = this->fileConfig->Read("/paths/vae", vae_path).ToStdString();
+ this->cfg->embedding = this->fileConfig->Read("/paths/embedding", embedding_path).ToStdString();
+ this->cfg->presets = this->fileConfig->Read("/paths/presets", presets_path).ToStdString();
+
+ this->cfg->jobs = this->fileConfig->Read("/paths/presets", jobs_path).ToStdString();
+
+ this->cfg->output = this->fileConfig->Read("/paths/output", imagespath).ToStdString();
+ this->cfg->keep_model_in_memory = this->fileConfig->Read("/keep_model_in_memory", this->cfg->keep_model_in_memory);
+ this->cfg->save_all_image = this->fileConfig->Read("/save_all_image", this->cfg->save_all_image);
+
+ // populate data from sd_params as default...
+
+ if (!this->modelLoaded)
+ {
+ this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
+ this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
+ this->m_clip_skip->SetValue(this->sd_params->clip_skip);
+ this->m_steps->SetValue(this->sd_params->sample_steps);
+ this->m_width->SetValue(this->sd_params->width);
+ this->m_height->SetValue(this->sd_params->height);
+ this->m_batch_count->SetValue(this->sd_params->batch_count);
+ }
+}
+
+void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
+{
+ this->initConfig();
+ this->settingsWindow->Destroy();
+}
+
+void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
+{
+ if (e.GetSkipped() == false)
+ {
+ e.Skip();
+ }
+ auto msg = e.GetString().ToStdString();
+
+ std::string token = msg.substr(0, msg.find(":"));
+ std::string content = msg.substr(msg.find(":") + 1);
+
+ // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
+ if (token == "QUEUE")
+ {
+ // only numbers here...
+ QM::QueueEvents event = (QM::QueueEvents)std::stoi(content);
+ // only handle the QUEUE messages, what this class generate
+ // alway QM::EueueItem the payload, with the new data
+ QM::QueueItem payload;
+ payload = e.GetPayload<QM::QueueItem>();
+ switch (event)
+ {
+ // new item added
+ case QM::QueueEvents::ITEM_ADDED:
+ this->OnQueueItemManagerItemAdded(payload);
+ break;
+ // item status changed
+ case QM::QueueEvents::ITEM_STATUS_CHANGED:
+ this->OnQueueItemManagerItemStatusChanged(payload);
+ break;
+ // item updated... ? ? ?
+ case QM::QueueEvents::ITEM_UPDATED:
+ this->OnQueueItemManagerItemUpdated(payload);
+ break;
+ case QM::QueueEvents::ITEM_START:
+ this->StartGeneration(payload);
+ break;
+
+ default:
+ break;
+ }
+ }
+ if (token == "MODEL_LOAD_DONE")
+ {
+ // this->m_generate->Enable();
+ // this->m_model->Enable();
+ // this->m_vae->Enable();
+ // this->m_refresh->Enable();
+
+ // this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
+ this->modelLoaded = true;
+ // std::lock_guard<std::mutex> guard(this->sdMutex);
+ // this->sd_ctx = e.GetPayload<sd_ctx_t *>();
+ if (!this->IsShownOnScreen())
+ {
+ this->notification->SetFlags(wxICON_INFORMATION);
+ this->notification->SetTitle("SD Gui");
+ this->notification->SetMessage(content);
+ this->notification->Show(5000);
+ }
+ }
+ if (token == "MODEL_LOAD_START")
+ {
+ // this->m_generate->Disable();
+ // this->m_model->Disable();
+ // this->m_vae->Disable();
+ // this->m_refresh->Disable();
+ this->logs->AppendText(fmt::format("Model load start: {}\n", content));
+ }
+ if (token == "MODEL_LOAD_ERROR")
+ {
+ // this->m_generate->Disable();
+ // this->m_model->Enable();
+ // this->m_vae->Disable();
+ // this->m_refresh->Enable();
+ this->logs->AppendText(fmt::format("Model load error: {}\n", content));
+ this->modelLoaded = false;
+ if (!this->IsShownOnScreen())
+ {
+ this->notification->SetFlags(wxICON_ERROR);
+ this->notification->SetTitle("SD Gui - error");
+ this->notification->SetMessage(content);
+ this->notification->Show(5000);
+ }
+ }
+
+ if (token == "GENERATION_START")
+ {
+ auto myjob = e.GetPayload<QM::QueueItem>();
+
+ // this->m_generate->Disable();
+ // this->m_model->Disable();
+ // this->m_vae->Disable();
+ // this->m_refresh->Disable();
+ this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
+ myjob.params.seed,
+ myjob.params.batch_count,
+ myjob.params.width,
+ myjob.params.height,
+ myjob.params.cfg_scale,
+ myjob.params.sample_steps));
+ }
+ // in the original SD.cpp the progress callback is not implemented... :(
+ if (token == "GENERATION_PROGRESS")
+ {
+ QM::QueueItem *myjob = e.GetPayload<QM::QueueItem *>();
+ // update column
+ auto store = this->m_joblist->GetStore();
+ // -1 the last (status)
+ // -2 ... (speed)
+ // -3 ... (progressbar)
+
+ // it/s format
+ /*
+ time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
+ progress.c_str(), step, steps,
+ time > 1.0f || time == 0 ? time : (1.0f / time)
+ */
+ wxString speed = wxString::Format(myjob->time > 1.0f ? "%.2fs/it" : "%.2fit/s", myjob->time > 1.0f || myjob->time == 0 ? myjob->time : (1.0f / myjob->time));
+ int progressCol = this->m_joblist->GetColumnCount() - 3;
+ int speedCol = this->m_joblist->GetColumnCount() - 2;
+ float current_progress = 100.f * (static_cast<float>(myjob->step) / static_cast<float>(myjob->steps));
+ if (current_progress < 2.f)
+ {
+ return;
+ }
+
+ for (unsigned int i = 0; i < store->GetItemCount(); i++)
+ {
+ auto _item = store->GetItem(i);
+ auto _item_data = store->GetItemData(_item);
+ auto *_qitem = reinterpret_cast<QM::QueueItem *>(_item_data);
+ if (_qitem->id == myjob->id)
+ {
+ store->SetValueByRow(static_cast<int>(current_progress), i, progressCol);
+ store->SetValueByRow(speed, i, speedCol);
+ this->m_joblist->Refresh();
+ break;
+ }
+ }
+
+ return;
+ for (auto it = this->JobTableItems.begin(); it != this->JobTableItems.end(); ++it)
+ {
+ if (it->second->id == myjob->id)
+ {
+ // store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), it->first, progressCol);
+ store->SetValueByRow(static_cast<int>(current_progress), it->first, progressCol);
+ store->SetValueByRow(speed, it->first, speedCol);
+ this->m_joblist->Refresh();
+ break;
+ }
+ }
+ // update column
+ }
+ if (token == "GENERATION_DONE")
+ {
+ // this->m_generate->Enable();
+ // this->m_model->Enable();
+ // this->m_vae->Enable();
+ // this->m_refresh->Enable();
+ // sd_image_t *results = e.GetPayload<sd_image_t *>();
+ // show images in new window...
+ /* for (int i = 0; i < this->sd_params->batch_count; i++)
+ {
+ MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
+ // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
+ wxImage img(results[i].width, results[i].height, results[i].data);
+
+ wxBitmapBundle wxBmapB(img);
+ imgWindow->m_bitmap->SetBitmap(wxBmapB);
+ imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
+ imgWindow->SetSize(results[i].width + 200, results[i].height);
+
+ std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
+ this->sd_params->prompt, this->sd_params->negative_prompt,
+ this->sd_params->seed + i, this->sd_params->cfg_scale,
+ this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
+ results[i].width, results[i].height);
+ imgWindow->m_textCtrl4->AppendText(wxString(details));
+ imgWindow->Show();
+
+ // imgWindow->m_bitmap->SetBitmap(img);
+ /// imgWindow->m_bitmap->Set
+ }*/
+ }
+ if (token == "GENERATION_ERROR")
+ {
+ this->logs->AppendText(fmt::format("Generation error: {}\n", content));
+ if (!this->IsShownOnScreen())
+ {
+ this->notification->SetFlags(wxICON_ERROR);
+ this->notification->SetTitle("SD Gui - error");
+ this->notification->SetMessage(content);
+ this->notification->Show(5000);
+ }
+ }
+ if (token == "SD_MESSAGE")
+ {
+ if (content.length() < 1)
+ {
+ return;
+ }
+ this->logs->AppendText(fmt::format("{}", content));
+ }
+ if (token == "MESSAGE")
+ {
+ this->logs->AppendText(fmt::format("{}\n", content));
+ if (!this->IsShownOnScreen())
+ {
+ this->notification->SetFlags(wxICON_INFORMATION);
+ this->notification->SetTitle("SD Gui");
+ this->notification->SetMessage(content);
+ this->notification->Show(5000);
+ }
+ }
+}
+
+sd_ctx_t *MainWindowUI::LoadModelv2(wxEvtHandler *eventHandler, QM::QueueItem myItem)
+{
+ wxThreadEvent *e = new wxThreadEvent();
+ e->SetString(wxString::Format("MODEL_LOAD_START:%s", myItem.params.model_path));
+ e->SetPayload(myItem);
+ wxQueueEvent(eventHandler, e);
+
+ // std::lock_guard<std::mutex> guard(this->sdMutex);
+ sd_ctx_t *sd_ctx_ = new_sd_ctx(
+ myItem.params.model_path.c_str(),
+ myItem.params.vae_path.c_str(),
+ myItem.params.taesd_path.c_str(),
+ myItem.params.controlnet_path.c_str(),
+ myItem.params.lora_model_dir.c_str(),
+ myItem.params.embeddings_path.c_str(),
+ true, false, false,
+ myItem.params.n_threads,
+ myItem.params.wtype,
+ myItem.params.rng_type,
+ myItem.params.schedule, false);
+
+ if (sd_ctx_ == NULL)
+ {
+ wxThreadEvent *c = new wxThreadEvent();
+ c->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", myItem.params.model_path));
+ c->SetPayload(myItem);
+ wxQueueEvent(eventHandler, c);
+ this->modelLoaded = false;
+ return nullptr;
+ }
+ else
+ {
+ wxThreadEvent *c = new wxThreadEvent();
+ c->SetString(wxString::Format("MODEL_LOAD_DONE:%s", myItem.params.model_path));
+ wxQueueEvent(eventHandler, c);
+ this->modelLoaded = true;
+ this->currentModel = myItem.params.model_path;
+ }
+ return sd_ctx_;
+}
+
+void MainWindowUI::LoadPresets()
+{
+ this->LoadFileList(sd_gui_utils::DirTypes::PRESETS);
+}
+
+void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
+{
+}
+
+void MainWindowUI::loadVaeList()
+{
+ if (!std::filesystem::exists(this->cfg->vae))
+ {
+ std::filesystem::create_directories(this->cfg->vae);
+ }
+ this->LoadFileList(sd_gui_utils::DirTypes::VAE);
+}
+
+MainWindowUI::~MainWindowUI()
+{
+ // this->Hide();
+ for (int i = 0; i < this->threads.size(); i++)
+ {
+ if (this->threads.at(i).joinable())
+ {
+ this->threads.at(i).join();
+ }
+ }
+}
+
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
wxVector<wxVariant> data;
+ auto created_at = sd_gui_utils::formatUnixTimestampToDate(item.created_at);
+
data.push_back(wxVariant(std::to_string(item.id)));
- data.push_back(wxVariant(std::to_string(item.created_at)));
- data.push_back(wxVariant(item.params.model_path));
+ data.push_back(wxVariant(created_at));
+ data.push_back(wxVariant(item.model));
data.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
data.push_back(wxVariant(std::to_string(item.params.seed)));
- data.push_back(wxVariant(QM::QueueStatus_str[item.status]));
-
- this->m_joblist->AppendItem(data);
+ data.push_back(item.status == QM::QueueStatus::DONE ? 100 : 1); // progressbar
+ data.push_back(wxString("-.--it/s")); // speed
+ data.push_back(wxVariant(QM::QueueStatus_str[item.status])); // status
auto store = this->m_joblist->GetStore();
- (*this->JobTableItems)[store->GetCount() - 1] = item;
+
+ QM::QueueItem *nItem = new QM::QueueItem(item);
+
+ this->JobTableItems[item.id] = nItem;
+ // store->AppendItem(data, wxUIntPtr(this->JobTableItems[item.id]));
+ store->PrependItem(data, wxUIntPtr(this->JobTableItems[item.id]));
}
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
@@ -334,7 +921,7 @@
if (type == sd_gui_utils::DirTypes::CHECKPOINT || type == sd_gui_utils::DirTypes::VAE)
{
- if (ext != ".safetensors" && ext != ".cptk")
+ if (ext != ".safetensors" && ext != ".ckpt")
{
continue;
}
@@ -403,528 +990,4 @@
this->m_preset_list->Enable();
}
}
-}
-
-void MainWindowUI::LoadPresets()
-{
- this->LoadFileList(sd_gui_utils::DirTypes::PRESETS);
-}
-
-void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
-{
- e.Skip();
-
- auto msg = e.GetString().ToStdString();
-
- std::string token = msg.substr(0, msg.find(":"));
- std::string content = msg.substr(msg.find(":") + 1);
-
- // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
- if (token == "QUEUE")
- {
- // only numbers here...
- QM::QueueEvents event = (QM::QueueEvents)std::stoi(content);
- // only handle the QUEUE messages, what this class generate
- // alway QM::EueueItem the payload, with the new data
- QM::QueueItem payload;
- payload = e.GetPayload<QM::QueueItem>();
- switch (event)
- {
- // new item added
- case QM::QueueEvents::ITEM_ADDED:
- this->OnQueueItemManagerItemAdded(payload);
- break;
- // item status changed
- case QM::QueueEvents::ITEM_STATUS_CHANGED:
- this->OnQueueItemManagerItemStatusChanged(payload);
- break;
- // item updated... ? ? ?
- case QM::QueueEvents::ITEM_UPDATED:
- this->OnQueueItemManagerItemUpdated(payload);
- break;
- case QM::QueueEvents::ITEM_START:
- this->StartGeneration(payload);
- break;
-
- default:
- break;
- }
- }
- if (token == "MODEL_LOAD_DONE")
- {
- // this->m_generate->Enable();
- // this->m_model->Enable();
- // this->m_vae->Enable();
- // this->m_refresh->Enable();
-
- // this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
- this->modelLoaded = true;
- this->sd_ctx = e.GetPayload<sd_ctx_t *>();
- if (!this->IsShownOnScreen())
- {
- this->notification->SetFlags(wxICON_INFORMATION);
- this->notification->SetTitle("SD Gui");
- this->notification->SetMessage(content);
- this->notification->Show(5000);
- }
- }
- if (token == "MODEL_LOAD_START")
- {
- // this->m_generate->Disable();
- // this->m_model->Disable();
- // this->m_vae->Disable();
- // this->m_refresh->Disable();
- this->logs->AppendText(fmt::format("Model load start: {}\n", content));
- }
- if (token == "MODEL_LOAD_ERROR")
- {
- // this->m_generate->Disable();
- // this->m_model->Enable();
- // this->m_vae->Disable();
- // this->m_refresh->Enable();
- this->logs->AppendText(fmt::format("Model load error: {}\n", content));
- this->modelLoaded = false;
- if (!this->IsShownOnScreen())
- {
- this->notification->SetFlags(wxICON_ERROR);
- this->notification->SetTitle("SD Gui - error");
- this->notification->SetMessage(content);
- this->notification->Show(5000);
- }
- }
-
- if (token == "GENERATION_START")
- {
- auto myjob = e.GetPayload<QM::QueueItem>();
-
- // this->m_generate->Disable();
- // this->m_model->Disable();
- // this->m_vae->Disable();
- // this->m_refresh->Disable();
- this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
- myjob.params.seed,
- myjob.params.batch_count,
- myjob.params.width,
- myjob.params.height,
- myjob.params.cfg_scale,
- myjob.params.sample_steps));
- }
- // never, not implemented in sd.cpp
- if (token == "GENERATION_PROGRESS")
- {
- // this->m_generate->Disable();
- // this->logs->AppendText(fmt::format("Generation progress: {}\n", content));
- }
- if (token == "GENERATION_DONE")
- {
- // this->m_generate->Enable();
- // this->m_model->Enable();
- // this->m_vae->Enable();
- // this->m_refresh->Enable();
- // sd_image_t *results = e.GetPayload<sd_image_t *>();
- // show images in new window...
- /* for (int i = 0; i < this->sd_params->batch_count; i++)
- {
- MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
- // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
- wxImage img(results[i].width, results[i].height, results[i].data);
-
- wxBitmapBundle wxBmapB(img);
- imgWindow->m_bitmap->SetBitmap(wxBmapB);
- imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
- imgWindow->SetSize(results[i].width + 200, results[i].height);
-
- std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
- this->sd_params->prompt, this->sd_params->negative_prompt,
- this->sd_params->seed + i, this->sd_params->cfg_scale,
- this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
- results[i].width, results[i].height);
- imgWindow->m_textCtrl4->AppendText(wxString(details));
- imgWindow->Show();
-
- // imgWindow->m_bitmap->SetBitmap(img);
- /// imgWindow->m_bitmap->Set
- }*/
- }
- if (token == "GENERATION_ERROR")
- {
- // this->m_generate->Enable();
- // this->m_model->Enable();
- // this->m_vae->Enable();
- this->logs->AppendText(fmt::format("Generation error: {}\n", content));
- if (!this->IsShownOnScreen())
- {
- this->notification->SetFlags(wxICON_ERROR);
- this->notification->SetTitle("SD Gui - error");
- this->notification->SetMessage(content);
- this->notification->Show(5000);
- }
- }
- if (token == "SD_MESSAGE")
- {
- if (content.length() < 1)
- {
- return;
- }
- this->logs->AppendText(fmt::format("{}", content));
- }
- if (token == "MESSAGE")
- {
- this->logs->AppendText(fmt::format("{}\n", content));
- if (!this->IsShownOnScreen())
- {
- this->notification->SetFlags(wxICON_INFORMATION);
- this->notification->SetTitle("SD Gui");
- this->notification->SetMessage(content);
- this->notification->Show(5000);
- }
- }
-}
-
-void MainWindowUI::LoadModel(wxEvtHandler *eventHandler, QM::QueueItem myItem)
-{
- wxThreadEvent *e = new wxThreadEvent();
- e->SetString(wxString::Format("MODEL_LOAD_START:%s", this->sd_params->model_path));
- e->SetPayload(myItem);
- wxQueueEvent(eventHandler, e);
-
- /*this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(),
- this->sd_params->vae_path.c_str(),
- this->sd_params->taesd_path.c_str(),
- this->sd_params->controlnet_path.c_str(),
- std::string(this->sd_params->lora_model_dir + "\\").c_str(),
- this->sd_params->embeddings_path.c_str(),
- true, // vae decode only
- true ,
- !this->cfg->keep_model_in_memory,
- this->sd_params->n_threads,
- this->sd_params->wtype ,
- this->sd_params->rng_type,
- this->sd_params->schedule,
- this->sd_params->control_net_cpu);*/
- // this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(), this->sd_params->vae_path.c_str(), this->sd_params->taesd_path.c_str(), this->sd_params->lora_model_dir.c_str(), true, false, false, this->sd_params->n_threads, this->sd_params->wtype, this->sd_params->rng_type, this->sd_params->schedule);
- this->sd_ctx = new_sd_ctx(myItem.params.model_path.c_str(), myItem.params.vae_path.c_str(), myItem.params.taesd_path.c_str(), myItem.params.lora_model_dir.c_str(), true, false, false, myItem.params.n_threads, myItem.params.wtype, myItem.params.rng_type, myItem.params.schedule);
-
- if (this->sd_ctx == NULL)
- {
- wxThreadEvent *c = new wxThreadEvent();
- c->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", this->sd_params->model_path));
- c->SetPayload(myItem);
- wxQueueEvent(eventHandler, c);
- this->modelLoaded = false;
- return;
- }
- else
- {
- wxThreadEvent *c = new wxThreadEvent();
- c->SetString(wxString::Format("MODEL_LOAD_DONE:%s", this->sd_params->model_path));
- c->SetPayload(this->sd_ctx);
- wxQueueEvent(eventHandler, c);
- this->modelLoaded = true;
- return;
- }
-
- return;
-}
-
-void MainWindowUI::initConfig()
-{
- wxString datapath = wxStandardPaths::Get().GetUserDataDir() + wxFileName::GetPathSeparator() + "sd_ui_data" + wxFileName::GetPathSeparator();
- wxString imagespath = wxStandardPaths::Get().GetDocumentsDir() + wxFileName::GetPathSeparator() + "sd_ui_output" + wxFileName::GetPathSeparator();
-
- wxString model_path = datapath;
- model_path.append("checkpoints");
-
- wxString vae_path = datapath;
- vae_path.append("vae");
-
- wxString lora_path = datapath;
- lora_path.append("lora");
-
- wxString embedding_path = datapath;
- embedding_path.append("embedding");
-
- wxString presets_path = datapath;
- presets_path.append("presets");
-
- wxString jobs_path = datapath;
- jobs_path.append("queue_jobs");
-
- this->cfg->lora = this->fileConfig->Read("/paths/lora", lora_path).ToStdString();
- this->cfg->model = this->fileConfig->Read("/paths/model", model_path).ToStdString();
- this->cfg->vae = this->fileConfig->Read("/paths/vae", vae_path).ToStdString();
- this->cfg->embedding = this->fileConfig->Read("/paths/embedding", embedding_path).ToStdString();
- this->cfg->presets = this->fileConfig->Read("/paths/presets", presets_path).ToStdString();
-
- this->cfg->jobs = this->fileConfig->Read("/paths/presets", jobs_path).ToStdString();
-
- this->cfg->output = this->fileConfig->Read("/paths/output", imagespath).ToStdString();
- this->cfg->keep_model_in_memory = this->fileConfig->Read("/keep_model_in_memory", this->cfg->keep_model_in_memory);
- this->cfg->save_all_image = this->fileConfig->Read("/save_all_image", this->cfg->save_all_image);
-
- // populate data from sd_params as default...
-
- if (!this->modelLoaded)
- {
- this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
- this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
- this->m_clip_skip->SetValue(this->sd_params->clip_skip);
- this->m_steps->SetValue(this->sd_params->sample_steps);
- this->m_width->SetValue(this->sd_params->width);
- this->m_height->SetValue(this->sd_params->height);
- this->m_batch_count->SetValue(this->sd_params->batch_count);
- }
-}
-
-void MainWindowUI::StartGeneration(QM::QueueItem myJob)
-{
- this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
-}
-
-void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
-{
- this->initConfig();
- this->settingsWindow->Destroy();
-}
-
-void MainWindowUI::loadModelList()
-{
- this->m_sampler->Clear();
- for (auto sampler : sd_gui_utils::sample_method_str)
- {
- int _u = this->m_sampler->Append(sampler);
-
- if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
- {
- this->m_sampler->Select(_u);
- }
- }
-
- this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
-
- for (auto model : this->ModelFiles)
- {
- // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
- auto size = std::filesystem::file_size(model.second);
-
- wxVector<wxVariant> data;
- data.push_back(wxVariant(model.first));
- data.push_back(wxVariant(std::to_string(size)));
- this->m_data_model_list->AppendItem(data);
- }
- this->m_data_model_list->Refresh();
-}
-
-void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
-{
- // @brief model loading is done in the same thread which generating stuffs... no need new thread
- // @brief to load the model if no model loaded, or the loaded model is different from the job's model...
- // @brief if all second job have different model, it will be slower than same model on all jobs
-
- // TODO: abort job if model can not be loaded...
- if (!this->modelLoaded)
- {
- this->LoadModel(eventHandler, myItem);
- this->currentModel = myItem.params.model_path;
- }
- else
- {
- if (myItem.params.model_path != this->currentModel)
- {
- free_sd_ctx(this->sd_ctx);
- this->LoadModel(eventHandler, myItem);
- }
- }
- if (!this->modelLoaded)
- {
- wxThreadEvent *f = new wxThreadEvent();
- f->SetString("GENERATION_ERROR:Model load failed...");
- f->SetPayload(myItem);
- wxQueueEvent(eventHandler, f);
- return;
- }
-
- auto start = std::chrono::system_clock::now();
-
- wxThreadEvent *e = new wxThreadEvent();
- e->SetString(wxString::Format("GENERATION_START:%s", this->sd_params->model_path));
- e->SetPayload(myItem);
- wxQueueEvent(eventHandler, e);
-
- sd_image_t *control_image = NULL;
- sd_image_t *results;
-
-
- results = txt2img(this->sd_ctx,
- this->sd_params->prompt.c_str(), this->sd_params->negative_prompt.c_str(), this->sd_params->clip_skip, this->sd_params->cfg_scale, this->sd_params->width, this->sd_params->height, this->sd_params->sample_method, this->sd_params->sample_steps, this->sd_params->seed, this->sd_params->batch_count);
-
- if (results == NULL)
- {
- wxThreadEvent *f = new wxThreadEvent();
- f->SetString("GENERATION_ERROR:Something wrong happened at image generation...");
- f->SetPayload(myItem);
- wxQueueEvent(eventHandler, f);
- return;
- }
- if (!std::filesystem::exists(this->cfg->output))
- {
- std::filesystem::create_directories(this->cfg->output);
- }
- /* save image(s) */
-
- const auto p1 = std::chrono::system_clock::now();
- auto ctime = std::chrono::duration_cast<std::chrono::seconds>(p1.time_since_epoch()).count();
-
- for (int i = 0; i < this->sd_params->batch_count; i++)
- {
- if (results[i].data == NULL)
- {
- continue;
- }
-
- // handle data??
- wxImage *img = new wxImage(results[i].width, results[i].height, results[i].data);
- std::string filename = this->cfg->output;
- std::string extension = ".png";
-
- if (this->sd_params->batch_count > 1)
- {
- filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + "_" + std::to_string(i) + extension;
- }
- else
- {
- filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + extension;
- }
- if (!img->SaveFile(filename))
- {
- wxThreadEvent *g = new wxThreadEvent();
- g->SetString(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename));
- g->SetPayload(myItem);
- wxQueueEvent(eventHandler, g);
- }
- else
- {
- myItem.images.emplace_back(filename);
- }
-
- // handle data??
- }
-
- auto end = std::chrono::system_clock::now();
- std::chrono::duration<double> elapsed_seconds = end - start;
-
- // send to notify the user...
- wxThreadEvent *h = new wxThreadEvent();
- auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
- h->SetString(wxString(msg.c_str()));
- h->SetPayload(myItem);
- wxQueueEvent(eventHandler, h);
-
- wxThreadEvent *i = new wxThreadEvent();
- i->SetString(wxString::Format("GENERATION_DONE:ok"));
- i->SetPayload(results);
- wxQueueEvent(eventHandler, i);
-
- // send to the queue manager
- wxThreadEvent *j = new wxThreadEvent();
- j->SetString(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED));
- j->SetPayload(myItem);
- wxQueueEvent(eventHandler, j);
-
- return;
-}
-
-void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
-{
- if (level != sd_log_level_t::SD_LOG_INFO)
- {
- return;
- }
- // wxEvtHandler *eventHandler = (wxEvtHandler *)data;
- MainWindowUI *ui = (MainWindowUI *)data;
- wxEvtHandler *eventHandler = ui->GetEventHandler();
- wxThreadEvent *e = new wxThreadEvent();
- e->SetString(wxString::Format("SD_MESSAGE:%s", text));
- e->SetPayload(level);
- wxQueueEvent(eventHandler, e);
-}
-
-void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
-{
- /*
- this->m_joblist->AppendTextColumn("Id");
- this->m_joblist->AppendTextColumn("Created");
- this->m_joblist->AppendTextColumn("Model");
- this->m_joblist->AppendTextColumn("Sampler");
- this->m_joblist->AppendTextColumn("Seed");
- this->m_joblist->AppendTextColumn("Status");
-
- */
- // TODO: how to update an item in the table...
- // auto item = this->m_joblist->RowToItem();
- auto store = this->m_joblist->GetStore();
-
- for (auto it = this->JobTableItems->begin(); it != this->JobTableItems->end(); ++it)
- {
- if (it->second.id == item.id)
- {
- // auto store = *this->m_joblist()->GetStore();
- int lastCol = this->m_joblist->GetColumnCount();
- // always update the last col, so the last col is always need to be the status
- store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), it->first, lastCol - 1);
- this->m_joblist->Refresh();
- break;
- }
- }
-}
-
-MainWindowUI::~MainWindowUI()
-{
- // clean up things...
- if (this->modelLoaded)
- {
- free_sd_ctx(this->sd_ctx);
- }
- for (int i = 0; i < this->threads.size(); i++)
- {
- this->threads.at(i).join();
- }
- this->threads.clear();
- this->Destroy();
- exit(0);
-}
-
-void MainWindowUI::loadVaeList()
-{
- if (!std::filesystem::exists(this->cfg->vae))
- {
- std::filesystem::create_directories(this->cfg->vae);
- }
- this->LoadFileList(sd_gui_utils::DirTypes::VAE);
-}
-
-void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
-{
-}
-
-void MainWindowUI::StartLoadModel()
-{
- // prepare
- auto selection = this->m_model->GetStringSelection();
- if (selection == "-none-")
- {
- free_sd_ctx(this->sd_ctx);
- return;
- }
- if (this->modelLoaded)
- {
- free_sd_ctx(this->sd_ctx);
- }
- // disable ui
- this->m_model->Disable();
- this->m_vae->Disable();
- this->m_refresh->Disable();
- this->sd_params->model_path = this->ModelFiles.at(selection.ToStdString());
- this->sd_params->lora_model_dir = this->fileConfig->Read("/paths/lora", "").ToStdString();
- // this->threads.push_back(std::thread(std::bind(&MainWindowUI::LoadModel, this, this->GetEventHandler())));
}
\ No newline at end of file
--
Gitblit v1.9.3