diff --git a/github/ApplicationOAuth.py b/github/ApplicationOAuth.py index b846361648..1b6bade41a 100644 --- a/github/ApplicationOAuth.py +++ b/github/ApplicationOAuth.py @@ -32,6 +32,7 @@ import github.AccessToken import github.Auth +from github.Consts import DEFAULT_BASE_URL, DEFAULT_OAUTH_URL from github.GithubException import BadCredentialsException, GithubException from github.GithubObject import Attribute, NonCompletableGithubObject, NotSet from github.Requester import Requester @@ -82,6 +83,16 @@ def _useAttributes(self, attributes: dict[str, Any]) -> None: if "client_secret" in attributes: # pragma no branch self._client_secret = self._makeStringAttribute(attributes["client_secret"]) + def get_oauth_url(self, path: str) -> str: + if not path.startswith("/"): + path = f"/{path}" + + if self._requester.base_url == DEFAULT_BASE_URL: + base_url = DEFAULT_OAUTH_URL + else: + base_url = f"{self._requester.scheme}://{self._requester.hostname_and_port}/login/oauth" + return f"{base_url}{path}" + def get_login_url( self, redirect_uri: str | None = None, @@ -104,8 +115,7 @@ def get_login_url( query = urllib.parse.urlencode(parameters) - base_url = "https://github.com/login/oauth/authorize" - return f"{base_url}?{query}" + return self.get_oauth_url(f"/authorize?{query}") def get_access_token(self, code: str, state: str | None = None) -> AccessToken: """ @@ -124,7 +134,7 @@ def get_access_token(self, code: str, state: str | None = None) -> AccessToken: headers, data = self._checkError( *self._requester.requestJsonAndCheck( "POST", - "https://github.com/login/oauth/access_token", + self.get_oauth_url("/access_token"), headers={"Accept": "application/json"}, input=post_parameters, ) @@ -165,7 +175,7 @@ def refresh_access_token(self, refresh_token: str) -> AccessToken: headers, data = self._checkError( *self._requester.requestJsonAndCheck( "POST", - "https://github.com/login/oauth/access_token", + self.get_oauth_url("/access_token"), headers={"Accept": "application/json"}, input=post_parameters, ) diff --git a/github/Consts.py b/github/Consts.py index dcc0a08851..503d6fecd4 100644 --- a/github/Consts.py +++ b/github/Consts.py @@ -154,6 +154,7 @@ repoVisibilityPreview = "application/vnd.github.nebula-preview+json" DEFAULT_BASE_URL = "https://api.github.com" +DEFAULT_OAUTH_URL = "https://github.com/login/oauth" DEFAULT_STATUS_URL = "https://status.github.com" DEFAULT_USER_AGENT = "PyGithub/Python" # As of 2018-05-17, Github imposes a 10s limit for completion of API requests. diff --git a/github/Requester.py b/github/Requester.py index baa912e8b5..504d396ea1 100644 --- a/github/Requester.py +++ b/github/Requester.py @@ -506,10 +506,20 @@ def base_url(self) -> str: def graphql_url(self) -> str: return self.__graphql_url + @property + def scheme(self) -> str: + return self.__scheme + @property def hostname(self) -> str: return self.__hostname + @property + def hostname_and_port(self) -> str: + if self.__port is None: + return self.hostname + return f"{self.hostname}:{self.__port}" + @property def auth(self) -> Optional["Auth"]: return self.__auth @@ -917,7 +927,7 @@ def __makeAbsoluteUrl(self, url: str) -> str: "status.github.com", "github.com", ], o.hostname - assert o.path.startswith((self.__prefix, self.__graphql_prefix, "/api/")), o.path + assert o.path.startswith((self.__prefix, self.__graphql_prefix, "/api/", "/login/oauth")), o.path assert o.port == self.__port, o.port url = o.path if o.query != "": diff --git a/tests/ApplicationOAuth.py b/tests/ApplicationOAuth.py index 108264dc0e..9154d913b9 100644 --- a/tests/ApplicationOAuth.py +++ b/tests/ApplicationOAuth.py @@ -49,6 +49,8 @@ def setUp(self): self.CLIENT_ID = "client_id_removed" self.CLIENT_SECRET = "client_secret_removed" self.app = self.g.get_oauth_application(self.CLIENT_ID, self.CLIENT_SECRET) + self.ent_gh = github.Github(base_url="http://my.enterprise.com/path/to/github") + self.ent_app = self.ent_gh.get_oauth_application(self.CLIENT_ID, self.CLIENT_SECRET) def testLoginURL(self): BASE_URL = "https://github.com/login/oauth/authorize" @@ -78,6 +80,45 @@ def testGetAccessToken(self): self.assertIsNone(access_token.refresh_expires_in) self.assertIsNone(access_token.refresh_expires_at) + def testEnterpriseSupport(self): + requester = self.ent_gh._Github__requester + self.assertEqual(requester.scheme, "http") + self.assertEqual(requester.hostname, "my.enterprise.com") + self.assertEqual(requester.hostname_and_port, "my.enterprise.com") + self.assertEqual(self.ent_app.get_oauth_url("auth"), "http://my.enterprise.com/login/oauth/auth") + gh_w_port = github.Github( + base_url="http://my.enterprise.com:443/path/to/github" + )._Github__requester.hostname_and_port + self.assertEqual(gh_w_port, "my.enterprise.com:443") + + def testEnterpriseLoginURL(self): + BASE_URL = "http://my.enterprise.com/login/oauth/authorize" + sample_uri = "https://myapp.com/some/path" + sample_uri_encoded = "https%3A%2F%2Fmyapp.com%2Fsome%2Fpath" + self.assertEqual(self.ent_app.get_login_url(), f"{BASE_URL}?client_id={self.CLIENT_ID}") + self.assertTrue(f"redirect_uri={sample_uri_encoded}" in self.ent_app.get_login_url(redirect_uri=sample_uri)) + self.assertTrue(f"client_id={self.CLIENT_ID}" in self.ent_app.get_login_url(redirect_uri=sample_uri)) + self.assertTrue("state=123abc" in self.ent_app.get_login_url(state="123abc", login="user")) + self.assertTrue("login=user" in self.ent_app.get_login_url(state="123abc", login="user")) + self.assertTrue(f"client_id={self.CLIENT_ID}" in self.ent_app.get_login_url(state="123abc", login="user")) + + def testEnterpriseGetAccessToken(self): + access_token = self.ent_app.get_access_token("oauth_code_removed", state="state_removed") + # Test string representation + self.assertEqual( + str(access_token), + 'AccessToken(type="bearer", token="acces...", scope="", ' + "refresh_token_expires_in=None, refresh_token=None, expires_in=None)", + ) + self.assertEqual(access_token.token, "access_token_removed") + self.assertEqual(access_token.type, "bearer") + self.assertEqual(access_token.scope, "") + self.assertIsNone(access_token.expires_in) + self.assertIsNone(access_token.expires_at) + self.assertIsNone(access_token.refresh_token) + self.assertIsNone(access_token.refresh_expires_in) + self.assertIsNone(access_token.refresh_expires_at) + def testGetAccessTokenWithExpiry(self): with mock.patch("github.AccessToken.datetime") as dt: dt.now = mock.Mock(return_value=datetime(2023, 6, 7, 12, 0, 0, 123, tzinfo=timezone.utc)) diff --git a/tests/ReplayData/ApplicationOAuth.testEnterpriseGetAccessToken.txt b/tests/ReplayData/ApplicationOAuth.testEnterpriseGetAccessToken.txt new file mode 100644 index 0000000000..f1ade32a71 --- /dev/null +++ b/tests/ReplayData/ApplicationOAuth.testEnterpriseGetAccessToken.txt @@ -0,0 +1,10 @@ +http +POST +my.enterprise.com +None +/login/oauth/access_token +{'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'PyGithub/Python'} +{"client_secret": "client_secret_removed", "code": "oauth_code_removed", "client_id": "client_id_removed", "state": "state_removed"} +200 +[('Date', 'Fri, 25 Jan 2019 11:06:39 GMT'), ('Content-Type', 'application/json; charset=utf-8'), ('Transfer-Encoding', 'chunked'), ('Server', 'GitHub.com'), ('Status', '200 OK'), ('Vary', 'X-PJAX, Accept-Encoding'), ('ETag', 'W/"deebfe47f0039427b39ec010749014f6"'), ('Cache-Control', 'max-age=0, private, must-revalidate'), ('Set-Cookie', 'has_recent_activity=1; path=/; expires=Fri, 25 Jan 2019 12:06:38 -0000, ignored_unsupported_browser_notice=false; path=/'), ('X-Request-Id', 'ed8794eb-dc95-481f-8e52-2cd5db0494a0'), ('Strict-Transport-Security', 'max-age=31536000; includeSubdomains; preload'), ('X-Frame-Options', 'deny'), ('X-Content-Type-Options', 'nosniff'), ('X-XSS-Protection', '1; mode=block'), ('Referrer-Policy', 'origin-when-cross-origin, strict-origin-when-cross-origin'), ('Expect-CT', 'max-age=2592000, report-uri="https://api.github.com/_private/browser/errors"'), ('Content-Security-Policy', "default-src 'none'; base-uri 'self'; block-all-mixed-content; connect-src 'self' uploads.github.com www.githubstatus.com collector.githubapp.com api.github.com www.google-analytics.com github-cloud.s3.amazonaws.com github-production-repository-file-5c1aeb.s3.amazonaws.com github-production-upload-manifest-file-7fdce7.s3.amazonaws.com github-production-user-asset-6210df.s3.amazonaws.com wss://live.github.com; font-src github.githubassets.com; form-action 'self' github.com gist.github.com; frame-ancestors 'none'; frame-src render.githubusercontent.com; img-src 'self' data: github.githubassets.com assets-cdn.github.com identicons.github.com collector.githubapp.com github-cloud.s3.amazonaws.com *.githubusercontent.com; manifest-src 'self'; media-src 'none'; script-src github.githubassets.com; style-src 'unsafe-inline' github.githubassets.com"), ('Content-Encoding', 'gzip'), ('X-GitHub-Request-Id', 'C8AC:1D8B2:126D746:1BF8DE4:5C4AEDBE')] +{"access_token":"access_token_removed","token_type":"bearer","scope":""}