diff --git a/matrix_gpt/generate_clients/anthropic.py b/matrix_gpt/generate_clients/anthropic.py index 9f59c02..94cd26f 100644 --- a/matrix_gpt/generate_clients/anthropic.py +++ b/matrix_gpt/generate_clients/anthropic.py @@ -13,7 +13,8 @@ class AnthropicApiClient(ApiClient): def _create_client(self, base_url: str = None): return AsyncAnthropic( - api_key=self._api_key + api_key=self._api_key, + base_url=self._api_base ) def prepare_context(self, context: list, system_prompt: str = None, injected_system_prompt: str = None): @@ -31,8 +32,6 @@ class AnthropicApiClient(ApiClient): dummy = self.text_msg(f'<{self._BOT_NAME} did not respond>', self._BOT_NAME) if self._context[i]['role'] == self._HUMAN_NAME else self.text_msg(f'<{self._HUMAN_NAME} did not respond>', self._HUMAN_NAME) self._context.insert(i + 1, dummy) i += 1 - # if self._context[-1]['role'] == self._HUMAN_NAME: - # self._context.append(self.generate_text_msg(f'<{self._BOT_NAME} did not respond>', self._BOT_NAME)) def text_msg(self, content: str, role: str): assert role in [self._HUMAN_NAME, self._BOT_NAME] diff --git a/matrix_gpt/generate_clients/api_client.py b/matrix_gpt/generate_clients/api_client.py index de2c771..95aef66 100644 --- a/matrix_gpt/generate_clients/api_client.py +++ b/matrix_gpt/generate_clients/api_client.py @@ -10,14 +10,15 @@ class ApiClient: _HUMAN_NAME = 'user' _BOT_NAME = 'assistant' - def __init__(self, api_key: str, client_helper: MatrixClientHelper, room: MatrixRoom, event: Event): + def __init__(self, api_key: str, client_helper: MatrixClientHelper, room: MatrixRoom, event: Event, api_base: str = None): self._api_key = api_key self._client_helper = client_helper self._room = room self._event = event + self._api_base = api_base self._context = [] - def _create_client(self, base_url: str = None): + def _create_client(self): raise NotImplementedError def check_ignore_request(self) -> bool: diff --git a/matrix_gpt/generate_clients/copilot.py b/matrix_gpt/generate_clients/copilot.py index afb55ca..c0d92bb 100644 --- a/matrix_gpt/generate_clients/copilot.py +++ b/matrix_gpt/generate_clients/copilot.py @@ -33,7 +33,7 @@ class CopilotClient(ApiClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _create_client(self, api_base: str = None): + def _create_client(self): return None def append_msg(self, content: str, role: str): diff --git a/matrix_gpt/generate_clients/openai.py b/matrix_gpt/generate_clients/openai.py index 55d38a7..c0b80c8 100644 --- a/matrix_gpt/generate_clients/openai.py +++ b/matrix_gpt/generate_clients/openai.py @@ -12,10 +12,10 @@ class OpenAIClient(ApiClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _create_client(self, api_base: str = None): + def _create_client(self): return AsyncOpenAI( api_key=self._api_key, - base_url=api_base + base_url=self._api_base ) def append_msg(self, content: str, role: str):