OAuth refresh
This commit is contained in:
parent
446c8c2a52
commit
faca12783e
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
@ -69,7 +70,7 @@ class Session(LibrespotSession):
|
|||
def __init__(
|
||||
self,
|
||||
session_builder: LibrespotSession.Builder,
|
||||
token: TokenProvider.StoredToken,
|
||||
oauth: OAuth,
|
||||
language: str = "en",
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -89,7 +90,7 @@ class Session(LibrespotSession):
|
|||
),
|
||||
ApResolver.get_random_accesspoint(),
|
||||
)
|
||||
self.__token = token
|
||||
self.__oauth = oauth
|
||||
self.__language = language
|
||||
self.connect()
|
||||
self.authenticate(session_builder.login_credentials)
|
||||
|
@ -112,8 +113,7 @@ class Session(LibrespotSession):
|
|||
.build()
|
||||
)
|
||||
session = LibrespotSession.Builder(conf).stored_file(str(cred_file))
|
||||
token = session.login_credentials.auth_data # TODO: this is wrong
|
||||
return Session(session, token, language)
|
||||
return Session(session, OAuth(), language) # TODO
|
||||
|
||||
@staticmethod
|
||||
def from_oauth(
|
||||
|
@ -148,7 +148,7 @@ class Session(LibrespotSession):
|
|||
typ=Authentication.AuthenticationType.values()[3],
|
||||
auth_data=token.access_token.encode(),
|
||||
)
|
||||
return Session(session, token, language)
|
||||
return Session(session, auth, language)
|
||||
|
||||
def __get_playable(
|
||||
self, playable_id: PlayableId, quality: Quality
|
||||
|
@ -188,9 +188,9 @@ class Session(LibrespotSession):
|
|||
self.api(),
|
||||
)
|
||||
|
||||
def token(self) -> TokenProvider.StoredToken:
|
||||
"""Returns API token"""
|
||||
return self.__token
|
||||
def oauth(self) -> OAuth:
|
||||
"""Returns OAuth service"""
|
||||
return self.__oauth
|
||||
|
||||
def language(self) -> str:
|
||||
"""Returns session language"""
|
||||
|
@ -288,7 +288,7 @@ class TokenProvider(LibrespotTokenProvider):
|
|||
self._session = session
|
||||
|
||||
def get_token(self, *scopes) -> TokenProvider.StoredToken:
|
||||
return self._session.token()
|
||||
return self._session.oauth().get_token()
|
||||
|
||||
class StoredToken(LibrespotTokenProvider.StoredToken):
|
||||
def __init__(self, obj):
|
||||
|
@ -309,6 +309,11 @@ class OAuth:
|
|||
self.__server_thread.start()
|
||||
|
||||
def get_authorization_url(self) -> str:
|
||||
"""
|
||||
Generate OAuth URL
|
||||
Returns:
|
||||
OAuth URL
|
||||
"""
|
||||
self.__code_verifier = generate_code_verifier()
|
||||
code_challenge = get_code_challenge(self.__code_verifier)
|
||||
params = {
|
||||
|
@ -322,19 +327,48 @@ class OAuth:
|
|||
return f"{AUTH_URL}authorize?{urlencode(params)}"
|
||||
|
||||
def await_token(self) -> TokenProvider.StoredToken:
|
||||
"""
|
||||
Blocks until server thread gets token
|
||||
Returns:
|
||||
StoredToken
|
||||
"""
|
||||
self.__server_thread.join()
|
||||
return self.__token
|
||||
|
||||
def set_token(self, code: str) -> None:
|
||||
def get_token(self) -> TokenProvider.StoredToken:
|
||||
"""
|
||||
Gets a valid token
|
||||
Returns:
|
||||
StoredToken
|
||||
"""
|
||||
if self.__token is None:
|
||||
raise RuntimeError("Session isn't authenticated!")
|
||||
elif self.__token.expired():
|
||||
self.set_token(self.__token.refresh_token, OAuth.RequestType.REFRESH)
|
||||
return self.__token
|
||||
|
||||
def set_token(self, code: str, request_type: RequestType) -> None:
|
||||
"""
|
||||
Fetches and sets stored token
|
||||
Returns:
|
||||
StoredToken
|
||||
"""
|
||||
token_url = f"{AUTH_URL}api/token"
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
body = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": REDIRECT_URI,
|
||||
"client_id": CLIENT_ID,
|
||||
"code_verifier": self.__code_verifier,
|
||||
}
|
||||
if request_type == OAuth.RequestType.LOGIN:
|
||||
body = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": REDIRECT_URI,
|
||||
"client_id": CLIENT_ID,
|
||||
"code_verifier": self.__code_verifier,
|
||||
}
|
||||
elif request_type == OAuth.RequestType.REFRESH:
|
||||
body = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": code,
|
||||
"client_id": CLIENT_ID,
|
||||
}
|
||||
response = post(token_url, headers=headers, data=body)
|
||||
if response.status_code != 200:
|
||||
raise IOError(
|
||||
|
@ -348,6 +382,10 @@ class OAuth:
|
|||
httpd.authenticator = self
|
||||
httpd.serve_forever()
|
||||
|
||||
class RequestType(IntEnum):
|
||||
LOGIN = 0
|
||||
REFRESH = 1
|
||||
|
||||
class OAuthHTTPServer(HTTPServer):
|
||||
authenticator: OAuth
|
||||
|
||||
|
@ -371,7 +409,9 @@ class OAuth:
|
|||
|
||||
if code:
|
||||
if isinstance(self.server, OAuth.OAuthHTTPServer):
|
||||
self.server.authenticator.set_token(code[0])
|
||||
self.server.authenticator.set_token(
|
||||
code[0], OAuth.RequestType.LOGIN
|
||||
)
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
|
|
|
@ -106,7 +106,7 @@ class Selection:
|
|||
|
||||
def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None:
|
||||
arg_range = range(len(args))
|
||||
category_str = " # " + " ".join("{:<38}" for _ in arg_range)
|
||||
category_str = "# " + " ".join("{:<38}" for _ in arg_range)
|
||||
print(category_str.format(*[s.upper() for s in list(args)]))
|
||||
for item in items:
|
||||
count += 1
|
||||
|
|
Loading…
Reference in New Issue