Skip to content

Commit

Permalink
Replace ErrAuthenticationRequired with AuthenticationRequiredError (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Feb 1, 2024
1 parent 7c50f09 commit 9ae828c
Show file tree
Hide file tree
Showing 20 changed files with 194 additions and 116 deletions.
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
### Features Added

### Breaking Changes
> These changes affect only code written against a beta version such as v1.6.0-beta.1
* Replaced `ErrAuthenticationRequired` with `AuthenticationRequiredError`, a struct
type that carries the `TokenRequestOptions` passed to the `GetToken` call which
returned the error.

### Bugs Fixed
* Fixed more cases in which credential chains like `DefaultAzureCredential`
Expand Down
36 changes: 33 additions & 3 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"runtime"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
Expand Down Expand Up @@ -114,6 +115,7 @@ func TestUserAuthentication(t *testing.T) {
name: credNameBrowser,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
return NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
DisableAutomaticAuthentication: disableAutoAuth,
Expand All @@ -126,6 +128,7 @@ func TestUserAuthentication(t *testing.T) {
name: credNameDeviceCode,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
o := DeviceCodeCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
DisableAutomaticAuthentication: disableAutoAuth,
Expand All @@ -143,6 +146,7 @@ func TestUserAuthentication(t *testing.T) {
name: credNameUserPassword,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
opts := UsernamePasswordCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
TokenCachePersistenceOptions: tcpo,
Expand Down Expand Up @@ -264,8 +268,19 @@ func TestUserAuthentication(t *testing.T) {
t.Run("DisableAutomaticAuthentication/"+credential.name, func(t *testing.T) {
cred, err := credential.new(nil, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true)
require.NoError(t, err)
_, err = cred.GetToken(context.Background(), testTRO)
require.ErrorIs(t, err, ErrAuthenticationRequired)
expected := policy.TokenRequestOptions{
Claims: "claims",
EnableCAE: true,
Scopes: []string{"scope"},
TenantID: "tenant",
}
_, err = cred.GetToken(context.Background(), expected)
require.Contains(t, err.Error(), credential.name)
require.Contains(t, err.Error(), "Call Authenticate")
var actual *AuthenticationRequiredError
require.ErrorAs(t, err, &actual)
require.Equal(t, expected, actual.TokenRequestOptions)

if credential.name != credNameBrowser || runManualTests {
_, err = cred.Authenticate(context.Background(), &testTRO)
require.NoError(t, err)
Expand All @@ -274,6 +289,20 @@ func TestUserAuthentication(t *testing.T) {
require.NoError(t, err)
}
})
t.Run("DisableAutomaticAuthentication/ChainedTokenCredential/"+credential.name, func(t *testing.T) {
cred, err := credential.new(nil, policy.ClientOptions{}, AuthenticationRecord{}, true)
require.NoError(t, err)
expected := azcore.AccessToken{ExpiresOn: time.Now().UTC(), Token: tokenValue}
fake := NewFakeCredential()
fake.SetResponse(expected, nil)
chain, err := NewChainedTokenCredential([]azcore.TokenCredential{cred, fake}, nil)
require.NoError(t, err)
// ChainedTokenCredential should continue iterating when a credential returns
// AuthenticationRequiredError i.e., it should call fake.GetToken() and return the expected token
actual, err := chain.GetToken(context.Background(), testTRO)
require.NoError(t, err)
require.Equal(t, expected, actual)
})
}
}
}
Expand Down Expand Up @@ -635,7 +664,8 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
// tenant resolution should have succeeded because the specified tenant is allowed,
// however the credential should have returned a different error because automatic
// authentication is disabled
require.ErrorIs(t, ErrAuthenticationRequired, err)
var e *AuthenticationRequiredError
require.ErrorAs(t, err, &e)
}
})

Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/azure_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ func TestAzureCLICredential_DefaultChainError(t *testing.T) {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
var ue *credentialUnavailableError
if !errors.As(err, &ue) {
t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err)
var cu credentialUnavailable
if !errors.As(err, &cu) {
t.Fatalf("expected %T, got %T: %q", cu, err, err)
}
}

Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/azure_developer_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func TestAzureDeveloperCLICredential_DefaultChainError(t *testing.T) {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
var ue *credentialUnavailableError
if !errors.As(err, &ue) {
t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err)
var cu credentialUnavailable
if !errors.As(err, &cu) {
t.Fatalf("expected %T, got %T: %q", cu, err, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
errs []error
successfulCredential azcore.TokenCredential
token azcore.AccessToken
unavailableErr *credentialUnavailableError
unavailableErr credentialUnavailable
)
for _, cred := range c.sources {
token, err = cred.GetToken(ctx, opts)
Expand Down
8 changes: 4 additions & 4 deletions sdk/azidentity/chained_token_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenUnavailable(t *testin
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, received %T", err)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, received %T", err)
}
expectedError := `ChainedTokenCredential: failed to acquire a token.
Attempted credentials:
Expand Down Expand Up @@ -186,8 +186,8 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenCustomName(t *testing
}
cred.name = "CustomNameCredential"
_, err = cred.GetToken(context.Background(), testTRO)
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, received %T", err)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, received %T", err)
}
expectedError := `CustomNameCredential: failed to acquire a token.
Attempted credentials:
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/confidential_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque
if err != nil {
// We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code.
// We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError.
var unavailableErr *credentialUnavailableError
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
res := getResponseFromError(err)
err = newAuthenticationFailedError(c.name, err.Error(), res, err)
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ type defaultCredentialErrorReporter struct {
}

func (d *defaultCredentialErrorReporter) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if _, ok := d.err.(*credentialUnavailableError); ok {
if _, ok := d.err.(credentialUnavailable); ok {
return azcore.AccessToken{}, d.err
}
return azcore.AccessToken{}, newCredentialUnavailableError(d.credType, d.err.Error())
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) {
for i := 0; i < 2; i++ {
// expecting credentialUnavailableError because delay exceeds the wrapper's timeout
_, err = chain.GetToken(context.Background(), testTRO)
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, got %T: %v", err, err)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, got %T: %v", err, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/developer_credential_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const cliTimeout = 10 * time.Second
// the next credential in its chain (another developer credential).
func unavailableIfInChain(err error, inDefaultChain bool) error {
if err != nil && inDefaultChain {
var unavailableErr *credentialUnavailableError
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
err = newCredentialUnavailableError(credNameAzureDeveloperCLI, err.Error())
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/device_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type DeviceCodeCredentialOptions struct {
ClientID string

// DisableAutomaticAuthentication prevents the credential from automatically prompting the user to authenticate.
// When this option is true, [DeviceCodeCredential.GetToken] will return [ErrAuthenticationRequired] when user
// interaction is necessary to acquire a token.
// When this option is true, GetToken will return AuthenticationRequiredError when user interaction is necessary
// to acquire a token.
DisableAutomaticAuthentication bool

// DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or
Expand Down
44 changes: 36 additions & 8 deletions sdk/azidentity/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@ import (
"fmt"
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
msal "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
)

// ErrAuthenticationRequired indicates a credential's Authenticate method must be called to acquire a token
// because user interaction is required and the credential is configured not to automatically prompt the user.
var ErrAuthenticationRequired error = &credentialUnavailableError{"can't acquire a token without user interaction. Call Authenticate to interactively authenticate a user"}

// getResponseFromError retrieves the response carried by
// an AuthenticationFailedError or MSAL CallErr, if any
func getResponseFromError(err error) *http.Response {
Expand Down Expand Up @@ -110,8 +107,34 @@ func (*AuthenticationFailedError) NonRetriable() {

var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil)

// credentialUnavailableError indicates a credential can't attempt authentication because it lacks required
// data or state
// AuthenticationRequiredError indicates a credential's Authenticate method must be called to acquire a token
// because the credential requires user interaction and is configured not to request it automatically.
type AuthenticationRequiredError struct {
credentialUnavailableError

// TokenRequestOptions for the required token. Pass this to the credential's Authenticate method.
TokenRequestOptions policy.TokenRequestOptions
}

func newAuthenticationRequiredError(credType string, tro policy.TokenRequestOptions) error {
return &AuthenticationRequiredError{
credentialUnavailableError: credentialUnavailableError{
credType + " can't acquire a token without user interaction. Call Authenticate to authenticate a user interactively",
},
TokenRequestOptions: tro,
}
}

var (
_ credentialUnavailable = (*AuthenticationRequiredError)(nil)
_ errorinfo.NonRetriable = (*AuthenticationRequiredError)(nil)
)

type credentialUnavailable interface {
error
credentialUnavailable()
}

type credentialUnavailableError struct {
message string
}
Expand All @@ -135,6 +158,11 @@ func (e *credentialUnavailableError) Error() string {
}

// NonRetriable is a marker method indicating this error should not be retried. It has no implementation.
func (e *credentialUnavailableError) NonRetriable() {}
func (*credentialUnavailableError) NonRetriable() {}

func (*credentialUnavailableError) credentialUnavailable() {}

var _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil)
var (
_ credentialUnavailable = (*credentialUnavailableError)(nil)
_ errorinfo.NonRetriable = (*credentialUnavailableError)(nil)
)
32 changes: 0 additions & 32 deletions sdk/azidentity/example_cache_test.go

This file was deleted.

19 changes: 15 additions & 4 deletions sdk/azidentity/example_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// Helpers and variables to keep the examples tidy
// Helpers, variables, fakes to keep the examples tidy

const (
certPath = "testdata/certificate.pem"
clientID = "fake-client-id"
tenantID = "fake-tenant"
authRecordPath = "fake/path"
certPath = "testdata/certificate.pem"
clientID = "fake-client-id"
tenantID = "fake-tenant"
)

func handleError(err error) {
Expand All @@ -28,3 +29,13 @@ func handleError(err error) {

var cred azcore.TokenCredential
var err error

type exampleServiceClient struct{}

func newServiceClient(azcore.TokenCredential) (exampleServiceClient, error) {
return exampleServiceClient{}, nil
}

func (exampleServiceClient) Method() error {
return nil
}
38 changes: 38 additions & 0 deletions sdk/azidentity/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,49 @@
package azidentity_test

import (
"context"
"errors"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
)

// Credentials that require user interaction such as [InteractiveBrowserCredential] and [DeviceCodeCredential]
// can optionally return this error instead of automatically prompting for user interaction. This allows applications
// to decide when to request user interaction. This example shows how to handle the error and authenticate a user
// interactively. It shows [InteractiveBrowserCredential] but the same pattern applies to [DeviceCodeCredential].
func ExampleAuthenticationRequiredError() {
cred, err := azidentity.NewInteractiveBrowserCredential(
&azidentity.InteractiveBrowserCredentialOptions{
// This option is useful only for applications that need to control when to prompt users to
// authenticate. If the timing of user interaction isn't important, don't set this option.
DisableAutomaticAuthentication: true,
},
)
if err != nil {
// TODO: handle error
}
// this could be any client that authenticates with an azidentity credential
client, err := newServiceClient(cred)
if err != nil {
// TODO: handle error
}
err = client.Method()
if err != nil {
var are *azidentity.AuthenticationRequiredError
if errors.As(err, &are) {
// The client requested a token and the credential requires user interaction. Whenever it's convenient
// for the application, call Authenticate to prompt the user. Pass the error's TokenRequestOptions to
// request a token with the parameters the client specified.
_, err = cred.Authenticate(context.TODO(), &are.TokenRequestOptions)
if err != nil {
// TODO: handle error
}
// TODO: retry the client method; it should succeed because the credential now has the required token
}
}
}

func ExampleNewOnBehalfOfCredentialWithCertificate() {
data, err := os.ReadFile(certPath)
if err != nil {
Expand Down

0 comments on commit 9ae828c

Please sign in to comment.