Skip to content

Commit e8739b3

Browse files
committed
New ChangeTrackingMixin implementation
1 parent 054f6f7 commit e8739b3

File tree

5 files changed

+65
-92
lines changed

5 files changed

+65
-92
lines changed

redash/handlers/dashboards.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def post(self, dashboard_slug):
9191
def delete(self, dashboard_slug):
9292
dashboard = models.Dashboard.get_by_slug_and_org(dashboard_slug, self.current_org)
9393
dashboard.is_archived = True
94-
dashboard.record_changes(changed_by=self.current_user)
9594
models.db.session.add(dashboard)
9695
d = dashboard.to_dict(with_widgets=True, user=self.current_user)
9796
models.db.session.commit()

redash/handlers/queries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def get(self, query_id):
142142
def delete(self, query_id):
143143
query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
144144
require_admin_or_owner(query.user_id)
145-
query.archive(self.current_user)
145+
query.archive()
146146

147147

148148
class QueryForkResource(BaseResource):

redash/models.py

Lines changed: 58 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44
import itertools
55
import json
66
import logging
7-
import time
87

98
from funcy import project
109
from flask_sqlalchemy import SQLAlchemy
1110
from flask.ext.sqlalchemy import SignallingSession
1211
from flask_login import UserMixin, AnonymousUserMixin
1312
from sqlalchemy.dialects import postgresql
14-
from sqlalchemy.event import listens_for
15-
from sqlalchemy.inspection import inspect
13+
from sqlalchemy.event import listens_for, listen
1614
from sqlalchemy.types import TypeDecorator
1715
from sqlalchemy.orm import object_session
1816
# noinspection PyUnresolvedReferences
@@ -84,50 +82,61 @@ def process_result_value(self, value, dialect):
8482

8583
class TimestampMixin(object):
8684
updated_at = Column(db.DateTime(True), default=db.func.now(),
87-
onupdate=db.func.now(), nullable=False)
85+
onupdate=db.func.now(), nullable=False)
8886
created_at = Column(db.DateTime(True), default=db.func.now(),
89-
nullable=False)
87+
nullable=False)
9088

9189

9290
class ChangeTrackingMixin(object):
93-
skipped_fields = ('id', 'created_at', 'updated_at', 'version')
94-
_clean_values = None
95-
96-
def __init__(self, *a, **kw):
97-
super(ChangeTrackingMixin, self).__init__(*a, **kw)
98-
self.record_changes(self.user)
99-
100-
def prep_cleanvalues(self):
101-
self.__dict__['_clean_values'] = {}
102-
for attr in inspect(self.__class__).column_attrs:
103-
col, = attr.columns
104-
# 'query' is col name but not attr name
105-
self._clean_values[col.name] = None
106-
107-
def __setattr__(self, key, value):
108-
if self._clean_values is None:
109-
self.prep_cleanvalues()
110-
for attr in inspect(self.__class__).column_attrs:
111-
col, = attr.columns
112-
previous = getattr(self, attr.key, None)
113-
self._clean_values[col.name] = previous
114-
115-
super(ChangeTrackingMixin, self).__setattr__(key, value)
116-
117-
def record_changes(self, changed_by):
118-
db.session.add(self)
119-
db.session.flush()
91+
@classmethod
92+
def after_change_listener(cls, mapper, connection, target):
93+
state = db.inspect(target)
12094
changes = {}
121-
for attr in inspect(self.__class__).column_attrs:
122-
col, = attr.columns
123-
if attr.key not in self.skipped_fields:
124-
changes[col.name] = {'previous': self._clean_values[col.name],
125-
'current': getattr(self, attr.key)}
12695

127-
db.session.add(Change(object=self,
128-
object_version=self.version,
129-
user=changed_by,
130-
change=changes))
96+
for attr in state.attrs:
97+
if attr.key not in cls.tracked_columns:
98+
continue
99+
100+
hist = state.get_history(attr.key, True)
101+
102+
if not hist.has_changes():
103+
continue
104+
105+
if hist.deleted:
106+
previous = hist.deleted[0]
107+
else:
108+
previous = None
109+
110+
changes[attr.key] = {
111+
'previous': previous,
112+
'current': attr.value
113+
}
114+
115+
if changes:
116+
changed_by = cls.fetch_current_user_id() or target.user_id
117+
db.session.add(Change(object=target,
118+
object_version=target.version,
119+
user_id=changed_by,
120+
change=changes))
121+
122+
@staticmethod
123+
def fetch_current_user_id():
124+
from flask_login import current_user
125+
from flask import has_app_context, has_request_context
126+
127+
# Return None if we are outside of request context.
128+
if not has_app_context() or not has_request_context():
129+
return
130+
try:
131+
return current_user.id
132+
except AttributeError:
133+
return
134+
135+
@classmethod
136+
def __declare_last__(cls):
137+
# get called after mappings are completed
138+
listen(cls, 'after_update', cls.after_change_listener)
139+
listen(cls, 'after_insert', cls.after_change_listener)
131140

132141

133142
class BelongsToOrgMixin(object):
@@ -612,6 +621,9 @@ def should_schedule_next(previous_iteration, now, schedule):
612621

613622

614623
class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
624+
tracked_columns = ('data_source_id', 'latest_query_data_id', 'name', 'description', 'query_text', 'user_id',
625+
'is_archived', 'is_draft', 'schedule', 'options')
626+
615627
id = Column(db.Integer, primary_key=True)
616628
version = Column(db.Integer)
617629
org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
@@ -684,7 +696,7 @@ def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, w
684696

685697
return d
686698

687-
def archive(self, user=None):
699+
def archive(self):
688700
db.session.add(self)
689701
self.is_archived = True
690702
self.schedule = None
@@ -696,9 +708,6 @@ def archive(self, user=None):
696708
for a in self.alerts:
697709
db.session.delete(a)
698710

699-
if user:
700-
self.record_changes(user)
701-
702711
@classmethod
703712
def all_queries(cls, groups, drafts=False):
704713
q = (cls.query.join(User, Query.user_id == User.id)
@@ -798,20 +807,6 @@ def fork(self, user):
798807
db.session.add(forked_query)
799808
return forked_query
800809

801-
def update_instance_tracked(self, changing_user, old_object=None, *args, **kwargs):
802-
self.version += 1
803-
self.update_instance(*args, **kwargs)
804-
# save Change record
805-
new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self)
806-
return new_change
807-
808-
def tracked_save(self, changing_user, old_object=None, *args, **kwargs):
809-
self.version += 1
810-
self.save(*args, **kwargs)
811-
# save Change record
812-
new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self)
813-
return new_change
814-
815810
@property
816811
def runtime(self):
817812
return self.latest_query_data.runtime
@@ -831,6 +826,8 @@ def __unicode__(self):
831826
return unicode(self.id)
832827

833828

829+
830+
834831
@listens_for(Query.query_text, 'set')
835832
def gen_query_hash(target, val, oldval, initiator):
836833
target.query_hash = utils.gen_query_hash(val)
@@ -950,10 +947,6 @@ def to_dict(self, full=True):
950947

951948
return d
952949

953-
@classmethod
954-
def log_change(cls, changed_by, obj):
955-
return cls.create(object=obj, object_version=obj.version, user=changed_by, change=obj.changes)
956-
957950
@classmethod
958951
def last_change(cls, obj):
959952
return db.session.query(cls).filter(
@@ -1050,6 +1043,8 @@ def generate_slug(ctx):
10501043

10511044

10521045
class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
1046+
tracked_columns = ('slug', 'name', 'user_id', 'layout', 'dashboard_filters_enabled', 'is_archived', 'is_draft')
1047+
10531048
id = Column(db.Integer, primary_key=True)
10541049
version = Column(db.Integer)
10551050
org_id = Column(db.Integer, db.ForeignKey("organizations.id"))

tests/models/test_changes.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def create_object(factory):
1111
data_source=factory.data_source,
1212
org=factory.org)
1313

14+
db.session.commit()
15+
1416
return obj
1517

1618

@@ -23,41 +25,28 @@ def test_returns_initial_state(self):
2325

2426

2527
class TestLogChange(BaseTestCase):
26-
def obj(self):
27-
obj = Query(name='Query',
28-
description='',
29-
query_text='SELECT 1',
30-
user=self.factory.user,
31-
data_source=self.factory.data_source,
32-
org=self.factory.org)
33-
34-
return obj
35-
3628
def test_properly_logs_first_creation(self):
3729
obj = create_object(self.factory)
38-
obj.record_changes(changed_by=self.factory.user)
3930
change = Change.last_change(obj)
4031

4132
self.assertIsNotNone(change)
4233
self.assertEqual(change.object_version, 1)
34+
self.assertEqual(obj.user, change.user)
4335

4436
def test_skips_unnecessary_fields(self):
4537
obj = create_object(self.factory)
46-
obj.record_changes(changed_by=self.factory.user)
4738
change = Change.last_change(obj)
4839

4940
self.assertIsNotNone(change)
5041
self.assertEqual(change.object_version, 1)
51-
for field in ChangeTrackingMixin.skipped_fields:
52-
self.assertNotIn(field, change.change)
42+
for field in change.change:
43+
self.assertIn(field, Query.tracked_columns)
5344

5445
def test_properly_log_modification(self):
5546
obj = create_object(self.factory)
56-
obj.record_changes(changed_by=self.factory.user)
5747
obj.name = 'Query 2'
5848
obj.description = 'description'
5949
db.session.flush()
60-
obj.record_changes(changed_by=self.factory.user)
6150

6251
change = Change.last_change(obj)
6352

@@ -67,11 +56,3 @@ def test_properly_log_modification(self):
6756
self.assertIn('name', change.change)
6857
self.assertIn('description', change.change)
6958

70-
def test_logs_create_method(self):
71-
q = Query(name='Query', description='', query_text='',
72-
user=self.factory.user, data_source=self.factory.data_source,
73-
org=self.factory.org)
74-
change = Change.last_change(q)
75-
76-
self.assertIsNotNone(change)
77-
self.assertEqual(q.user, change.user)

tests/models/test_queries.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@
22
from redash.models import Query, db
33

44

5-
class TestApiKeyGetByObject(BaseTestCase):
6-
5+
class TestQueryFork(BaseTestCase):
76
def assert_visualizations(self, origin_q, origin_v, forked_q, forked_v):
87
self.assertEqual(origin_v.options, forked_v.options)
98
self.assertEqual(origin_v.type, forked_v.type)
109
self.assertNotEqual(origin_v.id, forked_v.id)
1110
self.assertNotEqual(origin_v.query_rel, forked_v.query_rel)
1211
self.assertEqual(forked_q.id, forked_v.query_rel.id)
1312

14-
1513
def test_fork_with_visualizations(self):
1614
# prepare original query and visualizations
1715
data_source = self.factory.create_data_source(

0 commit comments

Comments
 (0)