Skip to content

Commit

Permalink
Added GCP cloud type for OAuth (#189)
Browse files Browse the repository at this point in the history
Added the GCP cloud type (i.e. domain .gcp.databricks.com) to OAuth
implementation.

Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
  • Loading branch information
rcypher-databricks committed Feb 20, 2024
1 parent 5adddfc commit d70ab7c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
12 changes: 12 additions & 0 deletions auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,16 @@ var databricksAzureDomains []string = []string{
".databricks.azure.us",
}

var databricksGCPDomains []string = []string{
".gcp.databricks.com",
}

type CloudType int

const (
AWS = iota
Azure
GCP
Unknown
)

Expand All @@ -100,6 +105,8 @@ func (cl CloudType) String() string {
return "AWS"
case Azure:
return "Azure"
case GCP:
return "GCP"
}

return "Unknown"
Expand All @@ -119,5 +126,10 @@ func InferCloudFromHost(hostname string) CloudType {
}
}

for _, d := range databricksGCPDomains {
if strings.Contains(hostname, d) {
return GCP
}
}
return Unknown
}
6 changes: 6 additions & 0 deletions auth/oauth/u2m/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ const (

awsClientId = "databricks-sql-connector"
awsRedirectURL = "localhost:8030"

gcpClientId = "databricks-sql-connector"
gcpRedirectURL = "localhost:8030"
)

func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticator, error) {
Expand All @@ -43,6 +46,9 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato
} else if cloud == oauth.Azure {
clientID = azureClientId
redirectURL = azureRedirectURL
} else if cloud == oauth.GCP {
clientID = gcpClientId
redirectURL = gcpRedirectURL
} else {
return nil, errors.New("unhandled cloud type: " + cloud.String())
}
Expand Down
55 changes: 55 additions & 0 deletions examples/oauth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ import (
"time"

dbsql "github.com/databricks/databricks-sql-go"
"github.com/databricks/databricks-sql-go/auth/oauth/m2m"
"github.com/databricks/databricks-sql-go/auth/oauth/u2m"
"github.com/joho/godotenv"
)

func main() {
testU2M()
testM2M()
}

func testU2M() {
err := godotenv.Load()

if err != nil {
Expand Down Expand Up @@ -62,3 +68,52 @@ func main() {
}
fmt.Println(res)
}

func testM2M() {
err := godotenv.Load()

if err != nil {
log.Fatal(err.Error())
}

clientID := os.Getenv("DATABRICKS_CLIENT_ID")
clientSecret := os.Getenv("DATABRICKS_CLIENT_SECRET")
host := os.Getenv("DATABRICKS_HOST")
authenticator := m2m.NewAuthenticator(clientID, clientSecret, host)

connector, err := dbsql.NewConnector(
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
dbsql.WithAuthenticator(authenticator),
)
if err != nil {
log.Fatal(err)
}

db := sql.OpenDB(connector)
defer db.Close()

// Pinging should require logging in
if err := db.Ping(); err != nil {
fmt.Println(err)
}

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

var res int

// Running query should not require logging in as we should have a token
// from when ping was called.
err1 := db.QueryRowContext(ctx, `select 1`).Scan(&res)

if err1 != nil {
if err1 == sql.ErrNoRows {
fmt.Println("not found")
return
} else {
fmt.Printf("err: %v\n", err1)
}
}
fmt.Println(res)
}

0 comments on commit d70ab7c

Please sign in to comment.