diff --git a/psqlextra/manager/manager.py b/psqlextra/manager/manager.py index a787a763..c574a67d 100644 --- a/psqlextra/manager/manager.py +++ b/psqlextra/manager/manager.py @@ -248,7 +248,7 @@ def upsert(self, conflict_target: List, fields: Dict, index_predicate: str=None) self.on_conflict(conflict_target, ConflictAction.UPDATE, index_predicate) return self.insert(**fields) - def upsert_and_get(self, conflict_target: List, fields: Dict): + def upsert_and_get(self, conflict_target: List, fields: Dict, index_predicate: str=None): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -259,15 +259,19 @@ def upsert_and_get(self, conflict_target: List, fields: Dict): fields: Fields to insert/update. + index_predicate: + The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking + conflicts) + Returns: The model instance representing the row that was created/updated. """ - self.on_conflict(conflict_target, ConflictAction.UPDATE) + self.on_conflict(conflict_target, ConflictAction.UPDATE, index_predicate) return self.insert_and_get(**fields) - def bulk_upsert(self, conflict_target: List, rows: List[Dict]): + def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate: str=None): """Creates a set of new records or updates the existing ones with the specified data. @@ -277,9 +281,13 @@ def bulk_upsert(self, conflict_target: List, rows: List[Dict]): rows: Rows to upsert. + + index_predicate: + The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking + conflicts) """ - self.on_conflict(conflict_target, ConflictAction.UPDATE) + self.on_conflict(conflict_target, ConflictAction.UPDATE, index_predicate) return self.bulk_insert(rows) def _build_insert_compiler(self, rows: List[Dict]): @@ -468,7 +476,7 @@ def get_queryset(self): return PostgresQuerySet(self.model, using=self._db) - def on_conflict(self, fields: List[Union[str, Tuple[str]]], action): + def on_conflict(self, fields: List[Union[str, Tuple[str]]], action, index_predicate: str=None): """Sets the action to take when conflicts arise when attempting to insert/create a new row. @@ -478,8 +486,11 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action): action: The action to take when the conflict occurs. + + index_predicate: + The index predicate to satisfy an arbiter partial index. """ - return self.get_queryset().on_conflict(fields, action) + return self.get_queryset().on_conflict(fields, action, index_predicate) def upsert(self, conflict_target: List, fields: Dict, index_predicate: str=None) -> int: """Creates a new record or updates the existing one @@ -501,7 +512,7 @@ def upsert(self, conflict_target: List, fields: Dict, index_predicate: str=None) return self.get_queryset().upsert(conflict_target, fields, index_predicate) - def upsert_and_get(self, conflict_target: List, fields: Dict): + def upsert_and_get(self, conflict_target: List, fields: Dict, index_predicate: str=None): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -512,14 +523,17 @@ def upsert_and_get(self, conflict_target: List, fields: Dict): fields: Fields to insert/update. + index_predicate: + The index predicate to satisfy an arbiter partial index. + Returns: The model instance representing the row that was created/updated. """ - return self.get_queryset().upsert_and_get(conflict_target, fields) + return self.get_queryset().upsert_and_get(conflict_target, fields, index_predicate) - def bulk_upsert(self, conflict_target: List, rows: List[Dict]): + def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate: str=None): """Creates a set of new records or updates the existing ones with the specified data. @@ -527,11 +541,14 @@ def bulk_upsert(self, conflict_target: List, rows: List[Dict]): conflict_target: Fields to pass into the ON CONFLICT clause. + index_predicate: + The index predicate to satisfy an arbiter partial index. + rows: Rows to upsert. """ - return self.get_queryset().bulk_upsert(conflict_target, rows) + return self.get_queryset().bulk_upsert(conflict_target, rows, index_predicate) @staticmethod def _on_model_save(sender, **kwargs):