1
- from pyasn1_modules .rfc5126 import ContentType
1
+ from django .db .models import Value , When , Case , F , Q , OuterRef , Subquery
2
+ from django .db .models .fields import CharField , IntegerField
3
+ from django .db .models .functions import Concat , Cast
4
+ from django .contrib .contenttypes .models import ContentType
2
5
from rest_framework import generics
3
6
from rest_framework import permissions as drf_permissions
4
7
from rest_framework .exceptions import NotFound
5
- from django .core .exceptions import ObjectDoesNotExist
8
+ from django .core .exceptions import ObjectDoesNotExist , PermissionDenied
6
9
7
10
from framework .auth .oauth_scopes import CoreScopes
8
11
from api .base .views import JSONAPIBaseView
19
22
CollectionProvider ,
20
23
PreprintProvider ,
21
24
RegistrationProvider ,
22
- AbstractProvider ,
25
+ AbstractProvider , AbstractNode , Preprint , OSFUser ,
23
26
)
24
- from osf .models .notification import NotificationSubscription
27
+ from osf .models .notification import NotificationSubscription , NotificationType
25
28
26
29
27
30
class SubscriptionList (JSONAPIBaseView , generics .ListAPIView , ListFilterMixin ):
@@ -38,8 +41,47 @@ class SubscriptionList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin):
38
41
required_write_scopes = [CoreScopes .NULL ]
39
42
40
43
def get_queryset (self ):
41
- return NotificationSubscription .objects .filter (
42
- user = self .request .user ,
44
+ user_guid = self .request .user ._id
45
+ provider_ct = ContentType .objects .get (app_label = 'osf' , model = 'abstractprovider' )
46
+
47
+ provider_subquery = AbstractProvider .objects .filter (
48
+ id = Cast (OuterRef ('object_id' ), IntegerField ()),
49
+ ).values ('_id' )[:1 ]
50
+
51
+ node_subquery = AbstractNode .objects .filter (
52
+ id = Cast (OuterRef ('object_id' ), IntegerField ()),
53
+ ).values ('guids___id' )[:1 ]
54
+
55
+ return NotificationSubscription .objects .filter (user = self .request .user ).annotate (
56
+ event_name = Case (
57
+ When (
58
+ notification_type__name = NotificationType .Type .NODE_FILES_UPDATED .value ,
59
+ then = Value ('files_updated' ),
60
+ ),
61
+ When (
62
+ notification_type__name = NotificationType .Type .USER_FILE_UPDATED .value ,
63
+ then = Value ('global_file_updated' ),
64
+ ),
65
+ default = F ('notification_type__name' ),
66
+ output_field = CharField (),
67
+ ),
68
+ legacy_id = Case (
69
+ When (
70
+ notification_type__name = NotificationType .Type .NODE_FILES_UPDATED .value ,
71
+ then = Concat (Subquery (node_subquery ), Value ('_file_updated' )),
72
+ ),
73
+ When (
74
+ notification_type__name = NotificationType .Type .USER_FILE_UPDATED .value ,
75
+ then = Value (f'{ user_guid } _global' ),
76
+ ),
77
+ When (
78
+ Q (notification_type__name = NotificationType .Type .PROVIDER_NEW_PENDING_SUBMISSIONS .value ) &
79
+ Q (content_type = provider_ct ),
80
+ then = Concat (Subquery (provider_subquery ), Value ('_new_pending_submissions' )),
81
+ ),
82
+ default = F ('notification_type__name' ),
83
+ output_field = CharField (),
84
+ ),
43
85
)
44
86
45
87
@@ -67,10 +109,63 @@ class SubscriptionDetail(JSONAPIBaseView, generics.RetrieveUpdateAPIView):
67
109
68
110
def get_object (self ):
69
111
subscription_id = self .kwargs ['subscription_id' ]
112
+ user_guid = self .request .user ._id
113
+
114
+ provider_ct = ContentType .objects .get (app_label = 'osf' , model = 'abstractprovider' )
115
+ node_ct = ContentType .objects .get (app_label = 'osf' , model = 'abstractnode' )
116
+
117
+ provider_subquery = AbstractProvider .objects .filter (
118
+ id = Cast (OuterRef ('object_id' ), IntegerField ()),
119
+ ).values ('_id' )[:1 ]
120
+
121
+ node_subquery = AbstractNode .objects .filter (
122
+ id = Cast (OuterRef ('object_id' ), IntegerField ()),
123
+ ).values ('guids___id' )[:1 ]
124
+
125
+ guid_id , * event_parts = subscription_id .split ('_' )
126
+ event = '_' .join (event_parts ) if event_parts else ''
127
+
128
+ subscription_obj = AbstractNode .load (guid_id ) or Preprint .load (guid_id ) or OSFUser .load (guid_id )
129
+
130
+ if event != 'global' :
131
+ obj_filter = Q (
132
+ object_id = getattr (subscription_obj , 'id' , None ),
133
+ content_type = ContentType .objects .get_for_model (subscription_obj .__class__ ),
134
+ notification_type__name__icontains = event ,
135
+ )
136
+ else :
137
+ obj_filter = Q ()
138
+
70
139
try :
71
- obj = NotificationSubscription .objects .get (id = subscription_id )
140
+ obj = NotificationSubscription .objects .annotate (
141
+ legacy_id = Case (
142
+ When (
143
+ notification_type__name = NotificationType .Type .NODE_FILES_UPDATED .value ,
144
+ content_type = node_ct ,
145
+ then = Concat (Subquery (node_subquery ), Value ('_file_updated' )),
146
+ ),
147
+ When (
148
+ notification_type__name = NotificationType .Type .USER_FILE_UPDATED .value ,
149
+ then = Value (f'{ user_guid } _global' ),
150
+ ),
151
+ When (
152
+ notification_type__name = NotificationType .Type .PROVIDER_NEW_PENDING_SUBMISSIONS .value ,
153
+ content_type = provider_ct ,
154
+ then = Concat (Subquery (provider_subquery ), Value ('_new_pending_submissions' )),
155
+ ),
156
+ default = Value (f'{ user_guid } _global' ),
157
+ output_field = CharField (),
158
+ ),
159
+ ).filter (obj_filter )
160
+
72
161
except ObjectDoesNotExist :
73
162
raise NotFound
163
+
164
+ try :
165
+ obj = obj .filter (user = self .request .user ).get ()
166
+ except ObjectDoesNotExist :
167
+ raise PermissionDenied
168
+
74
169
self .check_object_permissions (self .request , obj )
75
170
return obj
76
171
0 commit comments