Skip to content

Commit

Permalink
Add ScanLocation to pgtype.TimestamptzCodec
Browse files Browse the repository at this point in the history
If ScanLocation is set, it will be used to convert the time to the given
location when scanning from the database.

The Codec interface is now implemented by *pgtype.TimestamptzCodec
instead of pgtype.TimestamptzCodec. This is technically a breaking
change, but it is extremely unlikely that anyone is depending on this,
and if there is downstream breakage it is trivial to fix.

#1195
#1945
  • Loading branch information
jackc committed Mar 16, 2024
1 parent 1b6227a commit a22564d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pgtype/pgtype_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}})
defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}})
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})
Expand Down
39 changes: 25 additions & 14 deletions pgtype/timestamptz.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (tstz *Timestamptz) Scan(src any) error {

switch src := src.(type) {
case string:
return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz)
return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz)
case time.Time:
*tstz = Timestamptz{Time: src, Valid: true}
return nil
Expand Down Expand Up @@ -124,17 +124,21 @@ func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
return nil
}

type TimestamptzCodec struct{}
type TimestamptzCodec struct {
// ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that
// the timestamptz represents.
ScanLocation *time.Location
}

func (TimestamptzCodec) FormatSupported(format int16) bool {
func (*TimestamptzCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}

func (TimestamptzCodec) PreferredFormat() int16 {
func (*TimestamptzCodec) PreferredFormat() int16 {
return BinaryFormatCode
}

func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(TimestamptzValuer); !ok {
return nil
}
Expand Down Expand Up @@ -220,27 +224,27 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by
return buf, nil
}

func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {

switch format {
case BinaryFormatCode:
switch target.(type) {
case TimestamptzScanner:
return scanPlanBinaryTimestamptzToTimestamptzScanner{}
return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation}
}
case TextFormatCode:
switch target.(type) {
case TimestamptzScanner:
return scanPlanTextTimestamptzToTimestamptzScanner{}
return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation}
}
}

return nil
}

type scanPlanBinaryTimestamptzToTimestamptzScanner struct{}
type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location }

func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestamptzScanner)

if src == nil {
Expand All @@ -264,15 +268,18 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) e
microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000,
(microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000),
)
if plan.location != nil {
tim = tim.In(plan.location)
}
tstz = Timestamptz{Time: tim, Valid: true}
}

return scanner.ScanTimestamptz(tstz)
}

type scanPlanTextTimestamptzToTimestamptzScanner struct{}
type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location }

func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestamptzScanner)

if src == nil {
Expand Down Expand Up @@ -312,13 +319,17 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) err
tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location())
}

if plan.location != nil {
tim = tim.In(plan.location)
}

tstz = Timestamptz{Time: tim, Valid: true}
}

return scanner.ScanTimestamptz(tstz)
}

func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
Expand All @@ -336,7 +347,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1
return tstz.Time, nil
}

func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
Expand Down
34 changes: 34 additions & 0 deletions pgtype/timestamptz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,40 @@ func TestTimestamptzCodec(t *testing.T) {
})
}

func TestTimestamptzCodecWithLocationUTC(t *testing.T) {
skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)")

connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "timestamptz",
OID: pgtype.TimestamptzOID,
Codec: &pgtype.TimestamptzCodec{ScanLocation: time.UTC},
})
}

pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{
{time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))},
})
}

func TestTimestamptzCodecWithLocationLocal(t *testing.T) {
skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)")

connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "timestamptz",
OID: pgtype.TimestamptzOID,
Codec: &pgtype.TimestamptzCodec{ScanLocation: time.Local},
})
}

pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{
{time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))},
})
}

// https://github.com/jackc/pgx/v4/pgtype/pull/128
func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
Expand Down

0 comments on commit a22564d

Please sign in to comment.