wxWidgets-based Stable Diffusion C++ GUI
Ferenc Szontágh
2024-02-04 5e5b8dcd5ce488ec4fdcb5ff07272830711566de
ui/MainWindowUI.cpp
@@ -105,6 +105,11 @@
    }
}
void MainWindowUI::onRandomGenerateButton(wxCommandEvent &event)
{
    this->m_seed->SetValue(sd_gui_utils::generateRandomInt(100000000, 999999999));
}
void MainWindowUI::onResolutionSwap(wxCommandEvent &event)
{
    auto oldW = this->m_width->GetValue();
@@ -137,10 +142,27 @@
void MainWindowUI::onContextMenu(wxDataViewEvent &event)
{
    auto *source = (wxDataViewListCtrl *)event.GetEventObject();
    wxMenu menu;
    menu.Append(0, "Calculate Hash");
    menu.Append(1, "Download info from CivitAi.com");
    menu.SetClientData((void *)source);
    if (source == this->m_joblist)
    {
        menu.Append(1, "Copy and restart");
        menu.Append(2, "Copy paramters to text2img");
        menu.Append(3, "Copy paramters to img2img");
        menu.Append(4, "Details...");
    }
    if (source == this->m_data_model_list)
    {
        menu.Append(1, "Calculate Hash");
        menu.Append(2, "Download info from CivitAi.com");
    }
    menu.Connect(wxEVT_COMMAND_MENU_SELECTED, wxCommandEventHandler(MainWindowUI::OnPopupClick), NULL, this);
    PopupMenu(&menu);
}
@@ -280,99 +302,178 @@
    }
}
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
{
    if (level == sd_log_level_t::SD_LOG_INFO || level == sd_log_level_t::SD_LOG_ERROR)
    std::string basepath;
    switch (type)
    {
        auto *eventHandler = (wxEvtHandler *)data;
        wxThreadEvent *e = new wxThreadEvent();
        e->SetString(wxString::Format("SD_MESSAGE:%s", text));
        e->SetPayload(level);
        wxQueueEvent(eventHandler, e);
    case sd_gui_utils::DirTypes::VAE:
        this->VaeFiles.clear();
        this->m_vae->Clear();
        this->m_vae->Append("-none-");
        this->m_vae->Select(0);
        basepath = this->cfg->vae;
        break;
    case sd_gui_utils::DirTypes::LORA:
        basepath = this->cfg->lora;
        break;
    case sd_gui_utils::DirTypes::CHECKPOINT:
        this->ModelFiles.clear();
        this->m_model->Clear();
        this->m_model->Append("-none-");
        this->m_model->Select(0);
        basepath = this->cfg->model;
        break;
    case sd_gui_utils::DirTypes::PRESETS:
        this->Presets.clear();
        this->m_preset_list->Clear();
        this->m_preset_list->Append("-none-");
        this->m_preset_list->Select(0);
        basepath = this->cfg->presets;
        break;
    }
    if (!std::filesystem::exists(basepath))
    {
        std::filesystem::create_directories(basepath);
    }
    int i = 0;
    for (auto const &dir_entry : std::filesystem::recursive_directory_iterator(basepath))
    {
        if (!dir_entry.exists() || !dir_entry.is_regular_file() || !dir_entry.path().has_extension())
        {
            continue;
        }
        std::filesystem::path path = dir_entry.path();
        std::string ext = path.extension().string();
        if (type == sd_gui_utils::DirTypes::CHECKPOINT || type == sd_gui_utils::DirTypes::VAE)
        {
            if (ext != ".safetensors" && ext != ".ckpt")
            {
                continue;
            }
        }
        if (type == sd_gui_utils::DirTypes::PRESETS)
        {
            if (ext != ".json")
            {
                continue;
            }
        }
        std::string name = path.filename().replace_extension("").string();
        // prepend the subdirectory to the modelname
        // // wxFileName::GetPathSeparator()
        auto path_name = path.string();
        sd_gui_utils::replace(path_name, basepath, "");
        sd_gui_utils::replace(path_name, "//", "");
        sd_gui_utils::replace(path_name, "\\\\", "");
        sd_gui_utils::replace(path_name, ext, "");
        name = path_name.substr(1);
        if (type == sd_gui_utils::CHECKPOINT)
        {
            this->m_model->Append(name);
            this->ModelFiles.emplace(name, dir_entry.path().string());
        }
        if (type == sd_gui_utils::VAE)
        {
            this->m_vae->Append(name);
            this->VaeFiles.emplace(name, dir_entry.path().string());
        }
        if (type == sd_gui_utils::PRESETS)
        {
            sd_gui_utils::generator_preset preset;
            std::ifstream f(path.string());
            try
            {
                nlohmann::json data = nlohmann::json::parse(f);
                preset = data;
                preset.path = path.string();
                this->m_preset_list->Append(preset.name);
                this->Presets.emplace(preset.name, preset);
            }
            catch (const std::exception &e)
            {
                std::remove(path.string().c_str());
                std::cerr << e.what() << '\n';
            }
        }
    }
    if (type == sd_gui_utils::CHECKPOINT)
    {
        this->logs->AppendText(fmt::format("Loaded checkpoints: {}\n", this->ModelFiles.size()));
    }
    if (type == sd_gui_utils::VAE)
    {
        this->logs->AppendText(fmt::format("Loaded vaes: {}\n", this->VaeFiles.size()));
    }
    if (type == sd_gui_utils::PRESETS)
    {
        this->logs->AppendText(fmt::format("Loaded presets: {}\n", this->Presets.size()));
        if (this->Presets.size() > 0)
        {
            this->m_preset_list->Enable();
        }
    }
}
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
    wxVector<wxVariant> data;
    auto created_at = sd_gui_utils::formatUnixTimestampToDate(item.created_at);
    data.push_back(wxVariant(std::to_string(item.id)));
    data.push_back(wxVariant(created_at));
    data.push_back(wxVariant(item.model));
    data.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
    data.push_back(wxVariant(std::to_string(item.params.seed)));
    data.push_back(item.status == QM::QueueStatus::DONE ? 100 : 1); // progressbar
    data.push_back(wxString("-.--it/s"));                           // speed
    data.push_back(wxVariant(QM::QueueStatus_str[item.status]));    // status
    auto store = this->m_joblist->GetStore();
    int lastCol = this->m_joblist->GetColumnCount() - 1;
    QM::QueueItem *nItem = new QM::QueueItem(item);
    for (unsigned int i = 0; i < store->GetItemCount(); i++)
    this->JobTableItems[item.id] = nItem;
    //  store->AppendItem(data, wxUIntPtr(this->JobTableItems[item.id]));
    store->PrependItem(data, wxUIntPtr(this->JobTableItems[item.id]));
}
MainWindowUI::~MainWindowUI()
{
    // this->Hide();
    /* for (int i = 0; i < this->threads.size(); i++)
       {
           if (this->threads.at(i).joinable())
           {
               this->threads.at(i).join();
           }
       }*/
    for (auto &t : this->threads)
    {
        auto _item = store->GetItem(i);
        auto _item_data = store->GetItemData(_item);
        auto *_qitem = reinterpret_cast<QM::QueueItem *>(_item_data);
        if (_qitem->id == item.id)
        {
            store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), i, lastCol);
            this->m_joblist->Refresh();
            break;
        }
        t->join();
    }
}
void MainWindowUI::loadModelList()
void MainWindowUI::loadVaeList()
{
    this->m_sampler->Clear();
    for (auto sampler : sd_gui_utils::sample_method_str)
    if (!std::filesystem::exists(this->cfg->vae))
    {
        int _u = this->m_sampler->Append(sampler);
        if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
        {
            this->m_sampler->Select(_u);
        }
        std::filesystem::create_directories(this->cfg->vae);
    }
    this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
    for (auto model : this->ModelFiles)
    {
        // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
        uintmax_t size = std::filesystem::file_size(model.second);
        auto humanSize = sd_gui_utils::humanReadableFileSize(size);
        auto hs = wxString::Format("%.1f %s", humanSize.first, humanSize.second);
        wxVector<wxVariant> data;
        data.push_back(wxVariant(model.first));
        data.push_back(hs);
        data.push_back("--");
        this->m_data_model_list->AppendItem(data);
    }
    this->m_data_model_list->Refresh();
    this->LoadFileList(sd_gui_utils::DirTypes::VAE);
}
void MainWindowUI::StartGeneration(QM::QueueItem myJob)
void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
{
    // here starts the trhead
    // this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
    // this->threads.emplace_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
    // this->threads.emplace_back(std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob));
    std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob);
}
void MainWindowUI::HandleSDProgress(int step, int steps, float time, void *data)
{
    sd_gui_utils::VoidHolder *objs = (sd_gui_utils::VoidHolder *)data;
    wxEvtHandler *eventHandler = (wxEvtHandler *)objs->p1;
    QM::QueueItem *myItem = (QM::QueueItem *)objs->p2;
    myItem->step = step;
    myItem->steps = steps;
    myItem->time = time;
    /*
        format it/s
        time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
               progress.c_str(), step, steps,
               time > 1.0f || time == 0 ? time : (1.0f / time)
    */
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("GENERATION_PROGRESS:%d/%d", step, steps));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
}
void MainWindowUI::Generate(wxEvtHandler *eventHandler, QM::QueueItem myItem)
@@ -556,6 +657,109 @@
{
    this->initConfig();
    this->settingsWindow->Destroy();
}
void MainWindowUI::OnPopupClick(wxCommandEvent &evt)
{
    void *data = static_cast<wxMenu *>(evt.GetEventObject())->GetClientData();
}
void MainWindowUI::HandleSDLog(sd_log_level_t level, const char *text, void *data)
{
    if (level == sd_log_level_t::SD_LOG_INFO || level == sd_log_level_t::SD_LOG_ERROR)
    {
        auto *eventHandler = (wxEvtHandler *)data;
        wxThreadEvent *e = new wxThreadEvent();
        e->SetString(wxString::Format("SD_MESSAGE:%s", text));
        e->SetPayload(level);
        wxQueueEvent(eventHandler, e);
    }
}
void MainWindowUI::OnQueueItemManagerItemStatusChanged(QM::QueueItem item)
{
    auto store = this->m_joblist->GetStore();
    int lastCol = this->m_joblist->GetColumnCount() - 1;
    for (unsigned int i = 0; i < store->GetItemCount(); i++)
    {
        auto _item = store->GetItem(i);
        auto _item_data = store->GetItemData(_item);
        auto *_qitem = reinterpret_cast<QM::QueueItem *>(_item_data);
        if (_qitem->id == item.id)
        {
            store->SetValueByRow(wxVariant(QM::QueueStatus_str[item.status]), i, lastCol);
            this->m_joblist->Refresh();
            break;
        }
    }
}
void MainWindowUI::loadModelList()
{
    this->m_sampler->Clear();
    for (auto sampler : sd_gui_utils::sample_method_str)
    {
        int _u = this->m_sampler->Append(sampler);
        if (sampler == sd_gui_utils::sample_method_str[this->sd_params->sample_method])
        {
            this->m_sampler->Select(_u);
        }
    }
    this->LoadFileList(sd_gui_utils::DirTypes::CHECKPOINT);
    for (auto model : this->ModelFiles)
    {
        // auto size = sd_gui_utils::HumanReadable{std::filesystem::file_size(model.second)};
        uintmax_t size = std::filesystem::file_size(model.second);
        auto humanSize = sd_gui_utils::humanReadableFileSize(size);
        auto hs = wxString::Format("%.1f %s", humanSize.first, humanSize.second);
        wxVector<wxVariant> data;
        data.push_back(wxVariant(model.first));
        data.push_back(hs);
        data.push_back("--");
        this->m_data_model_list->AppendItem(data);
    }
    this->m_data_model_list->Refresh();
}
void MainWindowUI::StartGeneration(QM::QueueItem myJob)
{
    // here starts the trhead
    // this->threads.push_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
    // this->threads.emplace_back(std::thread(std::bind(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob)));
    // this->threads.emplace_back(std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob));
    // std::thread j(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob);
    // std::thread *p(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob);
    std::thread *p = new std::thread(&MainWindowUI::Generate, this, this->GetEventHandler(), myJob);
    this->threads.emplace_back(p);
}
void MainWindowUI::HandleSDProgress(int step, int steps, float time, void *data)
{
    sd_gui_utils::VoidHolder *objs = (sd_gui_utils::VoidHolder *)data;
    wxEvtHandler *eventHandler = (wxEvtHandler *)objs->p1;
    QM::QueueItem *myItem = (QM::QueueItem *)objs->p2;
    myItem->step = step;
    myItem->steps = steps;
    myItem->time = time;
    /*
        format it/s
        time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
               progress.c_str(), step, steps,
               time > 1.0f || time == 0 ? time : (1.0f / time)
    */
    wxThreadEvent *e = new wxThreadEvent();
    e->SetString(wxString::Format("GENERATION_PROGRESS:%d/%d", step, steps));
    e->SetPayload(myItem);
    wxQueueEvent(eventHandler, e);
}
void MainWindowUI::OnThreadMessage(wxThreadEvent &e)
@@ -820,174 +1024,4 @@
void MainWindowUI::LoadPresets()
{
    this->LoadFileList(sd_gui_utils::DirTypes::PRESETS);
}
void MainWindowUI::OnQueueItemManagerItemUpdated(QM::QueueItem item)
{
}
void MainWindowUI::loadVaeList()
{
    if (!std::filesystem::exists(this->cfg->vae))
    {
        std::filesystem::create_directories(this->cfg->vae);
    }
    this->LoadFileList(sd_gui_utils::DirTypes::VAE);
}
/**
 * Waits for the generation threads recorded in `threads` (filled by
 * StartGeneration) before the window is destroyed.
 *
 * Only joinable threads are joined: std::thread::join() throws
 * std::system_error for a non-joinable thread, and an exception escaping a
 * destructor calls std::terminate.
 *
 * NOTE(review): StartGeneration allocates each std::thread with `new` and
 * stores the pointer in `threads`; nothing here releases those objects. If the
 * container holds raw pointers they leak — consider std::unique_ptr<std::thread>
 * in the header.
 */
MainWindowUI::~MainWindowUI()
{
    for (auto &worker : this->threads)
    {
        if (worker && worker->joinable())
        {
            worker->join();
        }
    }
}
void MainWindowUI::OnQueueItemManagerItemAdded(QM::QueueItem item)
{
    wxVector<wxVariant> data;
    auto created_at = sd_gui_utils::formatUnixTimestampToDate(item.created_at);
    data.push_back(wxVariant(std::to_string(item.id)));
    data.push_back(wxVariant(created_at));
    data.push_back(wxVariant(item.model));
    data.push_back(wxVariant(sd_gui_utils::sample_method_str[(int)item.params.sample_method]));
    data.push_back(wxVariant(std::to_string(item.params.seed)));
    data.push_back(item.status == QM::QueueStatus::DONE ? 100 : 1); // progressbar
    data.push_back(wxString("-.--it/s"));                           // speed
    data.push_back(wxVariant(QM::QueueStatus_str[item.status]));    // status
    auto store = this->m_joblist->GetStore();
    QM::QueueItem *nItem = new QM::QueueItem(item);
    this->JobTableItems[item.id] = nItem;
    //  store->AppendItem(data, wxUIntPtr(this->JobTableItems[item.id]));
    store->PrependItem(data, wxUIntPtr(this->JobTableItems[item.id]));
}
void MainWindowUI::LoadFileList(sd_gui_utils::DirTypes type)
{
    std::string basepath;
    switch (type)
    {
    case sd_gui_utils::DirTypes::VAE:
        this->VaeFiles.clear();
        this->m_vae->Clear();
        this->m_vae->Append("-none-");
        this->m_vae->Select(0);
        basepath = this->cfg->vae;
        break;
    case sd_gui_utils::DirTypes::LORA:
        basepath = this->cfg->lora;
        break;
    case sd_gui_utils::DirTypes::CHECKPOINT:
        this->ModelFiles.clear();
        this->m_model->Clear();
        this->m_model->Append("-none-");
        this->m_model->Select(0);
        basepath = this->cfg->model;
        break;
    case sd_gui_utils::DirTypes::PRESETS:
        this->Presets.clear();
        this->m_preset_list->Clear();
        this->m_preset_list->Append("-none-");
        this->m_preset_list->Select(0);
        basepath = this->cfg->presets;
        break;
    }
    if (!std::filesystem::exists(basepath))
    {
        std::filesystem::create_directories(basepath);
    }
    int i = 0;
    for (auto const &dir_entry : std::filesystem::recursive_directory_iterator(basepath))
    {
        if (!dir_entry.exists() || !dir_entry.is_regular_file() || !dir_entry.path().has_extension())
        {
            continue;
        }
        std::filesystem::path path = dir_entry.path();
        std::string ext = path.extension().string();
        if (type == sd_gui_utils::DirTypes::CHECKPOINT || type == sd_gui_utils::DirTypes::VAE)
        {
            if (ext != ".safetensors" && ext != ".ckpt")
            {
                continue;
            }
        }
        if (type == sd_gui_utils::DirTypes::PRESETS)
        {
            if (ext != ".json")
            {
                continue;
            }
        }
        std::string name = path.filename().replace_extension("").string();
        // prepend the subdirectory to the modelname
        // // wxFileName::GetPathSeparator()
        auto path_name = path.string();
        sd_gui_utils::replace(path_name, basepath, "");
        sd_gui_utils::replace(path_name, "//", "");
        sd_gui_utils::replace(path_name, "\\\\", "");
        sd_gui_utils::replace(path_name, ext, "");
        name = path_name.substr(1);
        if (type == sd_gui_utils::CHECKPOINT)
        {
            this->m_model->Append(name);
            this->ModelFiles.emplace(name, dir_entry.path().string());
        }
        if (type == sd_gui_utils::VAE)
        {
            this->m_vae->Append(name);
            this->VaeFiles.emplace(name, dir_entry.path().string());
        }
        if (type == sd_gui_utils::PRESETS)
        {
            sd_gui_utils::generator_preset preset;
            std::ifstream f(path.string());
            try
            {
                nlohmann::json data = nlohmann::json::parse(f);
                preset = data;
                preset.path = path.string();
                this->m_preset_list->Append(preset.name);
                this->Presets.emplace(preset.name, preset);
            }
            catch (const std::exception &e)
            {
                std::remove(path.string().c_str());
                std::cerr << e.what() << '\n';
            }
        }
    }
    if (type == sd_gui_utils::CHECKPOINT)
    {
        this->logs->AppendText(fmt::format("Loaded checkpoints: {}\n", this->ModelFiles.size()));
    }
    if (type == sd_gui_utils::VAE)
    {
        this->logs->AppendText(fmt::format("Loaded vaes: {}\n", this->VaeFiles.size()));
    }
    if (type == sd_gui_utils::PRESETS)
    {
        this->logs->AppendText(fmt::format("Loaded presets: {}\n", this->Presets.size()));
        if (this->Presets.size() > 0)
        {
            this->m_preset_list->Enable();
        }
    }
}