Skip to content

Commit 2ed97cf

Browse files
lkiesowshaardie
authored andcommitted
Database Session Wrapper
This patch introduces a database session wrapper ensuring that database sessions are properly closed even if errors occur while executing a function.
1 parent d4323b7 commit 2ed97cf

File tree

3 files changed

+40
-19
lines changed

3 files changed

+40
-19
lines changed

pyca/db.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
create_engine
1616
from sqlalchemy.orm import sessionmaker
1717
from datetime import datetime
18+
from functools import wraps
1819
Base = declarative_base()
1920

2021

@@ -39,6 +40,28 @@ def get_session():
3940
return Session()
4041

4142

43+
def with_session(f):
44+
"""Wrapper for f to make a SQLAlchemy session present within the function
45+
46+
:param f: Function to call
47+
:type f: Function
48+
:raises e: Possible exception of f
49+
:return: Result of f
50+
"""
51+
@wraps(f)
52+
def decorated(*args, **kwargs):
53+
session = get_session()
54+
try:
55+
result = f(session, *args, **kwargs)
56+
except Exception as e:
57+
session.rollback()
58+
raise e
59+
finally:
60+
session.close()
61+
return result
62+
return decorated
63+
64+
4265
class Constants():
4366

4467
@classmethod

pyca/ui/jsonapi.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pyca.config import config
44
from pyca.db import Service, ServiceStatus, UpcomingEvent, \
55
RecordedEvent, UpstreamState
6-
from pyca.db import get_session, Status, ServiceStates
6+
from pyca.db import with_session, Status, ServiceStates
77
from pyca.ui import app
88
from pyca.ui.utils import requires_auth, jsonapi_mediatype
99
from pyca.utils import get_service_status, ensurelist
@@ -79,11 +79,11 @@ def internal_state():
7979
@app.route('/api/events/')
8080
@requires_auth
8181
@jsonapi_mediatype
82-
def events():
82+
@with_session
83+
def events(db):
8384
'''Serve a JSON representation of events splitted by upcoming and already
8485
recorded events.
8586
'''
86-
db = get_session()
8787
upcoming_events = db.query(UpcomingEvent)\
8888
.order_by(UpcomingEvent.start)
8989
recorded_events = db.query(RecordedEvent)\
@@ -97,10 +97,10 @@ def events():
9797
@app.route('/api/events/<uid>')
9898
@requires_auth
9999
@jsonapi_mediatype
100-
def event(uid):
100+
@with_session
101+
def event(db, uid):
101102
'''Return a specific events JSON
102103
'''
103-
db = get_session()
104104
event = db.query(RecordedEvent).filter(RecordedEvent.uid == uid).first() \
105105
or db.query(UpcomingEvent).filter(UpcomingEvent.uid == uid).first()
106106

@@ -112,7 +112,8 @@ def event(uid):
112112
@app.route('/api/events/<uid>', methods=['DELETE'])
113113
@requires_auth
114114
@jsonapi_mediatype
115-
def delete_event(uid):
115+
@with_session
116+
def delete_event(db, uid):
116117
'''Delete a specific event identified by its uid. Note that only recorded
117118
events can be deleted. Events in the buffer for upcoming events are
118119
regularly replaced anyway and a manual removal could have unpredictable
@@ -124,7 +125,6 @@ def delete_event(uid):
124125
Returns 404 if event does not exist
125126
'''
126127
logger.info('deleting event %s via api', uid)
127-
db = get_session()
128128
events = db.query(RecordedEvent).filter(RecordedEvent.uid == uid)
129129
if not events.count():
130130
return make_error_response('No event with specified uid', 404)
@@ -140,7 +140,8 @@ def delete_event(uid):
140140
@app.route('/api/events/<uid>', methods=['PATCH'])
141141
@requires_auth
142142
@jsonapi_mediatype
143-
def modify_event(uid):
143+
@with_session
144+
def modify_event(db, uid):
144145
'''Modify an event specified by its uid. The modifications for the event
145146
are expected as JSON with the content type correctly set in the request.
146147
@@ -163,7 +164,6 @@ def modify_event(uid):
163164
except Exception:
164165
return make_error_response('Invalid data', 400)
165166

166-
db = get_session()
167167
event = db.query(RecordedEvent).filter(RecordedEvent.uid == uid).first()
168168
if not event:
169169
return make_error_response('No event with specified uid', 404)
@@ -177,7 +177,8 @@ def modify_event(uid):
177177

178178
@app.route('/api/metrics', methods=['GET'])
179179
@requires_auth
180-
def metrics():
180+
@with_session
181+
def metrics(dbs):
181182
'''Serve several metrics about the pyCA services and the machine via
182183
json.'''
183184
# Get Disk Usage
@@ -189,7 +190,6 @@ def metrics():
189190
# Get Memory
190191
memory = psutil.virtual_memory()
191192

192-
dbs = get_session()
193193
# Get Services
194194
srvs = dbs.query(ServiceStates)
195195
services = []
@@ -202,7 +202,6 @@ def metrics():
202202
state = dbs.query(UpstreamState).filter(
203203
UpstreamState.url == config()['server']['url']).first()
204204
last_synchronized = state.last_synced.isoformat() if state else None
205-
dbs.close()
206205
return make_response(
207206
{'meta': {
208207
'services': services,

pyca/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,27 +171,26 @@ def recording_state(recording_id, status):
171171
logger.warning('Could not set recording state to %s: %s', status, e)
172172

173173

174-
def update_event_status(event, status):
174+
@db.with_session
175+
def update_event_status(dbs, event, status):
175176
'''Update the status of a particular event in the database.
176177
'''
177-
dbs = db.get_session()
178178
dbs.query(db.RecordedEvent).filter(db.RecordedEvent.start == event.start)\
179179
.update({'status': status})
180180
event.status = status
181181
dbs.commit()
182182

183183

184-
def set_service_status(service, status):
184+
@db.with_session
185+
def set_service_status(dbs, service, status):
185186
'''Update the status of a particular service in the database.
186187
'''
187188
srv = db.ServiceStates()
188189
srv.type = service
189190
srv.status = status
190191

191-
dbs = db.get_session()
192192
dbs.merge(srv)
193193
dbs.commit()
194-
dbs.close()
195194

196195

197196
def set_service_status_immediate(service, status):
@@ -202,10 +201,10 @@ def set_service_status_immediate(service, status):
202201
update_agent_state()
203202

204203

205-
def get_service_status(service):
204+
@db.with_session
205+
def get_service_status(dbs, service):
206206
'''Update the status of a particular service in the database.
207207
'''
208-
dbs = db.get_session()
209208
srvs = dbs.query(db.ServiceStates).filter(db.ServiceStates.type == service)
210209

211210
if srvs.count():

0 commit comments

Comments
 (0)