7
7
8
8
use std:: sync:: Arc ;
9
9
use std:: thread;
10
+ use std:: thread:: JoinHandle ;
10
11
use std:: time:: Instant ;
11
12
12
13
use crossbeam:: channel:: unbounded;
@@ -18,10 +19,13 @@ pub struct TaskQueue<S, T> {
18
19
processor : Arc < dyn TaskProcessor < S , T > > ,
19
20
pub receiver : Receiver < T > ,
20
21
scheduler : Arc < TaskScheduler < T > > ,
22
+ active_thread_handles : Vec < JoinHandle < ( ) > > ,
21
23
}
22
24
23
25
pub trait TaskProcessor < S , T > : Send + Sync + ' static {
24
26
fn process ( & self , state : Arc < S > , task : T ) ;
27
+
28
+ fn is_serial_task ( & self , task : & T ) -> bool ;
25
29
}
26
30
27
31
pub struct TaskScheduler < T > {
@@ -48,25 +52,62 @@ where
48
52
processor,
49
53
receiver,
50
54
scheduler : Arc :: new ( TaskScheduler { sender } ) ,
55
+ active_thread_handles : Vec :: new ( ) ,
51
56
}
52
57
}
53
58
54
59
pub fn get_scheduler ( & self ) -> Arc < TaskScheduler < T > > {
55
60
Arc :: clone ( & self . scheduler )
56
61
}
57
62
58
- pub fn process ( & self , state : Arc < S > , task : T ) {
63
+ pub fn process ( & mut self , state : Arc < S > , task : T ) {
64
+ let processor = Arc :: clone ( & self . processor ) ;
65
+ let is_serial_task = processor. is_serial_task ( & task) ;
66
+
67
+ if is_serial_task {
68
+ // Before starting a serial task, we need to make sure that all
69
+ // previous tasks have been completed, otherwise the serial task
70
+ // might interfere with them.
71
+ self . ensure_previous_tasks_completed ( ) ;
72
+ } else {
73
+ // We do this in order for the active threads to not grow
74
+ // indefinitely, if there hasn't been a serial task in a while.
75
+ self . cleanup_already_finished_tasks ( ) ;
76
+ }
77
+
59
78
let task_str = format ! ( "{:?}" , & task) ;
60
79
let now = Instant :: now ( ) ;
61
80
debug ! ( "Processing task {:?}" , & task_str) ;
62
- let processor = Arc :: clone ( & self . processor ) ;
63
- thread:: spawn ( move || {
81
+
82
+ let handle = thread:: spawn ( move || {
64
83
processor. process ( state, task) ;
84
+
65
85
debug ! (
66
86
"task {} completed in {}ms" ,
67
87
task_str,
68
88
now. elapsed( ) . as_millis( )
69
89
) ;
70
90
} ) ;
91
+
92
+ if is_serial_task {
93
+ // If the task is serial, we need to wait for its thread
94
+ // to complete, before moving onto the next task.
95
+ let _ = handle. join ( ) ;
96
+ } else {
97
+ self . active_thread_handles . push ( handle) ;
98
+ }
99
+ }
100
+
101
+ fn ensure_previous_tasks_completed ( & mut self ) {
102
+ for handle in self . active_thread_handles . drain ( ..) {
103
+ // We don't actually care whether the thread has panicked or not,
104
+ // we just want to make sure it's finished.
105
+ let _ = handle. join ( ) ;
106
+ }
107
+ }
108
+
109
+ fn cleanup_already_finished_tasks ( & mut self ) {
110
+ self . active_thread_handles
111
+ . retain ( |handle| handle. is_finished ( ) == false ) ;
71
112
}
72
113
}
0 commit comments