Skip to content

Commit

Permalink
Cut string output parameter fix (#168)
Browse files Browse the repository at this point in the history
* Reserve extra space for out parameter + tests.

* Test for []byte parameter.

---------

Co-authored-by: El-76 <anton.ostroumov@gmail.com>
  • Loading branch information
El-76 and DS-AI committed Jan 23, 2024
1 parent 5738a68 commit 5821c4e
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 8 deletions.
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (b *Bulk) createColMetadata() []byte {
}
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

writeTypeInfo(buf, &b.bulkColumns[i].ti)
writeTypeInfo(buf, &b.bulkColumns[i].ti, false)

if col.ti.TypeId == typeNText ||
col.ti.TypeId == typeText ||
Expand Down
149 changes: 149 additions & 0 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,155 @@ SELECT @param2 = 'World'
})
}

func TestOutputINOUTStringParam(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE vinout
@sinout NVARCHAR(4000) OUTPUT
AS
BEGIN
IF @sinout = 'empty'
SET @sinout = NULL
ELSE
SET @sinout = 'long_long_value'
END;
`
sqltextdrop := `DROP PROCEDURE vinout;`
sqltextrun := `vinout`

checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)

t.Run("original test", func(t *testing.T) {
sinout := "short_value"
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if sinout != "long_long_value" {
t.Errorf("expected long_long_value, got %s", sinout)
}
})

t.Run("nullable value", func(t *testing.T) {
sinout := sql.NullString{String: "short_value", Valid: true}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if !sinout.Valid || sinout.String != "long_long_value" {
if sinout.Valid {
t.Errorf("expected long_long_value, got %s", sinout.String)
} else {
t.Errorf("expected long_long_value, got NULL")
}
}
})

t.Run("null value", func(t *testing.T) {
sinout := sql.NullString{}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if !sinout.Valid || sinout.String != "long_long_value" {
if sinout.Valid {
t.Errorf("expected long_long_value, got %s", sinout.String)
} else {
t.Errorf("expected long_long_value, got NULL")
}
}
})

t.Run("null result", func(t *testing.T) {
sinout := sql.NullString{String: "empty", Valid: true}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if sinout.Valid {
t.Errorf("expected NULL, got %s", sinout.String)
}
})
}

func TestOutputINOUTBytesParam(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE vinout
@binout VARBINARY(4000) OUTPUT
AS
BEGIN
SET @binout = CONVERT(VARBINARY(4000), 'long_long_value')
END;
`
sqltextdrop := `DROP PROCEDURE vinout;`
sqltextrun := `vinout`

checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)

t.Run("original test", func(t *testing.T) {
binout := []byte("short_value")
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("binout", sql.Out{Dest: &binout}),
)
if err != nil {
t.Error(err)
}

if !bytes.Equal(binout, []byte("long_long_value")) {
t.Errorf("expected long_long_value, got %s", string(binout))
}
})
}

func TestOutputINOUTParam(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE abinout
Expand Down
4 changes: 2 additions & 2 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil {
return
}
err = writeTypeInfo(buf, &param.ti)
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0)
if err != nil {
return
}
Expand All @@ -82,7 +82,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
return
}
if (param.Flags & fEncrypted) == fEncrypted {
err = writeTypeInfo(buf, &param.tiOriginal)
err = writeTypeInfo(buf, &param.tiOriginal, false)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion tvp_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for i, column := range columnStr {
binary.Write(buf, binary.LittleEndian, column.UserType)
binary.Write(buf, binary.LittleEndian, column.Flags)
writeTypeInfo(buf, &columnStr[i].ti)
writeTypeInfo(buf, &columnStr[i].ti, false)
writeBVarChar(buf, "")
}
// The returned error is always nil
Expand Down
8 changes: 4 additions & 4 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) {
}

// https://msdn.microsoft.com/en-us/library/dd358284.aspx
func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
func writeTypeInfo(w io.Writer, ti *typeInfo, out bool) (err error) {
err = binary.Write(w, binary.LittleEndian, ti.TypeId)
if err != nil {
return
Expand All @@ -162,7 +162,7 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
case typeTvp:
ti.Writer = writeFixedType
default: // all others are VARLENTYPE
err = writeVarLen(w, ti)
err = writeVarLen(w, ti, out)
if err != nil {
return
}
Expand All @@ -176,7 +176,7 @@ func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) {
}

// https://msdn.microsoft.com/en-us/library/dd358341.aspx
func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
func writeVarLen(w io.Writer, ti *typeInfo, out bool) (err error) {
switch ti.TypeId {

case typeDateN:
Expand Down Expand Up @@ -222,7 +222,7 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
typeNVarChar, typeNChar, typeXml, typeUdt:

// short len types
if ti.Size > 8000 || ti.Size == 0 {
if ti.Size > 8000 || ti.Size == 0 || out {
if err = binary.Write(w, binary.LittleEndian, uint16(0xffff)); err != nil {
return
}
Expand Down

0 comments on commit 5821c4e

Please sign in to comment.