Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ErrAuthenticationRequired with AuthenticationRequiredError #22317

Merged
merged 7 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
### Features Added

### Breaking Changes
> These changes affect only code written against a beta version such as v1.6.0-beta.1
* Replaced `ErrAuthenticationRequired` with `AuthenticationRequiredError`, a struct
type that carries the `TokenRequestOptions` passed to the `GetToken` call which
returned the error.

### Bugs Fixed
* Fixed more cases in which credential chains like `DefaultAzureCredential`
Expand Down
36 changes: 33 additions & 3 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"runtime"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
Expand Down Expand Up @@ -114,6 +115,7 @@ func TestUserAuthentication(t *testing.T) {
name: credNameBrowser,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
return NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
DisableAutomaticAuthentication: disableAutoAuth,
Expand All @@ -126,6 +128,7 @@ func TestUserAuthentication(t *testing.T) {
name: credNameDeviceCode,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
o := DeviceCodeCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
DisableAutomaticAuthentication: disableAutoAuth,
Expand All @@ -143,6 +146,7 @@ func TestUserAuthentication(t *testing.T) {
name: credNameUserPassword,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
opts := UsernamePasswordCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
TokenCachePersistenceOptions: tcpo,
Expand Down Expand Up @@ -264,8 +268,19 @@ func TestUserAuthentication(t *testing.T) {
t.Run("DisableAutomaticAuthentication/"+credential.name, func(t *testing.T) {
cred, err := credential.new(nil, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true)
require.NoError(t, err)
_, err = cred.GetToken(context.Background(), testTRO)
require.ErrorIs(t, err, ErrAuthenticationRequired)
expected := policy.TokenRequestOptions{
Claims: "claims",
EnableCAE: true,
Scopes: []string{"scope"},
TenantID: "tenant",
}
_, err = cred.GetToken(context.Background(), expected)
require.Contains(t, err.Error(), credential.name)
require.Contains(t, err.Error(), "Call Authenticate")
var actual *AuthenticationRequiredError
require.ErrorAs(t, err, &actual)
require.Equal(t, expected, actual.TokenRequestOptions)

if credential.name != credNameBrowser || runManualTests {
_, err = cred.Authenticate(context.Background(), &testTRO)
require.NoError(t, err)
Expand All @@ -274,6 +289,20 @@ func TestUserAuthentication(t *testing.T) {
require.NoError(t, err)
}
})
t.Run("DisableAutomaticAuthentication/ChainedTokenCredential/"+credential.name, func(t *testing.T) {
cred, err := credential.new(nil, policy.ClientOptions{}, AuthenticationRecord{}, true)
require.NoError(t, err)
expected := azcore.AccessToken{ExpiresOn: time.Now().UTC(), Token: tokenValue}
fake := NewFakeCredential()
fake.SetResponse(expected, nil)
chain, err := NewChainedTokenCredential([]azcore.TokenCredential{cred, fake}, nil)
require.NoError(t, err)
// ChainedTokenCredential should continue iterating when a credential returns
// AuthenticationRequiredError i.e., it should call fake.GetToken() and return the expected token
actual, err := chain.GetToken(context.Background(), testTRO)
require.NoError(t, err)
require.Equal(t, expected, actual)
})
}
}
}
Expand Down Expand Up @@ -635,7 +664,8 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
// tenant resolution should have succeeded because the specified tenant is allowed,
// however the credential should have returned a different error because automatic
// authentication is disabled
require.ErrorIs(t, ErrAuthenticationRequired, err)
var e *AuthenticationRequiredError
require.ErrorAs(t, err, &e)
}
})

Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/azure_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ func TestAzureCLICredential_DefaultChainError(t *testing.T) {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
var ue *credentialUnavailableError
if !errors.As(err, &ue) {
t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err)
var cu credentialUnavailable
if !errors.As(err, &cu) {
t.Fatalf("expected %T, got %T: %q", cu, err, err)
}
}

Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/azure_developer_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func TestAzureDeveloperCLICredential_DefaultChainError(t *testing.T) {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
var ue *credentialUnavailableError
if !errors.As(err, &ue) {
t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err)
var cu credentialUnavailable
if !errors.As(err, &cu) {
t.Fatalf("expected %T, got %T: %q", cu, err, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
errs []error
successfulCredential azcore.TokenCredential
token azcore.AccessToken
unavailableErr *credentialUnavailableError
unavailableErr credentialUnavailable
)
for _, cred := range c.sources {
token, err = cred.GetToken(ctx, opts)
Expand Down
8 changes: 4 additions & 4 deletions sdk/azidentity/chained_token_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenUnavailable(t *testin
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, received %T", err)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, received %T", err)
}
expectedError := `ChainedTokenCredential: failed to acquire a token.
Attempted credentials:
Expand Down Expand Up @@ -186,8 +186,8 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenCustomName(t *testing
}
cred.name = "CustomNameCredential"
_, err = cred.GetToken(context.Background(), testTRO)
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, received %T", err)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, received %T", err)
}
expectedError := `CustomNameCredential: failed to acquire a token.
Attempted credentials:
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/confidential_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque
if err != nil {
// We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code.
// We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError.
var unavailableErr *credentialUnavailableError
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
res := getResponseFromError(err)
err = newAuthenticationFailedError(c.name, err.Error(), res, err)
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ type defaultCredentialErrorReporter struct {
}

func (d *defaultCredentialErrorReporter) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if _, ok := d.err.(*credentialUnavailableError); ok {
if _, ok := d.err.(credentialUnavailable); ok {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @chlowell! Is there a reason to leave the credentialUnavailable type unexported? Trying to implement a credentialErrorReporter wrapper for our custom ChainedTokenCredential but can't due to this line. Thank you!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We didn't see a need to export the interface, and we generally prefer to keep public API surfaces as small as possible. What do you need credentialUnavailable for in your code? defaultCredentialErrorReporter doesn't need this line to function correctly.

Copy link
Member

@anlandu anlandu Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defaultCredentialErrorReporter doesn't need this line to function correctly.

Got it, had figured I could just unconditionally new it but the presence of the additional type assertion check had me questioning myself. I guess it's in the defaultCredentialErrorReporter just to save unnecessary new-s. Thanks!

return azcore.AccessToken{}, d.err
}
return azcore.AccessToken{}, newCredentialUnavailableError(d.credType, d.err.Error())
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) {
for i := 0; i < 2; i++ {
// expecting credentialUnavailableError because delay exceeds the wrapper's timeout
_, err = chain.GetToken(context.Background(), testTRO)
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, got %T: %v", err, err)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, got %T: %v", err, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/developer_credential_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const cliTimeout = 10 * time.Second
// the next credential in its chain (another developer credential).
func unavailableIfInChain(err error, inDefaultChain bool) error {
if err != nil && inDefaultChain {
var unavailableErr *credentialUnavailableError
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
err = newCredentialUnavailableError(credNameAzureDeveloperCLI, err.Error())
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/device_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type DeviceCodeCredentialOptions struct {
ClientID string

// DisableAutomaticAuthentication prevents the credential from automatically prompting the user to authenticate.
// When this option is true, [DeviceCodeCredential.GetToken] will return [ErrAuthenticationRequired] when user
// interaction is necessary to acquire a token.
// When this option is true, GetToken will return AuthenticationRequiredError when user interaction is necessary
// to acquire a token.
DisableAutomaticAuthentication bool

// DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or
Expand Down
44 changes: 36 additions & 8 deletions sdk/azidentity/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@ import (
"fmt"
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
msal "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
)

// ErrAuthenticationRequired indicates a credential's Authenticate method must be called to acquire a token
// because user interaction is required and the credential is configured not to automatically prompt the user.
var ErrAuthenticationRequired error = &credentialUnavailableError{"can't acquire a token without user interaction. Call Authenticate to interactively authenticate a user"}

// getResponseFromError retrieves the response carried by
// an AuthenticationFailedError or MSAL CallErr, if any
func getResponseFromError(err error) *http.Response {
Expand Down Expand Up @@ -110,8 +107,34 @@ func (*AuthenticationFailedError) NonRetriable() {

var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil)

// credentialUnavailableError indicates a credential can't attempt authentication because it lacks required
// data or state
// AuthenticationRequiredError indicates a credential's Authenticate method must be called to acquire a token
// because the credential requires user interaction and is configured not to request it automatically.
type AuthenticationRequiredError struct {
credentialUnavailableError

// TokenRequestOptions for the required token. Pass this to the credential's Authenticate method.
TokenRequestOptions policy.TokenRequestOptions
}

func newAuthenticationRequiredError(credType string, tro policy.TokenRequestOptions) error {
return &AuthenticationRequiredError{
credentialUnavailableError: credentialUnavailableError{
credType + " can't acquire a token without user interaction. Call Authenticate to authenticate a user interactively",
},
TokenRequestOptions: tro,
}
}

var (
_ credentialUnavailable = (*AuthenticationRequiredError)(nil)
_ errorinfo.NonRetriable = (*AuthenticationRequiredError)(nil)
)

type credentialUnavailable interface {
error
credentialUnavailable()
}

type credentialUnavailableError struct {
message string
}
Expand All @@ -135,6 +158,11 @@ func (e *credentialUnavailableError) Error() string {
}

// NonRetriable is a marker method indicating this error should not be retried. It has no implementation.
func (e *credentialUnavailableError) NonRetriable() {}
func (*credentialUnavailableError) NonRetriable() {}

func (*credentialUnavailableError) credentialUnavailable() {}

var _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil)
var (
_ credentialUnavailable = (*credentialUnavailableError)(nil)
_ errorinfo.NonRetriable = (*credentialUnavailableError)(nil)
)
32 changes: 0 additions & 32 deletions sdk/azidentity/example_cache_test.go

This file was deleted.

19 changes: 15 additions & 4 deletions sdk/azidentity/example_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// Helpers and variables to keep the examples tidy
// Helpers, variables, fakes to keep the examples tidy

const (
certPath = "testdata/certificate.pem"
clientID = "fake-client-id"
tenantID = "fake-tenant"
authRecordPath = "fake/path"
certPath = "testdata/certificate.pem"
clientID = "fake-client-id"
tenantID = "fake-tenant"
)

func handleError(err error) {
Expand All @@ -28,3 +29,13 @@ func handleError(err error) {

var cred azcore.TokenCredential
var err error

type exampleServiceClient struct{}

func newServiceClient(azcore.TokenCredential) (exampleServiceClient, error) {
return exampleServiceClient{}, nil
}

func (exampleServiceClient) Method() error {
return nil
}
38 changes: 38 additions & 0 deletions sdk/azidentity/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,49 @@
package azidentity_test

import (
"context"
"errors"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
)

// Credentials that require user interaction such as [InteractiveBrowserCredential] and [DeviceCodeCredential]
// can optionally return this error instead of automatically prompting for user interaction. This allows applications
// to decide when to request user interaction. This example shows how to handle the error and authenticate a user
// interactively. It shows [InteractiveBrowserCredential] but the same pattern applies to [DeviceCodeCredential].
func ExampleAuthenticationRequiredError() {
cred, err := azidentity.NewInteractiveBrowserCredential(
&azidentity.InteractiveBrowserCredentialOptions{
// This option is useful only for applications that need to control when to prompt users to
// authenticate. If the timing of user interaction isn't important, don't set this option.
DisableAutomaticAuthentication: true,
},
)
if err != nil {
// TODO: handle error
}
// this could be any client that authenticates with an azidentity credential
client, err := newServiceClient(cred)
if err != nil {
// TODO: handle error
}
err = client.Method()
if err != nil {
var are *azidentity.AuthenticationRequiredError
if errors.As(err, &are) {
// The client requested a token and the credential requires user interaction. Whenever it's convenient
// for the application, call Authenticate to prompt the user. Pass the error's TokenRequestOptions to
// request a token with the parameters the client specified.
_, err = cred.Authenticate(context.TODO(), &are.TokenRequestOptions)
if err != nil {
// TODO: handle error
}
// TODO: retry the client method; it should succeed because the credential now has the required token
}
}
}

func ExampleNewOnBehalfOfCredentialWithCertificate() {
data, err := os.ReadFile(certPath)
if err != nil {
Expand Down