Skip to content

Commit

Permalink
Allow callers to pass go context through to hooks (rs#559)
Browse files Browse the repository at this point in the history
Add Ctx(context.Context) to Event and Context, allowing
log.Info().Ctx(ctx).Msg("hello").  Registered hooks can retrieve the
context from Event.GetCtx().  Facilitates writing hooks which fetch
tracing context from the go context.
  • Loading branch information
danielbprice authored and mAdkins committed Mar 2, 2024
1 parent 28cc6ac commit b3cf0d9
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 20 deletions.
56 changes: 42 additions & 14 deletions README.md
Expand Up @@ -513,25 +513,53 @@ stdlog.Print("hello world")

### context.Context integration

The `Logger` instance could be attached to `context.Context` values with `logger.WithContext(ctx)`
and extracted from it using `zerolog.Ctx(ctx)`.
Go contexts are commonly passed throughout Go code, and this can help you pass
your Logger into places it might otherwise be hard to inject. The `Logger`
instance may be attached to Go context (`context.Context`) using
`Logger.WithContext(ctx)` and extracted from it using `zerolog.Ctx(ctx)`.
For example:

Example to add logger to context:
```go
// this code attach logger instance to context fields
ctx := context.Background()
logger := zerolog.New(os.Stdout)
ctx = logger.WithContext(ctx)
someFunc(ctx)
func f() {
logger := zerolog.New(os.Stdout)
ctx := context.Background()

// Attach the Logger to the context.Context
ctx = logger.WithContext(ctx)
someFunc(ctx)
}

func someFunc(ctx context.Context) {
// Get Logger from the go Context. if it's nil, then
// `zerolog.DefaultContextLogger` is returned, if
// `DefaultContextLogger` is nil, then a disabled logger is returned.
logger := zerolog.Ctx(ctx)
logger.Info().Msg("Hello")
}
```

Extracting logger from context:
A second form of `context.Context` integration allows you to pass the current
context.Context into the logged event, and retrieve it from hooks. This can be
useful to log trace and span IDs or other information stored in the go context,
and facilitates the unification of logging and tracing in some systems:

```go
func someFunc(ctx context.Context) {
// get logger from context. if it's nill, then `zerolog.DefaultContextLogger` is returned,
// if `DefaultContextLogger` is nil, then disabled logger returned.
logger := zerolog.Ctx(ctx)
logger.Info().Msg("Hello")
type TracingHook struct{}

func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) {
ctx := e.Ctx()
spanId := getSpanIdFromContext(ctx) // as per your tracing framework
e.Str("span-id", spanId)
}

func f() {
// Setup the logger
logger := zerolog.New(os.Stdout)
logger = logger.Hook(TracingHook{})

ctx := context.Background()
// Use the Ctx function to make the context available to the hook
logger.Info().Ctx(ctx).Msg("Hello")
}
```

Expand Down
9 changes: 9 additions & 0 deletions benchmark_test.go
@@ -1,6 +1,7 @@
package zerolog

import (
"context"
"errors"
"io/ioutil"
"net"
Expand Down Expand Up @@ -160,6 +161,7 @@ func BenchmarkLogFieldType(b *testing.B) {
{"a", "a", 0},
}
errs := []error{errors.New("a"), errors.New("b"), errors.New("c"), errors.New("d"), errors.New("e")}
ctx := context.Background()
types := map[string]func(e *Event) *Event{
"Bool": func(e *Event) *Event {
return e.Bool("k", bools[0])
Expand Down Expand Up @@ -191,6 +193,9 @@ func BenchmarkLogFieldType(b *testing.B) {
"Errs": func(e *Event) *Event {
return e.Errs("k", errs)
},
"Ctx": func(e *Event) *Event {
return e.Ctx(ctx)
},
"Time": func(e *Event) *Event {
return e.Time("k", times[0])
},
Expand Down Expand Up @@ -284,6 +289,7 @@ func BenchmarkContextFieldType(b *testing.B) {
{"a", "a", 0},
}
errs := []error{errors.New("a"), errors.New("b"), errors.New("c"), errors.New("d"), errors.New("e")}
ctx := context.Background()
types := map[string]func(c Context) Context{
"Bool": func(c Context) Context {
return c.Bool("k", bools[0])
Expand Down Expand Up @@ -318,6 +324,9 @@ func BenchmarkContextFieldType(b *testing.B) {
"Errs": func(c Context) Context {
return c.Errs("k", errs)
},
"Ctx": func(c Context) Context {
return c.Ctx(ctx)
},
"Time": func(c Context) Context {
return c.Time("k", times[0])
},
Expand Down
10 changes: 10 additions & 0 deletions context.go
@@ -1,6 +1,7 @@
package zerolog

import (
"context"
"fmt"
"io/ioutil"
"math"
Expand Down Expand Up @@ -165,6 +166,15 @@ func (c Context) Err(err error) Context {
return c.AnErr(ErrorFieldName, err)
}

// Ctx adds the context.Context to the logger context. The context.Context is
// not rendered in the error message, but is made available for hooks to use.
// A typical use case is to extract tracing information from the
// context.Context.
func (c Context) Ctx(ctx context.Context) Context {
c.l.ctx = ctx
return c
}

// Bool adds the field key with val as a bool to the logger context.
func (c Context) Bool(key string, b bool) Context {
c.l.context = enc.AppendBool(enc.AppendKey(c.l.context, key), b)
Expand Down
30 changes: 27 additions & 3 deletions event.go
@@ -1,6 +1,7 @@
package zerolog

import (
"context"
"fmt"
"net"
"os"
Expand All @@ -24,9 +25,10 @@ type Event struct {
w LevelWriter
level Level
done func(msg string)
stack bool // enable error stack trace
ch []Hook // hooks from context
skipFrame int // The number of additional frames to skip when printing the caller.
stack bool // enable error stack trace
ch []Hook // hooks from context
skipFrame int // The number of additional frames to skip when printing the caller.
ctx context.Context // Optional Go context for event
}

func putEvent(e *Event) {
Expand Down Expand Up @@ -417,6 +419,28 @@ func (e *Event) Stack() *Event {
return e
}

// Ctx adds the Go Context to the *Event context. The context is not rendered
// in the output message, but is available to hooks and to Func() calls via the
// GetCtx() accessor. A typical use case is to extract tracing information from
// the Go Ctx.
func (e *Event) Ctx(ctx context.Context) *Event {
if e != nil {
e.ctx = ctx
}
return e
}

// GetCtx retrieves the Go context.Context which is optionally stored in the
// Event. This allows Hooks and functions passed to Func() to retrieve values
// which are stored in the context.Context. This can be useful in tracing,
// where span information is commonly propagated in the context.Context.
func (e *Event) GetCtx() context.Context {
if e == nil || e.ctx == nil {
return context.Background()
}
return e.ctx
}

// Bool adds the field key with val as a bool to the *Event context.
func (e *Event) Bool(key string, b bool) *Event {
if e == nil {
Expand Down
34 changes: 34 additions & 0 deletions hook_test.go
Expand Up @@ -2,10 +2,15 @@ package zerolog

import (
"bytes"
"context"
"io/ioutil"
"testing"
)

type contextKeyType int

var contextKey contextKeyType

var (
levelNameHook = HookFunc(func(e *Event, level Level, msg string) {
levelName := level.String()
Expand All @@ -31,6 +36,12 @@ var (
discardHook = HookFunc(func(e *Event, level Level, message string) {
e.Discard()
})
contextHook = HookFunc(func(e *Event, level Level, message string) {
contextData, ok := e.GetCtx().Value(contextKey).(string)
if ok {
e.Str("context-data", contextData)
}
})
)

func TestHook(t *testing.T) {
Expand Down Expand Up @@ -120,6 +131,29 @@ func TestHook(t *testing.T) {
log = log.Hook(discardHook)
log.Log().Msg("test message")
}},
{"Context/Background", `{"level":"info","message":"test message"}` + "\n", func(log Logger) {
log = log.Hook(contextHook)
log.Info().Ctx(context.Background()).Msg("test message")
}},
{"Context/nil", `{"level":"info","message":"test message"}` + "\n", func(log Logger) {
// passing `nil` where a context is wanted is against
// the rules, but people still do it.
log = log.Hook(contextHook)
log.Info().Ctx(nil).Msg("test message") // nolint
}},
{"Context/valid", `{"level":"info","context-data":"12345abcdef","message":"test message"}` + "\n", func(log Logger) {
ctx := context.Background()
ctx = context.WithValue(ctx, contextKey, "12345abcdef")
log = log.Hook(contextHook)
log.Info().Ctx(ctx).Msg("test message")
}},
{"Context/With/valid", `{"level":"info","context-data":"12345abcdef","message":"test message"}` + "\n", func(log Logger) {
ctx := context.Background()
ctx = context.WithValue(ctx, contextKey, "12345abcdef")
log = log.Hook(contextHook)
log = log.With().Ctx(ctx).Logger()
log.Info().Msg("test message")
}},
{"None", `{"level":"error"}` + "\n", func(log Logger) {
log.Error().Msg("")
}},
Expand Down
6 changes: 4 additions & 2 deletions log.go
Expand Up @@ -82,8 +82,7 @@
// log.Warn().Msg("")
// // Output: {"level":"warn","severity":"warn"}
//
//
// Caveats
// # Caveats
//
// There is no fields deduplication out-of-the-box.
// Using the same key multiple times creates new key in final JSON each time.
Expand All @@ -99,6 +98,7 @@
package zerolog

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -218,6 +218,7 @@ type Logger struct {
context []byte
hooks []Hook
stack bool
ctx context.Context
}

// New creates a root logger with given output writer. If the output writer implements
Expand Down Expand Up @@ -455,6 +456,7 @@ func (l *Logger) newEvent(level Level, done func(string)) *Event {
e := newEvent(l.w, level)
e.done = done
e.ch = l.hooks
e.ctx = l.ctx
if level != NoLevel && LevelFieldName != "" {
e.Str(LevelFieldName, LevelFieldMarshalFunc(level))
}
Expand Down
6 changes: 5 additions & 1 deletion log_test.go
Expand Up @@ -2,6 +2,7 @@ package zerolog

import (
"bytes"
"context"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -119,7 +120,8 @@ func TestWith(t *testing.T) {
Uint64("uint64", 10).
Float32("float32", 11.101).
Float64("float64", 12.30303).
Time("time", time.Time{})
Time("time", time.Time{}).
Ctx(context.Background())
_, file, line, _ := runtime.Caller(0)
caller := fmt.Sprintf("%s:%d", file, line+3)
log := ctx.Caller().Logger()
Expand Down Expand Up @@ -344,6 +346,7 @@ func TestFields(t *testing.T) {
Dur("dur", 1*time.Second).
Time("time", time.Time{}).
TimeDiff("diff", now, now.Add(-10*time.Second)).
Ctx(context.Background()).
Msg("")
if got, want := decodeIfBinaryToString(out.Bytes()), `{"caller":"`+caller+`","string":"foo","stringer":"127.0.0.1","stringer_nil":null,"bytes":"bar","hex":"12ef","json":{"some":"json"},"cbor":"data:application/cbor;base64,gwGCAgOCBAU=","func":"func_output","error":"some error","bool":true,"int":1,"int8":2,"int16":3,"int32":4,"int64":5,"uint":6,"uint8":7,"uint16":8,"uint32":9,"uint64":10,"IPv4":"192.168.0.100","IPv6":"2001:db8:85a3::8a2e:370:7334","Mac":"00:14:22:01:23:45","Prefix":"192.168.0.100/24","float32":11.1234,"float64":12.321321321,"dur":1000,"time":"0001-01-01T00:00:00Z","diff":10000}`+"\n"; got != want {
t.Errorf("invalid log output:\ngot: %v\nwant: %v", got, want)
Expand Down Expand Up @@ -462,6 +465,7 @@ func TestFieldsDisabled(t *testing.T) {
Dur("dur", 1*time.Second).
Time("time", time.Time{}).
TimeDiff("diff", now, now.Add(-10*time.Second)).
Ctx(context.Background()).
Msg("")
if got, want := decodeIfBinaryToString(out.Bytes()), ""; got != want {
t.Errorf("invalid log output:\ngot: %v\nwant: %v", got, want)
Expand Down

0 comments on commit b3cf0d9

Please sign in to comment.