Skip to content

Commit

Permalink
Server: add and support unix listener (UDS)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaj13 committed Dec 5, 2022
1 parent 12b2fab commit 68ad0bf
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 1 deletion.
3 changes: 3 additions & 0 deletions changelog/18227.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:feature
**Server UDS Listener**: Adding listener to vault server to serve http request vi unix domain socket
```
3 changes: 2 additions & 1 deletion command/server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ type ListenerFactory func(*configutil.Listener, io.Writer, cli.Ui) (net.Listener

// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
"tcp": tcpListenerFactory,
"tcp": tcpListenerFactory,
"unix": unixListenerFactory,
}

// NewListener creates a new listener of the given type with the given
Expand Down
3 changes: 3 additions & 0 deletions command/server/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ func testListenerImpl(t *testing.T, ln net.Listener, connFn testListenerConnFn,
tlsConn.Handshake()
}
serverCh <- server
if expectedAddr == "" {
return
}
addr, _, err := net.SplitHostPort(server.RemoteAddr().String())
if err != nil {
t.Error(err)
Expand Down
36 changes: 36 additions & 0 deletions command/server/listener_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package server

import (
"io"
"net"

"github.com/hashicorp/go-secure-stdlib/reloadutil"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/internalshared/listenerutil"
"github.com/mitchellh/cli"
)

func unixListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reloadutil.ReloadFunc, error) {
addr := l.Address
if addr == "" {
addr = "/run/vault.sock"
}

var cfg *listenerutil.UnixSocketsConfig
if l.SocketMode != "" &&
l.SocketUser != "" &&
l.SocketGroup != "" {
cfg = &listenerutil.UnixSocketsConfig{
Mode: l.SocketMode,
User: l.SocketUser,
Group: l.SocketGroup,
}
}

ln, err := listenerutil.UnixSocketListener(addr, cfg)
if err != nil {
return nil, nil, nil, err
}

return ln, map[string]string{}, nil, nil
}
55 changes: 55 additions & 0 deletions command/server/listener_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package server

import (
"net"
"os"
"path/filepath"
"testing"

"github.com/hashicorp/vault/internalshared/configutil"
"github.com/mitchellh/cli"
)

func TestUnixListener(t *testing.T) {
ln, _, _, err := unixListenerFactory(&configutil.Listener{
Address: filepath.Join(t.TempDir(), "/vault.sock"),
}, nil, cli.NewMockUi())
if err != nil {
t.Fatalf("err: %s", err)
}

connFn := func(lnReal net.Listener) (net.Conn, error) {
return net.Dial("unix", ln.Addr().String())
}

testListenerImpl(t, ln, connFn, "", 0, "", false)
}

// TestUnixListener_cfg tests unix sockets config.
func TestUnixListener_cfg(t *testing.T) {
path := filepath.Join(t.TempDir(), "/vault.sock")
ln, _, _, err := unixListenerFactory(&configutil.Listener{
Address: path,
SocketMode: "0600",
SocketUser: "100",
SocketGroup: "100",
}, nil, cli.NewMockUi())
if err != nil {
t.Fatalf("err: %s", err)
}

t.Cleanup(func() {
ln.Close()
})

stat, err := os.Stat(path)
if err != nil {
t.Fatalf("err: %s", err)
}

expected := os.FileMode(0o600)
got := stat.Mode().Perm()
if got != expected {
t.Errorf("expected: %s, got: %s", expected, got)
}
}

0 comments on commit 68ad0bf

Please sign in to comment.