1
+ //go:build !windows
2
+ // +build !windows
3
+
1
4
package integration
2
5
3
6
import (
@@ -7,12 +10,21 @@ import (
7
10
"crypto/x509"
8
11
"crypto/x509/pkix"
9
12
"encoding/pem"
13
+ "fmt"
14
+ "io"
10
15
"log"
11
16
"math/big"
12
17
"net/http"
18
+ "net/http/httptest"
19
+ "net/url"
20
+ "strings"
21
+ "testing"
13
22
"time"
14
23
15
24
"github.com/fullsailor/pkcs7"
25
+ "github.com/stretchr/testify/assert"
26
+ "github.com/stretchr/testify/require"
27
+ "gotest.tools/v3/fs"
16
28
)
17
29
18
30
const instanceDocument = `{
@@ -106,21 +118,34 @@ func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter
106
118
}
107
119
}
108
120
109
- func stsHandler (w http.ResponseWriter , _ * http.Request ) {
110
- w .Header ().Set ("Content-Type" , "text/xml" )
111
- _ , err := w .Write ([]byte (`<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
112
- <GetCallerIdentityResult>
113
- <Arn>arn:aws:iam::1:user/Test</Arn>
114
- <UserId>AKIAI44QH8DHBEXAMPLE</UserId>
115
- <Account>1</Account>
116
- </GetCallerIdentityResult>
117
- <ResponseMetadata>
118
- <RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
119
- </ResponseMetadata>
120
- </GetCallerIdentityResponse>` ))
121
- if err != nil {
122
- w .WriteHeader (500 )
123
- }
121
+ func stsHandler (t * testing.T , accountID , user string ) http.Handler {
122
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
123
+ body , _ := io .ReadAll (r .Body )
124
+ defer r .Body .Close ()
125
+
126
+ form , _ := url .ParseQuery (string (body ))
127
+
128
+ // action must be GetCallerIdentity
129
+ assert .Equal (t , "GetCallerIdentity" , form .Get ("Action" ))
130
+
131
+ w .Header ().Set ("Content-Type" , "text/xml" )
132
+ _ , err := w .Write ([]byte (fmt .Sprintf (`<?xml version='1.0' encoding='utf-8'?>
133
+ <GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
134
+ <GetCallerIdentityResult>
135
+ <Arn>arn:aws:iam::%[1]s:user/%[2]s</Arn>
136
+ <UserId>AKIAI44QH8DHBEXAMPLE</UserId>
137
+ <Account>%[1]s</Account>
138
+ </GetCallerIdentityResult>
139
+ <ResponseMetadata>
140
+ <RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
141
+ </ResponseMetadata>
142
+ </GetCallerIdentityResponse>` , accountID , user )))
143
+ if err != nil {
144
+ t .Errorf ("failed to write response: %s" , err )
145
+ w .WriteHeader (http .StatusInternalServerError )
146
+ }
147
+ assert .NoError (t , err )
148
+ })
124
149
}
125
150
126
151
func ec2Handler (w http.ResponseWriter , _ * http.Request ) {
@@ -246,6 +271,100 @@ func ec2Handler(w http.ResponseWriter, _ *http.Request) {
246
271
</reservationSet>
247
272
</DescribeInstancesResponse>` ))
248
273
if err != nil {
249
- w .WriteHeader (500 )
274
+ w .WriteHeader (http . StatusInternalServerError )
250
275
}
251
276
}
277
+
278
+ func iamGetUserHandler (t * testing.T , accountID string ) http.Handler {
279
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
280
+ body , _ := io .ReadAll (r .Body )
281
+ form , _ := url .ParseQuery (string (body ))
282
+
283
+ // action must be GetUser
284
+ assert .Equal (t , "GetUser" , form .Get ("Action" ))
285
+
286
+ w .Header ().Set ("Content-Type" , "text/xml" )
287
+ _ , err := w .Write ([]byte (fmt .Sprintf (`<?xml version='1.0' encoding='utf-8'?>
288
+ <GetUserResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
289
+ <GetUserResult>
290
+ <User>
291
+ <Path>/</Path>
292
+ <UserName>%[1]s</UserName>
293
+ <UserId>m3o9qmhhl9dnjlh2fflg</UserId>
294
+ <Arn>arn:aws:iam::%[2]s:user/%[1]s</Arn>
295
+ <CreateDate>2024-07-21T17:21:27.259000Z</CreateDate>
296
+ </User>
297
+ </GetUserResult>
298
+ <ResponseMetadata>
299
+ <RequestId>3d0e2445-64ea-4bfb-9244-30d810773f9e</RequestId>
300
+ </ResponseMetadata>
301
+ </GetUserResponse>` , form .Get ("UserName" ), accountID )))
302
+ if err != nil {
303
+ w .WriteHeader (http .StatusInternalServerError )
304
+ }
305
+ assert .NoError (t , err )
306
+ })
307
+ }
308
+
309
+ func setupDatasourcesVaultAWSTest (t * testing.T , accountID , user string ) (* fs.Dir , * vaultClient , * httptest.Server , []byte ) {
310
+ t .Helper ()
311
+
312
+ priv , der , _ := certificateGenerate ()
313
+ cert := pem .EncodeToMemory (& pem.Block {Type : "CERTIFICATE" , Bytes : der })
314
+
315
+ mux := http .NewServeMux ()
316
+ mux .HandleFunc ("/latest/dynamic/instance-identity/pkcs7" , pkcsHandler (priv , der ))
317
+ mux .HandleFunc ("/latest/dynamic/instance-identity/document" , instanceDocumentHandler )
318
+ mux .HandleFunc ("/latest/api/token" , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
319
+ var b []byte
320
+ if r .Body != nil {
321
+ var err error
322
+ b , err = io .ReadAll (r .Body )
323
+ if ! assert .NoError (t , err ) {
324
+ w .WriteHeader (http .StatusInternalServerError )
325
+ return
326
+ }
327
+ defer r .Body .Close ()
328
+ }
329
+ t .Logf ("IMDS Token request: %s %s: %s" , r .Method , r .URL , b )
330
+
331
+ w .Write ([]byte ("testtoken" ))
332
+ }))
333
+ mux .HandleFunc ("/latest/meta-data/instance-id" , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
334
+ t .Logf ("IMDS request: %s %s" , r .Method , r .URL )
335
+ w .Write ([]byte ("i-00000000" ))
336
+ }))
337
+ mux .Handle ("/sts/" , stsHandler (t , accountID , user ))
338
+ mux .Handle ("/iam/" , iamGetUserHandler (t , accountID ))
339
+ mux .HandleFunc ("/ec2/" , ec2Handler )
340
+ mux .HandleFunc ("/" , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
341
+ t .Logf ("unhandled request: %s %s" , r .Method , r .URL )
342
+ w .WriteHeader (http .StatusNotFound )
343
+ }))
344
+
345
+ // Vault sends requests to "/sts///" for some reason, and the ServeMux
346
+ // responds by redirecting to "/sts/" which Vault rejects. So we need to
347
+ // handle the extra slashes in a middleware first.
348
+ stripSlashes := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
349
+ for strings .HasSuffix (r .URL .Path , "//" ) {
350
+ r .URL .Path = r .URL .Path [:len (r .URL .Path )- 1 ]
351
+ }
352
+ mux .ServeHTTP (w , r )
353
+ })
354
+
355
+ srv := httptest .NewServer (stripSlashes )
356
+ t .Cleanup (srv .Close )
357
+
358
+ tmpDir , v := startVault (t )
359
+
360
+ err := v .vc .Sys ().PutPolicy ("writepol" , `path "*" {
361
+ policy = "write"
362
+ }` )
363
+ require .NoError (t , err )
364
+ err = v .vc .Sys ().PutPolicy ("readpol" , `path "*" {
365
+ policy = "read"
366
+ }` )
367
+ require .NoError (t , err )
368
+
369
+ return tmpDir , v , srv , cert
370
+ }
0 commit comments