OAuth refresh

This commit is contained in:
Zotify 2024-08-15 16:16:50 +12:00
parent 446c8c2a52
commit faca12783e
2 changed files with 59 additions and 19 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from enum import IntEnum
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from threading import Thread
@ -69,7 +70,7 @@ class Session(LibrespotSession):
def __init__(
self,
session_builder: LibrespotSession.Builder,
token: TokenProvider.StoredToken,
oauth: OAuth,
language: str = "en",
) -> None:
"""
@ -89,7 +90,7 @@ class Session(LibrespotSession):
),
ApResolver.get_random_accesspoint(),
)
self.__token = token
self.__oauth = oauth
self.__language = language
self.connect()
self.authenticate(session_builder.login_credentials)
@ -112,8 +113,7 @@ class Session(LibrespotSession):
.build()
)
session = LibrespotSession.Builder(conf).stored_file(str(cred_file))
token = session.login_credentials.auth_data # TODO: this is wrong
return Session(session, token, language)
return Session(session, OAuth(), language) # TODO
@staticmethod
def from_oauth(
@ -148,7 +148,7 @@ class Session(LibrespotSession):
typ=Authentication.AuthenticationType.values()[3],
auth_data=token.access_token.encode(),
)
return Session(session, token, language)
return Session(session, auth, language)
def __get_playable(
self, playable_id: PlayableId, quality: Quality
@ -188,9 +188,9 @@ class Session(LibrespotSession):
self.api(),
)
def token(self) -> TokenProvider.StoredToken:
"""Returns API token"""
return self.__token
def oauth(self) -> OAuth:
"""Returns OAuth service"""
return self.__oauth
def language(self) -> str:
"""Returns session language"""
@ -288,7 +288,7 @@ class TokenProvider(LibrespotTokenProvider):
self._session = session
def get_token(self, *scopes) -> TokenProvider.StoredToken:
return self._session.token()
return self._session.oauth().get_token()
class StoredToken(LibrespotTokenProvider.StoredToken):
def __init__(self, obj):
@ -309,6 +309,11 @@ class OAuth:
self.__server_thread.start()
def get_authorization_url(self) -> str:
"""
Generate OAuth URL
Returns:
OAuth URL
"""
self.__code_verifier = generate_code_verifier()
code_challenge = get_code_challenge(self.__code_verifier)
params = {
@ -322,19 +327,48 @@ class OAuth:
return f"{AUTH_URL}authorize?{urlencode(params)}"
def await_token(self) -> TokenProvider.StoredToken:
"""
Blocks until server thread gets token
Returns:
StoredToken
"""
self.__server_thread.join()
return self.__token
def set_token(self, code: str) -> None:
def get_token(self) -> TokenProvider.StoredToken:
"""
Gets a valid token
Returns:
StoredToken
"""
if self.__token is None:
raise RuntimeError("Session isn't authenticated!")
elif self.__token.expired():
self.set_token(self.__token.refresh_token, OAuth.RequestType.REFRESH)
return self.__token
def set_token(self, code: str, request_type: RequestType) -> None:
"""
Fetches and sets stored token
Returns:
StoredToken
"""
token_url = f"{AUTH_URL}api/token"
headers = {"Content-Type": "application/x-www-form-urlencoded"}
body = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": REDIRECT_URI,
"client_id": CLIENT_ID,
"code_verifier": self.__code_verifier,
}
if request_type == OAuth.RequestType.LOGIN:
body = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": REDIRECT_URI,
"client_id": CLIENT_ID,
"code_verifier": self.__code_verifier,
}
elif request_type == OAuth.RequestType.REFRESH:
body = {
"grant_type": "refresh_token",
"refresh_token": code,
"client_id": CLIENT_ID,
}
response = post(token_url, headers=headers, data=body)
if response.status_code != 200:
raise IOError(
@ -348,6 +382,10 @@ class OAuth:
httpd.authenticator = self
httpd.serve_forever()
class RequestType(IntEnum):
LOGIN = 0
REFRESH = 1
class OAuthHTTPServer(HTTPServer):
authenticator: OAuth
@ -371,7 +409,9 @@ class OAuth:
if code:
if isinstance(self.server, OAuth.OAuthHTTPServer):
self.server.authenticator.set_token(code[0])
self.server.authenticator.set_token(
code[0], OAuth.RequestType.LOGIN
)
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()

View File

@ -106,7 +106,7 @@ class Selection:
def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None:
arg_range = range(len(args))
category_str = " # " + " ".join("{:<38}" for _ in arg_range)
category_str = "# " + " ".join("{:<38}" for _ in arg_range)
print(category_str.format(*[s.upper() for s in list(args)]))
for item in items:
count += 1