From bce9758e32d0c3029a26efb486a342ff05f6df72 Mon Sep 17 00:00:00 2001
From: znone <glyc@sina.com.cn>
Date: Sat, 27 Feb 2021 12:57:25 +0000
Subject: [PATCH] PostgreSQL: support array type

---
 include/qtl_postgres.hpp |  109 ++++++++++++++++++++++++++++++++++--
 README_CN.md             |    6 +
 README.md                |    6 +
 3 files changed, 110 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index b235fd1..a495c81 100644
--- a/README.md
+++ b/README.md
@@ -411,13 +411,14 @@
 | smallint | int16_t |
 | bigint | int64_t |
 | real | float |
-| DOUBLE | double |
+| double | double |
 | text | const char*<br>std::string |
 | bytea | qtl::const_blob_data<br>std::vector<uint8_t> |
 | oid | qtl::postgres::large_object |
 | date | qtl::postgres::date |
 | timestamp | qtl::postgres::timestamp |
 | interval | qtl::postgres::interval |
+| array | std::vector |
 
 ### PostgreSQL field data binding
 
@@ -428,13 +429,14 @@
 | smallint | int16_t |
 | bigint | int64_t |
 | real | float |
-| DOUBLE | double |
+| double | double |
 | text | char[N]<br>std::array&lt;char, N&gt;<br>std::string |
 | bytea | qtl::const_blob_data<br>qtl::blob_data<br>std::vector<uint8_t> |
 | oid | qtl::postgres::large_object |
 | date | qtl::postgres::date |
 | timestamp | qtl::postgres::timestamp |
 | interval | qtl::postgres::interval |
+| array | std::vector |
 
 ### C ++ classes related to PostgreSQL
 - qtl::postgres::database
diff --git a/README_CN.md b/README_CN.md
index c44f017..2bfe98c 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -411,13 +411,14 @@
 | smallint | int16_t |
 | bigint | int64_t |
 | real | float |
-| DOUBLE | double |
+| double | double |
 | text | const char*<br>std::string |
 | bytea | qtl::const_blob_data<br>std::vector<uint8_t> |
 | oid | qtl::postgres::large_object |
 | date | qtl::postgres::date |
 | timestamp | qtl::postgres::timestamp |
 | interval | qtl::postgres::interval |
+| array | std::vector |
 
 ### PostgreSQL的字段数据绑定
 
@@ -428,13 +429,14 @@
 | smallint | int16_t |
 | bigint | int64_t |
 | real | float |
-| DOUBLE | double |
+| double | double |
 | text | char[N]<br>std::array&lt;char, N&gt;<br>std::string |
 | bytea | qtl::const_blob_data<br>qtl::blob_data<br>std::vector<uint8_t> |
 | oid | qtl::postgres::large_object |
 | date | qtl::postgres::date |
 | timestamp | qtl::postgres::timestamp |
 | interval | qtl::postgres::interval |
+| array | std::vector |
 
 ### ODBC相关的C++类
 - qtl::postgres::database
diff --git a/include/qtl_postgres.hpp b/include/qtl_postgres.hpp
index 47be380..cbb780a 100644
--- a/include/qtl_postgres.hpp
+++ b/include/qtl_postgres.hpp
@@ -98,8 +98,8 @@
 		return ntoh(static_cast<uint64_t>(v));
 	}
 
-	template<typename T>
-	inline T& ntoh_inplace(typename std::enable_if<std::is_integral<T>::value && !std::is_const<T>::value, T>::type& v)
+	template<typename T, typename = typename std::enable_if<std::is_integral<T>::value && !std::is_const<T>::value>::type>
+	inline T& ntoh_inplace(T& v)
 	{
 		v = ntoh(v);
 		return v;
@@ -134,8 +134,8 @@
 		return hton(static_cast<uint64_t>(v));
 	}
 
-	template<typename T>
-	inline T& hton_inplace(typename std::enable_if<std::is_integral<T>::value && !std::is_const<T>::value>::type& v)
+	template<typename T, typename = typename std::enable_if<std::is_integral<T>::value && !std::is_const<T>::value>::type>
+	inline T& hton_inplace(T& v)
 	{
 		v = hton(v);
 		return v;
@@ -684,6 +684,17 @@
 	int m_fd;
 };
 
+struct array_header
+{
+	int32_t ndim;
+	int32_t flags;
+	int32_t elemtype;
+	struct dimension {
+		int32_t length;
+		int32_t lower_bound;
+	} dims[1];
+};
+
 /*
 	template<typename T>
 	struct oid_traits
@@ -901,6 +912,76 @@
 	{
 		return object_traits<int32_t>::data(v.oid(), data);
 	}
+};
+
+template<typename T, Oid id> 
+struct array_traits : public base_object_traits<std::vector<T>, id>
+{
+	typedef typename base_object_traits<std::vector<T>, id>::value_type value_type;
+	static value_type get(const char* data, size_t n)
+	{
+		array_header header = *reinterpret_cast<const array_header*>(data);
+		detail::ntoh_inplace(header.ndim);
+		detail::ntoh_inplace(header.flags);
+		detail::ntoh_inplace(header.elemtype);
+		detail::ntoh_inplace(header.dims[0].length);
+		detail::ntoh_inplace(header.dims[0].lower_bound);
+		if (header.ndim != 1 || !object_traits<T>::is_match(header.elemtype))
+			throw std::bad_cast();
+
+		std::vector<T> result;
+		data += sizeof(array_header);
+		result.reserve(header.dims[0].length);
+
+		for (int32_t i = 0; i != header.dims[0].length; i++)
+		{
+			int32_t size = detail::ntoh(*reinterpret_cast<const int32_t*>(data));
+			const char* elem_data = data + sizeof(int32_t);
+			result.push_back(object_traits<T>::get(elem_data, size));
+			data = elem_data + size;
+		}
+		return result;
+	}
+	static std::pair<const char*, size_t> data(const std::vector<T>& v, std::vector<char>& data)
+	{
+		assert(v.size() <= INT32_MAX);
+		data.resize(sizeof(array_header));
+		array_header* header = reinterpret_cast<array_header*>(data.data());
+		header->ndim = detail::hton(1);
+		header->flags = detail::hton(0);
+		header->elemtype = detail::hton(object_traits<T>::type);
+		header->dims[0].length = detail::hton(static_cast<int32_t>(v.size()));
+		header->dims[0].lower_bound = detail::hton(1);
+		std::vector<char> temp;
+		for (const T& e : v)
+		{
+			std::pair<const char*, size_t> buffer = object_traits<T>::data(e, temp);
+			int32_t size = detail::hton(static_cast<int32_t>(buffer.second));
+			data.insert(data.end(), reinterpret_cast<char*>(&size), reinterpret_cast<char*>(&size) + sizeof(int32_t));
+			data.insert(data.end(), buffer.first, buffer.first + buffer.second);
+		}
+		return std::make_pair(reinterpret_cast<const char*>(data.data()), data.size());
+	}
+};
+
+template<> class object_traits<std::vector<int16_t>> : public array_traits<int16_t, INT2ARRAYOID>
+{
+};
+
+template<> class object_traits<std::vector<int32_t>> : public array_traits<int32_t, INT4ARRAYOID>
+{
+};
+
+template<> class object_traits<std::vector<float>> : public array_traits<float, FLOAT4ARRAYOID>
+{
+};
+
+template<> class object_traits<std::vector<std::string>> : public array_traits<std::string, TEXTARRAYOID>
+{
+};
+
+template<> class object_traits<std::vector<large_object>> : public array_traits<large_object, OIDARRAYOID>
+{
 };
 
 struct binder
@@ -1174,7 +1255,10 @@
 	template<class Type>
 	void bind_field(size_t index, Type&& value)
 	{
-		value = m_binders[index].get<typename std::remove_const<Type>::type>();
+		if (m_res.is_null(0, static_cast<int>(index)))
+			value = Type();
+		else
+			value = m_binders[index].get<typename std::remove_const<Type>::type>();
 	}
 
 	void bind_field(size_t index, char* value, size_t length)
@@ -1208,11 +1292,22 @@
 
 	void bind_field(size_t index, large_object&& value)
 	{
-		value = m_binders[index].get<large_object>(m_conn);
+		if (m_res.is_null(0, static_cast<int>(index)))
+			value.close();
+		else
+			value = m_binders[index].get<large_object>(m_conn);
 	}
 	void bind_field(size_t index, blob_data&& value)
 	{
-		m_binders[index].get(value);
+		if (m_res.is_null(0, static_cast<int>(index)))
+		{
+			value.data = nullptr;
+			value.size = 0;
+		}
+		else
+		{
+			m_binders[index].get(value);
+		}
 	}
 
 protected:

--
Gitblit v1.9.3