Skip to content

Commit 1da9105

Browse files
authoredNov 17, 2024··
feat(fs): Support Vault AWS IAM auth (#2264)
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
1 parent b772227 commit 1da9105

6 files changed

+225
-82
lines changed
 

‎internal/datafs/fsys.go

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ func FSysForPath(ctx context.Context, path string) (fs.FS, error) {
101101
fsys = vaultauth.WithAuthMethod(compositeVaultAuthMethod(fileFsys), fsys)
102102
}
103103

104+
fsys = fsimpl.WithContextFS(ctx, fsys)
105+
104106
return fsys, nil
105107
}
106108

‎internal/datafs/vaultauth.go

+29-6
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,12 @@ func compositeVaultAuthMethod(envFsys fs.FS) api.AuthMethod {
2020
return vaultauth.CompositeAuthMethod(
2121
vaultauth.EnvAuthMethod(),
2222
envEC2AuthAdapter(envFsys),
23+
envIAMAuthAdapter(envFsys),
2324
)
2425
}
2526

26-
// func CompositeVaultAuthMethod() api.AuthMethod {
27-
// return compositeVaultAuthMethod(WrapWdFS(osfs.NewFS()))
28-
// }
29-
3027
// envEC2AuthAdapter builds an AWS EC2 authentication method from environment
31-
// variables, for use only with [CompositeVaultAuthMethod]
28+
// variables, for use only with [compositeVaultAuthMethod]
3229
func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod {
3330
mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws")
3431

@@ -61,8 +58,34 @@ func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod {
6158
return &ec2AuthNonceWriter{AWSAuth: awsauth, nonce: nonce, output: output}
6259
}
6360

61+
// envIAMAuthAdapter builds an AWS IAM authentication method from environment
62+
// variables, for use only with [compositeVaultAuthMethod]
63+
func envIAMAuthAdapter(envFS fs.FS) api.AuthMethod {
64+
mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws")
65+
role := GetenvFsys(envFS, "VAULT_AUTH_AWS_ROLE")
66+
67+
// temporary workaround while we wait to deprecate AWS_META_ENDPOINT
68+
if endpoint := os.Getenv("AWS_META_ENDPOINT"); endpoint != "" {
69+
deprecated.WarnDeprecated(context.Background(), "Use AWS_EC2_METADATA_SERVICE_ENDPOINT instead of AWS_META_ENDPOINT")
70+
if os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") == "" {
71+
os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", endpoint)
72+
}
73+
}
74+
75+
awsauth, err := aws.NewAWSAuth(
76+
aws.WithIAMAuth(),
77+
aws.WithMountPath(mountPath),
78+
aws.WithRole(role),
79+
)
80+
if err != nil {
81+
return nil
82+
}
83+
84+
return awsauth
85+
}
86+
6487
// ec2AuthNonceWriter - wraps an AWSAuth, and writes the nonce to the nonce
65-
// output file
88+
// output file - only for ec2 auth
6689
type ec2AuthNonceWriter struct {
6790
*aws.AWSAuth
6891
nonce string

‎internal/tests/integration/datasources_vault_ec2_test.go

+2-59
Original file line numberDiff line numberDiff line change
@@ -4,71 +4,14 @@
44
package integration
55

66
import (
7-
"encoding/pem"
8-
"io"
9-
"net/http"
10-
"net/http/httptest"
117
"testing"
128

13-
"github.com/stretchr/testify/assert"
149
"github.com/stretchr/testify/require"
15-
"gotest.tools/v3/fs"
1610
)
1711

18-
func setupDatasourcesVaultEc2Test(t *testing.T) (*fs.Dir, *vaultClient, *httptest.Server, []byte) {
19-
t.Helper()
20-
21-
priv, der, _ := certificateGenerate()
22-
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
23-
24-
mux := http.NewServeMux()
25-
mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der))
26-
mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler)
27-
mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28-
var b []byte
29-
if r.Body != nil {
30-
var err error
31-
b, err = io.ReadAll(r.Body)
32-
if !assert.NoError(t, err) {
33-
w.WriteHeader(http.StatusInternalServerError)
34-
return
35-
}
36-
defer r.Body.Close()
37-
}
38-
t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b)
39-
40-
w.Write([]byte("testtoken"))
41-
}))
42-
mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43-
t.Logf("IMDS request: %s %s", r.Method, r.URL)
44-
w.Write([]byte("i-00000000"))
45-
}))
46-
mux.HandleFunc("/sts/", stsHandler)
47-
mux.HandleFunc("/ec2/", ec2Handler)
48-
mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49-
t.Logf("unhandled request: %s %s", r.Method, r.URL)
50-
w.WriteHeader(http.StatusNotFound)
51-
}))
52-
53-
srv := httptest.NewServer(mux)
54-
t.Cleanup(srv.Close)
55-
56-
tmpDir, v := startVault(t)
57-
58-
err := v.vc.Sys().PutPolicy("writepol", `path "*" {
59-
policy = "write"
60-
}`)
61-
require.NoError(t, err)
62-
err = v.vc.Sys().PutPolicy("readpol", `path "*" {
63-
policy = "read"
64-
}`)
65-
require.NoError(t, err)
66-
67-
return tmpDir, v, srv, cert
68-
}
69-
7012
func TestDatasources_VaultEc2(t *testing.T) {
71-
tmpDir, v, srv, cert := setupDatasourcesVaultEc2Test(t)
13+
accountID, user := "1", "Test"
14+
tmpDir, v, srv, cert := setupDatasourcesVaultAWSTest(t, accountID, user)
7215

7316
v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"})
7417
defer v.vc.Logical().Delete("secret/foo")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//go:build !windows
2+
// +build !windows
3+
4+
package integration
5+
6+
import (
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestDatasources_VaultIAM(t *testing.T) {
13+
accountID := "000000000000"
14+
user := "foo"
15+
16+
tmpDir, v, srv, _ := setupDatasourcesVaultAWSTest(t, accountID, user)
17+
18+
v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"})
19+
defer v.vc.Logical().Delete("secret/foo")
20+
21+
err := v.vc.Sys().EnableAuth("aws", "aws", "")
22+
require.NoError(t, err)
23+
defer v.vc.Sys().DisableAuth("aws")
24+
25+
endpoint := srv.URL
26+
27+
accessKeyID := "secret"
28+
secretAccessKey := "access"
29+
30+
_, err = v.vc.Logical().Write("auth/aws/config/client", map[string]interface{}{
31+
"access_key": accessKeyID,
32+
"secret_key": secretAccessKey,
33+
"endpoint": endpoint,
34+
"iam_endpoint": endpoint + "/iam",
35+
"sts_endpoint": endpoint + "/sts",
36+
"sts_region": "us-east-1",
37+
})
38+
require.NoError(t, err)
39+
40+
_, err = v.vc.Logical().Write("auth/aws/role/foo", map[string]interface{}{
41+
"auth_type": "iam",
42+
"bound_iam_principal_arn": "arn:aws:iam::" + accountID + ":*",
43+
"policies": "readpol",
44+
"max_ttl": "5m",
45+
})
46+
require.NoError(t, err)
47+
48+
o, e, err := cmd(t, "-d", "vault=vault:///secret/",
49+
"-i", `{{(ds "vault" "foo").value}}`).
50+
withEnv("HOME", tmpDir.Join("home")).
51+
withEnv("VAULT_ADDR", "http://"+v.addr).
52+
withEnv("AWS_ACCESS_KEY_ID", accessKeyID).
53+
withEnv("AWS_SECRET_ACCESS_KEY", secretAccessKey).
54+
run()
55+
assertSuccess(t, o, e, err, "bar")
56+
}

‎internal/tests/integration/datasources_vault_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func startVault(t *testing.T) (*fs.Dir, *vaultClient) {
6969
"-dev",
7070
"-dev-root-token-id="+vaultRootToken,
7171
"-dev-kv-v1", // default to v1, so we can test v1 and v2
72-
"-log-level=err",
72+
"-log-level=info",
7373
"-dev-listen-address="+vaultAddr,
7474
"-config="+tmpDir.Join("config.json"),
7575
)

‎internal/tests/integration/test_ec2_utils.go ‎internal/tests/integration/test_ec2_utils_test.go

+135-16
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
//go:build !windows
2+
// +build !windows
3+
14
package integration
25

36
import (
@@ -7,12 +10,21 @@ import (
710
"crypto/x509"
811
"crypto/x509/pkix"
912
"encoding/pem"
13+
"fmt"
14+
"io"
1015
"log"
1116
"math/big"
1217
"net/http"
18+
"net/http/httptest"
19+
"net/url"
20+
"strings"
21+
"testing"
1322
"time"
1423

1524
"github.com/fullsailor/pkcs7"
25+
"github.com/stretchr/testify/assert"
26+
"github.com/stretchr/testify/require"
27+
"gotest.tools/v3/fs"
1628
)
1729

1830
const instanceDocument = `{
@@ -106,21 +118,34 @@ func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter
106118
}
107119
}
108120

109-
func stsHandler(w http.ResponseWriter, _ *http.Request) {
110-
w.Header().Set("Content-Type", "text/xml")
111-
_, err := w.Write([]byte(`<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
112-
<GetCallerIdentityResult>
113-
<Arn>arn:aws:iam::1:user/Test</Arn>
114-
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
115-
<Account>1</Account>
116-
</GetCallerIdentityResult>
117-
<ResponseMetadata>
118-
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
119-
</ResponseMetadata>
120-
</GetCallerIdentityResponse>`))
121-
if err != nil {
122-
w.WriteHeader(500)
123-
}
121+
func stsHandler(t *testing.T, accountID, user string) http.Handler {
122+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123+
body, _ := io.ReadAll(r.Body)
124+
defer r.Body.Close()
125+
126+
form, _ := url.ParseQuery(string(body))
127+
128+
// action must be GetCallerIdentity
129+
assert.Equal(t, "GetCallerIdentity", form.Get("Action"))
130+
131+
w.Header().Set("Content-Type", "text/xml")
132+
_, err := w.Write([]byte(fmt.Sprintf(`<?xml version='1.0' encoding='utf-8'?>
133+
<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
134+
<GetCallerIdentityResult>
135+
<Arn>arn:aws:iam::%[1]s:user/%[2]s</Arn>
136+
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
137+
<Account>%[1]s</Account>
138+
</GetCallerIdentityResult>
139+
<ResponseMetadata>
140+
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
141+
</ResponseMetadata>
142+
</GetCallerIdentityResponse>`, accountID, user)))
143+
if err != nil {
144+
t.Errorf("failed to write response: %s", err)
145+
w.WriteHeader(http.StatusInternalServerError)
146+
}
147+
assert.NoError(t, err)
148+
})
124149
}
125150

126151
func ec2Handler(w http.ResponseWriter, _ *http.Request) {
@@ -246,6 +271,100 @@ func ec2Handler(w http.ResponseWriter, _ *http.Request) {
246271
</reservationSet>
247272
</DescribeInstancesResponse>`))
248273
if err != nil {
249-
w.WriteHeader(500)
274+
w.WriteHeader(http.StatusInternalServerError)
250275
}
251276
}
277+
278+
func iamGetUserHandler(t *testing.T, accountID string) http.Handler {
279+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
280+
body, _ := io.ReadAll(r.Body)
281+
form, _ := url.ParseQuery(string(body))
282+
283+
// action must be GetUser
284+
assert.Equal(t, "GetUser", form.Get("Action"))
285+
286+
w.Header().Set("Content-Type", "text/xml")
287+
_, err := w.Write([]byte(fmt.Sprintf(`<?xml version='1.0' encoding='utf-8'?>
288+
<GetUserResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
289+
<GetUserResult>
290+
<User>
291+
<Path>/</Path>
292+
<UserName>%[1]s</UserName>
293+
<UserId>m3o9qmhhl9dnjlh2fflg</UserId>
294+
<Arn>arn:aws:iam::%[2]s:user/%[1]s</Arn>
295+
<CreateDate>2024-07-21T17:21:27.259000Z</CreateDate>
296+
</User>
297+
</GetUserResult>
298+
<ResponseMetadata>
299+
<RequestId>3d0e2445-64ea-4bfb-9244-30d810773f9e</RequestId>
300+
</ResponseMetadata>
301+
</GetUserResponse>`, form.Get("UserName"), accountID)))
302+
if err != nil {
303+
w.WriteHeader(http.StatusInternalServerError)
304+
}
305+
assert.NoError(t, err)
306+
})
307+
}
308+
309+
func setupDatasourcesVaultAWSTest(t *testing.T, accountID, user string) (*fs.Dir, *vaultClient, *httptest.Server, []byte) {
310+
t.Helper()
311+
312+
priv, der, _ := certificateGenerate()
313+
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
314+
315+
mux := http.NewServeMux()
316+
mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der))
317+
mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler)
318+
mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
319+
var b []byte
320+
if r.Body != nil {
321+
var err error
322+
b, err = io.ReadAll(r.Body)
323+
if !assert.NoError(t, err) {
324+
w.WriteHeader(http.StatusInternalServerError)
325+
return
326+
}
327+
defer r.Body.Close()
328+
}
329+
t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b)
330+
331+
w.Write([]byte("testtoken"))
332+
}))
333+
mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
334+
t.Logf("IMDS request: %s %s", r.Method, r.URL)
335+
w.Write([]byte("i-00000000"))
336+
}))
337+
mux.Handle("/sts/", stsHandler(t, accountID, user))
338+
mux.Handle("/iam/", iamGetUserHandler(t, accountID))
339+
mux.HandleFunc("/ec2/", ec2Handler)
340+
mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341+
t.Logf("unhandled request: %s %s", r.Method, r.URL)
342+
w.WriteHeader(http.StatusNotFound)
343+
}))
344+
345+
// Vault sends requests to "/sts///" for some reason, and the ServeMux
346+
// responds by redirecting to "/sts/" which Vault rejects. So we need to
347+
// handle the extra slashes in a middleware first.
348+
stripSlashes := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
349+
for strings.HasSuffix(r.URL.Path, "//") {
350+
r.URL.Path = r.URL.Path[:len(r.URL.Path)-1]
351+
}
352+
mux.ServeHTTP(w, r)
353+
})
354+
355+
srv := httptest.NewServer(stripSlashes)
356+
t.Cleanup(srv.Close)
357+
358+
tmpDir, v := startVault(t)
359+
360+
err := v.vc.Sys().PutPolicy("writepol", `path "*" {
361+
policy = "write"
362+
}`)
363+
require.NoError(t, err)
364+
err = v.vc.Sys().PutPolicy("readpol", `path "*" {
365+
policy = "read"
366+
}`)
367+
require.NoError(t, err)
368+
369+
return tmpDir, v, srv, cert
370+
}

0 commit comments

Comments
 (0)
Please sign in to comment.