From d565832e9b53467b2ae4924432f8f974c5778781 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Thu, 22 Feb 2024 01:05:05 +0100 Subject: [PATCH 1/5] Add parameterized query support --- pypika/terms.py | 151 ++++++++++++++++++++++++++++----- pypika/tests/test_parameter.py | 104 +++++++++++++++++++++++ 2 files changed, 232 insertions(+), 23 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index ce7aed65..79d8c7ba 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -3,7 +3,7 @@ import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -288,57 +288,116 @@ def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() +def idx_placeholder_gen(idx: int) -> str: + return str(idx + 1) + + +def named_placeholder_gen(idx: int) -> str: + return f'param{idx + 1}' + + class Parameter(Term): is_aggregate = None - def __init__(self, placeholder: Union[str, int]) -> None: + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: super().__init__() - self.placeholder = placeholder + self._placeholder = placeholder + + @property + def placeholder(self): + if callable(self._placeholder): + return self._placeholder(None) + + return self._placeholder def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) + def update_parameters(self, param_key: Any, param_value: Any, **kwargs): + pass -class QmarkParameter(Parameter): - """Question mark style, e.g. ...WHERE name=?""" + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder - def __init__(self) -> None: - pass - def get_sql(self, **kwargs: Any) -> str: - return "?" +class ListParameter(Parameter): + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: + super().__init__() + self._placeholder = placeholder + self._parameters = list() + @property + def placeholder(self): + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) -class NumericParameter(Parameter): - """Numeric, positional style, e.g. ...WHERE name=:1""" + return self._placeholder - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) + def get_parameters(self, **kwargs): + return self._parameters + def update_parameters(self, value: Any, **kwargs): + self._parameters.append(value) -class NamedParameter(Parameter): - """Named style, e.g. ...WHERE name=:name""" + +class DictParameter(Parameter): + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None: + super().__init__() + self._placeholder = placeholder + self._parameters = dict() + + @property + def placeholder(self): + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) + + return self._placeholder + + def get_parameters(self, **kwargs): + return self._parameters + + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[1:] + + def update_parameters(self, param_key: Any, value: Any, **kwargs): + self._parameters[param_key] = value + + +class QmarkParameter(ListParameter): + def get_sql(self, **kwargs): + return '?' + + +class NumericParameter(ListParameter): + """Numeric, positional style, e.g. ...WHERE name=:1""" def get_sql(self, **kwargs: Any) -> str: return ":{placeholder}".format(placeholder=self.placeholder) -class FormatParameter(Parameter): +class FormatParameter(ListParameter): """ANSI C printf format codes, e.g. ...WHERE name=%s""" - def __init__(self) -> None: - pass - def get_sql(self, **kwargs: Any) -> str: return "%s" -class PyformatParameter(Parameter): +class NamedParameter(DictParameter): + """Named style, e.g. ...WHERE name=:name""" + + def get_sql(self, **kwargs: Any) -> str: + return ":{placeholder}".format(placeholder=self.placeholder) + + +class PyformatParameter(DictParameter): """Python extended format codes, e.g. ...WHERE name=%(name)s""" def get_sql(self, **kwargs: Any) -> str: return "%({placeholder})s".format(placeholder=self.placeholder) + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[2:-2] + class Negative(Term): def __init__(self, term: Term) -> None: @@ -385,9 +444,55 @@ def get_formatted_value(cls, value: Any, **kwargs): return "null" return str(value) - def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str: - sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + def get_sql( + self, + quote_char: Optional[str] = None, + secondary_quote_char: str = "'", + parameter: Parameter = None, + **kwargs: Any, + ) -> str: + if parameter is None: + sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) + return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + else: + # Don't stringify numbers when using a parameter + if isinstance(self.value, (int, float)): + value_sql = self.value + else: + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) + param_sql = parameter.get_sql(**kwargs) + param_key = parameter.get_param_key(placeholder=param_sql) + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) + + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) + + +class ParameterValueWrapper(ValueWrapper): + def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None: + super().__init__(value, alias) + self._parameter = parameter + + def get_sql( + self, + quote_char: Optional[str] = None, + secondary_quote_char: str = "'", + parameter: Parameter = None, + **kwargs: Any, + ) -> str: + if parameter is None: + sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) + return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + else: + # Don't stringify numbers when using a parameter + if isinstance(self.value, (int, float)): + value_sql = self.value + else: + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) + param_sql = self._parameter.get_sql(**kwargs) + param_key = self._parameter.get_param_key(placeholder=param_sql) + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) + + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) class JSON(Term): diff --git a/pypika/tests/test_parameter.py b/pypika/tests/test_parameter.py index e19666a0..b6a5b576 100644 --- a/pypika/tests/test_parameter.py +++ b/pypika/tests/test_parameter.py @@ -1,4 +1,5 @@ import unittest +from datetime import date from pypika import ( FormatParameter, @@ -10,6 +11,7 @@ Query, Tables, ) +from pypika.terms import ListParameter, ParameterValueWrapper class ParametrizedTests(unittest.TestCase): @@ -92,3 +94,105 @@ def test_format_parameter(self): def test_pyformat_parameter(self): self.assertEqual('%(buz)s', PyformatParameter('buz').get_sql()) + + +class ParametrizedTestsWithValues(unittest.TestCase): + table_abc, table_efg = Tables("abc", "efg") + + def test_param_insert(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo') + + parameter = QmarkParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) + self.assertEqual([1, 2.2, 'foo'], parameter.get_parameters()) + + def test_param_select_join(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == 'foobar') + .join(self.table_efg) + .on(self.table_abc.id == self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + .limit(10) + ) + + parameter = FormatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', + sql, + ) + self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters()) + + def test_param_select_subquery(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == 'foobar') + .where( + self.table_abc.id.isin( + Query.from_(self.table_efg) + .select(self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + ) + ) + .limit(10) + ) + + parameter = ListParameter(placeholder=lambda idx: f'&{idx+1}') + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10', + sql, + ) + self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters()) + + def test_join(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == 'buz') + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == 'bar') + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2', + sql, + ) + self.assertEqual({'param1': 'buz', 'param2': 'bar'}, parameter.get_parameters()) + + def test_join_with_parameter_value_wrapper(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == ParameterValueWrapper(Parameter(':buz'), 'buz')) + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == ParameterValueWrapper(NamedParameter('bar'), 'bar')) + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar', + sql, + ) + self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters()) From b82d8e2dd7ff125e923ed133563d38f306957cf0 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:52:31 +0100 Subject: [PATCH 2/5] Revert base Parameter constructor back to it's original signature --- pypika/terms.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index 79d8c7ba..d241a770 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -299,15 +299,12 @@ def named_placeholder_gen(idx: int) -> str: class Parameter(Term): is_aggregate = None - def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: + def __init__(self, placeholder: Union[str, int]) -> None: super().__init__() self._placeholder = placeholder @property def placeholder(self): - if callable(self._placeholder): - return self._placeholder(None) - return self._placeholder def get_sql(self, **kwargs: Any) -> str: @@ -322,8 +319,7 @@ def get_param_key(self, placeholder: Any, **kwargs): class ListParameter(Parameter): def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: - super().__init__() - self._placeholder = placeholder + super().__init__(placeholder=placeholder) self._parameters = list() @property @@ -342,8 +338,7 @@ def update_parameters(self, value: Any, **kwargs): class DictParameter(Parameter): def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None: - super().__init__() - self._placeholder = placeholder + super().__init__(placeholder=placeholder) self._parameters = dict() @property From a5c981c738abff47475edd7d597910af69fb3210 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Wed, 13 Mar 2024 22:37:19 +0100 Subject: [PATCH 3/5] Fix a few typehints and make code more DRY --- pypika/terms.py | 73 +++++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index d241a770..0a077b57 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -3,7 +3,21 @@ import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -323,11 +337,11 @@ def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_plac self._parameters = list() @property - def placeholder(self): + def placeholder(self) -> str: if callable(self._placeholder): return self._placeholder(len(self._parameters)) - return self._placeholder + return str(self._placeholder) def get_parameters(self, **kwargs): return self._parameters @@ -342,11 +356,11 @@ def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_pl self._parameters = dict() @property - def placeholder(self): + def placeholder(self) -> str: if callable(self._placeholder): return self._placeholder(len(self._parameters)) - return self._placeholder + return str(self._placeholder) def get_parameters(self, **kwargs): return self._parameters @@ -439,6 +453,12 @@ def get_formatted_value(cls, value: Any, **kwargs): return "null" return str(value) + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = parameter.get_sql(**kwargs) + param_key = parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key + def get_sql( self, quote_char: Optional[str] = None, @@ -449,17 +469,16 @@ def get_sql( if parameter is None: sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + + # Don't stringify numbers when using a parameter + if isinstance(self.value, (int, float)): + value_sql = self.value else: - # Don't stringify numbers when using a parameter - if isinstance(self.value, (int, float)): - value_sql = self.value - else: - value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) - param_sql = parameter.get_sql(**kwargs) - param_key = parameter.get_param_key(placeholder=param_sql) - parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) + param_sql, param_key = self._get_param_data(parameter, **kwargs) + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) - return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) class ParameterValueWrapper(ValueWrapper): @@ -467,27 +486,11 @@ def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None super().__init__(value, alias) self._parameter = parameter - def get_sql( - self, - quote_char: Optional[str] = None, - secondary_quote_char: str = "'", - parameter: Parameter = None, - **kwargs: Any, - ) -> str: - if parameter is None: - sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) - else: - # Don't stringify numbers when using a parameter - if isinstance(self.value, (int, float)): - value_sql = self.value - else: - value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) - param_sql = self._parameter.get_sql(**kwargs) - param_key = self._parameter.get_param_key(placeholder=param_sql) - parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) - - return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = self._parameter.get_sql(**kwargs) + param_key = self._parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key class JSON(Term): From e3244366e219778b4a1803b4ad1e90955ae5aea1 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Fri, 27 Sep 2024 14:34:55 +0200 Subject: [PATCH 4/5] add test for PyformatParameter --- pypika/tests/test_parameter.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pypika/tests/test_parameter.py b/pypika/tests/test_parameter.py index b6a5b576..c11e9afc 100644 --- a/pypika/tests/test_parameter.py +++ b/pypika/tests/test_parameter.py @@ -196,3 +196,11 @@ def test_join_with_parameter_value_wrapper(self): sql, ) self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters()) + + def test_pyformat_parameter(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo') + + parameter = PyformatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql) + self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters()) From 9ba124180598065b2ca6f48424253ca950091c97 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Fri, 11 Oct 2024 10:15:34 +0200 Subject: [PATCH 5/5] fix linting issues --- pypika/terms.py | 1 + pypika/tests/test_terms.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pypika/terms.py b/pypika/terms.py index 0a077b57..a277e1a5 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -654,6 +654,7 @@ def __init__( if isinstance(table, str): # avoid circular import at load time from pypika.queries import Table + table = Table(table) self.table = table diff --git a/pypika/tests/test_terms.py b/pypika/tests/test_terms.py index 607c4c01..4c7590df 100644 --- a/pypika/tests/test_terms.py +++ b/pypika/tests/test_terms.py @@ -20,7 +20,7 @@ def test_init_with_str_table(self): test_table_name = "test_table" field = Field(name="name", table=test_table_name) self.assertEqual(field.table, Table(name=test_table_name)) - + class FieldHashingTests(TestCase): def test_tabled_eq_fields_equally_hashed(self):