Skip to content

Add return_model to bulk_upsert #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions psqlextra/manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def upsert_and_get(self, conflict_target: List, fields: Dict, index_predicate: s
self.on_conflict(conflict_target, ConflictAction.UPDATE, index_predicate)
return self.insert_and_get(**fields)

def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate: str=None):
def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate: str=None, return_model: bool=False):
"""Creates a set of new records or updates the existing
ones with the specified data.

Expand All @@ -294,13 +294,21 @@ def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate:
index_predicate:
The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking
conflicts)

return_model (default: False):
If model instances should be returned rather than
just dicts.

Returns:
A list of either the dicts of the rows upserted, including the pk or
the models of the rows upserted
"""

if not rows or len(rows) <= 0:
return

self.on_conflict(conflict_target, ConflictAction.UPDATE, index_predicate)
return self.bulk_insert(rows)
return self.bulk_insert(rows, return_model)

def _build_insert_compiler(self, rows: List[Dict]):
"""Builds the SQL compiler for a insert query.
Expand Down Expand Up @@ -545,7 +553,7 @@ def upsert_and_get(self, conflict_target: List, fields: Dict, index_predicate: s

return self.get_queryset().upsert_and_get(conflict_target, fields, index_predicate)

def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate: str=None):
def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate: str=None, return_model: bool=False):
"""Creates a set of new records or updates the existing
ones with the specified data.

Expand All @@ -558,9 +566,17 @@ def bulk_upsert(self, conflict_target: List, rows: List[Dict], index_predicate:

rows:
Rows to upsert.

return_model (default: False):
If model instances should be returned rather than
just dicts.

Returns:
A list of either the dicts of the rows upserted, including the pk or
the models of the rows upserted
"""

return self.get_queryset().bulk_upsert(conflict_target, rows, index_predicate)
return self.get_queryset().bulk_upsert(conflict_target, rows, index_predicate, return_model)

@staticmethod
def _on_model_save(sender, **kwargs):
Expand Down