diff --git a/src/base/0.0.0/end-no-respawn-fields.py b/src/base/0.0.0/end-no-respawn-fields.py index f4637dfcb..6d39576a5 100644 --- a/src/base/0.0.0/end-no-respawn-fields.py +++ b/src/base/0.0.0/end-no-respawn-fields.py @@ -20,7 +20,7 @@ def migrate(cr, version): """ ) execute_values( - cr._obj, + getattr(cr, '_obj__', None) or cr._obj, "INSERT INTO no_respawn(model, field) VALUES %s", # fmt:off [ diff --git a/src/base/tests/test_util.py b/src/base/tests/test_util.py index c56c26add..be206466a 100644 --- a/src/base/tests/test_util.py +++ b/src/base/tests/test_util.py @@ -28,6 +28,7 @@ _model_of_path, ) from odoo.addons.base.maintenance.migrations.util.exceptions import MigrationError +from odoo.addons.base.maintenance.migrations.util.pg import cursor_get_connection USE_ORM_DOMAIN = util.misc.version_gte("saas~18.2") NOTNOT = () if USE_ORM_DOMAIN else ("!", "!") @@ -842,7 +843,7 @@ def test_create_column_with_fk(self): def test_ColumnList(self): cr = self.env.cr - s = lambda c: c.as_string(cr._cnx) + s = lambda c: c.as_string(cursor_get_connection(cr)) columns = util.ColumnList(["a", "A"], ['"a"', '"A"']) self.assertEqual(len(columns), 2) diff --git a/src/testing.py b/src/testing.py index 21794c9c7..ae17c3f89 100644 --- a/src/testing.py +++ b/src/testing.py @@ -28,6 +28,7 @@ from . import util from .util import json +from .util.pg import cursor_get_connection _logger = logging.getLogger(__name__) @@ -142,7 +143,7 @@ def _set_value(self, key, value): ON CONFLICT (key) DO UPDATE SET value=EXCLUDED.value """.format(DATA_TABLE) self._data_table_cr.execute(query, (key, value)) - self._data_table_cr._cnx.commit() + cursor_get_connection(self._data_table_cr).commit() def _get_value(self, key): self._init_db() @@ -169,7 +170,7 @@ def _init_db(self): value JSONB NOT NULL )""".format(DATA_TABLE) self._data_table_cr.execute(query) - self._data_table_cr._cnx.commit() + cursor_get_connection(self._data_table_cr).commit() UpgradeCommon.__initialized = True def _setup_registry(self): @@ -355,7 +356,7 @@ def setUp(self): def commit(self): if self.dbname == config["log_db"].split("/")[-1]: - self._cnx.commit() + cursor_get_connection(self).commit() else: raise RuntimeError("Commit is forbidden in integrity cases") diff --git a/src/util/fields.py b/src/util/fields.py index d2a15ae21..93857139a 100644 --- a/src/util/fields.py +++ b/src/util/fields.py @@ -53,6 +53,7 @@ def make_index_name(table_name, column_name): alter_column_type, column_exists, column_type, + cursor_get_connection, explode_execute, explode_query_range, format_query, @@ -898,7 +899,7 @@ def convert_binary_field_to_attachment(cr, model, field, encoded=True, name_fiel [count] = cr.fetchone() A = env(cr)["ir.attachment"] - iter_cur = cr._cnx.cursor("fetch_binary") + iter_cur = cursor_get_connection(cr).cursor("fetch_binary") iter_cur.itersize = 1 iter_cur.execute( format_query( @@ -1484,7 +1485,7 @@ def update_server_actions_fields(cr, src_model, dst_model=None, fields_mapping=N ) else: psycopg2.extras.execute_values( - cr._obj, + getattr(cr, '_obj__') or cr._obj, """ WITH field_ids AS ( SELECT mf1.id as old_field_id, mf2.id as new_field_id diff --git a/src/util/inconsistencies.py b/src/util/inconsistencies.py index 2387d4bcc..c0d425cee 100644 --- a/src/util/inconsistencies.py +++ b/src/util/inconsistencies.py @@ -9,7 +9,7 @@ from .helpers import _validate_model, table_of_model from .misc import Sentinel, chunks, str2bool -from .pg import format_query, get_value_or_en_translation, target_of +from .pg import cursor_get_connection, format_query, get_value_or_en_translation, target_of from .report import add_to_migration_reports, get_anchor_link_to_record, html_escape _logger = logging.getLogger(__name__) @@ -228,7 +228,7 @@ def verify_uoms( _validate_model(model) table = table_of_model(cr, model) - q = lambda s: quote_ident(s, cr._cnx) + q = lambda s: quote_ident(s, cursor_get_connection(cr)) if include_archived_products is FROM_ENV: include_archived_products = INCLUDE_ARCHIVED_PRODUCTS @@ -404,7 +404,7 @@ def verify_products( table = table_of_model(cr, model) foreign_table = table_of_model(cr, foreign_model) - q = lambda s: quote_ident(s, cr._cnx) + q = lambda s: quote_ident(s, cursor_get_connection(cr)) if include_archived_products is FROM_ENV: include_archived_products = INCLUDE_ARCHIVED_PRODUCTS diff --git a/src/util/pg.py b/src/util/pg.py index ef9401d90..982f75110 100644 --- a/src/util/pg.py +++ b/src/util/pg.py @@ -217,7 +217,7 @@ def wrap(arg): args = tuple(wrap(a) for a in args) kwargs = {k: wrap(v) for k, v in kwargs.items()} - return SQLStr(sql.SQL(query).format(*args, **kwargs).as_string(cr._cnx)) + return SQLStr(sql.SQL(query).format(*args, **kwargs).as_string(cursor_get_connection(cr))) def explode_query(cr, query, alias=None, num_buckets=8, prefix=None): @@ -557,7 +557,7 @@ def create_column(cr, table, column, definition, **kwargs): fk = ( sql.SQL("REFERENCES {}(id) ON DELETE {}") .format(sql.Identifier(fk_table), sql.SQL(on_delete_action)) - .as_string(cr._cnx) + .as_string(cursor_get_connection(cr)) ) elif on_delete_action is not no_def: raise ValueError("`on_delete_action` argument can only be used if `fk_table` argument is set.") @@ -845,7 +845,7 @@ def get_index_on(cr, table, *columns): """ _validate_table(table) - if cr._cnx.server_version >= 90500: + if cursor_get_connection(cr).server_version >= 90500: position = "array_position(x.indkey, x.unnest_indkey)" else: # array_position does not exists prior postgresql 9.5 @@ -980,10 +980,10 @@ class ColumnList(UserList, sql.Composable): >>> list(columns) ['id', '"field_Yx"'] - >>> columns.using(alias="t").as_string(cr._cnx) + >>> columns.using(alias="t").as_string(cursor_get_connection(cr)) '"t"."id", "t"."field_Yx"' - >>> columns.using(leading_comma=True).as_string(cr._cnx) + >>> columns.using(leading_comma=True).as_string(cursor_get_connection(cr)) ', "id", "field_Yx"' >>> util.format_query(cr, "SELECT {} t.name FROM table t", columns.using(alias="t", trailing_comma=True)) @@ -1026,7 +1026,7 @@ def from_unquoted(cls, cr, list_): :param list(str) list_: list of unquoted column names """ - quoted = [quote_ident(c, cr._cnx) for c in list_] + quoted = [quote_ident(c, cursor_get_connection(cr)) for c in list_] return cls(list_, quoted) def using(self, leading_comma=KEEP_CURRENT, trailing_comma=KEEP_CURRENT, alias=KEEP_CURRENT): @@ -1482,7 +1482,8 @@ def get_m2m_tables(cr, table): class named_cursor(object): def __init__(self, cr, itersize=None): - self._ncr = cr._cnx.cursor("upg_nc_" + uuid.uuid4().hex, withhold=True) + pgconn = cursor_get_connection(cr) + self._ncr = pgconn.cursor("upg_nc_" + uuid.uuid4().hex, withhold=True) if itersize: self._ncr.itersize = itersize @@ -1571,3 +1572,9 @@ def create_id_sequence(cr, table, set_as_default=True): table=table_sql, ) ) + + +def cursor_get_connection(cursor): + if hasattr(cursor, '_cnx__'): + return cursor._cnx__ + return cursor._cnx diff --git a/src/util/records.py b/src/util/records.py index 105a61f28..68d174952 100644 --- a/src/util/records.py +++ b/src/util/records.py @@ -1476,7 +1476,7 @@ def replace_record_references_batch(cr, id_mapping, model_src, model_dst=None, r ignores.append("ir_model_data") cr.execute("CREATE UNLOGGED TABLE _upgrade_rrr(old int PRIMARY KEY, new int)") - execute_values(cr, "INSERT INTO _upgrade_rrr (old, new) VALUES %s", id_mapping.items()) + execute_values(getattr(cr, '_obj__', cr), "INSERT INTO _upgrade_rrr (old, new) VALUES %s", id_mapping.items()) if model_src == model_dst: fk_def = [] diff --git a/src/util/snippets.py b/src/util/snippets.py index a4e091546..868979532 100644 --- a/src/util/snippets.py +++ b/src/util/snippets.py @@ -15,7 +15,7 @@ from .exceptions import MigrationError from .helpers import table_of_model from .misc import import_script, log_progress -from .pg import column_exists, column_type, get_max_workers, table_exists +from .pg import column_exists, column_type, cursor_get_connection, get_max_workers, table_exists _logger = logging.getLogger(__name__) utf8_parser = html.HTMLParser(encoding="utf-8") @@ -44,7 +44,7 @@ def add_snippet_names(cr, table, column, snippets, select_query): it = log_progress(cr.fetchall(), _logger, qualifier="rows", size=cr.rowcount, log_hundred_percent=True) def quote(ident): - return quote_ident(ident, cr._cnx) + return quote_ident(ident, cursor_get_connection(cr)) for res_id, regex_matches, arch in it: regex_matches = [match[0] for match in regex_matches] # noqa: PLW2901 @@ -88,7 +88,7 @@ def get_html_fields(cr): # yield (table, column) of stored html fields (that needs snippets updates) for table, columns in html_fields(cr): for column in columns: - yield table, quote_ident(column, cr._cnx) + yield table, quote_ident(column, cursor_get_connection(cr)) def html_fields(cr): @@ -316,7 +316,7 @@ def convert_html_columns(cr, table, columns, converter_callback, where_column="I def determine_chunk_limit_ids(cr, table, column_arr, where): bytes_per_chunk = 100 * 1024 * 1024 - columns = ", ".join(quote_ident(column, cr._cnx) for column in column_arr if column != "id") + columns = ", ".join(quote_ident(column, cursor_get_connection(cr)) for column in column_arr if column != "id") cr.execute( f""" WITH info AS (