diff --git a/callbacks/create.go b/callbacks/create.go index 210a46f7f..d930e9225 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -293,13 +293,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - for field, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } } } } diff --git a/callbacks/create_test.go b/callbacks/create_test.go new file mode 100644 index 000000000..da6b172bd --- /dev/null +++ b/callbacks/create_test.go @@ -0,0 +1,71 @@ +package callbacks + +import ( + "reflect" + "sync" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +var schemaCache = &sync.Map{} + +func TestConvertToCreateValues_DestType_Slice(t *testing.T) { + type user struct { + ID int `gorm:"primaryKey"` + Name string + Email string `gorm:"default:(-)"` + Age int `gorm:"default:(-)"` + } + + s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{}) + if err != nil { + t.Errorf("parse schema error: %v, is not expected", err) + return + } + dest := []*user{ + { + ID: 1, + Name: "alice", + Email: "email", + Age: 18, + }, + { + ID: 2, + Name: "bob", + Email: "email", + Age: 19, + }, + } + stmt := &gorm.Statement{ + DB: &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { return time.Time{} }, + }, + Statement: &gorm.Statement{ + Settings: sync.Map{}, + Schema: s, + }, + }, + ReflectValue: reflect.ValueOf(dest), + Dest: dest, + } + + stmt.Schema = s + + values := ConvertToCreateValues(stmt) + expected := clause.Values{ + // column has value + defaultValue column has value (which should have a stable order) + Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}}, + Values: [][]interface{}{ + {"alice", "email", 18, 1}, + {"bob", "email", 19, 2}, + }, + } + if !reflect.DeepEqual(expected, values) { + t.Errorf("expected: %v got %v", expected, values) + } +}