diff --git a/.circleci/config.yml b/.circleci/config.yml
index f85823b7..bb545bad 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -7,7 +7,7 @@ executors:
         type: string
     docker:
       - image: python:<< parameters.version >>-buster
-      - image: postgres:12.0
+      - image: postgres:13.0
         environment:
           POSTGRES_DB: 'psqlextra'
           POSTGRES_USER: 'psqlextra'
diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py
index fe9c3324..90717b6a 100644
--- a/psqlextra/backend/introspection.py
+++ b/psqlextra/backend/introspection.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from psqlextra.types import PostgresPartitioningMethod
 
@@ -48,6 +48,22 @@ def partition_by_name(
 class PostgresIntrospection(base_impl.introspection()):
     """Adds introspection features specific to PostgreSQL."""
 
+    # TODO: This class is a mess, both here and in
+    # the base.
+    #
+    # Some methods return untyped dicts, some named tuples,
+    # some flat lists of strings. It's horribly inconsistent.
+    #
+    # Most methods are poorly named. For example, `get_table_description`
+    # does not return a complete table description. It merely returns
+    # the columns.
+    #
+    # We do our best in this class to stay consistent with
+    # the base in Django by respecting its naming scheme
+    # and commonly used return types. Creating an API that
+    # matches the look&feel from the Django base class
+    # is more important than fixing those issues.
+
     def get_partitioned_tables(
         self, cursor
     ) -> PostgresIntrospectedPartitonedTable:
@@ -172,6 +188,9 @@ def get_partition_key(self, cursor, table_name: str) -> List[str]:
         cursor.execute(sql, (table_name,))
         return [row[0] for row in cursor.fetchall()]
 
+    def get_columns(self, cursor, table_name: str):
+        return self.get_table_description(cursor, table_name)
+
     def get_constraints(self, cursor, table_name: str):
         """Retrieve any constraints or keys (unique, pk, fk, check, index)
         across one or more columns.
@@ -202,15 +221,68 @@ def get_table_locks(self, cursor) -> List[Tuple[str, str, str]]:
         cursor.execute(
             """
-            SELECT
-                n.nspname,
-                t.relname,
-                l.mode
-            FROM pg_locks l
-            INNER JOIN pg_class t ON t.oid = l.relation
-            INNER JOIN pg_namespace n ON n.oid = t.relnamespace
-            WHERE t.relnamespace >= 2200
-            ORDER BY n.nspname, t.relname, l.mode"""
+                SELECT
+                    n.nspname,
+                    t.relname,
+                    l.mode
+                FROM pg_locks l
+                INNER JOIN pg_class t ON t.oid = l.relation
+                INNER JOIN pg_namespace n ON n.oid = t.relnamespace
+                WHERE t.relnamespace >= 2200
+                ORDER BY n.nspname, t.relname, l.mode
+            """
         )
 
         return cursor.fetchall()
+
+    def get_storage_settings(self, cursor, table_name: str) -> Dict[str, str]:
+        sql = """
+            SELECT
+                unnest(c.reloptions || array(select 'toast.' || x from pg_catalog.unnest(tc.reloptions) x))
+            FROM
+                pg_catalog.pg_class c
+            LEFT JOIN
+                pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
+            LEFT JOIN
+                pg_catalog.pg_am am ON (c.relam = am.oid)
+            WHERE
+                c.relname::text = %s
+                AND pg_catalog.pg_table_is_visible(c.oid)
+        """
+
+        cursor.execute(sql, (table_name,))
+
+        storage_settings = {}
+        for row in cursor.fetchall():
+            # It's hard to believe, but storage settings are really
+            # represented as `key=value` strings in Postgres.
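+            # For example: `autovacuum_enabled=false` or `fillfactor=70`.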
+            # See: https://www.postgresql.org/docs/current/catalog-pg-class.html
+            name, value = row[0].split("=")
+            storage_settings[name] = value
+
+        return storage_settings
+
+    def get_relations(self, cursor, table_name: str):
+        """Gets a dictionary {field_name: (field_name_other_table,
+        other_table)} representing all relations in the specified table.
+
+        This is overridden because the query in Django does not handle
+        relations between tables in different schemas properly.
+        """
+
+        cursor.execute(
+            """
+            SELECT a1.attname, c2.relname, a2.attname
+            FROM pg_constraint con
+            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
+            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
+            LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
+            LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
+            WHERE
+                con.conrelid = %s::regclass AND
+                con.contype = 'f' AND
+                pg_catalog.pg_table_is_visible(c1.oid)
+            """,
+            [table_name],
+        )
+        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py
index 413f039d..1e21b366 100644
--- a/psqlextra/backend/schema.py
+++ b/psqlextra/backend/schema.py
@@ -1,14 +1,21 @@
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Type
 from unittest import mock
 
+import django
+
 from django.core.exceptions import (
     FieldDoesNotExist,
     ImproperlyConfigured,
     SuspiciousOperation,
 )
 from django.db import transaction
+from django.db.backends.ddl_references import Statement
 from django.db.models import Field, Model
 
+from psqlextra.settings import (
+    postgres_prepend_local_search_path,
+    postgres_reset_local_search_path,
+)
 from psqlextra.type_assertions import is_sql_with_params
 from psqlextra.types import PostgresPartitioningMethod
 
@@ -19,12 +26,26 @@
     HStoreUniqueSchemaEditorSideEffect,
 )
 
+SchemaEditor = base_impl.schema_editor()
+
 
-class PostgresSchemaEditor(base_impl.schema_editor()):
+class PostgresSchemaEditor(SchemaEditor):
     """Schema editor that adds extra methods for PostgreSQL specific features
     and hooks into existing implementations to add side effects specific to
     PostgreSQL."""
 
+    sql_add_pk = "ALTER TABLE %s ADD PRIMARY KEY (%s)"
+
+    sql_create_fk_not_valid = f"{SchemaEditor.sql_create_fk} NOT VALID"
+    sql_validate_fk = "ALTER TABLE %s VALIDATE CONSTRAINT %s"
+
+    sql_create_sequence_with_owner = "CREATE SEQUENCE %s OWNED BY %s.%s"
+
+    sql_alter_table_storage_setting = "ALTER TABLE %s SET (%s = %s)"
+    sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)"
+
+    sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s"
+
     sql_create_view = "CREATE VIEW %s AS (%s)"
     sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)"
     sql_drop_view = "DROP VIEW IF EXISTS %s"
@@ -63,7 +84,7 @@ def __init__(self, connection, collect_sql=False, atomic=True):
         self.deferred_sql = []
         self.introspection = PostgresIntrospection(self.connection)
 
-    def create_model(self, model: Model) -> None:
+    def create_model(self, model: Type[Model]) -> None:
         """Creates a new model."""
 
         super().create_model(model)
@@ -71,7 +92,7 @@ def create_model(self, model: Model) -> None:
         for side_effect in self.side_effects:
             side_effect.create_model(model)
 
-    def delete_model(self, model: Model) -> None:
+    def delete_model(self, model: Type[Model]) -> None:
         """Drops/deletes an existing model."""
 
         for side_effect in self.side_effects:
@@ -79,8 +100,395 @@ def delete_model(self, model: Model) -> None:
 
         super().delete_model(model)
 
+    def clone_model_structure_to_schema(
+        self, model: Type[Model], *, schema_name: str
+    ) -> None:
+        """Creates a clone of the columns for the specified model in a
+        separate schema.
+
+        The table will have exactly the same name as the model table
+        in the default schema. It will have none of the constraints,
+        foreign keys and indexes.
+
+        Use this to create a temporary clone of a model table to
+        replace the original model table later on. The lack of
+        indices and constraints allows for greater write speeds.
+
+        The original model table will be unaffected.
+
+        Arguments:
+            model:
+                Model to clone the table of into the
+                specified schema.
+
+            schema_name:
+                Name of the schema to create the cloned
+                table in.
+        """
+
+        table_name = model._meta.db_table
+        quoted_table_name = self.quote_name(model._meta.db_table)
+        quoted_schema_name = self.quote_name(schema_name)
+
+        quoted_table_fqn = f"{quoted_schema_name}.{quoted_table_name}"
+
+        self.execute(
+            self.sql_create_table
+            % {
+                "table": quoted_table_fqn,
+                "definition": f"LIKE {quoted_table_name} INCLUDING ALL EXCLUDING CONSTRAINTS EXCLUDING INDEXES",
+            }
+        )
+
+        # Copy sequences
+        #
+        # Django 4.0 and older do not use IDENTITY so Postgres does
+        # not copy the sequences into the new table. We do it manually.
+        if django.VERSION < (4, 1):
+            with self.connection.cursor() as cursor:
+                sequences = self.introspection.get_sequences(
+                    cursor, table_name
+                )
+
+            for sequence in sequences:
+                if sequence["table"] != table_name:
+                    continue
+
+                quoted_sequence_name = self.quote_name(sequence["name"])
+                quoted_sequence_fqn = (
+                    f"{quoted_schema_name}.{quoted_sequence_name}"
+                )
+                quoted_column_name = self.quote_name(sequence["column"])
+
+                self.execute(
+                    self.sql_create_sequence_with_owner
+                    % (
+                        quoted_sequence_fqn,
+                        quoted_table_fqn,
+                        quoted_column_name,
+                    )
+                )
+
+                self.execute(
+                    self.sql_alter_column
+                    % {
+                        "table": quoted_table_fqn,
+                        "changes": self.sql_alter_column_default
+                        % {
+                            "column": quoted_column_name,
+                            "default": "nextval('%s')" % quoted_sequence_fqn,
+                        },
+                    }
+                )
+
+        # Copy storage settings
+        #
+        # Postgres only copies column-level storage options, not
+        # the table-level storage options.
+        with self.connection.cursor() as cursor:
+            storage_settings = self.introspection.get_storage_settings(
+                cursor, model._meta.db_table
+            )
+
+        for setting_name, setting_value in storage_settings.items():
+            self.alter_table_storage_setting(
+                quoted_table_fqn, setting_name, setting_value
+            )
+
+    def clone_model_constraints_and_indexes_to_schema(
+        self, model: Type[Model], *, schema_name: str
+    ) -> None:
+        """Adds the constraints, foreign keys and indexes to a model table
+        that was cloned into a separate table without them by
+        `clone_model_structure_to_schema`.
+
+        Arguments:
+            model:
+                Model for which the cloned table was created.
+
+            schema_name:
+                Name of the schema in which the cloned table
+                resides.
+        """
+
+        with postgres_prepend_local_search_path(
+            [schema_name], using=self.connection.alias
+        ):
+            for constraint in model._meta.constraints:
+                self.add_constraint(model, constraint)
+
+            for index in model._meta.indexes:
+                self.add_index(model, index)
+
+            if model._meta.unique_together:
+                self.alter_unique_together(
+                    model, tuple(), model._meta.unique_together
+                )
+
+            if model._meta.index_together:
+                self.alter_index_together(
+                    model, tuple(), model._meta.index_together
+                )
+
+            for field in model._meta.local_concrete_fields:
+                # Django creates primary keys that were added to the
+                # model later with a custom name. We want the name as
+                # it was originally created.
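+                # Temporarily resetting the search path makes the
+                # constraint name lookups below resolve against the
+                # original table in the default schema, not the clone.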
+                if field.primary_key:
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
+                    ):
+                        [primary_key_name] = self._constraint_names(
+                            model, primary_key=True
+                        )
+
+                    self.execute(
+                        self.sql_create_pk
+                        % {
+                            "table": self.quote_name(model._meta.db_table),
+                            "name": self.quote_name(primary_key_name),
+                            "columns": self.quote_name(
+                                field.db_column or field.attname
+                            ),
+                        }
+                    )
+                    continue
+
+                # Django creates foreign keys in a single statement which
+                # acquires an AccessExclusiveLock on the referenced table.
+                # We want to avoid that and create the FK as NOT VALID. We
+                # can run VALIDATE in a separate transaction later to
+                # validate the entries without acquiring an
+                # AccessExclusiveLock.
+                if field.remote_field:
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
+                    ):
+                        [fk_name] = self._constraint_names(
+                            model, [field.column], foreign_key=True
+                        )
+
+                    sql = Statement(
+                        self.sql_create_fk_not_valid,
+                        table=self.quote_name(model._meta.db_table),
+                        name=self.quote_name(fk_name),
+                        column=self.quote_name(field.column),
+                        to_table=self.quote_name(
+                            field.target_field.model._meta.db_table
+                        ),
+                        to_column=self.quote_name(field.target_field.column),
+                        deferrable=self.connection.ops.deferrable_sql(),
+                    )
+
+                    self.execute(sql)
+
+                # It's hard to alter a field's check because it is defined
+                # by the field class, not the field instance. Handle this
+                # manually.
+                field_check = field.db_parameters(self.connection).get("check")
+                if field_check:
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
+                    ):
+                        [field_check_name] = self._constraint_names(
+                            model,
+                            [field.column],
+                            check=True,
+                            exclude={
+                                constraint.name
+                                for constraint in model._meta.constraints
+                            },
+                        )
+
+                    self.execute(
+                        self._create_check_sql(
+                            model, field_check_name, field_check
+                        )
+                    )
+
+                # Clone the field and alter its state to match our current
+                # table definition. This will cause Django to see the missing
+                # indices and create them.
+                if field.remote_field:
+                    # We add the foreign key constraint ourselves with NOT VALID,
+                    # hence, we specify `db_constraint=False` on both old/new.
+                    # Django won't touch the foreign key constraint.
+                    old_field = self._clone_model_field(
+                        field, db_index=False, unique=False, db_constraint=False
+                    )
+                    new_field = self._clone_model_field(
+                        field, db_constraint=False
+                    )
+                    self.alter_field(model, old_field, new_field)
+                else:
+                    old_field = self._clone_model_field(
+                        field, db_index=False, unique=False
+                    )
+                    new_field = self._clone_model_field(field)
+                    self.alter_field(model, old_field, new_field)
+
+    def clone_model_foreign_keys_to_schema(
+        self, model: Type[Model], schema_name: str
+    ) -> None:
+        """Validates the foreign keys in the cloned model table created by
+        `clone_model_structure_to_schema` and
+        `clone_model_constraints_and_indexes_to_schema`.
+
+        Do NOT run this in the same transaction in which the
+        foreign keys were added to the table. It WILL acquire
+        a long-lived AccessExclusiveLock.
+
+        Arguments:
+            model:
+                Model for which the cloned table was created.
+
+            schema_name:
+                Name of the schema in which the cloned table
+                resides.
+        """
+
+        constraint_names = self._constraint_names(model, foreign_key=True)
+
+        with postgres_prepend_local_search_path(
+            [schema_name], using=self.connection.alias
+        ):
+            for fk_name in constraint_names:
+                self.execute(
+                    self.sql_validate_fk
+                    % (
+                        self.quote_name(model._meta.db_table),
+                        self.quote_name(fk_name),
+                    )
+                )
+
+    def alter_table_storage_setting(
+        self, table_name: str, name: str, value: str
+    ) -> None:
+        """Alters a storage setting for a table.
+
+        See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS
+
+        Arguments:
+            table_name:
+                Name of the table to alter the setting for.
+
+            name:
+                Name of the setting to alter.
+
+            value:
+                Value to alter the setting to.
+
+                Note that this is always a string, even if it looks
+                like a number or a boolean. That's how Postgres
+                stores storage settings internally.
+        """
+
+        self.execute(
+            self.sql_alter_table_storage_setting
+            % (self.quote_name(table_name), name, value)
+        )
+
+    def alter_model_storage_setting(
+        self, model: Type[Model], name: str, value: str
+    ) -> None:
+        """Alters a storage setting for the model's table.
+
+        See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS
+
+        Arguments:
+            model:
+                Model of which to alter the table
+                setting.
+
+            name:
+                Name of the setting to alter.
+
+            value:
+                Value to alter the setting to.
+
+                Note that this is always a string, even if it looks
+                like a number or a boolean. That's how Postgres
+                stores storage settings internally.
+        """
+
+        self.alter_table_storage_setting(model._meta.db_table, name, value)
+
+    def reset_table_storage_setting(self, table_name: str, name: str) -> None:
+        """Resets a table's storage setting to the database or server default.
+
+        See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS
+
+        Arguments:
+            table_name:
+                Name of the table to reset the setting for.
+
+            name:
+                Name of the setting to reset.
+        """
+
+        self.execute(
+            self.sql_reset_table_storage_setting
+            % (self.quote_name(table_name), name)
+        )
+
+    def reset_model_storage_setting(
+        self, model: Type[Model], name: str
+    ) -> None:
+        """Resets a model's table storage setting to the database or server
+        default.
+
+        See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS
+
+        Arguments:
+            model:
+                Model for which to reset the table setting.
+
+            name:
+                Name of the setting to reset.
+        """
+
+        self.reset_table_storage_setting(model._meta.db_table, name)
+
+    def alter_table_schema(self, table_name: str, schema_name: str) -> None:
+        """Moves the specified table into the specified schema.
+
+        WARNING: Moving models into a different schema than the default
+        will break querying the model.
+
+        Arguments:
+            table_name:
+                Name of the table to move into the specified schema.
+
+            schema_name:
+                Name of the schema to move the table to.
+        """
+
+        self.execute(
+            self.sql_alter_table_schema
+            % (self.quote_name(table_name), self.quote_name(schema_name))
+        )
+
+    def alter_model_schema(self, model: Type[Model], schema_name: str) -> None:
+        """Moves the specified model's table into the specified schema.
+
+        WARNING: Moving models into a different schema than the default
+        will break querying the model.
+
+        Arguments:
+            model:
+                Model of which to move the table.
+
+            schema_name:
+                Name of the schema to move the model's table to.
+ """ + + self.execute( + self.sql_alter_table_schema + % ( + self.quote_name(model._meta.db_table), + self.quote_name(schema_name), + ) + ) + def refresh_materialized_view_model( - self, model: Model, concurrently: bool = False + self, model: Type[Model], concurrently: bool = False ) -> None: """Refreshes a materialized view.""" @@ -93,12 +501,12 @@ def refresh_materialized_view_model( sql = sql_template % self.quote_name(model._meta.db_table) self.execute(sql) - def create_view_model(self, model: Model) -> None: + def create_view_model(self, model: Type[Model]) -> None: """Creates a new view model.""" self._create_view_model(self.sql_create_view, model) - def replace_view_model(self, model: Model) -> None: + def replace_view_model(self, model: Type[Model]) -> None: """Replaces a view model with a newer version. This is used to alter the backing query of a view. @@ -106,18 +514,18 @@ def replace_view_model(self, model: Model) -> None: self._create_view_model(self.sql_replace_view, model) - def delete_view_model(self, model: Model) -> None: + def delete_view_model(self, model: Type[Model]) -> None: """Deletes a view model.""" sql = self.sql_drop_view % self.quote_name(model._meta.db_table) self.execute(sql) - def create_materialized_view_model(self, model: Model) -> None: + def create_materialized_view_model(self, model: Type[Model]) -> None: """Creates a new materialized view model.""" self._create_view_model(self.sql_create_materialized_view, model) - def replace_materialized_view_model(self, model: Model) -> None: + def replace_materialized_view_model(self, model: Type[Model]) -> None: """Replaces a materialized view with a newer version. This is used to alter the backing query of a materialized view. @@ -148,7 +556,7 @@ def replace_materialized_view_model(self, model: Model) -> None: self.execute(constraint_options["definition"]) - def delete_materialized_view_model(self, model: Model) -> None: + def delete_materialized_view_model(self, model: Type[Model]) -> None: """Deletes a materialized view model.""" sql = self.sql_drop_materialized_view % self.quote_name( @@ -156,7 +564,7 @@ def delete_materialized_view_model(self, model: Model) -> None: ) self.execute(sql) - def create_partitioned_model(self, model: Model) -> None: + def create_partitioned_model(self, model: Type[Model]) -> None: """Creates a new partitioned model.""" meta = self._partitioning_properties_for_model(model) @@ -188,14 +596,14 @@ def create_partitioned_model(self, model: Model) -> None: self.execute(sql, params) - def delete_partitioned_model(self, model: Model) -> None: + def delete_partitioned_model(self, model: Type[Model]) -> None: """Drops the specified partitioned model.""" return self.delete_model(model) def add_range_partition( self, - model: Model, + model: Type[Model], name: str, from_values: Any, to_values: Any, @@ -246,7 +654,7 @@ def add_range_partition( def add_list_partition( self, - model: Model, + model: Type[Model], name: str, values: List[Any], comment: Optional[str] = None, @@ -289,7 +697,7 @@ def add_list_partition( def add_hash_partition( self, - model: Model, + model: Type[Model], name: str, modulus: int, remainder: int, @@ -334,7 +742,7 @@ def add_hash_partition( self.set_comment_on_table(table_name, comment) def add_default_partition( - self, model: Model, name: str, comment: Optional[str] = None + self, model: Type[Model], name: str, comment: Optional[str] = None ) -> None: """Creates a new default partition for the specified partitioned model. 
@@ -370,7 +778,7 @@ def add_default_partition(
         if comment:
             self.set_comment_on_table(table_name, comment)
 
-    def delete_partition(self, model: Model, name: str) -> None:
+    def delete_partition(self, model: Type[Model], name: str) -> None:
         """Deletes the partition with the specified name."""
 
         sql = self.sql_delete_partition % self.quote_name(
@@ -379,7 +787,7 @@ def delete_partition(self, model: Model, name: str) -> None:
         self.execute(sql)
 
     def alter_db_table(
-        self, model: Model, old_db_table: str, new_db_table: str
+        self, model: Type[Model], old_db_table: str, new_db_table: str
     ) -> None:
         """Alters a table/model."""
 
@@ -388,7 +796,7 @@ def alter_db_table(
         for side_effect in self.side_effects:
             side_effect.alter_db_table(model, old_db_table, new_db_table)
 
-    def add_field(self, model: Model, field: Field) -> None:
+    def add_field(self, model: Type[Model], field: Field) -> None:
         """Adds a new field to an existing model."""
 
         super().add_field(model, field)
@@ -396,7 +804,7 @@ def add_field(self, model: Model, field: Field) -> None:
         for side_effect in self.side_effects:
             side_effect.add_field(model, field)
 
-    def remove_field(self, model: Model, field: Field) -> None:
+    def remove_field(self, model: Type[Model], field: Field) -> None:
         """Removes a field from an existing model."""
 
         for side_effect in self.side_effects:
@@ -406,7 +814,7 @@ def remove_field(self, model: Model, field: Field) -> None:
 
     def alter_field(
         self,
-        model: Model,
+        model: Type[Model],
         old_field: Field,
         new_field: Field,
         strict: bool = False,
@@ -418,13 +826,100 @@ def alter_field(
         for side_effect in self.side_effects:
             side_effect.alter_field(model, old_field, new_field, strict)
 
+    def vacuum_table(
+        self,
+        table_name: str,
+        columns: List[str] = [],
+        *,
+        full: bool = False,
+        freeze: bool = False,
+        verbose: bool = False,
+        analyze: bool = False,
+        disable_page_skipping: bool = False,
+        skip_locked: bool = False,
+        index_cleanup: bool = False,
+        truncate: bool = False,
+        parallel: Optional[int] = None,
+    ) -> None:
+        """Runs the VACUUM statement on the specified table with the specified
+        options.
+
+        Arguments:
+            table_name:
+                Name of the table to run VACUUM on.
+
+            columns:
+                Optionally, a list of columns to vacuum. If not
+                specified, all columns are vacuumed.
+        """
+
+        if self.connection.in_atomic_block:
+            raise SuspiciousOperation("Vacuum cannot be done in a transaction")
+
+        options = []
+        if full:
+            options.append("FULL")
+        if freeze:
+            options.append("FREEZE")
+        if verbose:
+            options.append("VERBOSE")
+        if analyze:
+            options.append("ANALYZE")
+        if disable_page_skipping:
+            options.append("DISABLE_PAGE_SKIPPING")
+        if skip_locked:
+            options.append("SKIP_LOCKED")
+        if index_cleanup:
+            options.append("INDEX_CLEANUP")
+        if truncate:
+            options.append("TRUNCATE")
+        if parallel is not None:
+            options.append(f"PARALLEL {parallel}")
+
+        sql = "VACUUM"
+
+        if options:
+            options_sql = ", ".join(options)
+            sql += f" ({options_sql})"
+
+        sql += f" {self.quote_name(table_name)}"
+
+        if columns:
+            columns_sql = ", ".join(
+                [self.quote_name(column) for column in columns]
+            )
+            sql += f" ({columns_sql})"
+
+        self.execute(sql)
+
+    def vacuum_model(
+        self, model: Type[Model], fields: List[Field] = [], **kwargs
+    ) -> None:
+        """Runs the VACUUM statement on the table of the specified model with
+        the specified options.
+
+        Arguments:
+            model:
+                Model whose table to run VACUUM on.
+
+            fields:
+                Optionally, a list of fields to vacuum. If not
+                specified, all fields are vacuumed.
+        """
+
+        columns = [
+            field.column for field in fields if field.concrete and field.column
+        ]
+        self.vacuum_table(model._meta.db_table, columns, **kwargs)
+
     def set_comment_on_table(self, table_name: str, comment: str) -> None:
         """Sets the comment on the specified table."""
 
         sql = self.sql_table_comment % (self.quote_name(table_name), "%s")
         self.execute(sql, (comment,))
 
-    def _create_view_model(self, sql: str, model: Model) -> None:
+    def _create_view_model(self, sql: str, model: Type[Model]) -> None:
         """Creates a new view model using the specified SQL query."""
 
         meta = self._view_properties_for_model(model)
@@ -451,7 +946,7 @@ def _extract_sql(self, method, *args):
         return tuple(execute.mock_calls[0])[1]
 
     @staticmethod
-    def _view_properties_for_model(model: Model):
+    def _view_properties_for_model(model: Type[Model]):
         """Gets the view options for the specified model.
 
         Raises:
@@ -483,7 +978,7 @@ def _view_properties_for_model(model: Model):
         return meta
 
     @staticmethod
-    def _partitioning_properties_for_model(model: Model):
+    def _partitioning_properties_for_model(model: Type[Model]):
         """Gets the partitioning options for the specified model.
 
         Raises:
@@ -546,5 +1041,29 @@ def _partitioning_properties_for_model(model: Model):
 
         return meta
 
-    def create_partition_table_name(self, model: Model, name: str) -> str:
+    def create_partition_table_name(self, model: Type[Model], name: str) -> str:
         return "%s_%s" % (model._meta.db_table.lower(), name.lower())
+
+    def _clone_model_field(self, field: Field, **overrides) -> Field:
+        """Clones the specified model field and overrides its kwargs with the
+        specified overrides.
+
+        The cloned field will not be contributed to the model.
+        """
+
+        _, _, field_args, field_kwargs = field.deconstruct()
+
+        cloned_field_args = field_args[:]
+        cloned_field_kwargs = {**field_kwargs, **overrides}
+
+        cloned_field = field.__class__(
+            *cloned_field_args, **cloned_field_kwargs
+        )
+        cloned_field.model = field.model
+        cloned_field.set_attributes_from_name(field.name)
+
+        if cloned_field.remote_field:
+            cloned_field.remote_field.model = field.remote_field.model
+            cloned_field.set_attributes_from_rel()
+
+        return cloned_field
diff --git a/psqlextra/settings.py b/psqlextra/settings.py
new file mode 100644
index 00000000..6dd32f37
--- /dev/null
+++ b/psqlextra/settings.py
@@ -0,0 +1,118 @@
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Union
+
+from django.core.exceptions import SuspiciousOperation
+from django.db import DEFAULT_DB_ALIAS, connections
+
+
+@contextmanager
+def postgres_set_local(
+    *,
+    using: str = DEFAULT_DB_ALIAS,
+    **options: Dict[str, Optional[Union[str, int, float, List[str]]]],
+) -> None:
+    """Sets the specified PostgreSQL options using SET LOCAL so that they
+    apply to the current transaction only.
+
+    The effect is undone when the context manager exits.
+
+    See https://www.postgresql.org/docs/current/runtime-config-client.html
+    for an overview of all available options.
+    """
+
+    connection = connections[using]
+    qn = connection.ops.quote_name
+
+    if not connection.in_atomic_block:
+        raise SuspiciousOperation(
+            "SET LOCAL makes no sense outside a transaction. Start a transaction first."
+        )
+
+    sql = []
+    params = []
+    for name, value in options.items():
+        if value is None:
+            sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
+            continue
+
+        # Settings that accept a list of values are actually
+        # stored as string lists. We cannot just pass a list
+        # of values. We have to create the comma-separated
+        # string ourselves.
+        if isinstance(value, list) or isinstance(value, tuple):
+            placeholder = ", ".join(["%s" for _ in value])
+            params.extend(value)
+        else:
+            placeholder = "%s"
+            params.append(value)
+
+        sql.append(f"SET LOCAL {qn(name)} = {placeholder}")
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            "SELECT name, setting FROM pg_settings WHERE name = ANY(%s)",
+            (list(options.keys()),),
+        )
+        original_values = dict(cursor.fetchall())
+        cursor.execute("; ".join(sql), params)
+
+    yield
+
+    # Put everything back to how it was. DEFAULT is
+    # not good enough as an outer SET LOCAL might
+    # have set a different value.
+    with connection.cursor() as cursor:
+        sql = []
+        params = []
+
+        for name, value in options.items():
+            original_value = original_values.get(name)
+            if original_value:
+                sql.append(f"SET LOCAL {qn(name)} = {original_value}")
+            else:
+                sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
+
+        cursor.execute("; ".join(sql), params)
+
+
+@contextmanager
+def postgres_set_local_search_path(
+    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
+) -> None:
+    """Sets the search path to the specified schemas."""
+
+    with postgres_set_local(search_path=search_path, using=using):
+        yield
+
+
+@contextmanager
+def postgres_prepend_local_search_path(
+    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
+) -> None:
+    """Prepends the current local search path with the specified schemas."""
+
+    connection = connections[using]
+
+    with connection.cursor() as cursor:
+        cursor.execute("SHOW search_path")
+        [
+            original_search_path,
+        ] = cursor.fetchone()
+
+        placeholders = ", ".join(["%s" for _ in search_path])
+        cursor.execute(
+            f"SET LOCAL search_path = {placeholders}, {original_search_path}",
+            tuple(search_path),
+        )
+
+        yield
+
+        cursor.execute(f"SET LOCAL search_path = {original_search_path}")
+
+
+@contextmanager
+def postgres_reset_local_search_path(*, using: str = DEFAULT_DB_ALIAS) -> None:
+    """Resets the local search path to the default."""
+
+    with postgres_set_local(search_path=None, using=using):
+        yield
diff --git a/tests/db_introspection.py b/tests/db_introspection.py
index bdcd4b19..285cd0e4 100644
--- a/tests/db_introspection.py
+++ b/tests/db_introspection.py
@@ -4,38 +4,100 @@
 This makes test code less verbose and easier to read/write.
 """
 
+from contextlib import contextmanager
+from typing import Optional
+
 from django.db import connection
 
+from psqlextra.settings import postgres_set_local
+
+
+@contextmanager
+def introspect(schema_name: Optional[str] = None):
+    with postgres_set_local(search_path=schema_name or None):
+        with connection.cursor() as cursor:
+            yield connection.introspection, cursor
+
 
-def table_names(include_views: bool = True):
+def table_names(
+    include_views: bool = True, *, schema_name: Optional[str] = None
+):
     """Gets a flat list of tables in the default database."""
 
-    with connection.cursor() as cursor:
-        introspection = connection.introspection
+    with introspect(schema_name) as (introspection, cursor):
         return introspection.table_names(cursor, include_views)
 
 
-def get_partitioned_table(table_name: str):
+def get_partitioned_table(
+    table_name: str,
+    *,
+    schema_name: Optional[str] = None,
+):
     """Gets the definition of a partitioned table in the default database."""
 
-    with connection.cursor() as cursor:
-        introspection = connection.introspection
+    with introspect(schema_name) as (introspection, cursor):
         return introspection.get_partitioned_table(cursor, table_name)
 
 
-def get_partitions(table_name: str):
+def get_partitions(
+    table_name: str,
+    *,
+    schema_name: Optional[str] = None,
+):
     """Gets a list of partitions for the specified partitioned table in the
     default database."""
 
-    with connection.cursor() as cursor:
-        introspection = connection.introspection
+    with introspect(schema_name) as (introspection, cursor):
         return introspection.get_partitions(cursor, table_name)
 
 
-def get_constraints(table_name: str):
-    """Gets a complete list of constraints and indexes for the specified
-    table."""
+def get_columns(
+    table_name: str,
+    *,
+    schema_name: Optional[str] = None,
+):
+    """Gets a list of columns for the specified table."""
+
+    with introspect(schema_name) as (introspection, cursor):
+        return introspection.get_columns(cursor, table_name)
+
+
+def get_relations(
+    table_name: str,
+    *,
+    schema_name: Optional[str] = None,
+):
+    """Gets a list of relations for the specified table."""
+
+    with introspect(schema_name) as (introspection, cursor):
+        return introspection.get_relations(cursor, table_name)
 
-    with connection.cursor() as cursor:
-        introspection = connection.introspection
+
+def get_constraints(
+    table_name: str,
+    *,
+    schema_name: Optional[str] = None,
+):
+    """Gets a list of constraints and indexes for the specified table."""
+
+    with introspect(schema_name) as (introspection, cursor):
         return introspection.get_constraints(cursor, table_name)
+
+
+def get_sequences(
+    table_name: str,
+    *,
+    schema_name: Optional[str] = None,
+):
+    """Gets a list of sequences owned by the specified table."""
+
+    with introspect(schema_name) as (introspection, cursor):
+        return introspection.get_sequences(cursor, table_name)
+
+
+def get_storage_settings(table_name: str, *, schema_name: Optional[str] = None):
+    """Gets a list of all storage settings that have been set on the specified
+    table."""
+
+    with introspect(schema_name) as (introspection, cursor):
+        return introspection.get_storage_settings(cursor, table_name)
diff --git a/tests/fake_model.py b/tests/fake_model.py
index 1254e762..ec626f3a 100644
--- a/tests/fake_model.py
+++ b/tests/fake_model.py
@@ -3,9 +3,10 @@
 import uuid
 
 from contextlib import contextmanager
+from typing import Type
 
 from django.apps import AppConfig, apps
-from django.db import connection
+from django.db import connection, models
 
 from psqlextra.models import (
     PostgresMaterializedViewModel,
@@ -39,6 +40,17 @@ def define_fake_model(
     return model
 
 
+def undefine_fake_model(model: Type[models.Model]) -> None:
+    """Removes the fake model from the app registry."""
+
+    app_label = model._meta.app_label or "tests"
+    app_models = apps.app_configs[app_label].models
+
+    for model_name in [model.__name__, model.__name__.lower()]:
+        if model_name in app_models:
+            del app_models[model_name]
+
+
 def define_fake_view_model(
     fields=None, view_options={}, meta_options={}, model_base=PostgresViewModel
 ):
@@ -115,6 +127,15 @@ def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}):
     return model
 
 
+def delete_fake_model(model: Type[models.Model]) -> None:
+    """Deletes a fake model from the database and the internal app registry."""
+
+    undefine_fake_model(model)
+
+    with connection.schema_editor() as schema_editor:
+        schema_editor.delete_model(model)
+
+
 @contextmanager
 def define_fake_app():
     """Creates and registers a fake Django app."""
diff --git a/tests/test_schema_editor_alter_schema.py b/tests/test_schema_editor_alter_schema.py
new file mode 100644
index 00000000..7fda103b
--- /dev/null
+++ b/tests/test_schema_editor_alter_schema.py
@@ -0,0 +1,44 @@
+import pytest
+
+from django.db import connection, models
+
+from psqlextra.backend.schema import PostgresSchemaEditor
+
+from .fake_model import get_fake_model
+
+
+@pytest.fixture
+def fake_model():
+    return get_fake_model(
+        {
+            "text": models.TextField(),
+        }
+    )
+
+
+def test_schema_editor_alter_table_schema(fake_model):
+    obj = fake_model.objects.create(text="hello")
+
+    with connection.cursor() as cursor:
+        cursor.execute("CREATE SCHEMA target")
+
+    schema_editor = PostgresSchemaEditor(connection)
+    schema_editor.alter_table_schema(fake_model._meta.db_table, "target")
+
+    with connection.cursor() as cursor:
+        cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}")
+        assert cursor.fetchall() == [(obj.id, obj.text)]
+
+
+def test_schema_editor_alter_model_schema(fake_model):
+    obj = fake_model.objects.create(text="hello")
+
+    with connection.cursor() as cursor:
+        cursor.execute("CREATE SCHEMA target")
+
+    schema_editor = PostgresSchemaEditor(connection)
+    schema_editor.alter_model_schema(fake_model, "target")
+
+    with connection.cursor() as cursor:
+        cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}")
+        assert cursor.fetchall() == [(obj.id, obj.text)]
diff --git a/tests/test_schema_editor_clone_model_to_schema.py b/tests/test_schema_editor_clone_model_to_schema.py
new file mode 100644
index 00000000..c3d41917
--- /dev/null
+++ b/tests/test_schema_editor_clone_model_to_schema.py
@@ -0,0 +1,330 @@
+import os
+
+from typing import Set, Tuple
+
+import django
+import pytest
+
+from django.contrib.postgres.fields import ArrayField
+from django.contrib.postgres.indexes import GinIndex
+from django.db import connection, models, transaction
+from django.db.models import Q
+
+from psqlextra.backend.schema import PostgresSchemaEditor
+
+from . import db_introspection
+from .fake_model import delete_fake_model, get_fake_model
+
+django_32_skip_reason = "Django < 3.2 can't support cloning models because it has hard-coded references to the public schema"
+
+
+def _create_schema() -> str:
+    name = os.urandom(4).hex()
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            "DROP SCHEMA IF EXISTS %s CASCADE"
+            % connection.ops.quote_name(name),
+            tuple(),
+        )
+        cursor.execute(
+            "CREATE SCHEMA %s" % connection.ops.quote_name(name), tuple()
+        )
+
+    return name
+
+
+@transaction.atomic
+def _assert_cloned_table_is_same(
+    source_table_fqn: Tuple[str, str],
+    target_table_fqn: Tuple[str, str],
+    excluding_constraints_and_indexes: bool = False,
+):
+    source_schema_name, source_table_name = source_table_fqn
+    target_schema_name, target_table_name = target_table_fqn
+
+    source_columns = db_introspection.get_columns(
+        source_table_name, schema_name=source_schema_name
+    )
+    target_columns = db_introspection.get_columns(
+        target_table_name, schema_name=target_schema_name
+    )
+    assert source_columns == target_columns
+
+    source_relations = db_introspection.get_relations(
+        source_table_name, schema_name=source_schema_name
+    )
+    target_relations = db_introspection.get_relations(
+        target_table_name, schema_name=target_schema_name
+    )
+    if excluding_constraints_and_indexes:
+        assert target_relations == {}
+    else:
+        assert source_relations == target_relations
+
+    source_constraints = db_introspection.get_constraints(
+        source_table_name, schema_name=source_schema_name
+    )
+    target_constraints = db_introspection.get_constraints(
+        target_table_name, schema_name=target_schema_name
+    )
+    if excluding_constraints_and_indexes:
+        assert target_constraints == {}
+    else:
+        assert source_constraints == target_constraints
+
+    source_sequences = db_introspection.get_sequences(
+        source_table_name, schema_name=source_schema_name
+    )
+    target_sequences = db_introspection.get_sequences(
+        target_table_name, schema_name=target_schema_name
+    )
+    assert source_sequences == target_sequences
+
+    source_storage_settings = db_introspection.get_storage_settings(
+        source_table_name,
+        schema_name=source_schema_name,
+    )
+    target_storage_settings = db_introspection.get_storage_settings(
+        target_table_name, schema_name=target_schema_name
+    )
+    assert source_storage_settings == target_storage_settings
+
+
+def _list_lock_modes_in_schema(schema_name: str) -> Set[str]:
+    with connection.cursor() as cursor:
+        cursor.execute(
+            """
+            SELECT
+                l.mode
+            FROM pg_locks l
+            INNER JOIN pg_class t ON t.oid = l.relation
+            INNER JOIN pg_namespace n ON n.oid = t.relnamespace
+            WHERE
+                t.relnamespace >= 2200
+                AND n.nspname = %s
+            ORDER BY n.nspname, t.relname, l.mode
+            """,
+            (schema_name,),
+        )
+
+        return {lock_mode for lock_mode, in cursor.fetchall()}
+
+
+def _clone_model_into_schema(model):
+    schema_name = _create_schema()
+
+    with PostgresSchemaEditor(connection) as schema_editor:
+        schema_editor.clone_model_structure_to_schema(
+            model, schema_name=schema_name
+        )
+        schema_editor.clone_model_constraints_and_indexes_to_schema(
+            model, schema_name=schema_name
+        )
+        schema_editor.clone_model_foreign_keys_to_schema(
+            model, schema_name=schema_name
+        )
+
+    return schema_name
+
+
+@pytest.fixture
+def fake_model_fk_target_1():
+    model = get_fake_model(
+        {
+            "name": models.TextField(),
+        },
+    )
+
+    yield model
+
+    delete_fake_model(model)
+
+
+@pytest.fixture
+def fake_model_fk_target_2():
+    model = get_fake_model(
+        {
+            "name": models.TextField(),
+        },
+    )
+
+    yield model
+
+    delete_fake_model(model)
+
+
+@pytest.fixture
+def fake_model(fake_model_fk_target_1, fake_model_fk_target_2):
+    model = get_fake_model(
+        {
+            "first_name": models.TextField(null=True),
+            "last_name": models.TextField(),
+            "age": models.PositiveIntegerField(),
+            "height": models.FloatField(),
+            "nicknames": ArrayField(base_field=models.TextField()),
+            "blob": models.JSONField(),
+            "family": models.ForeignKey(
+                fake_model_fk_target_1, on_delete=models.CASCADE
+            ),
+            "alternative_family": models.ForeignKey(
+                fake_model_fk_target_2, null=True, on_delete=models.SET_NULL
+            ),
+        },
+        meta_options={
+            "indexes": [
+                models.Index(fields=["age", "height"]),
+                models.Index(fields=["age"], name="age_index"),
+                GinIndex(fields=["nicknames"], name="nickname_index"),
+            ],
+            "constraints": [
+                models.UniqueConstraint(
+                    fields=["first_name", "last_name"],
+                    name="first_last_name_uniq",
+                ),
+                models.CheckConstraint(
+                    check=Q(age__gt=0, height__gt=0), name="age_height_check"
+                ),
+            ],
+            "unique_together": (
+                "first_name",
+                "nicknames",
+            ),
+            "index_together": (
+                "blob",
+                "age",
+            ),
+        },
+    )
+
+    yield model
+
+    delete_fake_model(model)
+
+
+@pytest.mark.skipif(
+    django.VERSION < (3, 2),
+    reason=django_32_skip_reason,
+)
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_clone_model_to_schema(
+    fake_model, fake_model_fk_target_1, fake_model_fk_target_2
+):
+    """Tests that cloning a model into a separate schema without obtaining
+    AccessExclusiveLock on the source table works as expected."""
+
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with schema_editor:
+        schema_editor.alter_table_storage_setting(
+            fake_model._meta.db_table, "autovacuum_enabled", "false"
+        )
+
+    table_name = fake_model._meta.db_table
+    source_schema_name = "public"
+    target_schema_name = _create_schema()
+
+    with schema_editor:
+        schema_editor.clone_model_structure_to_schema(
+            fake_model, schema_name=target_schema_name
+        )
+
+        assert _list_lock_modes_in_schema(source_schema_name) == {
+            "AccessShareLock"
+        }
+
+    _assert_cloned_table_is_same(
+        (source_schema_name, table_name),
+        (target_schema_name, table_name),
+        excluding_constraints_and_indexes=True,
+    )
+
+    with schema_editor:
+        schema_editor.clone_model_constraints_and_indexes_to_schema(
+            fake_model, schema_name=target_schema_name
+        )
+
+        assert _list_lock_modes_in_schema(source_schema_name) == {
+            "AccessShareLock",
+            "ShareRowExclusiveLock",
+        }
+
+    _assert_cloned_table_is_same(
+        (source_schema_name, table_name),
+        (target_schema_name, table_name),
+    )
+
+    with schema_editor:
+        schema_editor.clone_model_foreign_keys_to_schema(
+            fake_model, schema_name=target_schema_name
+        )
+
+        assert _list_lock_modes_in_schema(source_schema_name) == {
+            "AccessShareLock",
+            "RowShareLock",
+        }
+
+    _assert_cloned_table_is_same(
+        (source_schema_name, table_name),
+        (target_schema_name, table_name),
+    )
+
+
+@pytest.mark.skipif(
+    django.VERSION < (3, 2),
+    reason=django_32_skip_reason,
+)
+def test_schema_editor_clone_model_to_schema_custom_constraint_names(
+    fake_model, fake_model_fk_target_1
+):
+    """Tests that even if constraints were given custom names, the cloned
+    table has those same custom names."""
+
+    table_name = fake_model._meta.db_table
+    source_schema_name = "public"
+
+    constraints = db_introspection.get_constraints(table_name)
+
+    primary_key_constraint = next(
+        (
+            name
+            for name, constraint in constraints.items()
+            if constraint["primary_key"]
+        ),
+        None,
+    )
+    foreign_key_constraint = next(
+        (
+            name
+            for name, constraint in constraints.items()
+            if constraint["foreign_key"]
+            == (fake_model_fk_target_1._meta.db_table, "id")
+        ),
+        None,
+    )
+    check_constraint = next(
+        (
+            name
+            for name, constraint in constraints.items()
+            if constraint["check"] and constraint["columns"] == ["age"]
+        ),
+        None,
+    )
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            f"ALTER TABLE {table_name} RENAME CONSTRAINT {primary_key_constraint} TO custompkname"
+        )
+        cursor.execute(
+            f"ALTER TABLE {table_name} RENAME CONSTRAINT {foreign_key_constraint} TO customfkname"
+        )
+        cursor.execute(
+            f"ALTER TABLE {table_name} RENAME CONSTRAINT {check_constraint} TO customcheckname"
+        )
+
+    target_schema_name = _clone_model_into_schema(fake_model)
+
+    _assert_cloned_table_is_same(
+        (source_schema_name, table_name),
+        (target_schema_name, table_name),
+    )
diff --git a/tests/test_schema_editor_storage_settings.py b/tests/test_schema_editor_storage_settings.py
new file mode 100644
index 00000000..0f45934f
--- /dev/null
+++ b/tests/test_schema_editor_storage_settings.py
@@ -0,0 +1,47 @@
+import pytest
+
+from django.db import connection, models
+
+from psqlextra.backend.schema import PostgresSchemaEditor
+
+from . import db_introspection
+from .fake_model import get_fake_model
+
+
+@pytest.fixture
+def fake_model():
+    return get_fake_model(
+        {
+            "text": models.TextField(),
+        }
+    )
+
+
+def test_schema_editor_storage_settings_table_alter_and_reset(fake_model):
+    table_name = fake_model._meta.db_table
+    schema_editor = PostgresSchemaEditor(connection)
+
+    schema_editor.alter_table_storage_setting(
+        table_name, "autovacuum_enabled", "false"
+    )
+    assert db_introspection.get_storage_settings(table_name) == {
+        "autovacuum_enabled": "false"
+    }
+
+    schema_editor.reset_table_storage_setting(table_name, "autovacuum_enabled")
+    assert db_introspection.get_storage_settings(table_name) == {}
+
+
+def test_schema_editor_storage_settings_model_alter_and_reset(fake_model):
+    table_name = fake_model._meta.db_table
+    schema_editor = PostgresSchemaEditor(connection)
+
+    schema_editor.alter_model_storage_setting(
+        fake_model, "autovacuum_enabled", "false"
+    )
+    assert db_introspection.get_storage_settings(table_name) == {
+        "autovacuum_enabled": "false"
+    }
+
+    schema_editor.reset_model_storage_setting(fake_model, "autovacuum_enabled")
+    assert db_introspection.get_storage_settings(table_name) == {}
diff --git a/tests/test_schema_editor_vacuum.py b/tests/test_schema_editor_vacuum.py
new file mode 100644
index 00000000..59772e86
--- /dev/null
+++ b/tests/test_schema_editor_vacuum.py
@@ -0,0 +1,147 @@
+import pytest
+
+from django.core.exceptions import SuspiciousOperation
+from django.db import connection, models
+from django.test.utils import CaptureQueriesContext
+
+from psqlextra.backend.schema import PostgresSchemaEditor
+
+from .fake_model import delete_fake_model, get_fake_model
+
+
+@pytest.fixture
+def fake_model():
+    model = get_fake_model(
+        {
+            "name": models.TextField(),
+        }
+    )
+
+    yield model
+
+    delete_fake_model(model)
+
+
+@pytest.fixture
+def fake_model_non_concrete_field(fake_model):
+    model = get_fake_model(
+        {
+            "fk": models.ForeignKey(
+                fake_model, on_delete=models.CASCADE, related_name="fakes"
+            ),
+        }
+    )
+
+    yield model
+
+    delete_fake_model(model)
+
+
+def test_schema_editor_vacuum_not_in_transaction(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with pytest.raises(SuspiciousOperation):
+        schema_editor.vacuum_table(fake_model._meta.db_table)
+
+
+@pytest.mark.parametrize(
+    "kwargs,query",
+    [
+        (dict(), "VACUUM %s"),
+        (dict(full=True), "VACUUM (FULL) %s"),
+        (dict(analyze=True), "VACUUM (ANALYZE) %s"),
+        (dict(parallel=8), "VACUUM (PARALLEL 8) %s"),
+        (dict(analyze=True, verbose=True), "VACUUM (VERBOSE, ANALYZE) %s"),
+        (
+            dict(analyze=True, parallel=8, verbose=True),
+            "VACUUM (VERBOSE, ANALYZE, PARALLEL 8) %s",
+        ),
+        (dict(freeze=True), "VACUUM (FREEZE) %s"),
+        (dict(verbose=True), "VACUUM (VERBOSE) %s"),
+        (dict(disable_page_skipping=True), "VACUUM (DISABLE_PAGE_SKIPPING) %s"),
+        (dict(skip_locked=True), "VACUUM (SKIP_LOCKED) %s"),
+        (dict(index_cleanup=True), "VACUUM (INDEX_CLEANUP) %s"),
+        (dict(truncate=True), "VACUUM (TRUNCATE) %s"),
+    ],
+)
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_table(fake_model, kwargs, query):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_table(fake_model._meta.db_table, **kwargs)
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        query % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_table_columns(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_table(
+            fake_model._meta.db_table, ["id", "name"], analyze=True
+        )
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        'VACUUM (ANALYZE) %s ("id", "name")'
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_model(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_model(fake_model, analyze=True, parallel=8)
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        "VACUUM (ANALYZE, PARALLEL 8) %s"
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_model_fields(fake_model):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_model(
+            fake_model,
+            [fake_model._meta.get_field("name")],
+            analyze=True,
+            parallel=8,
+        )
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        'VACUUM (ANALYZE, PARALLEL 8) %s ("name")'
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
+
+
+@pytest.mark.django_db(transaction=True)
+def test_schema_editor_vacuum_model_non_concrete_fields(
+    fake_model, fake_model_non_concrete_field
+):
+    schema_editor = PostgresSchemaEditor(connection)
+
+    with CaptureQueriesContext(connection) as ctx:
+        schema_editor.vacuum_model(
+            fake_model,
+            [fake_model._meta.get_field("fakes")],
+            analyze=True,
+            parallel=8,
+        )
+
+    queries = [query["sql"] for query in ctx.captured_queries]
+    assert queries == [
+        "VACUUM (ANALYZE, PARALLEL 8) %s"
+        % connection.ops.quote_name(fake_model._meta.db_table)
+    ]
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 00000000..44519714
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,93 @@
+import pytest
+
+from django.core.exceptions import SuspiciousOperation
+from django.db import connection
+
+from psqlextra.settings import (
+    postgres_prepend_local_search_path,
+    postgres_reset_local_search_path,
+    postgres_set_local,
+    postgres_set_local_search_path,
+)
+
+
+def _get_current_setting(name: str) -> str:
+    with connection.cursor() as cursor:
+        cursor.execute(f"SHOW {name}")
+        return cursor.fetchone()[0]
+
+
+@postgres_set_local(statement_timeout="2s", lock_timeout="3s")
+def test_postgres_set_local_function_decorator():
+    assert _get_current_setting("statement_timeout") == "2s"
+    assert _get_current_setting("lock_timeout") == "3s"
+
+
+def test_postgres_set_local_context_manager():
+    with postgres_set_local(statement_timeout="2s"):
+        assert _get_current_setting("statement_timeout") == "2s"
+
+    assert _get_current_setting("statement_timeout") == "0"
+
+
+def test_postgres_set_local_iterable():
+    with postgres_set_local(search_path=["a", "public"]):
+        assert _get_current_setting("search_path") == "a, public"
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_set_local_nested():
+    with postgres_set_local(statement_timeout="2s"):
+        assert _get_current_setting("statement_timeout") == "2s"
+
+        with postgres_set_local(statement_timeout="3s"):
+            assert _get_current_setting("statement_timeout") == "3s"
+
+        assert _get_current_setting("statement_timeout") == "2s"
+
+    assert _get_current_setting("statement_timeout") == "0"
+
+
+@pytest.mark.django_db(transaction=True)
+def test_postgres_set_local_no_transaction():
+    with pytest.raises(SuspiciousOperation):
+        with postgres_set_local(statement_timeout="2s"):
+            pass
+
+
+def test_postgres_set_local_search_path():
+    with postgres_set_local_search_path(["a", "public"]):
+        assert _get_current_setting("search_path") == "a, public"
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_reset_local_search_path():
+    with postgres_set_local_search_path(["a", "public"]):
+        with postgres_reset_local_search_path():
+            assert _get_current_setting("search_path") == '"$user", public'
+
+        assert _get_current_setting("search_path") == "a, public"
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_prepend_local_search_path():
+    with postgres_prepend_local_search_path(["a", "b"]):
+        assert _get_current_setting("search_path") == 'a, b, "$user", public'
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_prepend_local_search_path_nested():
+    with postgres_prepend_local_search_path(["a", "b"]):
+        with postgres_prepend_local_search_path(["c"]):
+            assert (
+                _get_current_setting("search_path")
+                == 'c, a, b, "$user", public'
+            )
+
+        assert _get_current_setting("search_path") == 'a, b, "$user", public'
+
+    assert _get_current_setting("search_path") == '"$user", public'
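
A minimal end-to-end sketch of the cloning API introduced by this diff, mirroring the test above (the model `MyModel` and the schema name "clone" are illustrative; each `with` block runs in its own transaction, and foreign key validation deliberately runs in a separate transaction from foreign key creation):

    from django.db import connection
    from psqlextra.backend.schema import PostgresSchemaEditor

    schema_editor = PostgresSchemaEditor(connection)

    with schema_editor:  # copies columns, sequences and storage settings
        schema_editor.clone_model_structure_to_schema(
            MyModel, schema_name="clone"
        )

    with schema_editor:  # adds constraints/indexes; FKs created NOT VALID
        schema_editor.clone_model_constraints_and_indexes_to_schema(
            MyModel, schema_name="clone"
        )

    with schema_editor:  # VALIDATE CONSTRAINT, outside the FK-creating txn
        schema_editor.clone_model_foreign_keys_to_schema(
            MyModel, schema_name="clone"
        )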