Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: 'type XXXX int' will print wrong sql to terminal #6917

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions logger/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)

func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}

// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
Expand Down Expand Up @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
Expand Down
30 changes: 23 additions & 7 deletions logger/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ func format(v []byte, escaper string) string {
func TestExplainSQL(t *testing.T) {
type role string
type password []byte
type intType int
type floatType float64
var (
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
)

results := []struct {
Expand Down Expand Up @@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
},
}

for idx, r := range results {
Expand Down