wxWidgets-based Stable Diffusion C++ GUI
Ferenc Szontágh
2024-02-04 1683c8a090c3efc51c43107d5ede0dcd5d506e3b
ui/MainWindowUI.cpp
@@ -6,20 +6,23 @@
    this->ini_path = wxStandardPaths::Get().GetUserConfigDir() + wxFileName::GetPathSeparator() + "sd.ui.config.ini";
    this->sd_params = new sd_gui_utils::SDParams;
    this->notification = new wxNotificationMessage();
    this->JobTableItems = new std::map<int, QM::QueueItem>();
    // this->JobTableItems = new std::map<int, QM::QueueItem>();
    this->notification->SetParent(this);
    // prepare data list views
    this->m_data_model_list->AppendTextColumn("Name", wxDATAVIEW_CELL_INERT, 200);
    this->m_data_model_list->AppendTextColumn("Size");
    this->m_data_model_list->AppendTextColumn("Hash");
    this->m_joblist->AppendTextColumn("Id");
    this->m_joblist->AppendTextColumn("Created");
    this->m_joblist->AppendTextColumn("Created at");
    this->m_joblist->AppendTextColumn("Model");
    this->m_joblist->AppendTextColumn("Sampler");
    this->m_joblist->AppendTextColumn("Seed");
    this->m_joblist->AppendTextColumn("Status");
    this->m_joblist->AppendProgressColumn("Progress"); // progressbar
    this->m_joblist->AppendTextColumn("Speed");        // speed
    this->m_joblist->AppendTextColumn("Status");       // status
    this->SetTitle(this->GetTitle() + SD_GUI_VERSION);
@@ -30,7 +33,7 @@
    this->qmanager = new QM::QueueManager(this->GetEventHandler(), this->cfg->jobs);
    // set SD logger
    sd_set_log_callback(MainWindowUI::HandleSDLog, (void *)this);
    sd_set_log_callback(MainWindowUI::HandleSDLog, (void *)this->GetEventHandler());
    // load
    this->LoadPresets();
@@ -102,11 +105,6 @@
    }
}
void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
{
    this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetSelection();
}
void MainWindowUI::onResolutionSwap(wxCommandEvent &event)
{
    auto oldW = this->m_width->GetValue();
@@ -136,6 +134,16 @@
    // TODO: Implement onJoblistItemActivated
}
void MainWindowUI::onContextMenu(wxDataViewEvent &event)
{
    wxMenu menu;
    menu.Append(0, "Calculate Hash");
    menu.Append(1, "Download info from CivitAi.com");
    PopupMenu(&menu);
}
void MainWindowUI::onJoblistSelectionChanged(wxDataViewEvent &event)
{
    // TODO: Implement onJoblistSelectionChanged
@@ -162,9 +170,22 @@
    this->sd_params->width = this->m_width->GetValue();
    this->sd_params->height = this->m_height->GetValue();
    QM::QueueItem item;
    item.params = *this->sd_params;
    item.model = this->m_model->GetStringSelection().ToStdString();
    if (item.params.seed == -1)
    {
        item.params.seed = sd_gui_utils::generateRandomInt(100000000, 999999999);
    }
    // add the queue item
    auto id = this->qmanager->AddItem(this->sd_params);
    auto id = this->qmanager->AddItem(item);
}
void MainWindowUI::onSamplerSelect(wxCommandEvent &event)
{
    this->sd_params->sample_method = (sample_method_t)this->m_sampler->GetSelection();
}
void MainWindowUI::onSavePreset(wxCommandEvent &event)
@@ -258,30 +279,596 @@
        this->LoadPresets();
    }
}
/*
    this->m_joblist->AppendTextColumn("Id");
    this->m_joblist->AppendTextColumn("Created");
    this->m_joblist->AppendTextColumn("Model");
    this->m_joblist->AppendTextColumn("Sampler");
    this->m_joblist->AppendTextColumn("Seed");
    this->m_joblist->AppendTextColumn("Status");
*/
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
{
    if (level == sd_log_level_t::SD_LOG_INFO || level == sd_log_level_t::SD_LOG_ERROR)
    {
        auto *eventHandler = (wxEvtHandler *)data;
        wxThreadEvent *e = new wxThreadEvent();
        e->SetString(wxString::Format("SD_MESSAGE:%s", text));
        e->SetPayload(level);
        wxQueueEvent(eventHandler, e);
    }
}
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
{
    auto store = this->m_joblist->GetStore();
    int lastCol = this->m_joblist->GetColumnCount() - 1;
    for (unsigned int i = 0; i < store->GetItemCount(); i++)
    {
        auto _item = store->GetItem(i);
        auto _item_data = store->GetItemData(_item);
        auto *_qitem = reinterpret_cast<QM::QueueItem *>(_item_data);
        if (_qitem->id == item.id)
        {
            store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), i, lastCol);
            this->m_joblist->Refresh();
            break;
        }
    }
}
void MainWindowUI::loadModelList()
{
    this->m_sampler->Clear();
    for (auto sampler : sd_gui_utils::sample_method_str)
    {
        int _u = this->m_sampler->Append(sampler);
        if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
        {
            this->m_sampler->Select(_u);
        }
    }
    this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
    for (auto model : this->ModelFiles)
    {
        // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
        uintmax_t size = std::filesystem::file_size(model.second);
        auto humanSize = sd_gui_utils::humanReadableFileSize(size);
        auto hs = wxString::Format("%.1f %s", humanSize.first, humanSize.second);
        wxVector<wxVariant> data;
        data.push_back(wxVariant(model.first));
        data.push_back(hs);
        data.push_back("--");
        this->m_data_model_list->AppendItem(data);
    }
    this->m_data_model_list->Refresh();
}
void MainWindowUI::StartGeneration(QM::QueueItem myJob)
{
    // here starts the trhead
    // this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
    // this->threads.emplace_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
    // this->threads.emplace_back(std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob));
    std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob);
}
void MainWindowUI::HandleSDProgress(int step, int steps, float time, void *data)
{
    sd_gui_utils::VoidHolder *objs = (sd_gui_utils::VoidHolder *)data;
    wxEvtHandler *eventHandler = (wxEvtHandler *)objs->p1;
    QM::QueueItem *myItem = (QM::QueueItem *)objs->p2;
    myItem->step = step;
    myItem->steps = steps;
    myItem->time = time;
    /*
        format it/s
        time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
               progress.c_str(), step, steps,
               time > 1.0f || time == 0 ? time : (1.0f / time)
    */
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("GENERATION_PROGRESS:%d/%d", step, steps));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
}
void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    sd_gui_utils::VoidHolder *vparams = new sd_gui_utils::VoidHolder;
    vparams->p1 = (void *)this->GetEventHandler();
    vparams->p2 = (void *)&myItem;
    sd_set_progress_callback(MainWindowUI::HandleSDProgress, (void *)vparams);
    if (!this->modelLoaded)
    {
        this->sd_ctx = this->LoadModelv2(eventHandler, myItem);
        this->currentModel = myItem.params.model_path;
    }
    else
    {
        if (myItem.params.model_path != this->currentModel)
        {
            free_sd_ctx(this->sd_ctx);
            this->sd_ctx = this->LoadModelv2(eventHandler, myItem);
        }
    }
    if (!this->modelLoaded || this->sd_ctx == nullptr)
    {
        wxThreadEvent *f = new wxThreadEvent();
        f->SetString("GENERATION_ERROR:Model load failed...");
        f->SetPayload(myItem);
        wxQueueEvent(eventHandler, f);
        return;
    }
    auto start = std::chrono::system_clock::now();
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("GENERATION_START:%s", this->sd_params->model_path));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
    sd_image_t *control_image = NULL;
    sd_image_t *results;
    // std::lock_guard<std::mutex> guard(this->sdMutex);
    results = txt2img(this->sd_ctx,
                      myItem.params.prompt.c_str(),
                      myItem.params.negative_prompt.c_str(),
                      myItem.params.clip_skip,
                      myItem.params.cfg_scale,
                      myItem.params.width,
                      myItem.params.height,
                      myItem.params.sample_method,
                      myItem.params.sample_steps,
                      myItem.params.seed,
                      myItem.params.batch_count,
                      control_image,
                      myItem.params.control_strength);
    if (results == NULL)
    {
        wxThreadEvent *f = new wxThreadEvent();
        f->SetString("GENERATION_ERROR:Something wrong happened at image generation...");
        f->SetPayload(myItem);
        wxQueueEvent(eventHandler, f);
        return;
    }
    if (!std::filesystem::exists(this->cfg->output))
    {
        std::filesystem::create_directories(this->cfg->output);
    }
    /* save image(s) */
    const auto p1 = std::chrono::system_clock::now();
    auto ctime = std::chrono::duration_cast<std::chrono::seconds>(p1.time_since_epoch()).count();
    for (int i = 0; i < this->sd_params->batch_count; i++)
    {
        if (results[i].data == NULL)
        {
            continue;
        }
        // handle data??
        wxImage *img = new wxImage(results[i].width, results[i].height, results[i].data);
        std::string filename = this->cfg->output;
        std::string extension = ".png";
        if (this->sd_params->batch_count > 1)
        {
            filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + "_" + std::to_string(i) + extension;
        }
        else
        {
            filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + extension;
        }
        if (!img->SaveFile(filename))
        {
            wxThreadEvent *g = new wxThreadEvent();
            g->SetString(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename));
            g->SetPayload(myItem);
            wxQueueEvent(eventHandler, g);
        }
        else
        {
            myItem.images.emplace_back(filename);
        }
        // handle data??
    }
    auto end = std::chrono::system_clock::now();
    std::chrono::duration<double> elapsed_seconds = end - start;
    // send to notify the user...
    wxThreadEvent *h = new wxThreadEvent();
    auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
    h->SetString(wxString(msg.c_str()));
    // h->SetPayload(myItem);
    wxQueueEvent(eventHandler, h);
    wxThreadEvent *i = new wxThreadEvent();
    i->SetString(wxString::Format("GENERATION_DONE:ok"));
    i->SetPayload(results);
    wxQueueEvent(eventHandler, i);
    // send to the queue manager
    wxThreadEvent *j = new wxThreadEvent();
    j->SetString(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED));
    j->SetPayload(myItem);
    wxQueueEvent(eventHandler, j);
}
void MainWindowUI::initConfig()
{
    wxString datapath = wxStandardPaths::Get().GetUserDataDir() + wxFileName::GetPathSeparator() + "sd_ui_data" + wxFileName::GetPathSeparator();
    wxString imagespath = wxStandardPaths::Get().GetDocumentsDir() + wxFileName::GetPathSeparator() + "sd_ui_output" + wxFileName::GetPathSeparator();
    wxString model_path = datapath;
    model_path.append("checkpoints");
    wxString vae_path = datapath;
    vae_path.append("vae");
    wxString lora_path = datapath;
    lora_path.append("lora");
    wxString embedding_path = datapath;
    embedding_path.append("embedding");
    wxString presets_path = datapath;
    presets_path.append("presets");
    wxString jobs_path = datapath;
    jobs_path.append("queue_jobs");
    this->cfg->lora = this->fileConfig->Read("/paths/lora", lora_path).ToStdString();
    this->cfg->model = this->fileConfig->Read("/paths/model", model_path).ToStdString();
    this->cfg->vae = this->fileConfig->Read("/paths/vae", vae_path).ToStdString();
    this->cfg->embedding = this->fileConfig->Read("/paths/embedding", embedding_path).ToStdString();
    this->cfg->presets = this->fileConfig->Read("/paths/presets", presets_path).ToStdString();
    this->cfg->jobs = this->fileConfig->Read("/paths/presets", jobs_path).ToStdString();
    this->cfg->output = this->fileConfig->Read("/paths/output", imagespath).ToStdString();
    this->cfg->keep_model_in_memory = this->fileConfig->Read("/keep_model_in_memory", this->cfg->keep_model_in_memory);
    this->cfg->save_all_image = this->fileConfig->Read("/save_all_image", this->cfg->save_all_image);
    // populate data from sd_params as default...
    if (!this->modelLoaded)
    {
        this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
        this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
        this->m_clip_skip->SetValue(this->sd_params->clip_skip);
        this->m_steps->SetValue(this->sd_params->sample_steps);
        this->m_width->SetValue(this->sd_params->width);
        this->m_height->SetValue(this->sd_params->height);
        this->m_batch_count->SetValue(this->sd_params->batch_count);
    }
}
void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
{
    this->initConfig();
    this->settingsWindow->Destroy();
}
void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
{
    if (e.GetSkipped() == false)
    {
        e.Skip();
    }
    auto msg = e.GetString().ToStdString();
    std::string token = msg.substr(0, msg.find(":"));
    std::string content = msg.substr(msg.find(":") + 1);
    // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
    if (token == "QUEUE")
    {
        // only numbers here...
        QM::QueueEvents event = (QM::QueueEvents)std::stoi(content);
        // only handle the QUEUE messages, what this class generate
        // alway QM::EueueItem the payload, with the new data
        QM::QueueItem payload;
        payload = e.GetPayload<QM::QueueItem>();
        switch (event)
        {
            // new item added
        case QM::QueueEvents::ITEM_ADDED:
            this->OnQueueItemManagerItemAdded(payload);
            break;
            // item status changed
        case QM::QueueEvents::ITEM_STATUS_CHANGED:
            this->OnQueueItemManagerItemStatusChanged(payload);
            break;
            // item updated... ? ? ?
        case QM::QueueEvents::ITEM_UPDATED:
            this->OnQueueItemManagerItemUpdated(payload);
            break;
        case QM::QueueEvents::ITEM_START:
            this->StartGeneration(payload);
            break;
        default:
            break;
        }
    }
    if (token == "MODEL_LOAD_DONE")
    {
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        // this->m_refresh->Enable();
        // this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
        this->modelLoaded = true;
        // std::lock_guard<std::mutex> guard(this->sdMutex);
        // this->sd_ctx = e.GetPayload<sd_ctx_t *>();
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_INFORMATION);
            this->notification->SetTitle("SD Gui");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
    if (token == "MODEL_LOAD_START")
    {
        // this->m_generate->Disable();
        // this->m_model->Disable();
        // this->m_vae->Disable();
        // this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Model load start: {}\n", content));
    }
    if (token == "MODEL_LOAD_ERROR")
    {
        // this->m_generate->Disable();
        // this->m_model->Enable();
        // this->m_vae->Disable();
        // this->m_refresh->Enable();
        this->logs->AppendText(fmt::format("Model load error: {}\n", content));
        this->modelLoaded = false;
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_ERROR);
            this->notification->SetTitle("SD Gui - error");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
    if (token == "GENERATION_START")
    {
        auto myjob = e.GetPayload<QM::QueueItem>();
        // this->m_generate->Disable();
        // this->m_model->Disable();
        // this->m_vae->Disable();
        // this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
                                           myjob.params.seed,
                                           myjob.params.batch_count,
                                           myjob.params.width,
                                           myjob.params.height,
                                           myjob.params.cfg_scale,
                                           myjob.params.sample_steps));
    }
    // in the original SD.cpp the progress callback is not implemented... :(
    if (token == "GENERATION_PROGRESS")
    {
        QM::QueueItem *myjob = e.GetPayload<QM::QueueItem *>();
        // update column
        auto store = this->m_joblist->GetStore();
        // -1 the last (status)
        // -2 ...      (speed)
        // -3 ...      (progressbar)
        // it/s format
        /*
        time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
               progress.c_str(), step, steps,
               time > 1.0f || time == 0 ? time : (1.0f / time)
        */
        wxString speed = wxString::Format(myjob->time > 1.0f ? "%.2fs/it" : "%.2fit/s", myjob->time > 1.0f || myjob->time == 0 ? myjob->time : (1.0f / myjob->time));
        int progressCol = this->m_joblist->GetColumnCount() - 3;
        int speedCol = this->m_joblist->GetColumnCount() - 2;
        float current_progress = 100.f * (static_cast<float>(myjob->step) / static_cast<float>(myjob->steps));
        if (current_progress < 2.f)
        {
            return;
        }
        for (unsigned int i = 0; i < store->GetItemCount(); i++)
        {
            auto _item = store->GetItem(i);
            auto _item_data = store->GetItemData(_item);
            auto *_qitem = reinterpret_cast<QM::QueueItem *>(_item_data);
            if (_qitem->id == myjob->id)
            {
                store->SetValueByRow(static_cast<int>(current_progress), i, progressCol);
                store->SetValueByRow(speed, i, speedCol);
                this->m_joblist->Refresh();
                break;
            }
        }
        return;
        for (auto it = this->JobTableItems.begin(); it != this->JobTableItems.end(); ++it)
        {
            if (it->second->id == myjob->id)
            {
                // store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), it->first, progressCol);
                store->SetValueByRow(static_cast<int>(current_progress), it->first, progressCol);
                store->SetValueByRow(speed, it->first, speedCol);
                this->m_joblist->Refresh();
                break;
            }
        }
        // update column
    }
    if (token == "GENERATION_DONE")
    {
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        // this->m_refresh->Enable();
        // sd_image_t *results = e.GetPayload<sd_image_t *>();
        // show images in new window...
        /* for (int i = 0; i < this->sd_params->batch_count; i++)
         {
             MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
             // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
             wxImage img(results[i].width, results[i].height, results[i].data);
             wxBitmapBundle wxBmapB(img);
             imgWindow->m_bitmap->SetBitmap(wxBmapB);
             imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
             imgWindow->SetSize(results[i].width + 200, results[i].height);
             std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
                                               this->sd_params->prompt, this->sd_params->negative_prompt,
                                               this->sd_params->seed + i, this->sd_params->cfg_scale,
                                               this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
                                               results[i].width, results[i].height);
             imgWindow->m_textCtrl4->AppendText(wxString(details));
             imgWindow->Show();
             // imgWindow->m_bitmap->SetBitmap(img);
             /// imgWindow->m_bitmap->Set
         }*/
    }
    if (token == "GENERATION_ERROR")
    {
        this->logs->AppendText(fmt::format("Generation error: {}\n", content));
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_ERROR);
            this->notification->SetTitle("SD Gui - error");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
    if (token == "SD_MESSAGE")
    {
        if (content.length() < 1)
        {
            return;
        }
        this->logs->AppendText(fmt::format("{}", content));
    }
    if (token == "MESSAGE")
    {
        this->logs->AppendText(fmt::format("{}\n", content));
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_INFORMATION);
            this->notification->SetTitle("SD Gui");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
}
sd_ctx_t *MainWindowUI::LoadModelv2(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("MODEL_LOAD_START:%s", myItem.params.model_path));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
    // std::lock_guard<std::mutex> guard(this->sdMutex);
    sd_ctx_t *sd_ctx_ = new_sd_ctx(
        myItem.params.model_path.c_str(),
        myItem.params.vae_path.c_str(),
        myItem.params.taesd_path.c_str(),
        myItem.params.controlnet_path.c_str(),
        myItem.params.lora_model_dir.c_str(),
        myItem.params.embeddings_path.c_str(),
        true, false, false,
        myItem.params.n_threads,
        myItem.params.wtype,
        myItem.params.rng_type,
        myItem.params.schedule, false);
    if (sd_ctx_ == NULL)
    {
        wxThreadEvent *c = new wxThreadEvent();
        c->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", myItem.params.model_path));
        c->SetPayload(myItem);
        wxQueueEvent(eventHandler, c);
        this->modelLoaded = false;
        return nullptr;
    }
    else
    {
        wxThreadEvent *c = new wxThreadEvent();
        c->SetString(wxString::Format("MODEL_LOAD_DONE:%s", myItem.params.model_path));
        wxQueueEvent(eventHandler, c);
        this->modelLoaded = true;
        this->currentModel = myItem.params.model_path;
    }
    return sd_ctx_;
}
void MainWindowUI::LoadPresets()
{
    this->LoadFileList(sd_gui_utils::DirTypes::PRESETS);
}
void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
{
}
void MainWindowUI::loadVaeList()
{
    if (!std::filesystem::exists(this->cfg->vae))
    {
        std::filesystem::create_directories(this->cfg->vae);
    }
    this->LoadFileList(sd_gui_utils::DirTypes::VAE);
}
MainWindowUI::~MainWindowUI()
{
    // this->Hide();
    // Block until every still-joinable worker thread has finished.
    for (auto &worker : this->threads)
    {
        if (worker.joinable())
        {
            worker.join();
        }
    }
}
// Adds a row for a newly queued job to the job list.
// NOTE(review): this block looks like lines from both sides of the diff interleaved:
//  - the row is added twice (AppendItem without client data, then PrependItem with it);
//  - the status string is pushed twice;
//  - JobTableItems is used both as a pointer (`(*this->JobTableItems)[...]`) and as an
//    object holding QueueItem pointers (`this->JobTableItems[item.id] = nItem`).
// Confirm which lines belong to the final version before changing anything here.
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
    wxVector<wxVariant> data;
    auto created_at = sd_gui_utils::formatUnixTimestampToDate(item.created_at);
    // Cell values, in push order: id, raw created_at, model path, formatted date, model,
    // sampler name, seed, status text.
    data.push_back(wxVariant(std::to_string(item.id)));
    data.push_back(wxVariant(std::to_string(item.created_at)));
    data.push_back(wxVariant(item.params.model_path));
    data.push_back(wxVariant(created_at));
    data.push_back(wxVariant(item.model));
    data.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
    data.push_back(wxVariant(std::to_string(item.params.seed)));
    data.push_back(wxVariant(QM::QueueStatus_str[item.status]));
    // NOTE(review): this appends a row with no client data (8 values at this point);
    // handlers that reinterpret the row data as QM::QueueItem* will see 0 for it.
    this->m_joblist->AppendItem(data);
    // Trailing progress / speed / status cells (the constructor hunk appends
    // Progress, Speed, Status as the last three columns — confirm the total count).
    data.push_back(item.status == QM::QueueStatus::DONE ? 100 : 1); // progressbar
    data.push_back(wxString("-.--it/s"));                           // speed
    data.push_back(wxVariant(QM::QueueStatus_str[item.status]));    // status
    auto store = this->m_joblist->GetStore();
    (*this->JobTableItems)[store->GetCount() - 1] = item;
    // Heap copy of the job used as the row's client data; no delete is visible in this view.
    QM::QueueItem *nItem = new QM::QueueItem(item);
    this->JobTableItems[item.id] = nItem;
    //  store->AppendItem(data, wxUIntPtr(this->JobTableItems[item.id]));
    // New jobs are shown at the top of the list, carrying a pointer to their QueueItem.
    store->PrependItem(data, wxUIntPtr(this->JobTableItems[item.id]));
}
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
@@ -334,7 +921,7 @@
        if (type == sd_gui_utils::DirTypes::CHECKPOINT || type == sd_gui_utils::DirTypes::VAE)
        {
            if (ext != ".safetensors" && ext != ".cptk")
            if (ext != ".safetensors" && ext != ".ckpt")
            {
                continue;
            }
@@ -403,528 +990,4 @@
            this->m_preset_list->Enable();
        }
    }
}
void MainWindowUI::LoadPresets()
{
    this->LoadFileList(sd_gui_utils::DirTypes::PRESETS);
}
void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
{
    e.Skip();
    auto msg = e.GetString().ToStdString();
    std::string token = msg.substr(0, msg.find(":"));
    std::string content = msg.substr(msg.find(":") + 1);
    // this->logs->AppendText(fmt::format("Got thread message: {}\n", e.GetString().ToStdString()));
    if (token == "QUEUE")
    {
        // only numbers here...
        QM::QueueEvents event = (QM::QueueEvents)std::stoi(content);
        // only handle the QUEUE messages, what this class generate
        // alway QM::EueueItem the payload, with the new data
        QM::QueueItem payload;
        payload = e.GetPayload<QM::QueueItem>();
        switch (event)
        {
            // new item added
        case QM::QueueEvents::ITEM_ADDED:
            this->OnQueueItemManagerItemAdded(payload);
            break;
            // item status changed
        case QM::QueueEvents::ITEM_STATUS_CHANGED:
            this->OnQueueItemManagerItemStatusChanged(payload);
            break;
            // item updated... ? ? ?
        case QM::QueueEvents::ITEM_UPDATED:
            this->OnQueueItemManagerItemUpdated(payload);
            break;
        case QM::QueueEvents::ITEM_START:
            this->StartGeneration(payload);
            break;
        default:
            break;
        }
    }
    if (token == "MODEL_LOAD_DONE")
    {
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        // this->m_refresh->Enable();
        // this->logs->AppendText(fmt::format("Model loaded: {}\n", content));
        this->modelLoaded = true;
        this->sd_ctx = e.GetPayload<sd_ctx_t *>();
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_INFORMATION);
            this->notification->SetTitle("SD Gui");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
    if (token == "MODEL_LOAD_START")
    {
        // this->m_generate->Disable();
        // this->m_model->Disable();
        // this->m_vae->Disable();
        // this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Model load start: {}\n", content));
    }
    if (token == "MODEL_LOAD_ERROR")
    {
        // this->m_generate->Disable();
        // this->m_model->Enable();
        // this->m_vae->Disable();
        // this->m_refresh->Enable();
        this->logs->AppendText(fmt::format("Model load error: {}\n", content));
        this->modelLoaded = false;
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_ERROR);
            this->notification->SetTitle("SD Gui - error");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
    if (token == "GENERATION_START")
    {
        auto myjob = e.GetPayload<QM::QueueItem>();
        // this->m_generate->Disable();
        // this->m_model->Disable();
        // this->m_vae->Disable();
        // this->m_refresh->Disable();
        this->logs->AppendText(fmt::format("Diffusion started. Seed: {} Batch: {} {}x{}px Cfg: {} Steps: {}\n",
                                           myjob.params.seed,
                                           myjob.params.batch_count,
                                           myjob.params.width,
                                           myjob.params.height,
                                           myjob.params.cfg_scale,
                                           myjob.params.sample_steps));
    }
    // never, not implemented in sd.cpp
    if (token == "GENERATION_PROGRESS")
    {
        // this->m_generate->Disable();
        // this->logs->AppendText(fmt::format("Generation progress: {}\n", content));
    }
    if (token == "GENERATION_DONE")
    {
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        // this->m_refresh->Enable();
        // sd_image_t *results = e.GetPayload<sd_image_t *>();
        // show images in new window...
        /* for (int i = 0; i < this->sd_params->batch_count; i++)
         {
             MainWindowImageViewer *imgWindow = new MainWindowImageViewer(this);
             // wxBitmap *img = new wxBitmap(results[i].data, (int)results[i].width, (int)results[i].height, (int)results[i].channel);
             wxImage img(results[i].width, results[i].height, results[i].data);
             wxBitmapBundle wxBmapB(img);
             imgWindow->m_bitmap->SetBitmap(wxBmapB);
             imgWindow->m_bitmap->SetSize(results[i].width, results[i].height);
             imgWindow->SetSize(results[i].width + 200, results[i].height);
             std::string details = fmt::format("Prompt:\n\n{}\n\nNegative prompt: \n\n{}\n\nSeed: {} \nCfg scale: {}\nClip skip: {}\nSampler: {}\nSteps: {}\nWidth: {} Height: {}",
                                               this->sd_params->prompt, this->sd_params->negative_prompt,
                                               this->sd_params->seed + i, this->sd_params->cfg_scale,
                                               this->sd_params->clip_skip, sd_gui_utils::sample_method_str[this->sd_params->sample_method], this->sd_params->sample_steps,
                                               results[i].width, results[i].height);
             imgWindow->m_textCtrl4->AppendText(wxString(details));
             imgWindow->Show();
             // imgWindow->m_bitmap->SetBitmap(img);
             /// imgWindow->m_bitmap->Set
         }*/
    }
    if (token == "GENERATION_ERROR")
    {
        // this->m_generate->Enable();
        // this->m_model->Enable();
        // this->m_vae->Enable();
        this->logs->AppendText(fmt::format("Generation error: {}\n", content));
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_ERROR);
            this->notification->SetTitle("SD Gui - error");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
    if (token == "SD_MESSAGE")
    {
        if (content.length() < 1)
        {
            return;
        }
        this->logs->AppendText(fmt::format("{}", content));
    }
    if (token == "MESSAGE")
    {
        this->logs->AppendText(fmt::format("{}\n", content));
        if (!this->IsShownOnScreen())
        {
            this->notification->SetFlags(wxICON_INFORMATION);
            this->notification->SetTitle("SD Gui");
            this->notification->SetMessage(content);
            this->notification->Show(5000);
        }
    }
}
void MainWindowUI::LoadModel(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("MODEL_LOAD_START:%s", this->sd_params->model_path));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
    /*this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(),
                              this->sd_params->vae_path.c_str(),
                              this->sd_params->taesd_path.c_str(),
                              this->sd_params->controlnet_path.c_str(),
                              std::string(this->sd_params->lora_model_dir + "\\").c_str(),
                              this->sd_params->embeddings_path.c_str(),
                              true, // vae decode only
                              true ,
                              !this->cfg->keep_model_in_memory,
                              this->sd_params->n_threads,
                              this->sd_params->wtype ,
                              this->sd_params->rng_type,
                              this->sd_params->schedule,
                              this->sd_params->control_net_cpu);*/
    // this->sd_ctx = new_sd_ctx(this->sd_params->model_path.c_str(), this->sd_params->vae_path.c_str(), this->sd_params->taesd_path.c_str(), this->sd_params->lora_model_dir.c_str(), true, false, false, this->sd_params->n_threads, this->sd_params->wtype, this->sd_params->rng_type, this->sd_params->schedule);
    this->sd_ctx = new_sd_ctx(myItem.params.model_path.c_str(), myItem.params.vae_path.c_str(), myItem.params.taesd_path.c_str(), myItem.params.lora_model_dir.c_str(), true, false, false, myItem.params.n_threads, myItem.params.wtype, myItem.params.rng_type, myItem.params.schedule);
    if (this->sd_ctx == NULL)
    {
        wxThreadEvent *c = new wxThreadEvent();
        c->SetString(wxString::Format("MODEL_LOAD_ERROR:%s", this->sd_params->model_path));
        c->SetPayload(myItem);
        wxQueueEvent(eventHandler, c);
        this->modelLoaded = false;
        return;
    }
    else
    {
        wxThreadEvent *c = new wxThreadEvent();
        c->SetString(wxString::Format("MODEL_LOAD_DONE:%s", this->sd_params->model_path));
        c->SetPayload(this->sd_ctx);
        wxQueueEvent(eventHandler, c);
        this->modelLoaded = true;
        return;
    }
    return;
}
void MainWindowUI::initConfig()
{
    wxString datapath = wxStandardPaths::Get().GetUserDataDir() + wxFileName::GetPathSeparator() + "sd_ui_data" + wxFileName::GetPathSeparator();
    wxString imagespath = wxStandardPaths::Get().GetDocumentsDir() + wxFileName::GetPathSeparator() + "sd_ui_output" + wxFileName::GetPathSeparator();
    wxString model_path = datapath;
    model_path.append("checkpoints");
    wxString vae_path = datapath;
    vae_path.append("vae");
    wxString lora_path = datapath;
    lora_path.append("lora");
    wxString embedding_path = datapath;
    embedding_path.append("embedding");
    wxString presets_path = datapath;
    presets_path.append("presets");
    wxString jobs_path = datapath;
    jobs_path.append("queue_jobs");
    this->cfg->lora = this->fileConfig->Read("/paths/lora", lora_path).ToStdString();
    this->cfg->model = this->fileConfig->Read("/paths/model", model_path).ToStdString();
    this->cfg->vae = this->fileConfig->Read("/paths/vae", vae_path).ToStdString();
    this->cfg->embedding = this->fileConfig->Read("/paths/embedding", embedding_path).ToStdString();
    this->cfg->presets = this->fileConfig->Read("/paths/presets", presets_path).ToStdString();
    this->cfg->jobs = this->fileConfig->Read("/paths/presets", jobs_path).ToStdString();
    this->cfg->output = this->fileConfig->Read("/paths/output", imagespath).ToStdString();
    this->cfg->keep_model_in_memory = this->fileConfig->Read("/keep_model_in_memory", this->cfg->keep_model_in_memory);
    this->cfg->save_all_image = this->fileConfig->Read("/save_all_image", this->cfg->save_all_image);
    // populate data from sd_params as default...
    if (!this->modelLoaded)
    {
        this->m_cfg->SetValue(static_cast<double>(this->sd_params->cfg_scale));
        this->m_seed->SetValue(static_cast<int>(this->sd_params->seed));
        this->m_clip_skip->SetValue(this->sd_params->clip_skip);
        this->m_steps->SetValue(this->sd_params->sample_steps);
        this->m_width->SetValue(this->sd_params->width);
        this->m_height->SetValue(this->sd_params->height);
        this->m_batch_count->SetValue(this->sd_params->batch_count);
    }
}
void MainWindowUI::StartGeneration(QM::QueueItem myJob)
{
    this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
}
void MainWindowUI::OnCloseSettings(wxCloseEvent &event)
{
    this->initConfig();
    this->settingsWindow->Destroy();
}
void MainWindowUI::loadModelList()
{
    this->m_sampler->Clear();
    for (auto sampler : sd_gui_utils::sample_method_str)
    {
        int _u = this->m_sampler->Append(sampler);
        if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
        {
            this->m_sampler->Select(_u);
        }
    }
    this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
    for (auto model : this->ModelFiles)
    {
        // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
        auto size = std::filesystem::file_size(model.second);
        wxVector<wxVariant> data;
        data.push_back(wxVariant(model.first));
        data.push_back(wxVariant(std::to_string(size)));
        this->m_data_model_list->AppendItem(data);
    }
    this->m_data_model_list->Refresh();
}
void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
{
    // @brief model loading is done in the same thread which generating stuffs... no need new thread
    // @brief to load the model if no model loaded, or the loaded model is different from the job's model...
    // @brief if all second job have different model, it will be slower than same model on all jobs
    // TODO: abort job if model can not be loaded...
    if (!this->modelLoaded)
    {
        this->LoadModel(eventHandler, myItem);
        this->currentModel = myItem.params.model_path;
    }
    else
    {
        if (myItem.params.model_path != this->currentModel)
        {
            free_sd_ctx(this->sd_ctx);
            this->LoadModel(eventHandler, myItem);
        }
    }
    if (!this->modelLoaded)
    {
        wxThreadEvent *f = new wxThreadEvent();
        f->SetString("GENERATION_ERROR:Model load failed...");
        f->SetPayload(myItem);
        wxQueueEvent(eventHandler, f);
        return;
    }
    auto start = std::chrono::system_clock::now();
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("GENERATION_START:%s", this->sd_params->model_path));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
    sd_image_t *control_image = NULL;
    sd_image_t *results;
    results = txt2img(this->sd_ctx,
                      this->sd_params->prompt.c_str(), this->sd_params->negative_prompt.c_str(), this->sd_params->clip_skip, this->sd_params->cfg_scale, this->sd_params->width, this->sd_params->height, this->sd_params->sample_method, this->sd_params->sample_steps, this->sd_params->seed, this->sd_params->batch_count);
    if (results == NULL)
    {
        wxThreadEvent *f = new wxThreadEvent();
        f->SetString("GENERATION_ERROR:Something wrong happened at image generation...");
        f->SetPayload(myItem);
        wxQueueEvent(eventHandler, f);
        return;
    }
    if (!std::filesystem::exists(this->cfg->output))
    {
        std::filesystem::create_directories(this->cfg->output);
    }
    /* save image(s) */
    const auto p1 = std::chrono::system_clock::now();
    auto ctime = std::chrono::duration_cast<std::chrono::seconds>(p1.time_since_epoch()).count();
    for (int i = 0; i < this->sd_params->batch_count; i++)
    {
        if (results[i].data == NULL)
        {
            continue;
        }
        // handle data??
        wxImage *img = new wxImage(results[i].width, results[i].height, results[i].data);
        std::string filename = this->cfg->output;
        std::string extension = ".png";
        if (this->sd_params->batch_count > 1)
        {
            filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + "_" + std::to_string(i) + extension;
        }
        else
        {
            filename = filename + wxFileName::GetPathSeparator() + std::to_string(ctime) + extension;
        }
        if (!img->SaveFile(filename))
        {
            wxThreadEvent *g = new wxThreadEvent();
            g->SetString(wxString::Format("GENERATION_ERROR:Failed to save image into %s", filename));
            g->SetPayload(myItem);
            wxQueueEvent(eventHandler, g);
        }
        else
        {
            myItem.images.emplace_back(filename);
        }
        // handle data??
    }
    auto end = std::chrono::system_clock::now();
    std::chrono::duration<double> elapsed_seconds = end - start;
    // send to notify the user...
    wxThreadEvent *h = new wxThreadEvent();
    auto msg = fmt::format("MESSAGE:Image generation done in {}s. Saved into {}", elapsed_seconds.count(), this->cfg->output);
    h->SetString(wxString(msg.c_str()));
    h->SetPayload(myItem);
    wxQueueEvent(eventHandler, h);
    wxThreadEvent *i = new wxThreadEvent();
    i->SetString(wxString::Format("GENERATION_DONE:ok"));
    i->SetPayload(results);
    wxQueueEvent(eventHandler, i);
    // send to the queue manager
    wxThreadEvent *j = new wxThreadEvent();
    j->SetString(wxString::Format("QUEUE:%d", QM::QueueEvents::ITEM_FINISHED));
    j->SetPayload(myItem);
    wxQueueEvent(eventHandler, j);
    return;
}
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
{
    if (level != sd_log_level_t::SD_LOG_INFO)
    {
        return;
    }
    // wxEvtHandler *eventHandler = (wxEvtHandler *)data;
    MainWindowUI *ui = (MainWindowUI *)data;
    wxEvtHandler *eventHandler = ui->GetEventHandler();
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("SD_MESSAGE:%s", text));
    e->SetPayload(level);
    wxQueueEvent(eventHandler, e);
}
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
{
    /*
    this->m_joblist->AppendTextColumn("Id");
    this->m_joblist->AppendTextColumn("Created");
    this->m_joblist->AppendTextColumn("Model");
    this->m_joblist->AppendTextColumn("Sampler");
    this->m_joblist->AppendTextColumn("Seed");
    this->m_joblist->AppendTextColumn("Status");
    */
    // TODO: how to update an item in the table...
    // auto item = this->m_joblist->RowToItem();
    auto store = this->m_joblist->GetStore();
    for (auto it = this->JobTableItems->begin(); it != this->JobTableItems->end(); ++it)
    {
        if (it->second.id == item.id)
        {
            // auto store = *this->m_joblist()->GetStore();
            int lastCol = this->m_joblist->GetColumnCount();
            // always update the last col, so the last col is always need to be the status
            store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), it->first, lastCol - 1);
            this->m_joblist->Refresh();
            break;
        }
    }
}
MainWindowUI::~MainWindowUI()
{
    // Clean up in dependency order: worker threads may still be running Generate() and
    // using sd_ctx, so they are joined BEFORE the model context is freed.
    for (auto &worker : this->threads)
    {
        // join() on a non-joinable thread throws std::system_error, which would
        // terminate the program from inside a destructor.
        if (worker.joinable())
        {
            worker.join();
        }
    }
    this->threads.clear();
    if (this->modelLoaded)
    {
        free_sd_ctx(this->sd_ctx);
        this->sd_ctx = NULL;
        this->modelLoaded = false;
    }
    // The window is already being destroyed here. Calling Destroy() would schedule a
    // second deletion, and exit(0) would skip the rest of the normal shutdown, so
    // neither is done.
}
void MainWindowUI::loadVaeList()
{
    if (!std::filesystem::exists(this->cfg->vae))
    {
        std::filesystem::create_directories(this->cfg->vae);
    }
    this->LoadFileList(sd_gui_utils::DirTypes::VAE);
}
void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
{
}
void MainWindowUI::StartLoadModel()
{
    // prepare
    auto selection = this->m_model->GetStringSelection();
    if (selection == "-none-")
    {
        free_sd_ctx(this->sd_ctx);
        return;
    }
    if (this->modelLoaded)
    {
        free_sd_ctx(this->sd_ctx);
    }
    // disable ui
    this->m_model->Disable();
    this->m_vae->Disable();
    this->m_refresh->Disable();
    this->sd_params->model_path = this->ModelFiles.at(selection.ToStdString());
    this->sd_params->lora_model_dir = this->fileConfig->Read("/paths/lora", "").ToStdString();
    // this->threads.push_back(std::thread(std::bind(&MainWindowUI::LoadModel, this, this->GetEventHandler())));
}