diff --git a/callbacks/preload.go b/callbacks/preload.go index 25ecfe761..cf7a0d2ba 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati } } else if rel := relationships.Relations[name]; rel != nil { if joined, nestedJoins := isJoined(name); joined { - reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) - tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { - return err + switch rv := db.Statement.ReflectValue; rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } + case reflect.Struct: + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + default: + return gorm.ErrInvalidData } } else { tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) diff --git a/tests/preload_test.go b/tests/preload_test.go index 26b08d7de..14f94139d 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/require" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, find2, value) + + var finds []Value + err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + require.Len(t, finds, 1) + AssertEqual(t, finds[0], value) } func TestEmbedPreload(t *testing.T) {