Skip to content

Allow specifying a specific constraint to use in ON CONFLICT #213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/source/conflict_handling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,41 @@ Specifying multiple columns is necessary in case of a constraint that spans mult
)


Specific constraint
*******************

Alternatively, instead of specifying the columns the constraint you're targetting applies to, you can also specify the exact constraint to use:

.. code-block:: python

from django.db import models
from psqlextra.models import PostgresModel

class MyModel(PostgresModel)
class Meta:
constraints = [
models.UniqueConstraint(
name="myconstraint",
fields=["first_name", "last_name"]
),
]

first_name = models.CharField(max_length=255)
last_name = models.CharField(max_length=255)

constraint = next(
constraint
for constraint in MyModel._meta.constraints
if constraint.name == "myconstraint"
), None)

obj = (
MyModel.objects
.on_conflict(constraint, ConflictAction.UPDATE)
.insert_and_get(first_name='Henk', last_name='Jansen')
)


HStore keys
***********
Catching conflicts in columns with a ``UNIQUE`` constraint on a :class:`~psqlextra.fields.HStoreField` key is also supported:
Expand Down
19 changes: 17 additions & 2 deletions psqlextra/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ def _rewrite_insert_on_conflict(

# build the conflict target, the columns to watch
# for conflicts
conflict_target = self._build_conflict_target()
on_conflict_clause = self._build_on_conflict_clause()
index_predicate = self.query.index_predicate # type: ignore[attr-defined]
update_condition = self.query.conflict_update_condition # type: ignore[attr-defined]

rewritten_sql = f"{sql} ON CONFLICT {conflict_target}"
rewritten_sql = f"{sql} {on_conflict_clause}"

if index_predicate:
expr_sql, expr_params = self._compile_expression(index_predicate)
Expand All @@ -270,6 +270,21 @@ def _rewrite_insert_on_conflict(

return (rewritten_sql, params)

def _build_on_conflict_clause(self):
if django.VERSION >= (2, 2):
from django.db.models.constraints import BaseConstraint
from django.db.models.indexes import Index

if isinstance(
self.query.conflict_target, BaseConstraint
) or isinstance(self.query.conflict_target, Index):
return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
self.query.conflict_target.name
)

conflict_target = self._build_conflict_target()
return f"ON CONFLICT {conflict_target}"

def _build_conflict_target(self):
"""Builds the `conflict_target` for the ON CONFLICT clause."""

Expand Down
6 changes: 5 additions & 1 deletion psqlextra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from .sql import PostgresInsertQuery, PostgresQuery
from .types import ConflictAction

ConflictTarget = List[Union[str, Tuple[str]]]
if TYPE_CHECKING:
from django.db.models.constraints import BaseConstraint
from django.db.models.indexes import Index

ConflictTarget = Union[List[Union[str, Tuple[str]]], "BaseConstraint", "Index"]


TModel = TypeVar("TModel", bound=models.Model, covariant=True)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_on_conflict_update.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import django
import pytest

from django.db import models
Expand Down Expand Up @@ -41,6 +42,35 @@ def test_on_conflict_update():
assert obj2.cookies == "choco"


@pytest.mark.skipif(
django.VERSION < (2, 2),
reason="Django < 2.2 doesn't implement constraints",
)
def test_on_conflict_update_by_unique_constraint():
model = get_fake_model(
{
"title": models.CharField(max_length=255, null=True),
},
meta_options={
"constraints": [
models.UniqueConstraint(name="test_uniq", fields=["title"]),
],
},
)

constraint = next(
(
constraint
for constraint in model._meta.constraints
if constraint.name == "test_uniq"
)
)

model.objects.on_conflict(constraint, ConflictAction.UPDATE).insert_and_get(
title="title"
)


def test_on_conflict_update_foreign_key_by_object():
"""Tests whether simple upsert works correctly when the conflicting field
is a foreign key specified as an object."""
Expand Down