Skip to content

Commit

Permalink
Fix insert id into map results, fix #6812
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 19, 2024
1 parent 1b0aa80 commit 81536f8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 24 deletions.
23 changes: 16 additions & 7 deletions callbacks/create.go
Expand Up @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) {
pkField *schema.Field
pkFieldName = "@id"
)

insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0

if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}

if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
Expand All @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) {
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}

insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}

// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
Expand All @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) {
}
}
}

if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}

for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
Expand Down
30 changes: 14 additions & 16 deletions tests/create_test.go
Expand Up @@ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) {
}

func TestCreateFromMapWithTable(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supportting for mysql")
}
tableDB := DB.Table("`users`")
tableDB := DB.Table("users")
supportLastInsertID := isMysql() || isSqlite()

// case 1: create from map[string]interface{}
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18}
record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
if err := tableDB.Create(record).Error; err != nil {
t.Fatalf("failed to create data from map with table, got error: %v", err)
}

if _, ok := record["@id"]; !ok {
if _, ok := record["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

Expand All @@ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to create from map, got error %v", err)
}

if int64(res["id"].(uint64)) != record["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
}

// case 2: create from *map[string]interface{}
Expand All @@ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) {
if err := tableDB2.Create(&record1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := record1["@id"]; !ok {
if _, ok := record1["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

Expand All @@ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to create from map, got error %v", err)
}

if int64(res1["id"].(uint64)) != record1["@id"] {
if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
t.Fatal("failed to create data from map with table, @id != id")
}

Expand All @@ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}

if _, ok := records[0]["@id"]; !ok {
if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

if _, ok := records[1]["@id"]; !ok {
if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

Expand All @@ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}

if int64(res2["id"].(uint64)) != records[0]["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
}

if int64(res3["id"].(uint64)) != records[1]["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
t.Errorf("failed to create data from map with table, @id != id")
}
}
2 changes: 1 addition & 1 deletion tests/go.mod
Expand Up @@ -11,7 +11,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5
gorm.io/driver/sqlserver v1.5.3
gorm.io/gorm v1.25.7
gorm.io/gorm v1.25.8
)

require (
Expand Down
4 changes: 4 additions & 0 deletions tests/helper_test.go
Expand Up @@ -281,6 +281,10 @@ func isMysql() bool {
return os.Getenv("GORM_DIALECT") == "mysql"
}

func isSqlite() bool {
return os.Getenv("GORM_DIALECT") == "sqlite"
}

func db(unscoped bool) *gorm.DB {
if unscoped {
return DB.Unscoped()
Expand Down

0 comments on commit 81536f8

Please sign in to comment.