Skip to content

Commit

Permalink
[WIP] authcallout investigation
Browse files Browse the repository at this point in the history
  • Loading branch information
aricart committed Jan 29, 2024
1 parent 5ecf769 commit 1bebf3a
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 15 deletions.
42 changes: 30 additions & 12 deletions server/auth_callout.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
reply := s.newRespInbox()
respCh := make(chan string, 1)

decodeResponse := func(rc *client, rmsg []byte, acc *Account) (*jwt.UserClaims, error) {
decodeResponse := func(rc *client, rmsg []byte, acc *Account) (string, error) {
account := acc.Name
_, msg := rc.msgParts(rmsg)

// This signals not authorized.
// Since this is an account subscription will always have "\r\n".
if len(msg) <= LEN_CR_LF {
return nil, fmt.Errorf("auth callout violation: %q on account %q", "no reason supplied", account)
return "", fmt.Errorf("auth callout violation: %q on account %q", "no reason supplied", account)
}
// Strip trailing CRLF.
msg = msg[:len(msg)-LEN_CR_LF]
Expand All @@ -95,34 +95,34 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
var err error
msg, err = xkp.Open(msg, pubAccXKey)
if err != nil {
return nil, fmt.Errorf("error decrypting auth callout response on account %q: %v", account, err)
return "", fmt.Errorf("error decrypting auth callout response on account %q: %v", account, err)
}
encrypted = true
}

cr, err := jwt.DecodeAuthorizationResponseClaims(string(msg))
if err != nil {
return nil, err
return "", err
}
vr := jwt.CreateValidationResults()
cr.Validate(vr)
if len(vr.Issues) > 0 {
return nil, fmt.Errorf("authorization response had validation errors: %v", vr.Issues[0])
return "", fmt.Errorf("authorization response had validation errors: %v", vr.Issues[0])
}

// the subject is the user id
if cr.Subject != pub {
return nil, errors.New("auth callout violation: auth callout response is not for expected user")
return "", errors.New("auth callout violation: auth callout response is not for expected user")
}

// check the audience to be the server ID
if cr.Audience != s.info.ID {
return nil, errors.New("auth callout violation: auth callout response is not for server")
return "", errors.New("auth callout violation: auth callout response is not for server")
}

// check if had an error message from the auth account
if cr.Error != _EMPTY_ {
return nil, fmt.Errorf("auth callout service returned an error: %v", cr.Error)
return "", fmt.Errorf("auth callout service returned an error: %v", cr.Error)
}

// if response is encrypted none of this is needed
Expand All @@ -133,12 +133,11 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
}
if pkStr != account {
if _, ok := acc.signingKeys[pkStr]; !ok {
return nil, errors.New("auth callout signing key is unknown")
return "", errors.New("auth callout signing key is unknown")
}
}
}

return jwt.DecodeUserClaims(cr.Jwt)
return cr.Jwt, nil
}

// getIssuerAccount returns the issuer (as per JWT) - it also asserts that
Expand Down Expand Up @@ -231,6 +230,8 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
}
}

fmt.Printf("AFTER climits %+v \n", arc.User.UserPermissionLimits)

return targetAcc, nil
}

Expand All @@ -240,11 +241,18 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
return string(append([]rune{unicode.ToUpper(r[0])}, r[1:]...))
}

arc, err := decodeResponse(rc, rmsg, racc)
njwt, err := decodeResponse(rc, rmsg, racc)
if err != nil {
respCh <- titleCase(err.Error())
return
}

arc, err := jwt.DecodeUserClaims(njwt)
if err != nil {
respCh <- fmt.Sprintf("Error decoding user JWT: %v", err)
return
}

vr := jwt.CreateValidationResults()
arc.Validate(vr)
if len(vr.Issues) > 0 {
Expand Down Expand Up @@ -272,6 +280,16 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize

// Build internal user and bind to the targeted account.
nkuser := buildInternalNkeyUser(arc, allowedConnTypes, targetAcc)
// here we need to put in the received JWT or some of the limits are not
// processed properly
if isOperatorMode {
c.mu.Lock()
fmt.Println(">>>> before", c.opts.JWT)
c.opts.JWT = njwt
fmt.Println(">>>> after", c.opts.JWT)

c.mu.Unlock()
}
if err := c.RegisterNkeyUser(nkuser); err != nil {
respCh <- fmt.Sprintf("Could not register auth callout user: %v", err)
return
Expand Down
170 changes: 167 additions & 3 deletions server/auth_callout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"os"
"reflect"
"sort"
"strings"
Expand Down Expand Up @@ -62,12 +63,18 @@ func serviceResponse(t *testing.T, userID string, serverID string, uJwt string,
return []byte(token)
}

func makeScopedRole(t *testing.T, role string, pub []string, sub []string) (jwt.Scope, nkeys.KeyPair) {
func makeScopedRole(t *testing.T, role string, pub []string, sub []string, allowResponses bool) (jwt.Scope, nkeys.KeyPair) {
akp, pk := createKey(t)
r := jwt.NewUserScope()
r.Key = pk
r.Template.Sub.Allow.Add(sub...)
r.Template.Pub.Allow.Add(pub...)
if allowResponses {
r.Template.Resp = &jwt.ResponsePermission{
MaxMsgs: 1,
Expires: time.Second * 3,
}
}
r.Role = role
return r, akp
}
Expand Down Expand Up @@ -131,7 +138,7 @@ func NewAuthTest(t *testing.T, config string, authHandler nats.MsgHandler, clien
a.srv, _ = RunServerWithConfig(a.conf)

var err error
a.authClient = a.Connect(clientOptions...)
a.authClient = a.ConnectCallout(clientOptions...)
_, err = a.authClient.Subscribe(AuthCalloutSubject, authHandler)
require_NoError(t, err)
return a
Expand All @@ -146,6 +153,15 @@ func (at *authTest) NewClient(clientOptions ...nats.Option) (*nats.Conn, error)
return conn, nil
}

func (at *authTest) ConnectCallout(clientOptions ...nats.Option) *nats.Conn {
conn, err := at.NewClient(clientOptions...)
if err != nil {
err = fmt.Errorf("callout client failed: %w", err)
}
require_NoError(at.t, err)
return conn
}

func (at *authTest) Connect(clientOptions ...nats.Option) *nats.Conn {
conn, err := at.NewClient(clientOptions...)
require_NoError(at.t, err)
Expand Down Expand Up @@ -476,6 +492,9 @@ func createBasicAccountUser(t *testing.T, accKp nkeys.KeyPair) (creds string) {
// For these deny all permission
uclaim.Permissions.Pub.Deny.Add(">")
uclaim.Permissions.Sub.Deny.Add(">")

// Uncomment this to set the sub limits
// uclaim.Limits.Subs = 0
vr := jwt.ValidationResults{}
uclaim.Validate(&vr)
require_Len(t, len(vr.Errors()), 0)
Expand Down Expand Up @@ -516,7 +535,7 @@ func TestAuthCalloutOperatorModeBasics(t *testing.T) {
accClaim := jwt.NewAccountClaims(tpub)
accClaim.Name = "TEST"
accClaim.SigningKeys.Add(tSigningPub)
scope, scopedKp := makeScopedRole(t, "foo", []string{"foo.>", "$SYS.REQ.USER.INFO"}, []string{"foo.>", "_INBOX.>"})
scope, scopedKp := makeScopedRole(t, "foo", []string{"foo.>", "$SYS.REQ.USER.INFO"}, []string{"foo.>", "_INBOX.>"}, false)
accClaim.SigningKeys.AddScopedSigner(scope)
accJwt, err := accClaim.Encode(oKp)
require_NoError(t, err)
Expand Down Expand Up @@ -674,6 +693,151 @@ func TestAuthCalloutOperatorModeBasics(t *testing.T) {
require_Equal(t, "foo.>", userInfo.Permissions.Subscribe.Allow[1])
}

func testAuthCalloutScopedUser(t *testing.T, allowAnyAccount bool) {
_, spub := createKey(t)
sysClaim := jwt.NewAccountClaims(spub)
sysClaim.Name = "$SYS"
sysJwt, err := sysClaim.Encode(oKp)
require_NoError(t, err)

// TEST account.
_, tpub := createKey(t)
_, tSigningPub := createKey(t)
accClaim := jwt.NewAccountClaims(tpub)
accClaim.Name = "TEST"
accClaim.SigningKeys.Add(tSigningPub)
scope, scopedKp := makeScopedRole(t, "foo", []string{"foo.>", "$SYS.REQ.USER.INFO"}, []string{"foo.>", "_INBOX.>"}, true)
accClaim.SigningKeys.AddScopedSigner(scope)
accJwt, err := accClaim.Encode(oKp)
require_NoError(t, err)

// AUTH service account.
akp, err := nkeys.FromSeed([]byte(authCalloutIssuerSeed))
require_NoError(t, err)

apub, err := akp.PublicKey()
require_NoError(t, err)

// The authorized user for the service.
upub, creds := createAuthServiceUser(t, akp)
defer removeFile(t, creds)

authClaim := jwt.NewAccountClaims(apub)
authClaim.Name = "AUTH"
authClaim.EnableExternalAuthorization(upub)
if allowAnyAccount {
authClaim.Authorization.AllowedAccounts.Add("*")
} else {
authClaim.Authorization.AllowedAccounts.Add(tpub)
}
authJwt, err := authClaim.Encode(oKp)
require_NoError(t, err)

conf := fmt.Sprintf(`
listen: 127.0.0.1:-1
operator: %s
system_account: %s
resolver: MEM
resolver_preload: {
%s: %s
%s: %s
%s: %s
}
`, ojwt, spub, apub, authJwt, tpub, accJwt, spub, sysJwt)

const scopedToken = "--Scoped--"
handler := func(m *nats.Msg) {
user, si, _, opts, _ := decodeAuthRequest(t, m.Data)
if opts.Token == scopedToken {
// must have no limits set
ujwt := createAuthUser(t, user, "scoped", tpub, tpub, scopedKp, 0, &jwt.UserPermissionLimits{})
m.Respond(serviceResponse(t, user, si.ID, ujwt, "", 0))
} else {
m.Respond(nil)
}
}

ac := NewAuthTest(t, conf, handler, nats.UserCredentials(creds))
defer ac.Cleanup()
resp, err := ac.authClient.Request(userDirectInfoSubj, nil, time.Second)
require_NoError(t, err)
response := ServerAPIResponse{Data: &UserInfo{}}
err = json.Unmarshal(resp.Data, &response)
require_NoError(t, err)

userInfo := response.Data.(*UserInfo)
expected := &UserInfo{
UserID: upub,
Account: apub,
Permissions: &Permissions{
Publish: &SubjectPermission{
Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account.
},
Subscribe: &SubjectPermission{},
},
}
if !reflect.DeepEqual(expected, userInfo) {
t.Fatalf("User info did not match expected, expected auto-deny permissions on callout subject")
}

// Bearer token etc..
// This is used by all users, and the customization will be in other connect args.
// This needs to also be bound to the authorization account.
creds = createBasicAccountUser(t, akp)
defer removeFile(t, creds)

// Send the signing key token. This should switch us to the test account, but the user
// is signed with the account signing key

d, err := os.ReadFile(creds)
require_NoError(t, err)

token, err := jwt.ParseDecoratedJWT(d)
require_NoError(t, err)

uc, err := jwt.DecodeUserClaims(token)
require_NoError(t, err)
t.Logf(">>>>>>>>> lim from test creds >>> %+v\n", uc.UserPermissionLimits)

nc := ac.Connect(nats.UserCredentials(creds), nats.Token(scopedToken))
require_NoError(t, err)

resp, err = nc.Request(userDirectInfoSubj, nil, time.Second)
require_NoError(t, err)
response = ServerAPIResponse{Data: &UserInfo{}}
err = json.Unmarshal(resp.Data, &response)
require_NoError(t, err)

userInfo = response.Data.(*UserInfo)
if userInfo.Account != tpub {
t.Fatalf("Expected to be switched to %q, but got %q", tpub, userInfo.Account)
}
require_True(t, len(userInfo.Permissions.Publish.Allow) == 2)
sort.Strings(userInfo.Permissions.Publish.Allow)
require_Equal(t, "foo.>", userInfo.Permissions.Publish.Allow[1])
sort.Strings(userInfo.Permissions.Subscribe.Allow)
require_True(t, len(userInfo.Permissions.Subscribe.Allow) == 2)
require_Equal(t, "foo.>", userInfo.Permissions.Subscribe.Allow[1])

_, err = nc.Subscribe("foo.>", func(msg *nats.Msg) {
require_NoError(t, msg.Respond(nil))
})

m, err := nc.Request("foo.bar", nil, time.Second)
require_NoError(t, err)
require_NotNil(t, m)

nc.Close()
}

func TestAuthCalloutScopedUserAssignedAccount(t *testing.T) {
testAuthCalloutScopedUser(t, false)
}

func TestAuthCalloutScopedUserAllAccount(t *testing.T) {
testAuthCalloutScopedUser(t, true)
}

const (
curveSeed = "SXAAXMRAEP6JWWHNB6IKFL554IE6LZVT6EY5MBRICPILTLOPHAG73I3YX4"
curvePublic = "XAB3NANV3M6N7AHSQP2U5FRWKKUT7EG2ZXXABV4XVXYQRJGM4S2CZGHT"
Expand Down
3 changes: 3 additions & 0 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/url"
"regexp"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -841,8 +842,10 @@ func (c *client) applyAccountLimits() {
c.msubs = jwt.NoLimit
if c.opts.JWT != _EMPTY_ { // user jwt implies account
if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
fmt.Printf("climits %+v \njwt: %s\n", uc.Limits, c.opts.JWT)
c.mpay = int32(uc.Limits.Payload)
c.msubs = int32(uc.Limits.Subs)
debug.PrintStack()
if uc.IssuerAccount != _EMPTY_ && uc.IssuerAccount != uc.Issuer {
if scope, ok := c.acc.signingKeys[uc.Issuer]; ok {
if userScope, ok := scope.(*jwt.UserScope); ok {
Expand Down

0 comments on commit 1bebf3a

Please sign in to comment.