2024-04-07 22:27:00 -06:00
from anthropic import AsyncAnthropic
2024-04-09 19:26:44 -06:00
from nio import RoomMessageImage
2024-04-07 22:27:00 -06:00
2024-04-09 19:26:44 -06:00
from matrix_gpt . chat_functions import download_mxc
2024-04-07 22:27:00 -06:00
from matrix_gpt . generate_clients . api_client import ApiClient
from matrix_gpt . generate_clients . command_info import CommandInfo
2024-04-09 19:26:44 -06:00
from matrix_gpt . image import process_image
2024-04-07 22:27:00 -06:00
class AnthropicApiClient(ApiClient):
    """ApiClient implementation that talks to Anthropic's Messages API via AsyncAnthropic."""

    # The Messages API rejects requests without an integer max_tokens. This value is
    # used when CommandInfo.max_tokens is 0/unset. 4096 is the output ceiling of the
    # original Claude 3 models -- NOTE(review): confirm it suits the configured models.
    _DEFAULT_MAX_TOKENS = 4096

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _create_client(self, base_url: str = None):
        """Build a fresh AsyncAnthropic client.

        :param base_url: Optional API base URL override; when omitted (None) the
            configured ``self._api_base`` is used.
        :return: A new ``AsyncAnthropic`` instance using ``self._api_key``.
        """
        return AsyncAnthropic(
            api_key=self._api_key,
            base_url=base_url if base_url is not None else self._api_base,
        )

    def prepare_context(self, context: list, system_prompt: str = None, injected_system_prompt: str = None):
        """Load a conversation into this client and normalise its turn order.

        :param context: List of message dicts (each with at least a ``'role'`` key).
            A shallow copy is stored so the placeholder turns inserted by
            ``verify_context`` do not mutate the caller's list.
        :param system_prompt: Accepted for ApiClient interface compatibility; unused
            here -- ``generate`` sends ``command_info.system_prompt`` instead.
        :param injected_system_prompt: Accepted for interface compatibility; currently
            ignored. NOTE(review): confirm dropping it for Anthropic is intended.
        """
        # Guard against preparing a client whose context was already populated.
        assert not self._context, 'prepare_context called on a non-empty context'
        self._context = list(context)
        self.verify_context()

    def verify_context(self):
        """
        Verify that the context alternates between the human and assistant, inserting the opposite user type if it does not alternate correctly.
        """
        i = 0
        # len() is re-evaluated each pass because inserts grow the list.
        while i < len(self._context) - 1:
            role = self._context[i]['role']
            if role == self._context[i + 1]['role']:
                # Two consecutive turns by the same role: insert a placeholder turn
                # for the other side. Any non-human role gets a human placeholder.
                if role == self._HUMAN_NAME:
                    missing = self._BOT_NAME
                else:
                    missing = self._HUMAN_NAME
                self._context.insert(i + 1, self.text_msg(f'<{missing} did not respond>', missing))
            i += 1

    def text_msg(self, content: str, role: str):
        """Build a single-text-block message dict in Anthropic's content-list format.

        :param content: Message text; coerced with ``str()``.
        :param role: Must be ``self._HUMAN_NAME`` or ``self._BOT_NAME``.
        :return: ``{"role": role, "content": [{"type": "text", "text": ...}]}``
        """
        assert role in [self._HUMAN_NAME, self._BOT_NAME]
        return {"role": role, "content": [{"type": "text", "text": str(content)}]}

    def append_msg(self, content: str, role: str):
        """Append a text message from ``role`` to the context."""
        assert role in [self._HUMAN_NAME, self._BOT_NAME]
        self._context.append(self.text_msg(content, role))

    async def append_img(self, img_event: RoomMessageImage, role: str):
        """Download a Matrix image, re-encode it, and append it as an image message.

        :param img_event: The Matrix image event; its ``url`` (mxc URI) is downloaded.
        :param role: Must be ``self._HUMAN_NAME`` or ``self._BOT_NAME``.
        """
        assert role in [self._HUMAN_NAME, self._BOT_NAME]
        img_bytes = await download_mxc(img_event.url, self._client_helper.client)
        # NOTE(review): assumes process_image returns base64-encoded PNG data, matching
        # the hard-coded media_type below -- confirm against matrix_gpt.image.
        encoded_image = await process_image(img_bytes, resize_px=784)
        self._context.append({
            "role": role,
            'content': [{
                'type': 'image',
                'source': {
                    'type': 'base64',
                    'media_type': 'image/png',
                    'data': encoded_image
                }
            }]
        })

    async def generate(self, command_info: CommandInfo, matrix_gpt_data: str = None):
        """Send the prepared context to the Messages API and return the reply text.

        :param command_info: Supplies model, max_tokens (0/unset -> class default),
            temperature and system prompt.
        :param matrix_gpt_data: Unused by this client; kept for interface compatibility.
        :return: ``(text, None)`` where ``text`` is the concatenation of all text
            blocks in the response ('' if there are none).
        """
        # max_tokens is mandatory for the Messages API, so never send None.
        max_tokens = command_info.max_tokens or self._DEFAULT_MAX_TOKENS
        # Use the client as an async context manager so its HTTP connection pool is
        # closed after the request instead of lingering until garbage collection.
        async with self._create_client() as client:
            response = await client.messages.create(
                model=command_info.model,
                max_tokens=max_tokens,
                temperature=command_info.temperature,
                system=command_info.system_prompt or '',
                messages=self.context,
            )
        # Guard against an empty content list and keep every text block, not just the first.
        text = ''.join(block.text for block in response.content if getattr(block, 'type', None) == 'text')
        return text, None