Skip to content

Commit a730265

Browse files
committed
fix: remove callback from callbacks if Remove() called
1 parent 1b0aa80 commit a730265

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

callbacks.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,23 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
186186
}
187187

188188
func (p *processor) compile() (err error) {
189-
var callbacks []*callback
189+
var (
190+
callbacks []*callback
191+
removed []string
192+
)
190193
for _, callback := range p.callbacks {
191194
if callback.match == nil || callback.match(p.db) {
192195
callbacks = append(callbacks, callback)
193196
}
197+
if callback.remove {
198+
removed = append(removed, callback.name)
199+
}
200+
}
201+
202+
if len(removed) > 0 {
203+
callbacks = removeCallbacks(callbacks, removed)
194204
}
205+
195206
p.callbacks = callbacks
196207

197208
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
@@ -339,3 +350,23 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
339350

340351
return
341352
}
353+
354+
func removeCallbacks(cs []*callback, names []string) []*callback {
355+
callbacks := make([]*callback, 0, len(cs))
356+
for _, callback := range cs {
357+
if contains(names, callback.name) {
358+
continue
359+
}
360+
callbacks = append(callbacks, callback)
361+
}
362+
return callbacks
363+
}
364+
365+
func contains(a []string, b string) bool {
366+
for _, v := range a {
367+
if b == v {
368+
return true
369+
}
370+
}
371+
return false
372+
}

tests/callbacks_test.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
9191
},
9292
{
9393
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
94-
results: []string{"c1", "c5", "c3", "c4"},
94+
results: []string{"c1", "c3", "c4", "c5"},
9595
},
9696
{
9797
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
@@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) {
206206
t.Errorf("callbacks tests failed, got %v", msg)
207207
}
208208
}
209+
210+
func TestCallbacksGet(t *testing.T) {
211+
db, _ := gorm.Open(nil, nil)
212+
createCallback := db.Callback().Create()
213+
214+
createCallback.Before("*").Register("c1", c1)
215+
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
216+
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
217+
}
218+
219+
createCallback.Remove("c1")
220+
if cb := createCallback.Get("c2"); cb != nil {
221+
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
222+
}
223+
}
224+
225+
func TestCallbacksRemove(t *testing.T) {
226+
db, _ := gorm.Open(nil, nil)
227+
createCallback := db.Callback().Create()
228+
229+
createCallback.Before("*").Register("c1", c1)
230+
createCallback.After("*").Register("c2", c2)
231+
createCallback.Before("c4").Register("c3", c3)
232+
createCallback.After("c2").Register("c4", c4)
233+
234+
// callbacks: []string{"c1", "c3", "c4", "c2"}
235+
createCallback.Remove("c1")
236+
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
237+
t.Errorf("callbacks tests failed, got %v", msg)
238+
}
239+
240+
createCallback.Remove("c4")
241+
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
242+
t.Errorf("callbacks tests failed, got %v", msg)
243+
}
244+
245+
createCallback.Remove("c2")
246+
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
247+
t.Errorf("callbacks tests failed, got %v", msg)
248+
}
249+
250+
createCallback.Remove("c3")
251+
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
252+
t.Errorf("callbacks tests failed, got %v", msg)
253+
}
254+
}

0 commit comments

Comments
 (0)