From 5a7568be177b7557b811c1b50edc3479e84a5a20 Mon Sep 17 00:00:00 2001
From: fszontagh <szf@fsociety.hu>
Date: Sat, 24 Feb 2024 14:28:39 +0000
Subject: [PATCH] lora: derive tensor n_dims from ne[] when creating tensors

---
 lora.hpp |   12 +++++++++++-
 1 files changed, 11 insertions(+), 1 deletions(-)

diff --git a/lora.hpp b/lora.hpp
index 66477f1..e2c4612 100644
--- a/lora.hpp
+++ b/lora.hpp
@@ -29,6 +29,16 @@
         return LORA_GRAPH_SIZE;
     }
 
+    static inline int ggml_n_dims_t(const TensorStorage& tensor) {
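+        // Infer the tensor's dimensionality from ne[]: the highest index with ne > 1, plus one (same logic as ggml_n_dims() uses for ggml_tensor).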
+        for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {
+            if (tensor.ne[i] > 1) {
+                return i + 1;
+            }
+        }
+        return 1;
+    }
+
     size_t get_params_mem_size() {
         return model_loader.get_params_mem_size(NULL);
     }
@@ -47,7 +57,7 @@
         auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
             const std::string& name = tensor_storage.name;
 
-            struct ggml_tensor* real = ggml_new_tensor(params_ctx, tensor_storage.type, tensor_storage.n_dims, tensor_storage.ne);
+            struct ggml_tensor* real = ggml_new_tensor(params_ctx, tensor_storage.type, ggml_n_dims_t(tensor_storage), tensor_storage.ne);
             ggml_allocr_alloc(alloc, real);
 
             *dst_tensor = real;

--
Gitblit v1.9.3