This commit is contained in:
parent
a721a1def0
commit
564bbe073d
|
@ -16,7 +16,6 @@ Global variable to sync importing and sharing the configured module.
|
||||||
class ApiClientManager:
|
class ApiClientManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._openai_api_key = None
|
self._openai_api_key = None
|
||||||
self._openai_api_base = None
|
|
||||||
self._anth_api_key = None
|
self._anth_api_key = None
|
||||||
self.logger = logging.getLogger('MatrixGPT').getChild('ApiClientManager')
|
self.logger = logging.getLogger('MatrixGPT').getChild('ApiClientManager')
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,10 @@ class AnthropicApiClient(ApiClient):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _create_client(self):
|
def _create_client(self, api_base: str = None):
|
||||||
return AsyncAnthropic(
|
return AsyncAnthropic(
|
||||||
api_key=self._api_key,
|
api_key=self._api_key,
|
||||||
base_url=self._api_base
|
base_url=api_base
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_context(self, context: list, system_prompt: str = None, injected_system_prompt: str = None):
|
def prepare_context(self, context: list, system_prompt: str = None, injected_system_prompt: str = None):
|
||||||
|
@ -58,7 +58,7 @@ class AnthropicApiClient(ApiClient):
|
||||||
})
|
})
|
||||||
|
|
||||||
async def generate(self, command_info: CommandInfo, matrix_gpt_data: str = None):
|
async def generate(self, command_info: CommandInfo, matrix_gpt_data: str = None):
|
||||||
r = await self._create_client().messages.create(
|
r = await self._create_client(command_info.api_base).messages.create(
|
||||||
model=command_info.model,
|
model=command_info.model,
|
||||||
max_tokens=None if command_info.max_tokens == 0 else command_info.max_tokens,
|
max_tokens=None if command_info.max_tokens == 0 else command_info.max_tokens,
|
||||||
temperature=command_info.temperature,
|
temperature=command_info.temperature,
|
||||||
|
|
|
@ -10,12 +10,11 @@ class ApiClient:
|
||||||
_HUMAN_NAME = 'user'
|
_HUMAN_NAME = 'user'
|
||||||
_BOT_NAME = 'assistant'
|
_BOT_NAME = 'assistant'
|
||||||
|
|
||||||
def __init__(self, api_key: str, client_helper: MatrixClientHelper, room: MatrixRoom, event: Event, api_base: str = None):
|
def __init__(self, api_key: str, client_helper: MatrixClientHelper, room: MatrixRoom, event: Event):
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._client_helper = client_helper
|
self._client_helper = client_helper
|
||||||
self._room = room
|
self._room = room
|
||||||
self._event = event
|
self._event = event
|
||||||
self._api_base = api_base
|
|
||||||
self._context = []
|
self._context = []
|
||||||
|
|
||||||
def _create_client(self):
|
def _create_client(self):
|
||||||
|
|
|
@ -12,10 +12,10 @@ class OpenAIClient(ApiClient):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _create_client(self):
|
def _create_client(self, api_base: str = None):
|
||||||
return AsyncOpenAI(
|
return AsyncOpenAI(
|
||||||
api_key=self._api_key,
|
api_key=self._api_key,
|
||||||
base_url=self._api_base
|
base_url=api_base
|
||||||
)
|
)
|
||||||
|
|
||||||
def append_msg(self, content: str, role: str):
|
def append_msg(self, content: str, role: str):
|
||||||
|
|
Loading…
Reference in New Issue