From a3c860ee7a791757c223180b041764bb90870893 Mon Sep 17 00:00:00 2001
From: NI1993 <60190218+NI1993@users.noreply.github.com>
Date: Tue, 20 Aug 2024 13:20:53 +0300
Subject: [PATCH] Optionally allow keeping original column names in
 fetch_dataframe method

---
 redshift_connector/cursor.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/redshift_connector/cursor.py b/redshift_connector/cursor.py
index 03e7af4..c89b0e1 100644
--- a/redshift_connector/cursor.py
+++ b/redshift_connector/cursor.py
@@ -504,13 +504,18 @@ def __next__(self: "Cursor") -> typing.List:
             else:
                 raise StopIteration()
 
-    def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> "pandas.DataFrame":
+    def fetch_dataframe(
+        self: "Cursor",
+        num: typing.Optional[int] = None,
+        lowercase_column_names: bool = True,
+    ) -> "pandas.DataFrame":
         """
         Fetches a user defined number of rows of a query result as a :class:`pandas.DataFrame`.
 
         Parameters
         ----------
         num : Optional[int] The number of rows to retrieve. If unspecified, all rows will be retrieved
+        lowercase_column_names : bool If set to True, column names in returend dataframe will be in lower case. Else, original column names are returned.
 
         Returns
         -------
@@ -519,11 +524,18 @@ def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> "pandas
         try:
             import pandas
         except ModuleNotFoundError:
-            raise ModuleNotFoundError(MISSING_MODULE_ERROR_MSG.format(module="pandas"))
+            raise ModuleNotFoundError(MISSING_MODULE_ERROR_MSG.format(module="pandas"))        
+
+        def _handle_name(name: typing.Union[str, bytes]) -> typing.Union[str, bytes]:
+            """Inner helper to optionally handle column name"""
+            output = name
+            if lowercase_column_names:
+                output = name.lower()
+            return output
 
         columns: typing.Optional[typing.List[typing.Union[str, bytes]]] = None
         try:
-            columns = [column[0].lower() for column in self.description]
+            columns = [_handle_name(column[0]) for column in self.description]
         except:
             warn("No row description was found. pandas dataframe will be missing column labels.", stacklevel=2)