Skip to content

Commit 4554c7e

Browse files
HitmanRanbochengwanli
andauthored
fix: make registerTask safe on concurrent use (#616)
Co-authored-by: chengwanli <chengwanli@p1staff.com>
1 parent 260989d commit 4554c7e

4 files changed

Lines changed: 57 additions & 29 deletions

File tree

v1/common/broker.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package common
22

33
import (
44
"errors"
5+
"sync"
56

67
"github.com/RichardKnop/machinery/v1/brokers/iface"
78
"github.com/RichardKnop/machinery/v1/config"
@@ -10,10 +11,15 @@ import (
1011
"github.com/RichardKnop/machinery/v1/tasks"
1112
)
1213

14+
type registeredTaskNames struct {
15+
sync.RWMutex
16+
items []string
17+
}
18+
1319
// Broker represents a base broker structure
1420
type Broker struct {
1521
cnf *config.Config
16-
registeredTaskNames []string
22+
registeredTaskNames registeredTaskNames
1723
retry bool
1824
retryFunc func(chan int)
1925
retryStopChan chan int
@@ -62,12 +68,14 @@ func (b *Broker) Publish(signature *tasks.Signature) error {
6268

6369
// SetRegisteredTaskNames sets registered task names
6470
func (b *Broker) SetRegisteredTaskNames(names []string) {
65-
b.registeredTaskNames = names
71+
b.registeredTaskNames.Lock()
72+
defer b.registeredTaskNames.Unlock()
73+
b.registeredTaskNames.items = names
6674
}
6775

6876
// IsTaskRegistered returns true if the task is registered with this broker
6977
func (b *Broker) IsTaskRegistered(name string) bool {
70-
for _, registeredTaskName := range b.registeredTaskNames {
78+
for _, registeredTaskName := range b.registeredTaskNames.items {
7179
if registeredTaskName == name {
7280
return true
7381
}
@@ -110,7 +118,7 @@ func (b *Broker) StopConsuming() {
110118

111119
// GetRegisteredTaskNames returns registered tasks names
112120
func (b *Broker) GetRegisteredTaskNames() []string {
113-
return b.registeredTaskNames
121+
return b.registeredTaskNames.items
114122
}
115123

116124
// AdjustRoutingKey makes sure the routing key is correct.

v1/server.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828
// All the tasks workers process are registered against the server
2929
type Server struct {
3030
config *config.Config
31-
registeredTasks map[string]interface{}
31+
registeredTasks *sync.Map
3232
broker brokersiface.Broker
3333
backend backendsiface.Backend
3434
lock lockiface.Lock
@@ -40,7 +40,7 @@ type Server struct {
4040
func NewServerWithBrokerBackendLock(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend, lock lockiface.Lock) *Server {
4141
srv := &Server{
4242
config: cnf,
43-
registeredTasks: map[string]interface{}{},
43+
registeredTasks: new(sync.Map),
4444
broker: brokerServer,
4545
backend: backendServer,
4646
lock: lock,
@@ -143,7 +143,11 @@ func (server *Server) RegisterTasks(namedTaskFuncs map[string]interface{}) error
143143
return err
144144
}
145145
}
146-
server.registeredTasks = namedTaskFuncs
146+
147+
for k, v := range namedTaskFuncs {
148+
server.registeredTasks.Store(k, v)
149+
}
150+
147151
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
148152
return nil
149153
}
@@ -153,20 +157,20 @@ func (server *Server) RegisterTask(name string, taskFunc interface{}) error {
153157
if err := tasks.ValidateTask(taskFunc); err != nil {
154158
return err
155159
}
156-
server.registeredTasks[name] = taskFunc
160+
server.registeredTasks.Store(name, taskFunc)
157161
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
158162
return nil
159163
}
160164

161165
// IsTaskRegistered returns true if the task name is registered with this broker
162166
func (server *Server) IsTaskRegistered(name string) bool {
163-
_, ok := server.registeredTasks[name]
167+
_, ok := server.registeredTasks.Load(name)
164168
return ok
165169
}
166170

167171
// GetRegisteredTask returns registered task by name
168172
func (server *Server) GetRegisteredTask(name string) (interface{}, error) {
169-
taskFunc, ok := server.registeredTasks[name]
173+
taskFunc, ok := server.registeredTasks.Load(name)
170174
if !ok {
171175
return nil, fmt.Errorf("Task not registered error: %s", name)
172176
}
@@ -340,12 +344,12 @@ func (server *Server) SendChord(chord *tasks.Chord, sendConcurrency int) (*resul
340344

341345
// GetRegisteredTaskNames returns slice of registered task names
342346
func (server *Server) GetRegisteredTaskNames() []string {
343-
taskNames := make([]string, len(server.registeredTasks))
344-
var i = 0
345-
for name := range server.registeredTasks {
346-
taskNames[i] = name
347-
i++
348-
}
347+
taskNames := make([]string, 0)
348+
349+
server.registeredTasks.Range(func(key, value interface{}) bool {
350+
taskNames = append(taskNames, key.(string))
351+
return true
352+
})
349353
return taskNames
350354
}
351355

v1/server_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@ func TestRegisterTask(t *testing.T) {
3333
assert.NoError(t, err, "test_task is not registered but it should be")
3434
}
3535

36+
func TestRegisterTaskInRaceCondition(t *testing.T) {
37+
t.Parallel()
38+
39+
server := getTestServer(t)
40+
for i:=0; i<10; i++ {
41+
go func() {
42+
err := server.RegisterTask("test_task", func() error { return nil })
43+
assert.NoError(t, err)
44+
_, err = server.GetRegisteredTask("test_task")
45+
assert.NoError(t, err, "test_task is not registered but it should be")
46+
}()
47+
}
48+
}
49+
3650
func TestGetRegisteredTask(t *testing.T) {
3751
t.Parallel()
3852

v2/server.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727
// All the tasks workers process are registered against the server
2828
type Server struct {
2929
config *config.Config
30-
registeredTasks map[string]interface{}
30+
registeredTasks *sync.Map
3131
broker brokersiface.Broker
3232
backend backendsiface.Backend
3333
lock lockiface.Lock
@@ -39,7 +39,7 @@ type Server struct {
3939
func NewServer(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend, lock lockiface.Lock) *Server {
4040
srv := &Server{
4141
config: cnf,
42-
registeredTasks: make(map[string]interface{}),
42+
registeredTasks: new(sync.Map),
4343
broker: brokerServer,
4444
backend: backendServer,
4545
lock: lock,
@@ -56,7 +56,7 @@ func NewServer(cnf *config.Config, brokerServer brokersiface.Broker, backendServ
5656
func NewServerWithBrokerBackendLock(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend, lock lockiface.Lock) *Server {
5757
srv := &Server{
5858
config: cnf,
59-
registeredTasks: map[string]interface{}{},
59+
registeredTasks: new(sync.Map),
6060
broker: brokerServer,
6161
backend: backendServer,
6262
lock: lock,
@@ -131,7 +131,9 @@ func (server *Server) RegisterTasks(namedTaskFuncs map[string]interface{}) error
131131
return err
132132
}
133133
}
134-
server.registeredTasks = namedTaskFuncs
134+
for k, v := range namedTaskFuncs {
135+
server.registeredTasks.Store(k, v)
136+
}
135137
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
136138
return nil
137139
}
@@ -141,20 +143,20 @@ func (server *Server) RegisterTask(name string, taskFunc interface{}) error {
141143
if err := tasks.ValidateTask(taskFunc); err != nil {
142144
return err
143145
}
144-
server.registeredTasks[name] = taskFunc
146+
server.registeredTasks.Store(name, taskFunc)
145147
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
146148
return nil
147149
}
148150

149151
// IsTaskRegistered returns true if the task name is registered with this broker
150152
func (server *Server) IsTaskRegistered(name string) bool {
151-
_, ok := server.registeredTasks[name]
153+
_, ok := server.registeredTasks.Load(name)
152154
return ok
153155
}
154156

155157
// GetRegisteredTask returns registered task by name
156158
func (server *Server) GetRegisteredTask(name string) (interface{}, error) {
157-
taskFunc, ok := server.registeredTasks[name]
159+
taskFunc, ok := server.registeredTasks.Load(name)
158160
if !ok {
159161
return nil, fmt.Errorf("Task not registered error: %s", name)
160162
}
@@ -328,12 +330,12 @@ func (server *Server) SendChord(chord *tasks.Chord, sendConcurrency int) (*resul
328330

329331
// GetRegisteredTaskNames returns slice of registered task names
330332
func (server *Server) GetRegisteredTaskNames() []string {
331-
taskNames := make([]string, len(server.registeredTasks))
332-
var i = 0
333-
for name := range server.registeredTasks {
334-
taskNames[i] = name
335-
i++
336-
}
333+
taskNames := make([]string, 0)
334+
335+
server.registeredTasks.Range(func(key, value interface{}) bool {
336+
taskNames = append(taskNames, key.(string))
337+
return true
338+
})
337339
return taskNames
338340
}
339341

0 commit comments

Comments
 (0)