Skip to content

Support violation_error_code and violation_error_message from UniqueConstraint in UniqueTogetherValidator #9766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,17 @@ def get_validators(self):
self.get_unique_for_date_validators()
)

def _get_constraint_violation_error_message(self, constraint):
"""
Returns the violation error message for the UniqueConstraint,
or None if the message is the default.
"""
violation_error_message = constraint.get_violation_error_message()
default_error_message = constraint.default_violation_error_message % {"name": constraint.name}
if violation_error_message == default_error_message:
return None
return violation_error_message

def get_unique_together_validators(self):
"""
Determine a default set of validators for any unique_together constraints.
Expand All @@ -1595,6 +1606,11 @@ def get_unique_together_validators(self):
for name, source in field_sources.items():
source_map[source].append(name)

unique_constraint_by_fields = {
constraint.fields: constraint for constraint in self.Meta.model._meta.constraints
if isinstance(constraint, models.UniqueConstraint)
}

# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
Expand All @@ -1621,11 +1637,17 @@ def get_unique_together_validators(self):
)

field_names = tuple(source_map[f][0] for f in unique_together)

constraint = unique_constraint_by_fields.get(tuple(unique_together))
violation_error_message = self._get_constraint_violation_error_message(constraint) if constraint else None

validator = UniqueTogetherValidator(
queryset=queryset,
fields=field_names,
condition_fields=tuple(source_map[f][0] for f in condition_fields),
condition=condition,
message=violation_error_message,
code=getattr(constraint, 'violation_error_code', None),
)
validators.append(validator)
return validators
Expand Down
7 changes: 5 additions & 2 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ class UniqueTogetherValidator:
message = _('The fields {field_names} must make a unique set.')
missing_message = _('This field is required.')
requires_context = True
code = 'unique'

def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None):
self.queryset = queryset
self.fields = fields
self.message = message or self.message
self.condition_fields = [] if condition_fields is None else condition_fields
self.condition = condition
self.code = code or self.code

def enforce_required_fields(self, attrs, serializer):
"""
Expand Down Expand Up @@ -198,7 +200,7 @@ def __call__(self, attrs, serializer):
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
field_names = ', '.join(self.fields)
message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')
raise ValidationError(message, code=self.code)

def __repr__(self):
return '<{}({})>'.format(
Expand All @@ -217,6 +219,7 @@ def __eq__(self, other):
and self.missing_message == other.missing_message
and self.queryset == other.queryset
and self.fields == other.fields
and self.code == other.code
)


Expand Down
57 changes: 57 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,32 @@ class Meta:
]


class UniqueConstraintCustomMessageCodeModel(models.Model):
username = models.CharField(max_length=32)
company_id = models.IntegerField()
role = models.CharField(max_length=32)

class Meta:
constraints = [
models.UniqueConstraint(
fields=("username", "company_id"),
name="unique_username_company_custom_msg",
violation_error_message="Username must be unique within a company.",
violation_error_code="duplicate_username",
)
if django_version[0] >= 5
else models.UniqueConstraint(
fields=("username", "company_id"),
name="unique_username_company_custom_msg",
violation_error_message="Username must be unique within a company.",
),
models.UniqueConstraint(
fields=("company_id", "role"),
name="unique_company_role_default_msg",
),
]


class UniqueConstraintSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
Expand All @@ -628,6 +654,12 @@ class Meta:
fields = ('title', 'age', 'tag')


class UniqueConstraintCustomMessageCodeSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintCustomMessageCodeModel
fields = ('username', 'company_id', 'role')


class TestUniqueConstraintValidation(TestCase):
def setUp(self):
self.instance = UniqueConstraintModel.objects.create(
Expand Down Expand Up @@ -778,6 +810,31 @@ class Meta:
)
assert serializer.is_valid()

def test_unique_constraint_custom_message_code(self):
UniqueConstraintCustomMessageCodeModel.objects.create(username="Alice", company_id=1, role="member")
expected_code = "duplicate_username" if django_version[0] >= 5 else UniqueTogetherValidator.code

serializer = UniqueConstraintCustomMessageCodeSerializer(data={
"username": "Alice",
"company_id": 1,
"role": "admin",
})
assert not serializer.is_valid()
assert serializer.errors == {"non_field_errors": ["Username must be unique within a company."]}
assert serializer.errors["non_field_errors"][0].code == expected_code

def test_unique_constraint_default_message_code(self):
UniqueConstraintCustomMessageCodeModel.objects.create(username="Alice", company_id=1, role="member")
serializer = UniqueConstraintCustomMessageCodeSerializer(data={
"username": "John",
"company_id": 1,
"role": "member",
})
expected_message = UniqueTogetherValidator.message.format(field_names=', '.join(("company_id", "role")))
assert not serializer.is_valid()
assert serializer.errors == {"non_field_errors": [expected_message]}
assert serializer.errors["non_field_errors"][0].code == UniqueTogetherValidator.code


# Tests for `UniqueForDateValidator`
# ----------------------------------
Expand Down