diff --git a/src/acquisition/covid_hosp/common/database.py b/src/acquisition/covid_hosp/common/database.py index ec00f662a..efbdb6c45 100644 --- a/src/acquisition/covid_hosp/common/database.py +++ b/src/acquisition/covid_hosp/common/database.py @@ -184,19 +184,16 @@ def nan_safe_dtype(dtype, value): for csv_name in self.key_columns: dataframe.loc[:, csv_name] = dataframe[csv_name].map(self.columns_and_types[csv_name].dtype) - num_columns = 2 + len(dataframe_columns_and_types) + len(self.additional_fields) - value_placeholders = ', '.join(['%s'] * num_columns) col_names = [f'`{i.sql_name}`' for i in dataframe_columns_and_types + self.additional_fields] - columns = ', '.join(col_names) - updates = ', '.join(f'{c}=new_values.{c}' for c in col_names) - # NOTE: list in `updates` presumes `publication_col_name` is part of the unique key and thus not needed in UPDATE - sql = f'INSERT INTO `{self.table_name}` (`id`, `{self.publication_col_name}`, {columns}) ' \ - f'VALUES ({value_placeholders}) AS new_values ' \ - f'ON DUPLICATE KEY UPDATE {updates}' + value_placeholders = ', '.join(['%s'] * (2 + len(col_names))) # extra 2 for `id` and `self.publication_col_name` cols + columnstring = ', '.join(col_names) + sql = f'REPLACE INTO `{self.table_name}` (`id`, `{self.publication_col_name}`, {columnstring}) VALUES ({value_placeholders})' id_and_publication_date = (0, publication_date) + num_values = len(dataframe.index) if logger: - logger.info('updating values', count=len(dataframe.index)) + logger.info('updating values', count=num_values) n = 0 + rows_affected = 0 many_values = [] with self.new_cursor() as cursor: for index, row in dataframe.iterrows(): @@ -212,6 +209,7 @@ def nan_safe_dtype(dtype, value): if n % 5_000 == 0: try: cursor.executemany(sql, many_values) + rows_affected += cursor.rowcount many_values = [] except Exception as e: if logger: @@ -220,6 +218,11 @@ def nan_safe_dtype(dtype, value): # insert final batch if many_values: cursor.executemany(sql, many_values) + rows_affected += cursor.rowcount + if logger: + # NOTE: REPLACE INTO marks 2 rows affected for a "replace" (one for a delete and one for a re-insert) + # which allows us to count rows which were updated + logger.info('rows affected', total=rows_affected, updated=rows_affected-num_values) # deal with non/seldomly updated columns used like a fk table (if this database needs it) if hasattr(self, 'AGGREGATE_KEY_COLS'): diff --git a/tests/acquisition/covid_hosp/common/test_database.py b/tests/acquisition/covid_hosp/common/test_database.py index c070a00ae..a45953313 100644 --- a/tests/acquisition/covid_hosp/common/test_database.py +++ b/tests/acquisition/covid_hosp/common/test_database.py @@ -148,7 +148,7 @@ def test_insert_dataset(self): actual_sql = mock_cursor.executemany.call_args[0][0] self.assertIn( - 'INSERT INTO `test_table` (`id`, `publication_date`, `sql_str_col`, `sql_int_col`, `sql_float_col`)', + 'REPLACE INTO `test_table` (`id`, `publication_date`, `sql_str_col`, `sql_int_col`, `sql_float_col`)', actual_sql) expected_values = [