#include "MainWindowUI.h"

#include <cstdlib>
#include <system_error>
|
|
// Builds the main window: creates the helper objects, sets up the list
// columns, loads the configuration and scans the model/VAE/preset folders.
MainWindowUI::MainWindowUI(wxWindow *parent)
    : UI(parent)
{
    // Per-user ini file backing the wxFileConfig created below.
    this->ini_path = wxStandardPaths::Get().GetUserConfigDir() + wxFileName::GetPathSeparator() + "sd.ui.config.ini";

    this->sd_params = new sd_gui_utils::SDParams;
    this->notification = new wxNotificationMessage();
    this->notification->SetParent(this);

    // Model table: name (fixed width), size, hash.
    this->m_data_model_list->AppendTextColumn("Name", wxDATAVIEW_CELL_INERT, 200);
    this->m_data_model_list->AppendTextColumn("Size");
    this->m_data_model_list->AppendTextColumn("Hash");

    // Job table. Progress, speed and status must stay the last three columns:
    // other handlers address them relative to GetColumnCount().
    static const char *const leadingJobColumns[] = {"Id", "Created at", "Model", "Sampler", "Seed"};
    for (const char *title : leadingJobColumns)
    {
        this->m_joblist->AppendTextColumn(title);
    }
    this->m_joblist->AppendProgressColumn("Progress");
    this->m_joblist->AppendTextColumn("Speed");
    this->m_joblist->AppendTextColumn("Status");

    this->SetTitle(this->GetTitle() + SD_GUI_VERSION);

    // Configuration must be loaded before anything that reads this->cfg.
    this->cfg = new sd_gui_utils::config;
    this->fileConfig = new wxFileConfig("sd.cpp.ui", wxEmptyString, this->ini_path);
    wxConfigBase::Set(this->fileConfig);
    this->initConfig();

    this->qmanager = new QM::QueueManager(this->GetEventHandler(), this->cfg->jobs);

    // Route stable-diffusion.cpp log output to this window's event handler.
    sd_set_log_callback(MainWindowUI::HandleSDLog, (void *)this->GetEventHandler());

    this->LoadPresets();
    this->loadModelList();
    this->loadVaeList();

    if (!this->ModelFiles.empty())
    {
        this->m_model->Enable();
    }
    if (!this->VaeFiles.empty())
    {
        this->m_vae->Enable();
    }

    Bind(wxEVT_THREAD, &MainWindowUI::OnThreadMessage, this);
}
|
|
// Opens a new settings window; OnCloseSettings reloads the config and destroys it.
void MainWindowUI::onSettings(wxCommandEvent &event)
{
    MainWindowSettings *window = new MainWindowSettings(this);
    this->settingsWindow = window;
    window->Bind(wxEVT_CLOSE_WINDOW, &MainWindowUI::OnCloseSettings, this);
    window->Show();
}
|
|
// Rescans the checkpoint and VAE folders and enables the choices that have entries.
void MainWindowUI::onModelsRefresh(wxCommandEvent &event)
{
    this->loadModelList();
    this->loadVaeList();

    const bool haveModels = !this->ModelFiles.empty();
    const bool haveVaes = !this->VaeFiles.empty();

    if (haveModels)
    {
        this->m_model->Enable();
    }
    if (haveVaes)
    {
        this->m_vae->Enable();
    }
}
|
|
// Stores the selected checkpoint in sd_params and enables "Generate".
// Index 0 is the "-none-" placeholder added by LoadFileList. An unknown or
// empty selection is treated as "no model" instead of letting
// ModelFiles.at() throw out of the event handler.
// The model itself is loaded later by the queue (see Generate), not here.
void MainWindowUI::onModelSelect(wxCommandEvent &event)
{
    const int selection = this->m_model->GetSelection();
    const std::string name = this->m_model->GetStringSelection().ToStdString();
    const auto found = this->ModelFiles.find(name);

    if (selection == wxNOT_FOUND || selection == 0 || found == this->ModelFiles.end())
    {
        this->m_generate->Disable();
        this->m_statusBar166->SetStatusText("Model: none");
        return;
    }

    this->m_generate->Enable();
    this->sd_params->model_path = found->second;
    this->m_statusBar166->SetStatusText("Model: " + this->sd_params->model_path);
}
|
|
// Stores the selected VAE path in sd_params. An empty path means no VAE.
// "-none-", an empty selection, or a name that is no longer in VaeFiles all
// clear the VAE, instead of letting VaeFiles.at() throw from the event handler.
void MainWindowUI::onVaeSelect(wxCommandEvent &event)
{
    const wxString selection = this->m_vae->GetStringSelection();
    const auto found = this->VaeFiles.find(selection.ToStdString());

    if (selection == "-none-" || found == this->VaeFiles.end())
    {
        this->sd_params->vae_path = std::string("");
        return;
    }

    this->sd_params->vae_path = found->second;
}
|
|
// Puts a fresh random nine-digit seed into the seed control.
void MainWindowUI::onRandomGenerateButton(wxCommandEvent &event)
{
    const auto seed = sd_gui_utils::generateRandomInt(100000000, 999999999);
    this->m_seed->SetValue(seed);
}
|
|
// Exchanges the width and height controls (portrait <-> landscape).
void MainWindowUI::onResolutionSwap(wxCommandEvent &event)
{
    const auto previousWidth = this->m_width->GetValue();
    const auto previousHeight = this->m_height->GetValue();

    this->m_height->SetValue(previousWidth);
    this->m_width->SetValue(previousHeight);
}
|
|
// Handler for the "start jobs" action. Not implemented yet.
void MainWindowUI::onJobsStart(wxCommandEvent &event)
{
    (void)event; // TODO: Implement onJobsStart
}
|
|
// Handler for the "pause jobs" action. Not implemented yet.
void MainWindowUI::onJobsPause(wxCommandEvent &event)
{
    (void)event; // TODO: Implement onJobsPause
}
|
|
// Handler for the "delete jobs" action. Not implemented yet.
void MainWindowUI::onJobsDelete(wxCommandEvent &event)
{
    (void)event; // TODO: Implement onJobsDelete
}
|
|
// Handler for activating (e.g. double-clicking) a job row. Not implemented yet.
void MainWindowUI::onJoblistItemActivated(wxDataViewEvent &event)
{
    (void)event; // TODO: Implement onJoblistItemActivated
}
|
|
// Shows the right-click popup menu for the job list or the model list.
// The menu remembers which list opened it as client data; OnPopupClick reads
// it back via GetClientData(). Menu ids 1..4 are only unique per list.
void MainWindowUI::onContextMenu(wxDataViewEvent &event)
{
    auto *source = static_cast<wxDataViewListCtrl *>(event.GetEventObject());

    wxMenu menu;
    menu.SetClientData((void *)source);

    if (source == this->m_joblist)
    {
        menu.Append(1, "Copy and restart");
        menu.Append(2, "Copy parameters to text2img");
        menu.Append(3, "Copy parameters to img2img");
        menu.Append(4, "Details...");
    }
    else if (source == this->m_data_model_list)
    {
        menu.Append(1, "Calculate Hash");
        menu.Append(2, "Download info from CivitAi.com");
    }

    // Same dispatch as the old Connect() call: menu events reach this->OnPopupClick.
    menu.Bind(wxEVT_COMMAND_MENU_SELECTED, &MainWindowUI::OnPopupClick, this);
    PopupMenu(&menu);
}
|
|
// Handler for job-list selection changes. Not implemented yet.
void MainWindowUI::onJoblistSelectionChanged(wxDataViewEvent &event)
{
    (void)event; // TODO: Implement onJoblistSelectionChanged
}
|
|
// Copies the current generator controls into sd_params and queues a copy of
// them as a new job. A seed of -1 is replaced by a random seed in the queued
// copy only.
void MainWindowUI::onGenerate(wxCommandEvent &event)
{
    const std::string model_name = this->m_model->GetStringSelection().ToStdString();
    const auto model = this->ModelFiles.find(model_name);
    if (model == this->ModelFiles.end())
    {
        // "-none-" or nothing selected: ModelFiles.at() used to throw here.
        this->logs->AppendText("Generation not queued: no model selected\n");
        return;
    }

    this->sd_params->model_path = model->second;
    this->sd_params->lora_model_dir = this->cfg->lora;
    this->sd_params->embeddings_path = this->cfg->embedding;

    this->sd_params->prompt = this->m_prompt->GetValue().ToStdString();
    this->sd_params->negative_prompt = this->m_neg_prompt->GetValue().ToStdString();

    this->sd_params->cfg_scale = static_cast<float>(this->m_cfg->GetValue());
    this->sd_params->seed = this->m_seed->GetValue();
    this->sd_params->clip_skip = this->m_clip_skip->GetValue();
    this->sd_params->sample_steps = this->m_steps->GetValue();

    // Keep the previous sampler when none is selected, instead of casting wxNOT_FOUND into the enum.
    const int sampler = this->m_sampler->GetCurrentSelection();
    if (sampler != wxNOT_FOUND)
    {
        this->sd_params->sample_method = (sample_method_t)sampler;
    }

    this->sd_params->batch_count = this->m_batch_count->GetValue();

    this->sd_params->width = this->m_width->GetValue();
    this->sd_params->height = this->m_height->GetValue();

    QM::QueueItem item;
    item.params = *this->sd_params;
    item.model = model_name;

    if (item.params.seed == -1)
    {
        item.params.seed = sd_gui_utils::generateRandomInt(100000000, 999999999);
    }

    this->qmanager->AddItem(item);
}
|
|
// Mirrors the sampler choice into sd_params (the choice index is the enum value).
void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
{
    const int index = this->m_sampler->GetSelection();
    this->sd_params->sample_method = static_cast<sample_method_t>(index);
}
|
|
// Asks for a preset name and saves the current generator settings as
// "<presets dir>/<name>.json", then reloads the preset list.
void MainWindowUI::onSavePreset(wxCommandEvent &event)
{
    wxTextEntryDialog dlg(this, "Please specify a name (only alphanumeric)");
    // wxTextValidator checks each style flag separately, so ALPHA | DIGITS
    // demanded characters that are both letters and digits (i.e. none).
    // ALPHANUMERIC is the intended "letters or digits" filter.
    dlg.SetTextValidator(wxFILTER_ALPHANUMERIC);
    if (dlg.ShowModal() != wxID_OK)
    {
        return;
    }

    const wxString preset_name = dlg.GetValue();
    if (preset_name.IsEmpty())
    {
        // An empty name would produce a file called ".json".
        this->logs->AppendText("Preset not saved: the name is empty\n");
        return;
    }

    sd_gui_utils::generator_preset preset;
    preset.cfg = this->m_cfg->GetValue();
    preset.seed = this->m_seed->GetValue();
    preset.clip_skip = this->m_clip_skip->GetValue();
    preset.steps = this->m_steps->GetValue();
    preset.width = this->m_width->GetValue();
    preset.height = this->m_height->GetValue();
    preset.sampler = (sample_method_t)this->m_sampler->GetSelection();
    preset.batch = this->m_batch_count->GetValue();
    preset.name = preset_name.ToStdString();
    preset.mode = "text2image";

    nlohmann::json j(preset);

    const std::string presetfile = fmt::format("{}{}{}.json",
                                               this->cfg->presets,
                                               wxString(wxFileName::GetPathSeparator()).ToStdString(),
                                               preset.name);

    std::ofstream file(presetfile);
    if (!file)
    {
        this->logs->AppendText(fmt::format("Failed to save preset into {}\n", presetfile));
        return;
    }
    file << j;
    file.close();

    this->LoadPresets();
}
|
|
// Copies the selected preset's values into the generator controls.
// Presets is keyed by the preset name (see LoadFileList), so a direct lookup
// replaces the old linear scan. That scan copied every entry and called
// GetString(-1) when nothing was selected.
void MainWindowUI::onLoadPreset(wxCommandEvent &event)
{
    const std::string name = this->m_preset_list->GetStringSelection().ToStdString();

    const auto found = this->Presets.find(name);
    if (found == this->Presets.end())
    {
        return; // "-none-" or nothing selected
    }

    const auto &preset = found->second;
    this->m_cfg->SetValue(preset.cfg);
    this->m_clip_skip->SetValue(preset.clip_skip);
    this->m_seed->SetValue(preset.seed);
    this->m_steps->SetValue(preset.steps);
    this->m_width->SetValue(preset.width);
    this->m_height->SetValue(preset.height);
    this->m_sampler->SetSelection(preset.sampler);
    this->m_batch_count->SetValue(preset.batch);
}
|
|
// Enables the load/delete preset buttons unless the "-none-" entry (index 0) is selected.
void MainWindowUI::onSelectPreset(wxCommandEvent &event)
{
    const bool realPresetSelected = this->m_preset_list->GetCurrentSelection() != 0;

    this->m_load_preset->Enable(realPresetSelected);
    this->m_delete_preset->Enable(realPresetSelected);
}
|
|
// Deletes the selected preset's JSON file and reloads the preset list.
void MainWindowUI::onDeletePreset(wxCommandEvent &event)
{
    const std::string name = this->m_preset_list->GetStringSelection().ToStdString();

    const auto found = this->Presets.find(name);
    if (found == this->Presets.end())
    {
        return;
    }

    // Copy the path: LoadPresets() clears Presets, invalidating `found`.
    const std::string path = found->second.path;
    if (std::remove(path.c_str()) != 0)
    {
        this->logs->AppendText(fmt::format("Failed to delete preset file: {}\n", path));
    }

    this->LoadPresets();

    // The reload selects "-none-", so there is nothing to load or delete any more.
    this->m_load_preset->Disable();
    this->m_delete_preset->Disable();
}
|
|
// Returns the display name of a file found under basepath: its path relative
// to basepath without the extension (e.g. "sub/model" for
// "<basepath>/sub/model.safetensors"). This works whether or not basepath
// ends with a separator.
static std::string RelativeListName(const std::filesystem::path &file, const std::string &basepath)
{
    std::string relative = file.string();
    if (!basepath.empty() && relative.compare(0, basepath.size(), basepath) == 0)
    {
        relative.erase(0, basepath.size());
    }
    else
    {
        // Not textually under basepath; fall back to the bare file name.
        relative = file.filename().string();
    }

    const std::string::size_type first = relative.find_first_not_of("/\\");
    relative = (first == std::string::npos) ? std::string() : relative.substr(first);
    return std::filesystem::path(relative).replace_extension().string();
}

// Rescans the directory for `type` and refills the matching widget and map.
// - CHECKPOINT / VAE: *.safetensors and *.ckpt files go into
//   m_model / ModelFiles or m_vae / VaeFiles.
// - PRESETS: *.json files are parsed into Presets and m_preset_list.
//   A file that fails to parse is deleted.
// - LORA: only makes sure the directory exists.
// Filesystem errors are logged instead of thrown; this runs from the constructor.
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
{
    std::string basepath;

    switch (type)
    {
    case sd_gui_utils::DirTypes::VAE:
        this->VaeFiles.clear();
        this->m_vae->Clear();
        this->m_vae->Append("-none-");
        this->m_vae->Select(0);
        basepath = this->cfg->vae;
        break;
    case sd_gui_utils::DirTypes::LORA:
        basepath = this->cfg->lora;
        break;
    case sd_gui_utils::DirTypes::CHECKPOINT:
        this->ModelFiles.clear();
        this->m_model->Clear();
        this->m_model->Append("-none-");
        this->m_model->Select(0);
        basepath = this->cfg->model;
        break;
    case sd_gui_utils::DirTypes::PRESETS:
        this->Presets.clear();
        this->m_preset_list->Clear();
        this->m_preset_list->Append("-none-");
        this->m_preset_list->Select(0);
        basepath = this->cfg->presets;
        break;
    default:
        break;
    }

    if (basepath.empty())
    {
        return; // unhandled type or unset path: create_directories("") would throw
    }

    std::error_code ec;
    if (!std::filesystem::exists(basepath, ec) && !ec)
    {
        std::filesystem::create_directories(basepath, ec);
    }
    if (ec)
    {
        this->logs->AppendText(fmt::format("Cannot access directory {}: {}\n", basepath, ec.message()));
        return;
    }

    if (type == sd_gui_utils::DirTypes::LORA)
    {
        return; // LoRA files are not listed anywhere
    }

    namespace fs = std::filesystem;
    fs::recursive_directory_iterator it(basepath, fs::directory_options::skip_permission_denied, ec);
    const fs::recursive_directory_iterator end;
    for (; !ec && it != end; it.increment(ec))
    {
        std::error_code entry_ec;
        if (!it->is_regular_file(entry_ec) || !it->path().has_extension())
        {
            continue;
        }

        const fs::path path = it->path();
        const std::string ext = path.extension().string();

        const bool isModelType = type == sd_gui_utils::DirTypes::CHECKPOINT || type == sd_gui_utils::DirTypes::VAE;
        if (isModelType && ext != ".safetensors" && ext != ".ckpt")
        {
            continue;
        }
        if (type == sd_gui_utils::DirTypes::PRESETS && ext != ".json")
        {
            continue;
        }

        const std::string name = RelativeListName(path, basepath);

        if (type == sd_gui_utils::DirTypes::CHECKPOINT)
        {
            this->m_model->Append(name);
            this->ModelFiles.emplace(name, path.string());
        }
        else if (type == sd_gui_utils::DirTypes::VAE)
        {
            this->m_vae->Append(name);
            this->VaeFiles.emplace(name, path.string());
        }
        else if (type == sd_gui_utils::DirTypes::PRESETS)
        {
            std::ifstream f(path.string());
            if (!f)
            {
                this->logs->AppendText(fmt::format("Cannot read preset {}\n", path.string()));
                continue;
            }
            try
            {
                nlohmann::json data = nlohmann::json::parse(f);
                sd_gui_utils::generator_preset preset;
                preset = data;
                preset.path = path.string();
                // Only list names that made it into the map (no duplicate entries).
                if (this->Presets.emplace(preset.name, preset).second)
                {
                    this->m_preset_list->Append(preset.name);
                }
            }
            catch (const std::exception &e)
            {
                // Invalid presets are deleted, as before. Close the stream first:
                // removing an open file fails on Windows.
                f.close();
                std::remove(path.string().c_str());
                std::cerr << e.what() << '\n';
                this->logs->AppendText(fmt::format("Removed invalid preset {}: {}\n", path.string(), e.what()));
            }
        }
    }
    if (ec)
    {
        this->logs->AppendText(fmt::format("Error while scanning {}: {}\n", basepath, ec.message()));
    }

    if (type == sd_gui_utils::DirTypes::CHECKPOINT)
    {
        this->logs->AppendText(fmt::format("Loaded checkpoints: {}\n", this->ModelFiles.size()));
    }
    if (type == sd_gui_utils::DirTypes::VAE)
    {
        this->logs->AppendText(fmt::format("Loaded vaes: {}\n", this->VaeFiles.size()));
    }
    if (type == sd_gui_utils::DirTypes::PRESETS)
    {
        this->logs->AppendText(fmt::format("Loaded presets: {}\n", this->Presets.size()));
        if (!this->Presets.empty())
        {
            this->m_preset_list->Enable();
        }
    }
}
|
|
// Adds a row for a newly queued job at the top of the job list.
// The row's values follow the column order set up in the constructor:
// Id, Created at, Model, Sampler, Seed, Progress, Speed, Status.
// A heap copy of the item is kept in JobTableItems and attached to the row
// as item data.
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
    const auto createdAt = sd_gui_utils::formatUnixTimestampToDate(item.created_at);
    const int initialProgress = item.status == QM::QueueStatus::DONE ? 100 : 1;

    wxVector<wxVariant> row;
    row.push_back(wxVariant(std::to_string(item.id)));
    row.push_back(wxVariant(createdAt));
    row.push_back(wxVariant(item.model));
    row.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
    row.push_back(wxVariant(std::to_string(item.params.seed)));
    row.push_back(initialProgress);
    row.push_back(wxString("-.--it/s"));
    row.push_back(wxVariant(QM::QueueStatus_str[item.status]));

    QM::QueueItem *stored = new QM::QueueItem(item);
    this->JobTableItems[item.id] = stored;

    this->m_joblist->GetStore()->PrependItem(row, wxUIntPtr(stored));
}
|
|
// Waits for generation threads, then releases the stable-diffusion context.
// The order matters: Generate() uses this->sd_ctx, so freeing it before the
// join (as before) could pull the context out from under a running job.
MainWindowUI::~MainWindowUI()
{
    for (auto &t : this->threads)
    {
        if (t->joinable())
        {
            t->join();
        }
    }
    // NOTE(review): the thread objects are allocated with `new` in StartGeneration
    // and not deleted here; the element type of `threads` is not visible in this file.

    if (this->modelLoaded && this->sd_ctx != nullptr)
    {
        free_sd_ctx(this->sd_ctx);
        this->sd_ctx = nullptr;
    }
}
|
|
void MainWindowUI::loadVaeList()
|
{
|
if (!std::filesystem::exists(this->cfg->vae))
|
{
|
std::filesystem::create_directories(this->cfg->vae);
|
}
|
this->LoadFileList(sd_gui_utils::DirTypes::VAE);
|
}
|
|
// Called for QM::QueueEvents::ITEM_UPDATED. Nothing is updated in the UI yet.
void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
{
    (void)item;
}
|
|
// Worker-thread body started by StartGeneration for one queue item.
// Steps:
//  1. (Re)load the model if needed.
//  2. Run txt2img.
//  3. Save every image as PNG into cfg->output.
//  4. Report to the UI thread via wxThreadEvents: GENERATION_START, then
//     GENERATION_ERROR / MESSAGE / GENERATION_DONE, and finally QUEUE:ITEM_FINISHED.
// Only myItem's own parameters are used. sd_params belongs to the UI thread
// and may already describe a different job.
void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    // Posts a "TOKEN:content" message, optionally with a copy of a queue item as payload.
    const auto post = [eventHandler](const wxString &message, const QM::QueueItem *payload) {
        wxThreadEvent *evt = new wxThreadEvent();
        evt->SetString(message);
        if (payload != nullptr)
        {
            evt->SetPayload(*payload);
        }
        wxQueueEvent(eventHandler, evt);
    };

    // Context for HandleSDProgress. The previous heap allocation was never freed.
    // The callback is presumably invoked synchronously from new_sd_ctx/txt2img on
    // this thread (confirm), so a local that lives for this call is enough.
    sd_gui_utils::VoidHolder vparams;
    vparams.p1 = (void *)this->GetEventHandler();
    vparams.p2 = (void *)&myItem;
    sd_set_progress_callback(MainWindowUI::HandleSDProgress, (void *)&vparams);

    if (!this->modelLoaded)
    {
        this->sd_ctx = this->LoadModelv2(eventHandler, myItem);
        this->currentModel = myItem.params.model_path;
        this->currentVaeModel = myItem.params.vae_path;
    }
    else if (myItem.params.model_path != this->currentModel || this->currentVaeModel != myItem.params.vae_path)
    {
        free_sd_ctx(this->sd_ctx);
        this->sd_ctx = this->LoadModelv2(eventHandler, myItem);
    }

    if (!this->modelLoaded || this->sd_ctx == nullptr)
    {
        post("GENERATION_ERROR:Model load failed...", &myItem);
        return;
    }

    const auto start = std::chrono::system_clock::now();

    post(wxString::Format("GENERATION_START:%s", myItem.params.model_path), &myItem);

    sd_image_t *control_image = NULL;
    sd_image_t *results = txt2img(this->sd_ctx,
                                  myItem.params.prompt.c_str(),
                                  myItem.params.negative_prompt.c_str(),
                                  myItem.params.clip_skip,
                                  myItem.params.cfg_scale,
                                  myItem.params.width,
                                  myItem.params.height,
                                  myItem.params.sample_method,
                                  myItem.params.sample_steps,
                                  myItem.params.seed,
                                  myItem.params.batch_count,
                                  control_image,
                                  myItem.params.control_strength);

    if (results == NULL)
    {
        post("GENERATION_ERROR:Something wrong happened at image generation...", &myItem);
        return;
    }

    if (!std::filesystem::exists(this->cfg->output))
    {
        std::filesystem::create_directories(this->cfg->output);
    }

    const auto p1 = std::chrono::system_clock::now();
    const auto ctime = std::chrono::duration_cast<std::chrono::seconds>(p1.time_since_epoch()).count();

    // `results` was requested with myItem.params.batch_count images. Iterating
    // over the UI's sd_params->batch_count (as before) could read past its end.
    const int batch_count = myItem.params.batch_count;
    for (int idx = 0; idx < batch_count; idx++)
    {
        if (results[idx].data == NULL)
        {
            continue;
        }

        // A stack wxImage takes ownership of the pixel buffer (static_data = false)
        // and releases it when it goes out of scope; the old heap wxImage leaked.
        wxImage img(results[idx].width, results[idx].height, results[idx].data);
        results[idx].data = NULL;

        std::string filename = this->cfg->output;
        std::string extension = ".png";

        if (batch_count > 1)
        {
            filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + "_" + std::to_string(idx) + extension;
        }
        else
        {
            filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + extension;
        }

        if (!img.SaveFile(filename))
        {
            post(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename), &myItem);
        }
        else
        {
            myItem.images.emplace_back(filename);
        }
    }

    // Release the result array itself (the sd.cpp CLI example frees it with free()).
    std::free(results);
    results = NULL;

    const auto end = std::chrono::system_clock::now();
    const std::chrono::duration<double> elapsed_seconds = end - start;

    const auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
    post(wxString(msg.c_str()), nullptr);

    // `results` has been freed, so GENERATION_DONE no longer carries it as payload.
    post(wxString::Format("GENERATION_DONE:ok"), nullptr);

    // Tell the queue manager this job is finished.
    post(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED), &myItem);
}
|
|
void MainWindowUI::initConfig()
|
{
|
wxString datapath = wxStandardPaths::Get().GetUserDataDir() + wxFileName::GetPathSeparator() + "sd_ui_data" + wxFileName::GetPathSeparator();
|
wxString imagespath = wxStandardPaths::Get().GetDocumentsDir() + wxFileName::GetPathSeparator() + "sd_ui_output" + wxFileName::GetPathSeparator();
|
|
wxString model_path = datapath;
|
model_path.append("checkpoints");
|
|
wxString vae_path = datapath;
|
vae_path.append("vae");
|
|
wxString lora_path = datapath;
|
lora_path.append("lora");
|
|
wxString embedding_path = datapath;
|
embedding_path.append("embedding");
|
|
wxString presets_path = datapath;
|
presets_path.append("presets");
|
|
wxString jobs_path = datapath;
|
jobs_path.append("queue_jobs");
|
|
this->cfg->lora = this->fileConfig->Read("/paths/lora", lora_path).ToStdString();
|
this->cfg->model = this->fileConfig->Read("/paths/model", model_path).ToStdString();
|
this->cfg->vae = this->fileConfig->Read("/paths/vae", vae_path).ToStdString();
|
this->cfg->embedding = this->fileConfig->Read("/paths/embedding", embedding_path).ToStdString();
|
this->cfg->presets = this->fileConfig->Read("/paths/presets", presets_path).ToStdString();
|
|
this->cfg->jobs = this->fileConfig->Read("/paths/presets", jobs_path).ToStdString();
|
|
this->cfg->output = this->fileConfig->Read("/paths/output", imagespath).ToStdString();
|
this->cfg->keep_model_in_memory = this->fileConfig->Read("/keep_model_in_memory", this->cfg->keep_model_in_memory);
|
this->cfg->save_all_image = this->fileConfig->Read("/save_all_image", this->cfg->save_all_image);
|
|
// populate data from sd_params as default...
|
|
if (!this->modelLoaded)
|
{
|
this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
|
this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
|
this->m_clip_skip->SetValue(this->sd_params->clip_skip);
|
this->m_steps->SetValue(this->sd_params->sample_steps);
|
this->m_width->SetValue(this->sd_params->width);
|
this->m_height->SetValue(this->sd_params->height);
|
this->m_batch_count->SetValue(this->sd_params->batch_count);
|
}
|
}
|
|
// Close handler of the settings window: re-reads the configuration the window
// may have changed, then destroys the window.
void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
{
    this->initConfig();
    this->settingsWindow->Destroy();
    // Drop the pointer so it cannot dangle once wx deletes the window.
    this->settingsWindow = nullptr;
}
|
|
// Handler for the context-menu entries built in onContextMenu.
// It reads back the list control stored as the menu's client data; no action
// is implemented yet.
void MainWindowUI::OnPopupClick(wxCommandEvent &evt)
{
    wxMenu *menu = static_cast<wxMenu *>(evt.GetEventObject());
    void *source = menu->GetClientData();
    (void)source; // TODO: dispatch on evt.GetId() for the list in `source`
}
|
|
// stable-diffusion.cpp log callback, registered in the constructor with the
// window's event handler as `data`. Only INFO and ERROR messages are forwarded,
// as "SD_MESSAGE:<text>" thread events carrying the level as payload.
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
{
    const bool forwarded = level == sd_log_level_t::SD_LOG_INFO || level == sd_log_level_t::SD_LOG_ERROR;
    if (!forwarded)
    {
        return;
    }

    wxEvtHandler *target = static_cast<wxEvtHandler *>(data);
    wxThreadEvent *evt = new wxThreadEvent();
    evt->SetString(wxString::Format("SD_MESSAGE:%s", text));
    evt->SetPayload(level);
    wxQueueEvent(target, evt);
}
|
|
// Updates the status column (the last one) of the job row whose attached
// QueueItem has the same id as `item`.
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
{
    auto store = this->m_joblist->GetStore();
    const int statusCol = this->m_joblist->GetColumnCount() - 1;
    const unsigned int rowCount = store->GetItemCount();

    for (unsigned int row = 0; row < rowCount; ++row)
    {
        const auto rowItem = store->GetItem(row);
        const auto *job = reinterpret_cast<QM::QueueItem *>(store->GetItemData(rowItem));
        if (job->id != item.id)
        {
            continue;
        }

        store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), row, statusCol);
        this->m_joblist->Refresh();
        break;
    }
}
|
|
void MainWindowUI::loadModelList()
|
{
|
this->m_sampler->Clear();
|
for (auto sampler : sd_gui_utils::sample_method_str)
|
{
|
int _u = this->m_sampler->Append(sampler);
|
|
if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
|
{
|
this->m_sampler->Select(_u);
|
}
|
}
|
|
this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
|
|
for (auto model : this->ModelFiles)
|
{
|
// auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
|
uintmax_t size = std::filesystem::file_size(model.second);
|
auto humanSize = sd_gui_utils::humanReadableFileSize(size);
|
auto hs = wxString::Format("%.1f %s", humanSize.first, humanSize.second);
|
wxVector<wxVariant> data;
|
data.push_back(wxVariant(model.first));
|
data.push_back(hs);
|
data.push_back("--");
|
this->m_data_model_list->AppendItem(data);
|
}
|
this->m_data_model_list->Refresh();
|
}
|
|
// Runs Generate() for `myJob` on a new thread. The thread is stored in
// `threads`, which the destructor joins.
void MainWindowUI::StartGeneration(QM::QueueItem myJob)
{
    this->threads.emplace_back(new std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob));
}
|
|
void MainWindowUI::HandleSDProgress(int step, int steps, float time, void *data)
|
{
|
sd_gui_utils::VoidHolder *objs = (sd_gui_utils::VoidHolder *)data;
|
wxEvtHandler *eventHandler = (wxEvtHandler *)objs->p1;
|
QM::QueueItem *myItem = (QM::QueueItem *)objs->p2;
|
myItem->step = step;
|
myItem->steps = steps;
|
myItem->time = time;
|
/*
|
format it/s
|
time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
|
progress.c_str(), step, steps,
|
time > 1.0f || time == 0 ? time : (1.0f / time)
|
*/
|
|
wxThreadEvent *e = new wxThreadEvent();
|
e->SetString(wxString::Format("GENERATION_PROGRESS:%d/%d", step, steps));
|
|
e->SetPayload(myItem);
|
wxQueueEvent(eventHandler, e);
|
}
|
|
// Dispatches "TOKEN:content" thread events posted by the queue manager,
// Generate()/LoadModelv2() and the stable-diffusion.cpp log callback.
// Informational and error tokens also raise a desktop notification while the
// window is not on screen.
void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
{
    if (e.GetSkipped() == false)
    {
        e.Skip();
    }

    // Without a ':' both token and content are the whole message (npos + 1 == 0).
    const std::string msg = e.GetString().ToStdString();
    const std::string::size_type sep = msg.find(':');
    const std::string token = msg.substr(0, sep);
    const std::string content = msg.substr(sep + 1);

    const auto notify = [this](int flags, const wxString &title, const std::string &text) {
        if (this->IsShownOnScreen())
        {
            return;
        }
        this->notification->SetFlags(flags);
        this->notification->SetTitle(title);
        this->notification->SetMessage(text);
        this->notification->Show(5000);
    };

    if (token == "QUEUE")
    {
        // Content is a QM::QueueEvents value; the payload is always the affected QueueItem.
        const QM::QueueEvents queueEvent = (QM::QueueEvents)std::stoi(content);
        const QM::QueueItem payload = e.GetPayload<QM::QueueItem>();

        switch (queueEvent)
        {
        case QM::QueueEvents::ITEM_ADDED:
            this->OnQueueItemManagerItemAdded(payload);
            break;
        case QM::QueueEvents::ITEM_STATUS_CHANGED:
            this->OnQueueItemManagerItemStatusChanged(payload);
            break;
        case QM::QueueEvents::ITEM_UPDATED:
            this->OnQueueItemManagerItemUpdated(payload);
            break;
        case QM::QueueEvents::ITEM_START:
            this->StartGeneration(payload);
            break;
        default:
            break;
        }
    }
    else if (token == "MODEL_LOAD_DONE")
    {
        this->modelLoaded = true;
        notify(wxICON_INFORMATION, "SD Gui", content);
    }
    else if (token == "MODEL_LOAD_START")
    {
        this->logs->AppendText(fmt::format("Model load start: {}\n", content));
    }
    else if (token == "MODEL_LOAD_ERROR")
    {
        this->logs->AppendText(fmt::format("Model load error: {}\n", content));
        this->modelLoaded = false;
        notify(wxICON_ERROR, "SD Gui - error", content);
    }
    else if (token == "GENERATION_START")
    {
        const auto myjob = e.GetPayload<QM::QueueItem>();
        this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
                                           myjob.params.seed,
                                           myjob.params.batch_count,
                                           myjob.params.width,
                                           myjob.params.height,
                                           myjob.params.cfg_scale,
                                           myjob.params.sample_steps));
    }
    else if (token == "GENERATION_PROGRESS")
    {
        // Payload is the QueueItem pointer posted by HandleSDProgress.
        // NOTE(review): it points into the worker's Generate() call; confirm
        // that it is still alive when this runs.
        QM::QueueItem *myjob = e.GetPayload<QM::QueueItem *>();
        if (myjob == nullptr || myjob->steps <= 0)
        {
            return; // avoids dividing by zero and casting NaN to int below
        }

        const float current_progress = 100.f * (static_cast<float>(myjob->step) / static_cast<float>(myjob->steps));
        if (current_progress < 2.f)
        {
            return;
        }

        // Slower than 1 s per step is shown as s/it, otherwise as it/s.
        const wxString speed = wxString::Format(myjob->time > 1.0f ? "%.2fs/it" : "%.2fit/s",
                                                myjob->time > 1.0f || myjob->time == 0 ? myjob->time : (1.0f / myjob->time));

        // Column layout (see constructor): ..., Progress (-3), Speed (-2), Status (-1).
        const int progressCol = this->m_joblist->GetColumnCount() - 3;
        const int speedCol = this->m_joblist->GetColumnCount() - 2;

        auto store = this->m_joblist->GetStore();
        for (unsigned int i = 0; i < store->GetItemCount(); i++)
        {
            auto _item = store->GetItem(i);
            auto *_qitem = reinterpret_cast<QM::QueueItem *>(store->GetItemData(_item));
            if (_qitem->id == myjob->id)
            {
                store->SetValueByRow(static_cast<int>(current_progress), i, progressCol);
                store->SetValueByRow(speed, i, speedCol);
                this->m_joblist->Refresh();
                break;
            }
        }
    }
    else if (token == "GENERATION_DONE")
    {
        // TODO: show the generated images (e.g. in a MainWindowImageViewer).
    }
    else if (token == "GENERATION_ERROR")
    {
        this->logs->AppendText(fmt::format("Generation error: {}\n", content));
        notify(wxICON_ERROR, "SD Gui - error", content);
    }
    else if (token == "SD_MESSAGE")
    {
        if (content.empty())
        {
            return;
        }
        // sd.cpp log lines already end with their own newline.
        this->logs->AppendText(wxString(content));
    }
    else if (token == "MESSAGE")
    {
        this->logs->AppendText(fmt::format("{}\n", content));
        notify(wxICON_INFORMATION, "SD Gui", content);
    }
}
|
|
// Creates a stable-diffusion context for myItem's model/VAE parameters.
// - Posts MODEL_LOAD_START first.
// - On failure: posts MODEL_LOAD_ERROR (with the item as payload), clears
//   modelLoaded and returns nullptr.
// - On success: posts MODEL_LOAD_DONE, records the loaded model/VAE paths and
//   returns the new context.
sd_ctx_t *MainWindowUI::LoadModelv2(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    const auto &p = myItem.params;

    wxThreadEvent *started = new wxThreadEvent();
    started->SetString(wxString::Format("MODEL_LOAD_START:%s", p.model_path));
    started->SetPayload(myItem);
    wxQueueEvent(eventHandler, started);

    sd_ctx_t *ctx = new_sd_ctx(p.model_path.c_str(),
                               p.vae_path.c_str(),
                               p.taesd_path.c_str(),
                               p.controlnet_path.c_str(),
                               p.lora_model_dir.c_str(),
                               p.embeddings_path.c_str(),
                               false, p.vae_tiling, false,
                               p.n_threads,
                               p.wtype,
                               p.rng_type,
                               p.schedule, false);

    wxThreadEvent *outcome = new wxThreadEvent();
    if (ctx == NULL)
    {
        outcome->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", p.model_path));
        outcome->SetPayload(myItem);
        wxQueueEvent(eventHandler, outcome);
        this->modelLoaded = false;
        return nullptr;
    }

    outcome->SetString(wxString::Format("MODEL_LOAD_DONE:%s", p.model_path));
    wxQueueEvent(eventHandler, outcome);
    this->modelLoaded = true;
    this->currentModel = p.model_path;
    this->currentVaeModel = p.vae_path;
    return ctx;
}
|
|
void MainWindowUI::LoadPresets()
|
{
|
this->LoadFileList(sd_gui_utils::DirTypes::PRESETS);
|
}
|