make process_image async
This commit is contained in:
parent
7c911d1235
commit
06396a70da
|
@ -51,7 +51,7 @@ class AnthropicApiClient(ApiClient):
|
|||
async def append_img(self, img_event: RoomMessageImage, role: str):
|
||||
assert role in [self._HUMAN_NAME, self._BOT_NAME]
|
||||
img_bytes = await download_mxc(img_event.url, self._client_helper.client)
|
||||
encoded_image = process_image(img_bytes, resize_px=784)
|
||||
encoded_image = await process_image(img_bytes, resize_px=784)
|
||||
self._context.append({
|
||||
"role": role,
|
||||
'content': [{
|
||||
|
|
|
@ -27,7 +27,7 @@ class OpenAIClient(ApiClient):
|
|||
async def append_img(self, img_event: RoomMessageImage, role: str):
|
||||
assert role in [self._HUMAN_NAME, self._BOT_NAME]
|
||||
img_bytes = await download_mxc(img_event.url, self._client_helper.client)
|
||||
encoded_image = process_image(img_bytes, resize_px=512)
|
||||
encoded_image = await process_image(img_bytes, resize_px=512)
|
||||
self._context.append({
|
||||
"role": role,
|
||||
'content': [{
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def process_image(source_bytes: bytes, resize_px: int):
|
||||
image = Image.open(io.BytesIO(source_bytes))
|
||||
async def process_image(source_bytes: bytes, resize_px: int):
|
||||
loop = asyncio.get_event_loop()
|
||||
image = await loop.run_in_executor(None, Image.open, io.BytesIO(source_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if min(width, height) > resize_px:
|
||||
|
@ -15,9 +17,9 @@ def process_image(source_bytes: bytes, resize_px: int):
|
|||
else:
|
||||
new_height = resize_px
|
||||
new_width = int((width / height) * new_height)
|
||||
image = image.resize((new_width, new_height))
|
||||
image = await loop.run_in_executor(None, image.resize, (new_width, new_height))
|
||||
|
||||
byte_arr = io.BytesIO()
|
||||
image.save(byte_arr, format='PNG')
|
||||
await loop.run_in_executor(None, image.save, byte_arr, 'PNG')
|
||||
image_bytes = byte_arr.getvalue()
|
||||
return base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
|
Loading…
Reference in New Issue