Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cut string output parameter fix #168

Merged
merged 2 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (b *Bulk) createColMetadata() []byte {
}
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

writeTypeInfo(buf, &b.bulkColumns[i].ti)
writeTypeInfo(buf, &b.bulkColumns[i].ti, false)

if col.ti.TypeId == typeNText ||
col.ti.TypeId == typeText ||
Expand Down
149 changes: 149 additions & 0 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,155 @@ SELECT @param2 = 'World'
})
}

func TestOutputINOUTStringParam(t *testing.T) {
shueybubbles marked this conversation as resolved.
Show resolved Hide resolved
sqltextcreate := `
CREATE PROCEDURE vinout
@sinout NVARCHAR(4000) OUTPUT
AS
BEGIN
IF @sinout = 'empty'
SET @sinout = NULL
ELSE
SET @sinout = 'long_long_value'
END;
`
sqltextdrop := `DROP PROCEDURE vinout;`
sqltextrun := `vinout`

checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)

t.Run("original test", func(t *testing.T) {
sinout := "short_value"
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if sinout != "long_long_value" {
t.Errorf("expected long_long_value, got %s", sinout)
}
})

t.Run("nullable value", func(t *testing.T) {
sinout := sql.NullString{String: "short_value", Valid: true}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if !sinout.Valid || sinout.String != "long_long_value" {
if sinout.Valid {
t.Errorf("expected long_long_value, got %s", sinout.String)
} else {
t.Errorf("expected long_long_value, got NULL")
}
}
})

t.Run("null value", func(t *testing.T) {
sinout := sql.NullString{}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if !sinout.Valid || sinout.String != "long_long_value" {
if sinout.Valid {
t.Errorf("expected long_long_value, got %s", sinout.String)
} else {
t.Errorf("expected long_long_value, got NULL")
}
}
})

t.Run("null result", func(t *testing.T) {
sinout := sql.NullString{String: "empty", Valid: true}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("sinout", sql.Out{Dest: &sinout}),
)
if err != nil {
t.Error(err)
}

if sinout.Valid {
t.Errorf("expected NULL, got %s", sinout.String)
}
})
}

func TestOutputINOUTBytesParam(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE vinout
@binout VARBINARY(4000) OUTPUT
AS
BEGIN
SET @binout = CONVERT(VARBINARY(4000), 'long_long_value')
END;
`
sqltextdrop := `DROP PROCEDURE vinout;`
sqltextrun := `vinout`

checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)

t.Run("original test", func(t *testing.T) {
binout := []byte("short_value")
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("binout", sql.Out{Dest: &binout}),
)
if err != nil {
t.Error(err)
}

if !bytes.Equal(binout, []byte("long_long_value")) {
t.Errorf("expected long_long_value, got %s", string(binout))
}
})
}

func TestOutputINOUTParam(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE abinout
Expand Down
4 changes: 2 additions & 2 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil {
return
}
err = writeTypeInfo(buf, &param.ti)
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0)
if err != nil {
return
}
Expand All @@ -82,7 +82,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
return
}
if (param.Flags & fEncrypted) == fEncrypted {
err = writeTypeInfo(buf, &param.tiOriginal)
err = writeTypeInfo(buf, &param.tiOriginal, false)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion tvp_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for i, column := range columnStr {
binary.Write(buf, binary.LittleEndian, column.UserType)
binary.Write(buf, binary.LittleEndian, column.Flags)
writeTypeInfo(buf, &columnStr[i].ti)
writeTypeInfo(buf, &columnStr[i].ti, false)
writeBVarChar(buf, "")
}
// The returned error is always nil
Expand Down
8 changes: 4 additions & 4 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) {
}

// https://msdn.microsoft.com/en-us/library/dd358284.aspx
func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
func writeTypeInfo(w io.Writer, ti *typeInfo, out bool) (err error) {
err = binary.Write(w, binary.LittleEndian, ti.TypeId)
if err != nil {
return
Expand All @@ -162,7 +162,7 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
case typeTvp:
ti.Writer = writeFixedType
default: // all others are VARLENTYPE
err = writeVarLen(w, ti)
err = writeVarLen(w, ti, out)
if err != nil {
return
}
Expand All @@ -176,7 +176,7 @@ func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) {
}

// https://msdn.microsoft.com/en-us/library/dd358341.aspx
func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
func writeVarLen(w io.Writer, ti *typeInfo, out bool) (err error) {
switch ti.TypeId {

case typeDateN:
Expand Down Expand Up @@ -222,7 +222,7 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
typeNVarChar, typeNChar, typeXml, typeUdt:

// short len types
if ti.Size > 8000 || ti.Size == 0 {
if ti.Size > 8000 || ti.Size == 0 || out {
if err = binary.Write(w, binary.LittleEndian, uint16(0xffff)); err != nil {
return
}
Expand Down