Skip to content

Commit

Permalink
fix: remove callback from callbacks if Remove() called
Browse files Browse the repository at this point in the history
  • Loading branch information
snackmgmg committed Mar 19, 2024
1 parent 1b0aa80 commit a730265
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
33 changes: 32 additions & 1 deletion callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,23 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
}

func (p *processor) compile() (err error) {
var callbacks []*callback
var (
callbacks []*callback
removed []string
)
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removed = append(removed, callback.name)
}
}

if len(removed) > 0 {
callbacks = removeCallbacks(callbacks, removed)
}

p.callbacks = callbacks

if p.fns, err = sortCallbacks(p.callbacks); err != nil {
Expand Down Expand Up @@ -339,3 +350,23 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {

return
}

func removeCallbacks(cs []*callback, names []string) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if contains(names, callback.name) {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}

func contains(a []string, b string) bool {
for _, v := range a {
if b == v {
return true
}
}
return false
}
48 changes: 47 additions & 1 deletion tests/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"},
results: []string{"c1", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
Expand Down Expand Up @@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

func TestCallbacksGet(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()

createCallback.Before("*").Register("c1", c1)
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
}

createCallback.Remove("c1")
if cb := createCallback.Get("c2"); cb != nil {
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
}
}

func TestCallbacksRemove(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()

createCallback.Before("*").Register("c1", c1)
createCallback.After("*").Register("c2", c2)
createCallback.Before("c4").Register("c3", c3)
createCallback.After("c2").Register("c4", c4)

// callbacks: []string{"c1", "c3", "c4", "c2"}
createCallback.Remove("c1")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}

createCallback.Remove("c4")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}

createCallback.Remove("c2")
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}

createCallback.Remove("c3")
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

0 comments on commit a730265

Please sign in to comment.