From 14337cf5b302c5385f0ae1393caf6df7e83fc539 Mon Sep 17 00:00:00 2001
From: znone <glyc@sina.com.cn>
Date: Sat, 07 Dec 2019 06:52:19 +0000
Subject: [PATCH] 1. 允许绑定字段到std::optional和std::any 2. 增加函数bind_fields可以一次绑定到多个字段 3. 查询函数返回数据库对象自身,以支持链式调用

---
 include/qtl_odbc.hpp |  197 +++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 174 insertions(+), 23 deletions(-)

diff --git a/include/qtl_odbc.hpp b/include/qtl_odbc.hpp
index 94ed746..c42d812 100644
--- a/include/qtl_odbc.hpp
+++ b/include/qtl_odbc.hpp
@@ -16,6 +16,11 @@
 #include <sys/time.h>
 #endif //_WIN32
 
+#if (ODBCVER >= 0x0380) && (WIN32_WINNT >= 0x0602)
+#define QTL_ODBC_ENABLE_ASYNC_MODE 1
+#endif //ODBC 3.80 && Windows
+
+
 #include "qtl_common.hpp"
 
 namespace qtl
@@ -34,7 +39,7 @@
 	error(const object<Type>& h, SQLINTEGER code);
 	error(SQLINTEGER code, const char* msg) : m_errno(code), m_errmsg(msg) { }
 	SQLINTEGER code() const { return m_errno; }
-	operator bool() { return m_errno!=SQL_SUCCESS || m_errno!=SQL_SUCCESS_WITH_INFO; }
+	operator bool() const { return m_errno!=SQL_SUCCESS || m_errno!=SQL_SUCCESS_WITH_INFO; }
 	virtual const char* what() const throw() override { return m_errmsg.data(); }
 private:
 	SQLINTEGER m_errno;
@@ -149,6 +154,13 @@
 		verify_error(SQLSetEnvAttr(m_handle, SQL_ATTR_ODBC_VERSION, (SQLPOINTER)SQL_OV_ODBC3, 0));
 	}
 	environment(environment&& src) : object(std::forward<environment>(src)) { }
+
+	int32_t version() const
+	{
+		int32_t ver = 0;
+		verify_error(SQLGetEnvAttr(m_handle, SQL_ATTR_ODBC_VERSION, &ver, sizeof(DWORD), NULL));
+		return ver;
+	}
 };
 
 class statement final : public object<SQL_HANDLE_STMT>
@@ -445,7 +457,7 @@
 	void bind_field(SQLUSMALLINT index, qtl::bind_string_helper<T>&& v)
 	{
 		SQLLEN length=0;
-		SQLColAttribute(m_handle, index+1, SQL_DESC_LENGTH, NULL, 0, NULL, &length);
+		verify_error(SQLColAttribute(m_handle, index+1, SQL_DESC_LENGTH, NULL, 0, NULL, &length));
 		typename qtl::bind_string_helper<T>::char_type* data=v.alloc(length);
 		bind_field(index, data, length+1);
 		m_params[index].m_after_fetch=[v](const param_data& p) mutable {
@@ -488,7 +500,7 @@
 		};
 	}
 
-	void bind_field(size_t index, blobbuf&& value)
+	void bind_field(SQLUSMALLINT index, blobbuf&& value)
 	{
 		m_params[index].m_data = nullptr;
 		m_params[index].m_size = 0;
@@ -498,7 +510,7 @@
 	}
 
 	template<typename Type>
-	void bind_field(size_t index, indicator<Type>&& value)
+	void bind_field(SQLUSMALLINT index, indicator<Type>&& value)
 	{
 		qtl::bind_field(*this, index, value.data);
 		param_data& param=m_params[index];
@@ -521,6 +533,145 @@
 		};
 	}
 
+#ifdef _QTL_ENABLE_CPP17
+
+	template<typename Type>
+	void bind_field(SQLUSMALLINT index, std::optional<Type>&& value)
+	{
+		qtl::bind_field(*this, index, *value);
+		param_data& param = m_params[index];
+		auto fetch_fun = param.m_after_fetch;
+		param.m_after_fetch = [fetch_fun, &value](const param_data& p) {
+			if (fetch_fun) fetch_fun(p);
+			if (p.m_indicator == SQL_NULL_DATA)
+				value.reset();
+		};
+	}
+
+	void bind_field(SQLUSMALLINT index, std::any&& value)
+	{
+		SQLLEN type = 0, isUnsigned=SQL_FALSE;
+		verify_error(SQLColAttribute(m_handle, index + 1, SQL_DESC_TYPE, NULL, 0, NULL, &type));
+		verify_error(SQLColAttribute(m_handle, index + 1, SQL_DESC_UNSIGNED, NULL, 0, NULL, &isUnsigned));
+		switch (type)
+		{
+		case SQL_BIT:
+			value.emplace<bool>();
+			bind_field(index, std::forward<bool>(std::any_cast<bool&>(value)));
+			break;
+		case SQL_TINYINT:
+			if (isUnsigned)
+			{
+				value.emplace<uint8_t>();
+				bind_field(index, std::forward<uint8_t>(std::any_cast<uint8_t&>(value)));
+			}
+			else
+			{
+				value.emplace<int8_t>();
+				bind_field(index, std::forward<int8_t>(std::any_cast<int8_t&>(value)));
+			}
+			break;
+		case SQL_SMALLINT:
+			if (isUnsigned)
+			{
+				value.emplace<uint16_t>();
+				bind_field(index, std::forward<uint16_t>(std::any_cast<uint16_t&>(value)));
+			}
+			else
+			{
+				value.emplace<int16_t>();
+				bind_field(index, std::forward<int16_t>(std::any_cast<int16_t&>(value)));
+			}
+			break;
+		case SQL_INTEGER:
+			if (isUnsigned)
+			{
+				value.emplace<uint32_t>();
+				bind_field(index, std::forward<uint32_t>(std::any_cast<uint32_t&>(value)));
+			}
+			else
+			{
+				value.emplace<int32_t>();
+				bind_field(index, std::forward<int32_t>(std::any_cast<int32_t&>(value)));
+			}
+			break;
+		case SQL_BIGINT:
+			if (isUnsigned)
+			{
+				value.emplace<uint64_t>();
+				bind_field(index, std::forward<uint64_t>(std::any_cast<uint64_t&>(value)));
+			}
+			else
+			{
+				value.emplace<int64_t>();
+				bind_field(index, std::forward<int64_t>(std::any_cast<int64_t&>(value)));
+			}
+			break;
+		case SQL_FLOAT:
+			value.emplace<float>();
+			bind_field(index, std::forward<float>(std::any_cast<float&>(value)));
+			break;
+		case SQL_DOUBLE:
+			value.emplace<double>();
+			bind_field(index, std::forward<double>(std::any_cast<double&>(value)));
+			break;
+		case SQL_NUMERIC:
+			value.emplace<SQL_NUMERIC_STRUCT>();
+			bind_field(index, std::forward<SQL_NUMERIC_STRUCT>(std::any_cast<SQL_NUMERIC_STRUCT&>(value)));
+			break;
+		case SQL_TIME:
+			value.emplace<SQL_TIME_STRUCT>();
+			bind_field(index, std::forward<SQL_TIME_STRUCT>(std::any_cast<SQL_TIME_STRUCT&>(value)));
+			break;
+		case SQL_DATE:
+			value.emplace<SQL_DATE_STRUCT>();
+			bind_field(index, std::forward<SQL_DATE_STRUCT>(std::any_cast<SQL_DATE_STRUCT&>(value)));
+			break;
+		case SQL_TIMESTAMP:
+			value.emplace<SQL_TIMESTAMP_STRUCT>();
+			bind_field(index, std::forward<SQL_TIMESTAMP_STRUCT>(std::any_cast<SQL_TIMESTAMP_STRUCT&>(value)));
+			break;
+		case SQL_INTERVAL_MONTH:
+		case SQL_INTERVAL_YEAR:
+		case SQL_INTERVAL_YEAR_TO_MONTH:
+		case SQL_INTERVAL_DAY:
+		case SQL_INTERVAL_HOUR:
+		case SQL_INTERVAL_MINUTE:
+		case SQL_INTERVAL_SECOND:
+		case SQL_INTERVAL_DAY_TO_HOUR:
+		case SQL_INTERVAL_DAY_TO_MINUTE:
+		case SQL_INTERVAL_DAY_TO_SECOND:
+		case SQL_INTERVAL_HOUR_TO_MINUTE:
+		case SQL_INTERVAL_HOUR_TO_SECOND:
+		case SQL_INTERVAL_MINUTE_TO_SECOND:
+			value.emplace<SQL_INTERVAL_STRUCT>();
+			bind_field(index, std::forward<SQL_INTERVAL_STRUCT>(std::any_cast<SQL_INTERVAL_STRUCT&>(value)));
+			break;
+		case SQL_CHAR:
+			value.emplace<std::string>();
+			bind_field(index, qtl::bind_string(std::any_cast<std::string&>(value)));
+			break;
+		case SQL_GUID:
+			value.emplace<SQLGUID>();
+			bind_field(index, std::forward<SQLGUID>(std::any_cast<SQLGUID&>(value)));
+			break;
+		case SQL_BINARY:
+			value.emplace<blobbuf>();
+			bind_field(index, std::forward<blobbuf>(std::any_cast<blobbuf&>(value)));
+			break;
+		default:
+			throw odbc::error(*this, SQL_ERROR);
+		}
+		param_data& param = m_params[index];
+		auto fetch_fun = param.m_after_fetch;
+		param.m_after_fetch = [fetch_fun, &value](const param_data& p) {
+			if (fetch_fun) fetch_fun(p);
+			if (p.m_indicator == SQL_NULL_DATA)
+				value.reset();
+		};
+	}
+
+#endif // C++17
 	template<typename Types>
 	void execute(const Types& params)
 	{
@@ -720,9 +871,9 @@
 	void open(const char* server_name, size_t server_name_length, 
 		const char* user_name, size_t user_name_length, const char* password, size_t password_length)
 	{
-		if(m_opened) close();
+		if (m_opened) close();
 		verify_error(SQLConnectA(m_handle, (SQLCHAR*)server_name, server_name_length, (SQLCHAR*)user_name, user_name_length, (SQLCHAR*)password, password_length));
-		m_opened=true;
+		m_opened = true;
 	}
 	void open(const char* server_name, const char* user_name, const char* password)
 	{
@@ -732,21 +883,21 @@
 	{
 		open(server_name.data(), server_name.size(), user_name.data(), user_name.size(), password.data(), password.size());
 	}
-	void open(const char* input_text, size_t text_length=SQL_NTS, SQLSMALLINT driver_completion=SQL_DRIVER_NOPROMPT, SQLHWND hwnd=NULL)
+	void open(const char* input_text, size_t text_length = SQL_NTS, SQLSMALLINT driver_completion = SQL_DRIVER_NOPROMPT, SQLHWND hwnd = NULL)
 	{
 		m_connection.resize(512);
 		SQLSMALLINT out_len;
-		if(m_opened) close();
-		verify_error(SQLDriverConnectA(m_handle, hwnd, (SQLCHAR*)input_text, (SQLSMALLINT)text_length, 
+		if (m_opened) close();
+		verify_error(SQLDriverConnectA(m_handle, hwnd, (SQLCHAR*)input_text, (SQLSMALLINT)text_length,
 			(SQLCHAR*)m_connection.data(), (SQLSMALLINT)m_connection.size(), &out_len, driver_completion));
 		m_connection.resize(out_len);
-		m_opened=true;
+		m_opened = true;
 	}
-	void open(const std::string& input_text, SQLSMALLINT driver_completion=SQL_DRIVER_NOPROMPT, SQLHWND hwnd=NULL)
+	void open(const std::string& input_text, SQLSMALLINT driver_completion = SQL_DRIVER_NOPROMPT, SQLHWND hwnd = NULL)
 	{
 		open(input_text.data(), input_text.size(), driver_completion, hwnd);
 	}
-	void open(SQLHWND hwnd, SQLSMALLINT driver_completion=SQL_DRIVER_COMPLETE)
+	void open(SQLHWND hwnd, SQLSMALLINT driver_completion = SQL_DRIVER_COMPLETE)
 	{
 		open("", SQL_NTS, driver_completion, hwnd);
 	}
@@ -755,27 +906,27 @@
 	template<typename InputPred>
 	void open(const char* connection_text, size_t text_length, InputPred&& pred)
 	{
-		SQLSMALLINT length=0;
-		SQLINTEGER ret=SQL_SUCCESS;
+		SQLSMALLINT length = 0;
+		SQLINTEGER ret = SQL_SUCCESS;
 		std::string input_text;
-		if(m_opened) close();
-		if(text_length==SQL_NTS)
-			input_text=connection_text;
+		if (m_opened) close();
+		if (text_length == SQL_NTS)
+			input_text = connection_text;
 		else
 			input_text.assign(connection_text, text_length);
 		m_connection.resize(1024);
-		while( (ret=SQLBrowseConnectA(m_handle, (SQLCHAR*)input_text.data(), SQL_NTS, 
+		while ((ret = SQLBrowseConnectA(m_handle, (SQLCHAR*)input_text.data(), SQL_NTS,
 			(SQLCHAR*)m_connection.data(), m_connection.size(), &length)) == SQL_NEED_DATA)
 		{
 			connection_parameters parameters;
 			parse_browse_string(m_connection.data(), length, parameters);
-			if(!pred(parameters))
+			if (!pred(parameters))
 				throw error(SQL_NEED_DATA, "User cancel operation.");
-			input_text=create_connection_text(parameters);
+			input_text = create_connection_text(parameters);
 		}
-		if(ret==SQL_ERROR || ret==SQL_SUCCESS_WITH_INFO)
+		if (ret == SQL_ERROR || ret == SQL_SUCCESS_WITH_INFO)
 			verify_error(ret);
-		m_opened=true;
+		m_opened = true;
 	}
 	template<typename InputPred>
 	void open(const char* connection_text, InputPred&& pred)
@@ -863,7 +1014,7 @@
 	{
 		SQLINTEGER value;
 		get_attribute(SQL_ATTR_CONNECTION_DEAD, value);
-		return value==SQL_CD_FALSE;
+		return value == SQL_CD_FALSE;
 	}
 
 	statement open_command(const char* query_text, size_t text_length)

--
Gitblit v1.9.3