diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index 53b95aa..46b0f2a 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -1,19 +1,20 @@ -import { pipeline, Transform, Readable } from "stream"; +import { pipeline, Readable, Transform } from "stream"; import StreamArray from "stream-json/streamers/StreamArray"; import { StringDecoder } from "string_decoder"; import { promisify } from "util"; import { APIFormat, keyPool } from "../../../shared/key-management"; import { - makeCompletionSSE, copySseResponseHeaders, initializeSseStream, + makeCompletionSSE, } from "../../../shared/streaming"; +import type { logger } from "../../../logger"; import { enqueue } from "../../queue"; import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from "."; +import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder"; import { EventAggregator } from "./streaming/event-aggregator"; import { SSEMessageTransformer } from "./streaming/sse-message-transformer"; import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; -import { viaEventStreamMarshaller } from "./streaming/via-event-stream-marshaller"; const pipelineAsync = promisify(pipeline); @@ -63,12 +64,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( } const prefersNativeEvents = req.inboundApi === req.outboundApi; - const contentType = proxyRes.headers["content-type"]; - const streamOptions = { contentType, api: req.outboundApi, logger: req.log }; + const streamOptions = { + contentType: proxyRes.headers["content-type"], + api: req.outboundApi, + logger: req.log, + }; // Decoder turns the raw response stream into a stream of events in some // format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc). 
- const decoder = selectDecoderStream({ ...streamOptions, input: proxyRes }); + const decoder = getDecoder({ ...streamOptions, input: proxyRes }); // Adapter transforms the decoded events into server-sent events. const adapter = new SSEStreamAdapter(streamOptions); // Aggregator compiles all events into a single response object. @@ -88,8 +92,6 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( .on("data", (msg) => { if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`); aggregator.addEvent(msg); - }).on("end", () => { - req.log.debug({ key: hash }, `Finished streaming response.`); }); try { @@ -125,14 +127,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( } }; -function selectDecoderStream(options: { +function getDecoder(options: { input: Readable; api: APIFormat; + logger: typeof logger; contentType?: string; }) { - const { api, contentType, input } = options; + const { api, contentType, input, logger } = options; if (contentType?.includes("application/vnd.amazon.eventstream")) { - return viaEventStreamMarshaller(input); + return getAwsEventStreamDecoder({ input, logger }); } else if (api === "google-ai") { return StreamArray.withParser(); } else { diff --git a/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts b/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts new file mode 100644 index 0000000..b394142 --- /dev/null +++ b/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts @@ -0,0 +1,93 @@ +import pino from "pino"; +import { Duplex, Readable } from "stream"; +import { EventStreamMarshaller } from "@smithy/eventstream-serde-node"; +import { fromUtf8, toUtf8 } from "@smithy/util-utf8"; +import { Message } from "@smithy/eventstream-codec"; + +/** + * Decodes a Readable stream, such as a proxied HTTP response, into a stream of + * Message objects using the AWS SDK's EventStreamMarshaller. 
Error events in + * the amazon eventstream protocol are decoded as Message objects and will not + * emit an error event on the decoder stream. + */ +export function getAwsEventStreamDecoder(params: { + input: Readable; + logger: pino.Logger; +}): Duplex { + const { input, logger } = params; + const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 }; + const eventStream = new EventStreamMarshaller(config).deserialize( + input, + async (input: Record<string, Message>) => { + const eventType = Object.keys(input)[0]; + let result; + if (eventType === "chunk") { + result = input[eventType]; + } else { + // AWS unmarshaller treats non-chunk (errors and exceptions) oddly. + result = { [eventType]: input[eventType] } as any; + } + return result; + } + ); + return new AWSEventStreamDecoder(eventStream, { logger }); +} + +class AWSEventStreamDecoder extends Duplex { + private readonly asyncIterable: AsyncIterable<Message>; + private iterator: AsyncIterator<Message>; + private reading: boolean; + private logger: pino.Logger; + + constructor( + asyncIterable: AsyncIterable<Message>, + options: { logger: pino.Logger } + ) { + super({ ...options, objectMode: true }); + this.asyncIterable = asyncIterable; + this.iterator = this.asyncIterable[Symbol.asyncIterator](); + this.reading = false; + this.logger = options.logger.child({ module: "aws-eventstream-decoder" }); + } + + async _read(_size: number) { + if (this.reading) return; + this.reading = true; + + try { + while (true) { + const { value, done } = await this.iterator.next(); + if (done) { + this.push(null); + break; + } + if (!this.push(value)) break; + } + } catch (err) { + // AWS SDK's EventStreamMarshaller emits errors in the stream itself as + // whatever our deserializer returns, which will not be Error objects + // because we want to pass the Message to the next stream for processing. + // Any actual Error thrown here is some failure during deserialization. 
+ const isAwsError = !(err instanceof Error); + + if (isAwsError) { + this.logger.warn({ err: err.headers }, "Received AWS error event"); + this.push(err); + this.push(null); + } else { + this.logger.error(err, "Error during AWS stream deserialization"); + this.destroy(err); + } + } finally { + this.reading = false; + } + } + + _write(_chunk: any, _encoding: string, callback: () => void) { + callback(); + } + + _final(callback: () => void) { + callback(); + } +} diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts index 20022d1..38157c5 100644 --- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts +++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts @@ -51,7 +51,7 @@ export class SSEStreamAdapter extends Transform { const event = Buffer.from(bytes, "base64").toString("utf8"); return ["event: completion", `data: ${event}`].join(`\n`); } - // Intentional fallthrough, non-JSON events will be something very weird + // Intentional fallthrough, as non-JSON events may as well be errors // noinspection FallThroughInSwitchStatementJS case "exception": case "error": @@ -61,7 +61,6 @@ export class SSEStreamAdapter extends Transform { switch (type) { case "throttlingexception": this.log.warn( - { message, type }, "AWS request throttled after streaming has already started; retrying" ); throw new RetryableError("AWS request throttled mid-stream"); @@ -142,7 +141,6 @@ export class SSEStreamAdapter extends Transform { } _flush(callback: (err?: Error | null) => void) { - this.log.debug("SSEStreamAdapter flushing"); callback(); } } diff --git a/src/proxy/middleware/response/streaming/via-event-stream-marshaller.ts b/src/proxy/middleware/response/streaming/via-event-stream-marshaller.ts deleted file mode 100644 index a00165a..0000000 --- a/src/proxy/middleware/response/streaming/via-event-stream-marshaller.ts +++ /dev/null @@ -1,65 +0,0 @@ -import { Duplex, Readable } 
from "stream"; -import { EventStreamMarshaller } from "@smithy/eventstream-serde-node"; -import { fromUtf8, toUtf8 } from "@smithy/util-utf8"; -import { Message } from "@smithy/eventstream-codec"; - -/** - * Decodes a Readable stream, such as a proxied HTTP response, into a stream of - * Message objects using the AWS SDK's EventStreamMarshaller. - * @param input - */ -export function viaEventStreamMarshaller(input: Readable): Duplex { - const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 }; - const eventStream = new EventStreamMarshaller(config).deserialize( - input, - // deserializer is always an object with one key. we just extract the value - // and pipe it to SSEStreamAdapter for it to turn it into an SSE stream - async (input: Record<string, Message>) => Object.values(input)[0] - ); - return new StreamFromIterable(eventStream); -} - -// In theory, Duplex.from(eventStream) would have rendered this wrapper -// unnecessary, but I was not able to get it to work for a number of reasons and -// needed more control over the stream's lifecycle. 
- -class StreamFromIterable extends Duplex { - private readonly asyncIterable: AsyncIterable<Message>; - private iterator: AsyncIterator<Message>; - private reading: boolean; - - constructor(asyncIterable: AsyncIterable<Message>, options = {}) { - super({ ...options, objectMode: true }); - this.asyncIterable = asyncIterable; - this.iterator = this.asyncIterable[Symbol.asyncIterator](); - this.reading = false; - } - - async _read(_size: number) { - if (this.reading) return; - this.reading = true; - - try { - while (true) { - const { value, done } = await this.iterator.next(); - if (done) { - this.push(null); - break; - } - if (!this.push(value)) break; - } - } catch (err) { - this.destroy(err); - } finally { - this.reading = false; - } - } - - _write(_chunk: any, _encoding: string, callback: () => void) { - callback(); - } - - _final(callback: () => void) { - callback(); - } -} diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 58a9477..5391466 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -41,7 +41,7 @@ const RATE_LIMIT_LOCKOUT = 4000; * to be used again. This is to prevent the queue from flooding a key with too * many requests while we wait to learn whether previous ones succeeded. */ -const KEY_REUSE_DELAY = 250; +const KEY_REUSE_DELAY = 500; export class AwsBedrockKeyProvider implements KeyProvider { readonly service = "aws";