Skip to content

Commit 6f4ee2d

Browse files
zinter (#1520)
* zinter * change options in _zaggregate * skip for previous versions * flake8 * validate the aggregate value * invalid aggregation * invalid aggregation * change options to get Co-authored-by: Chayim <[email protected]>
1 parent e9c2e45 commit 6f4ee2d

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

redis/client.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ class Redis:
595595
lambda r: r and set(r) or set()
596596
),
597597
**string_keys_to_dict(
598-
'ZPOPMAX ZPOPMIN ZDIFF ZRANGE ZRANGEBYSCORE ZREVRANGE '
599-
'ZREVRANGEBYSCORE', zset_score_pairs
598+
'ZPOPMAX ZPOPMIN ZINTER ZDIFF ZRANGE ZRANGEBYSCORE '
599+
'ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs
600600
),
601601
**string_keys_to_dict('BZPOPMIN BZPOPMAX', \
602602
lambda r:
@@ -2959,11 +2959,28 @@ def zincrby(self, name, amount, value):
29592959
"Increment the score of ``value`` in sorted set ``name`` by ``amount``"
29602960
return self.execute_command('ZINCRBY', name, amount, value)
29612961

2962+
def zinter(self, keys, aggregate=None, withscores=False):
2963+
"""
2964+
Return the intersect of multiple sorted sets specified by ``keys``.
2965+
With the ``aggregate`` option, it is possible to specify how the
2966+
results of the union are aggregated. This option defaults to SUM,
2967+
where the score of an element is summed across the inputs where it
2968+
exists. When this option is set to either MIN or MAX, the resulting
2969+
set will contain the minimum or maximum score of an element across
2970+
the inputs where it exists.
2971+
"""
2972+
return self._zaggregate('ZINTER', None, keys, aggregate,
2973+
withscores=withscores)
2974+
29622975
def zinterstore(self, dest, keys, aggregate=None):
29632976
"""
2964-
Intersect multiple sorted sets specified by ``keys`` into
2965-
a new sorted set, ``dest``. Scores in the destination will be
2966-
aggregated based on the ``aggregate``, or SUM if none is provided.
2977+
Intersect multiple sorted sets specified by ``keys`` into a new
2978+
sorted set, ``dest``. Scores in the destination will be aggregated
2979+
based on the ``aggregate``. This option defaults to SUM, where the
2980+
score of an element is summed across the inputs where it exists.
2981+
When this option is set to either MIN or MAX, the resulting set will
2982+
contain the minimum or maximum score of an element across the inputs
2983+
where it exists.
29672984
"""
29682985
return self._zaggregate('ZINTERSTORE', dest, keys, aggregate)
29692986

@@ -3253,8 +3270,12 @@ def zunionstore(self, dest, keys, aggregate=None):
32533270
"""
32543271
return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate)
32553272

3256-
def _zaggregate(self, command, dest, keys, aggregate=None):
3257-
pieces = [command, dest, len(keys)]
3273+
def _zaggregate(self, command, dest, keys, aggregate=None,
3274+
**options):
3275+
pieces = [command]
3276+
if dest is not None:
3277+
pieces.append(dest)
3278+
pieces.append(len(keys))
32583279
if isinstance(keys, dict):
32593280
keys, weights = keys.keys(), keys.values()
32603281
else:
@@ -3264,9 +3285,14 @@ def _zaggregate(self, command, dest, keys, aggregate=None):
32643285
pieces.append(b'WEIGHTS')
32653286
pieces.extend(weights)
32663287
if aggregate:
3267-
pieces.append(b'AGGREGATE')
3268-
pieces.append(aggregate)
3269-
return self.execute_command(*pieces)
3288+
if aggregate.upper() in ['SUM', 'MIN', 'MAX']:
3289+
pieces.append(b'AGGREGATE')
3290+
pieces.append(aggregate)
3291+
else:
3292+
raise DataError("aggregate can be sum, min or max.")
3293+
if options.get('withscores', False):
3294+
pieces.append(b'WITHSCORES')
3295+
return self.execute_command(*pieces, **options)
32703296

32713297
# HYPERLOGLOG COMMANDS
32723298
def pfadd(self, name, *values):

tests/test_commands.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,28 @@ def test_zlexcount(self, r):
15191519
assert r.zlexcount('a', '-', '+') == 7
15201520
assert r.zlexcount('a', '[b', '[f') == 5
15211521

1522+
@skip_if_server_version_lt('6.2.0')
1523+
def test_zinter(self, r):
1524+
r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1})
1525+
r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2})
1526+
r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4})
1527+
assert r.zinter(['a', 'b', 'c']) == [b'a3', b'a1']
1528+
# invalid aggregation
1529+
with pytest.raises(exceptions.DataError):
1530+
r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True)
1531+
# aggregate with SUM
1532+
assert r.zinter(['a', 'b', 'c'], withscores=True) \
1533+
== [(b'a3', 8), (b'a1', 9)]
1534+
# aggregate with MAX
1535+
assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \
1536+
== [(b'a3', 5), (b'a1', 6)]
1537+
# aggregate with MIN
1538+
assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \
1539+
== [(b'a1', 1), (b'a3', 1)]
1540+
# with weights
1541+
assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \
1542+
== [(b'a3', 20), (b'a1', 23)]
1543+
15221544
def test_zinterstore_sum(self, r):
15231545
r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1})
15241546
r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2})

0 commit comments

Comments
 (0)