wxWidgets-based Stable Diffusion C++ GUI
Ferenc Szontágh
2024-02-03 2088e1b7aa6419dec58800bc1d0cb24f5808affe
ui/MainWindowUI.cpp
@@ -6,6 +6,7 @@
    this->ini_path = wxStandardPaths::Get().GetUserConfigDir() + wxFileName::GetPathSeparator() + "sd.ui.config.ini";
    this->sd_params = new sd_gui_utils::SDParams;
    this->notification = new wxNotificationMessage();
    this->JobTableItems = new std::map<int, QM::QueueItem>();
    this->notification->SetParent(this);
@@ -15,7 +16,9 @@
    this->m_joblist->AppendTextColumn("Id");
    this->m_joblist->AppendTextColumn("Created");
    this->m_joblist->AppendTextColumn("Update");
    this->m_joblist->AppendTextColumn("Model");
    this->m_joblist->AppendTextColumn("Sampler");
    this->m_joblist->AppendTextColumn("Seed");
    this->m_joblist->AppendTextColumn("Status");
    this->SetTitle(this->GetTitle() + SD_GUI_VERSION);
@@ -97,12 +100,6 @@
        // add the selected vae
        this->sd_params->vae_path = this->VaeFiles.at(selection.ToStdString());
    }
    // just select the vae file and add the live paramaeters, do not load the model, just add job to the queu on press the queue button
    /*  if (this->m_generate->IsEnabled() == true)
    {
    wxPostEvent(this->m_model, event);
    }
    */
}
void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
@@ -146,10 +143,28 @@
void MainWindowUI::onGenerate(wxCommandEvent &event)
{
    // this->StartGeneration();
    //  dont start the generation, just add the job to the queue, to handle it the queue manager
    //  so it will possible to add more job afterall
    this->qmanager->AddItem(this->sd_params);
    // prepare params
    this->sd_params->model_path = this->ModelFiles.at(this->m_model->GetStringSelection().ToStdString());
    this->sd_params->lora_model_dir = this->cfg->lora;
    this->sd_params->embeddings_path = this->cfg->embedding;
    this->sd_params->prompt = this->m_prompt->GetValue().ToStdString();
    this->sd_params->negative_prompt = this->m_neg_prompt->GetValue().ToStdString();
    this->sd_params->cfg_scale = static_cast<float>(this->m_cfg->GetValue());
    this->sd_params->seed = this->m_seed->GetValue();
    this->sd_params->clip_skip = this->m_clip_skip->GetValue();
    this->sd_params->sample_steps = this->m_steps->GetValue();
    this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetCurrentSelection();
    this->sd_params->batch_count = this->m_batch_count->GetValue();
    this->sd_params->width = this->m_width->GetValue();
    this->sd_params->height = this->m_height->GetValue();
    // add the queue item
    auto id = this->qmanager->AddItem(this->sd_params);
}
void MainWindowUI::onSavePreset(wxCommandEvent &event)
@@ -243,29 +258,30 @@
        this->LoadPresets();
    }
}
/*
    this->m_joblist->AppendTextColumn("Id");
    this->m_joblist->AppendTextColumn("Created");
    this->m_joblist->AppendTextColumn("Model");
    this->m_joblist->AppendTextColumn("Sampler");
    this->m_joblist->AppendTextColumn("Seed");
    this->m_joblist->AppendTextColumn("Status");
*/
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
    wxVector<wxVariant> data;
    data.push_back(wxVariant(item.id));
    data.push_back(wxVariant(item.created_at));
    data.push_back(wxVariant(item.updated_at));
    data.push_back(wxVariant(std::to_string(item.id)));
    data.push_back(wxVariant(std::to_string(item.created_at)));
    data.push_back(wxVariant(item.params.model_path));
    data.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
    data.push_back(wxVariant(std::to_string(item.params.seed)));
    data.push_back(wxVariant(QM::QueueStatus_str[item.status]));
    this->m_joblist->AppendItem(data);
    /*
    for (auto model : this->ModelFiles)
    {
        // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
        auto size = std::filesystem::file_size(model.second);
        wxVector<wxVariant> data;
        data.push_back(wxVariant(model.first));
        data.push_back(wxVariant(std::to_string(size)));
        this->m_data_model_list->AppendItem(data);
    }
    this->m_data_model_list->Refresh();*/
    auto store = this->m_joblist->GetStore();
    (*this->JobTableItems)[store->GetCount() - 1] = item;
}
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
@@ -396,22 +412,22 @@
void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
{
    e.Skip();
    auto msg = e.GetString().ToStdString();
    std::string token = msg.substr(0, msg.find(":"));
    std::string content = msg.substr(msg.find(":") + 1);
    //    this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
    // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
    if (token == "QUEUE")
    {
        m_statusBar166->SetStatusText("got QUEUE cmd: " + msg);
        // only numbers here...
        QM::QueueEvents event = (QM::QueueEvents)std::stoi(content);
        // only handle the QUEUE messages, what this class generate
        // alway QM::EueueItem the payload, with the new data
        auto payload = e.GetPayload<QM::QueueItem>();
        QM::QueueItem payload;
        payload = e.GetPayload<QM::QueueItem>();
        switch (event)
        {
            // new item added
@@ -420,11 +436,14 @@
            break;
            // item status changed
        case QM::QueueEvents::ITEM_STATUS_CHANGED:
            this->OnQueueItemManagerItemAdded(payload);
            this->OnQueueItemManagerItemStatusChanged(payload);
            break;
            // item updated... ? ? ?
        case QM::QueueEvents::ITEM_UPDATED:
            this->OnQueueItemManagerItemAdded(payload);
            this->OnQueueItemManagerItemUpdated(payload);
            break;
        case QM::QueueEvents::ITEM_START:
            this->StartGeneration(payload);
            break;
        default:
@@ -433,12 +452,13 @@
    }
    if (token == "MODEL_LOAD_DONE")
    {
        this->m_generate->Enable();
        this->m_model->Enable();
        this->m_vae->Enable();
        this->m_refresh->Enable();
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        // this->m_refresh->Enable();
        this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
        // this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
        this->modelLoaded = true;
        this->sd_ctx = e.GetPayload<sd_ctx_t *>();
        if (!this->IsShownOnScreen())
        {
@@ -450,19 +470,20 @@
    }
    if (token == "MODEL_LOAD_START")
    {
        this->m_generate->Disable();
        this->m_model->Disable();
        this->m_vae->Disable();
        this->m_refresh->Disable();
        // this->m_generate->Disable();
        // this->m_model->Disable();
        // this->m_vae->Disable();
        // this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Model load start: {}\n", content));
    }
    if (token == "MODEL_LOAD_ERROR")
    {
        this->m_generate->Disable();
        this->m_model->Enable();
        this->m_vae->Disable();
        this->m_refresh->Enable();
        // this->m_generate->Disable();
        // this->m_model->Enable();
        // this->m_vae->Disable();
        // this->m_refresh->Enable();
        this->logs->AppendText(fmt::format("Model load error: {}\n", content));
        this->modelLoaded = false;
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_ERROR);
@@ -474,62 +495,62 @@
    if (token == "GENERATION_START")
    {
        sd_gui_utils::SDParams *params = e.GetPayload<sd_gui_utils::SDParams *>();
        auto myjob = e.GetPayload<QM::QueueItem>();
        this->m_generate->Disable();
        this->m_model->Disable();
        this->m_vae->Disable();
        this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Difusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
                                           params->seed,
                                           params->batch_count,
                                           params->width,
                                           params->height,
                                           params->cfg_scale,
                                           params->sample_steps));
        // this->m_generate->Disable();
        // this->m_model->Disable();
        // this->m_vae->Disable();
        // this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
                                           myjob.params.seed,
                                           myjob.params.batch_count,
                                           myjob.params.width,
                                           myjob.params.height,
                                           myjob.params.cfg_scale,
                                           myjob.params.sample_steps));
    }
    // never, not implemented in sd.cpp
    if (token == "GENERATION_PROGRESS")
    {
        this->m_generate->Disable();
        this->logs->AppendText(fmt::format("Generation progress: {}\n", content));
        // this->m_generate->Disable();
        // this->logs->AppendText(fmt::format("Generation progress: {}\n", content));
    }
    if (token == "GENERATION_DONE")
    {
        this->m_generate->Enable();
        this->m_model->Enable();
        this->m_vae->Enable();
        this->m_refresh->Enable();
        sd_image_t *results = e.GetPayload<sd_image_t *>();
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        // this->m_refresh->Enable();
        // sd_image_t *results = e.GetPayload<sd_image_t *>();
        // show images in new window...
        for (int i = 0; i < this->sd_params->batch_count; i++)
        {
            MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
            // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
            wxImage img(results[i].width, results[i].height, results[i].data);
        /* for (int i = 0; i < this->sd_params->batch_count; i++)
         {
             MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
             // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
             wxImage img(results[i].width, results[i].height, results[i].data);
            wxBitmapBundle wxBmapB(img);
            imgWindow->m_bitmap->SetBitmap(wxBmapB);
            imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
            imgWindow->SetSize(results[i].width + 200, results[i].height);
             wxBitmapBundle wxBmapB(img);
             imgWindow->m_bitmap->SetBitmap(wxBmapB);
             imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
             imgWindow->SetSize(results[i].width + 200, results[i].height);
            std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
                                              this->sd_params->prompt, this->sd_params->negative_prompt,
                                              this->sd_params->seed + i, this->sd_params->cfg_scale,
                                              this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
                                              results[i].width, results[i].height);
            imgWindow->m_textCtrl4->AppendText(wxString(details));
            imgWindow->Show();
             std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
                                               this->sd_params->prompt, this->sd_params->negative_prompt,
                                               this->sd_params->seed + i, this->sd_params->cfg_scale,
                                               this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
                                               results[i].width, results[i].height);
             imgWindow->m_textCtrl4->AppendText(wxString(details));
             imgWindow->Show();
            // imgWindow->m_bitmap->SetBitmap(img);
            /// imgWindow->m_bitmap->Set
        }
             // imgWindow->m_bitmap->SetBitmap(img);
             /// imgWindow->m_bitmap->Set
         }*/
    }
    if (token == "GENERATION_ERROR")
    {
        this->m_generate->Enable();
        this->m_model->Enable();
        this->m_vae->Enable();
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        this->logs->AppendText(fmt::format("Generation error: {}\n", content));
        if (!this->IsShownOnScreen())
        {
@@ -560,10 +581,11 @@
    }
}
void MainWindowUI::LoadModel(wxEvtHandler *eventHandler)
void MainWindowUI::LoadModel(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("MODEL_LOAD_START:%s", this->sd_params->model_path));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
    /*this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(),
@@ -580,22 +602,25 @@
                              this->sd_params->rng_type,
                              this->sd_params->schedule,
                              this->sd_params->control_net_cpu);*/
    this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(), this->sd_params->vae_path.c_str(), this->sd_params->taesd_path.c_str(), this->sd_params->lora_model_dir.c_str(), true, false, false, this->sd_params->n_threads, this->sd_params->wtype, this->sd_params->rng_type, this->sd_params->schedule);
    // this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(), this->sd_params->vae_path.c_str(), this->sd_params->taesd_path.c_str(), this->sd_params->lora_model_dir.c_str(), true, false, false, this->sd_params->n_threads, this->sd_params->wtype, this->sd_params->rng_type, this->sd_params->schedule);
    this->sd_ctx = new_sd_ctx(myItem.params.model_path.c_str(), myItem.params.vae_path.c_str(), myItem.params.taesd_path.c_str(), myItem.params.lora_model_dir.c_str(), true, false, false, myItem.params.n_threads, myItem.params.wtype, myItem.params.rng_type, myItem.params.schedule);
    if (this->sd_ctx == NULL)
    {
        wxThreadEvent *c = new wxThreadEvent();
        c->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", this->sd_params->model_path));
        c->SetPayload(myItem);
        wxQueueEvent(eventHandler, c);
        this->modelLoaded = false;
        return;
    }
    else
    {
        wxThreadEvent *c = new wxThreadEvent();
        c->SetString(wxString::Format("MODEL_LOAD_DONE:%s", this->sd_params->model_path));
        // c->SetEventObject(this->sd_ctx);
        c->SetPayload(this->sd_ctx);
        wxQueueEvent(eventHandler, c);
        this->modelLoaded = true;
        return;
    }
@@ -639,7 +664,7 @@
    // populate data from sd_params as default...
    if (!this->m_generate->IsEnabled())
    if (!this->modelLoaded)
    {
        this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
        this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
@@ -649,38 +674,11 @@
        this->m_height->SetValue(this->sd_params->height);
        this->m_batch_count->SetValue(this->sd_params->batch_count);
    }
    // hide unusable configs...
    /*
    if (SD_CPP_VERSION == "c6071fa") {
            // .. nope, configs in another window...
    }*/
}
void MainWindowUI::StartGeneration()
void MainWindowUI::StartGeneration(QM::QueueItem myJob)
{
    // prepare params
    this->sd_params->model_path = this->ModelFiles.at(this->m_model->GetStringSelection().ToStdString());
    this->sd_params->lora_model_dir = this->cfg->lora;
    this->sd_params->embeddings_path = this->cfg->embedding;
    this->sd_params->prompt = this->m_prompt->GetValue().ToStdString();
    this->sd_params->negative_prompt = this->m_neg_prompt->GetValue().ToStdString();
    this->sd_params->cfg_scale = static_cast<float>(this->m_cfg->GetValue());
    this->sd_params->seed = this->m_seed->GetValue();
    this->sd_params->clip_skip = this->m_clip_skip->GetValue();
    this->sd_params->sample_steps = this->m_steps->GetValue();
    /* sample method */
    this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetCurrentSelection();
    /* sample method */
    this->sd_params->batch_count = this->m_batch_count->GetValue();
    this->sd_params->width = this->m_width->GetValue();
    this->sd_params->height = this->m_height->GetValue();
    this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler())));
    this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
}
void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
@@ -717,31 +715,46 @@
    this->m_data_model_list->Refresh();
}
void MainWindowUI::Generate(wxEvtHandler *eventHandler)
void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    // calculate time
    // @brief model loading is done in the same thread which generating stuffs... no need new thread
    // @brief to load the model if no model loaded, or the loaded model is different from the job's model...
    // @brief if all second job have different model, it will be slower than same model on all jobs
    // TODO: abort job if model can not be loaded...
    if (!this->modelLoaded)
    {
        this->LoadModel(eventHandler, myItem);
        this->currentModel = myItem.params.model_path;
    }
    else
    {
        if (myItem.params.model_path != this->currentModel)
        {
            free_sd_ctx(this->sd_ctx);
            this->LoadModel(eventHandler, myItem);
        }
    }
    if (!this->modelLoaded)
    {
        wxThreadEvent *f = new wxThreadEvent();
        f->SetString("GENERATION_ERROR:Model load failed...");
        f->SetPayload(myItem);
        wxQueueEvent(eventHandler, f);
        return;
    }
    auto start = std::chrono::system_clock::now();
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("GENERATION_START:%s", this->sd_params->model_path));
    e->SetPayload(this->sd_params);
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
    sd_image_t *control_image = NULL;
    sd_image_t *results;
    /*results = txt2img(sd_ctx,
                      this->sd_params->prompt.c_str(),
                      this->sd_params->negative_prompt.c_str(),
                      this->sd_params->clip_skip,
                      this->sd_params->cfg_scale,
                      this->sd_params->width,
                      this->sd_params->height,
                      this->sd_params->sample_method,
                      this->sd_params->sample_steps,
                      this->sd_params->seed,
                      this->sd_params->batch_count,
                      control_image,
                      this->sd_params->control_strength);*/
    results = txt2img(this->sd_ctx,
                      this->sd_params->prompt.c_str(), this->sd_params->negative_prompt.c_str(), this->sd_params->clip_skip, this->sd_params->cfg_scale, this->sd_params->width, this->sd_params->height, this->sd_params->sample_method, this->sd_params->sample_steps, this->sd_params->seed, this->sd_params->batch_count);
@@ -749,6 +762,7 @@
    {
        wxThreadEvent *f = new wxThreadEvent();
        f->SetString("GENERATION_ERROR:Something wrong happened at image generation...");
        f->SetPayload(myItem);
        wxQueueEvent(eventHandler, f);
        return;
    }
@@ -785,7 +799,12 @@
        {
            wxThreadEvent *g = new wxThreadEvent();
            g->SetString(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename));
            g->SetPayload(myItem);
            wxQueueEvent(eventHandler, g);
        }
        else
        {
            myItem.images.emplace_back(filename);
        }
        // handle data??
@@ -798,18 +817,29 @@
    wxThreadEvent *h = new wxThreadEvent();
    auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
    h->SetString(wxString(msg.c_str()));
    h->SetPayload(myItem);
    wxQueueEvent(eventHandler, h);
    // send to reset the buttons
    wxThreadEvent *i = new wxThreadEvent();
    i->SetString(wxString::Format("GENERATION_DONE:ok"));
    i->SetPayload(results);
    wxQueueEvent(eventHandler, i);
    // send to the queue manager
    wxThreadEvent *j = new wxThreadEvent();
    j->SetString(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED));
    j->SetPayload(myItem);
    wxQueueEvent(eventHandler, j);
    return;
}
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
{
    if (level != sd_log_level_t::SD_LOG_INFO)
    {
        return;
    }
    // wxEvtHandler *eventHandler = (wxEvtHandler *)data;
    MainWindowUI *ui = (MainWindowUI *)data;
    wxEvtHandler *eventHandler = ui->GetEventHandler();
@@ -821,12 +851,37 @@
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
{
    /*
    this->m_joblist->AppendTextColumn("Id");
    this->m_joblist->AppendTextColumn("Created");
    this->m_joblist->AppendTextColumn("Model");
    this->m_joblist->AppendTextColumn("Sampler");
    this->m_joblist->AppendTextColumn("Seed");
    this->m_joblist->AppendTextColumn("Status");
    */
    // TODO: how to update an item in the table...
    // auto item = this->m_joblist->RowToItem();
    auto store = this->m_joblist->GetStore();
    for (auto it = this->JobTableItems->begin(); it != this->JobTableItems->end(); ++it)
    {
        if (it->second.id == item.id)
        {
            // auto store = *this->m_joblist()->GetStore();
            int lastCol = this->m_joblist->GetColumnCount();
            // always update the last col, so the last col is always need to be the status
            store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), it->first, lastCol - 1);
            this->m_joblist->Refresh();
            break;
        }
    }
}
MainWindowUI::~MainWindowUI()
{
    // clean up things...
    if (this->m_generate->IsEnabled())
    if (this->modelLoaded)
    {
        free_sd_ctx(this->sd_ctx);
    }
@@ -861,7 +916,7 @@
        free_sd_ctx(this->sd_ctx);
        return;
    }
    if (this->m_generate->IsEnabled())
    if (this->modelLoaded)
    {
        free_sd_ctx(this->sd_ctx);
    }
@@ -871,5 +926,5 @@
    this->m_refresh->Disable();
    this->sd_params->model_path = this->ModelFiles.at(selection.ToStdString());
    this->sd_params->lora_model_dir = this->fileConfig->Read("/paths/lora", "").ToStdString();
    this->threads.push_back(std::thread(std::bind(&MainWindowUI::LoadModel, this, this->GetEventHandler())));
    // this->threads.push_back(std::thread(std::bind(&MainWindowUI::LoadModel, this, this->GetEventHandler())));
}