fixes 'Premature close' caused by mishandled AWS unmarshaller error events

nai-degen 2024-02-10 14:47:14 -06:00
parent a2ae9f32db
commit 35dc0f4826
5 changed files with 109 additions and 80 deletions
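
Root cause, for context: @smithy's EventStreamMarshaller surfaces Amazon eventstream error events by throwing whatever the deserializer callback returned for them, and that value is a plain Message rather than an Error. The old wrapper's catch-all destroy(err) therefore tore down the pipeline on every AWS error event, which surfaced to the client only as a bare 'Premature close'. The new decoder distinguishes the two cases; a minimal sketch of the idea (names are illustrative, not the literal diff):

import { Duplex } from "stream";
import type { Message } from "@smithy/eventstream-codec";

// Drain the marshaller's async iterable into an object-mode Duplex,
// treating thrown non-Error values as AWS error events rather than
// stream failures.
async function drainEvents(events: AsyncIterable<Message>, out: Duplex) {
  try {
    for await (const message of events) out.push(message);
    out.push(null);
  } catch (err) {
    if (err instanceof Error) {
      out.destroy(err); // genuine deserialization failure
    } else {
      out.push(err); // AWS error event: forward it downstream...
      out.push(null); // ...then end cleanly instead of 'Premature close'
    }
  }
}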

@@ -1,19 +1,20 @@
import { pipeline, Transform, Readable } from "stream";
import { pipeline, Readable, Transform } from "stream";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { promisify } from "util";
import { APIFormat, keyPool } from "../../../shared/key-management";
import {
makeCompletionSSE,
copySseResponseHeaders,
initializeSseStream,
makeCompletionSSE,
} from "../../../shared/streaming";
import type { logger } from "../../../logger";
import { enqueue } from "../../queue";
import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
import { EventAggregator } from "./streaming/event-aggregator";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { viaEventStreamMarshaller } from "./streaming/via-event-stream-marshaller";
const pipelineAsync = promisify(pipeline);
@@ -63,12 +64,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
}
const prefersNativeEvents = req.inboundApi === req.outboundApi;
const contentType = proxyRes.headers["content-type"];
const streamOptions = { contentType, api: req.outboundApi, logger: req.log };
const streamOptions = {
contentType: proxyRes.headers["content-type"],
api: req.outboundApi,
logger: req.log,
};
// Decoder turns the raw response stream into a stream of events in some
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
const decoder = selectDecoderStream({ ...streamOptions, input: proxyRes });
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
// Adapter transforms the decoded events into server-sent events.
const adapter = new SSEStreamAdapter(streamOptions);
// Aggregator compiles all events into a single response object.
@@ -88,8 +92,6 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
.on("data", (msg) => {
if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
aggregator.addEvent(msg);
}).on("end", () => {
req.log.debug({ key: hash }, `Finished streaming response.`);
});
try {
@@ -125,14 +127,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
}
};
function selectDecoderStream(options: {
function getDecoder(options: {
input: Readable;
api: APIFormat;
logger: typeof logger;
contentType?: string;
}) {
const { api, contentType, input } = options;
const { api, contentType, input, logger } = options;
if (contentType?.includes("application/vnd.amazon.eventstream")) {
return viaEventStreamMarshaller(input);
return getAwsEventStreamDecoder({ input, logger });
} else if (api === "google-ai") {
return StreamArray.withParser();
} else {

@@ -0,0 +1,93 @@
import pino from "pino";
import { Duplex, Readable } from "stream";
import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
import { Message } from "@smithy/eventstream-codec";
/**
* Decodes a Readable stream, such as a proxied HTTP response, into a stream of
* Message objects using the AWS SDK's EventStreamMarshaller. Error events in
the Amazon eventstream protocol are decoded as Message objects and will not
* emit an error event on the decoder stream.
*/
export function getAwsEventStreamDecoder(params: {
input: Readable;
logger: pino.Logger;
}): Duplex {
const { input, logger } = params;
const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
const eventStream = new EventStreamMarshaller(config).deserialize(
input,
async (input: Record<string, Message>) => {
const eventType = Object.keys(input)[0];
let result;
if (eventType === "chunk") {
result = input[eventType];
} else {
// For non-chunk events (errors and exceptions), preserve the event-type
// key so the downstream adapter can tell what kind of error it received.
result = { [eventType]: input[eventType] } as any;
}
return result;
}
);
return new AWSEventStreamDecoder(eventStream, { logger });
}
class AWSEventStreamDecoder extends Duplex {
private readonly asyncIterable: AsyncIterable<Message>;
private iterator: AsyncIterator<Message>;
private reading: boolean;
private logger: pino.Logger;
constructor(
asyncIterable: AsyncIterable<Message>,
options: { logger: pino.Logger }
) {
super({ ...options, objectMode: true });
this.asyncIterable = asyncIterable;
this.iterator = this.asyncIterable[Symbol.asyncIterator]();
this.reading = false;
this.logger = options.logger.child({ module: "aws-eventstream-decoder" });
}
async _read(_size: number) {
if (this.reading) return;
this.reading = true;
try {
while (true) {
const { value, done } = await this.iterator.next();
if (done) {
this.push(null);
break;
}
if (!this.push(value)) break;
}
} catch (err) {
// The AWS SDK's EventStreamMarshaller surfaces error events by throwing
// whatever our deserializer returned for them, which is a plain Message
// rather than an Error so it can be passed to the next stream for
// processing. A genuine Error here means deserialization itself failed.
const isAwsError = !(err instanceof Error);
if (isAwsError) {
this.logger.warn({ err: err.headers }, "Received AWS error event");
this.push(err);
this.push(null);
} else {
this.logger.error(err, "Error during AWS stream deserialization");
this.destroy(err);
}
} finally {
this.reading = false;
}
}
_write(_chunk: any, _encoding: string, callback: () => void) {
callback();
}
_final(callback: () => void) {
callback();
}
}
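
A minimal usage sketch, assuming the pipeline shape from handleStreamedResponse above (the proxyRes and adapter parameters stand in for the upstream HTTP response and an SSEStreamAdapter; they are assumptions, not part of this file):

import { pipeline, Readable, Duplex } from "stream";
import { promisify } from "util";
import pino from "pino";
import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";

const pipelineAsync = promisify(pipeline);

// Wire the decoder between the upstream response and the SSE adapter.
async function streamAwsResponse(
  proxyRes: Readable,
  adapter: Duplex,
  logger: pino.Logger
) {
  const decoder = getAwsEventStreamDecoder({ input: proxyRes, logger });
  // AWS error events now arrive downstream as Message objects instead of
  // destroying the pipeline mid-stream.
  await pipelineAsync(proxyRes, decoder, adapter);
}

The no-op _write is what lets the Duplex sit inside the pipeline for lifecycle and error propagation while the marshaller consumes the input bytes itself.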

@@ -51,7 +51,7 @@ export class SSEStreamAdapter extends Transform {
const event = Buffer.from(bytes, "base64").toString("utf8");
return ["event: completion", `data: ${event}`].join(`\n`);
}
// Intentional fallthrough, non-JSON events will be something very weird
// Intentional fallthrough, as non-JSON events may as well be errors
// noinspection FallThroughInSwitchStatementJS
case "exception":
case "error":
@@ -61,7 +61,6 @@
switch (type) {
case "throttlingexception":
this.log.warn(
{ message, type },
"AWS request throttled after streaming has already started; retrying"
);
throw new RetryableError("AWS request throttled mid-stream");
@@ -142,7 +141,6 @@
}
_flush(callback: (err?: Error | null) => void) {
this.log.debug("SSEStreamAdapter flushing");
callback();
}
}
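
For reference: the "chunk" branch above base64-decodes the Message's bytes payload and re-emits it as a completion SSE event, so downstream consumers see ordinary text/event-stream output. Assuming a typical Anthropic-style payload from Bedrock (illustrative, not taken from the diff), a single decoded event looks like:

event: completion
data: {"completion":" Hello","stop_reason":null}

Error and exception events instead fall through to the switch above, where mid-stream throttling becomes a RetryableError so the request can be retried.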

@@ -1,65 +0,0 @@
import { Duplex, Readable } from "stream";
import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
import { Message } from "@smithy/eventstream-codec";
/**
* Decodes a Readable stream, such as a proxied HTTP response, into a stream of
* Message objects using the AWS SDK's EventStreamMarshaller.
* @param input The raw HTTP response stream to decode.
*/
export function viaEventStreamMarshaller(input: Readable): Duplex {
const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
const eventStream = new EventStreamMarshaller(config).deserialize(
input,
// The deserialized event is always an object with one key; extract its
// value and pipe it to SSEStreamAdapter, which turns it into an SSE stream.
async (input: Record<string, Message>) => Object.values(input)[0]
);
return new StreamFromIterable(eventStream);
}
// In theory, Duplex.from(eventStream) would have rendered this wrapper
// unnecessary, but I was not able to get it to work for a number of reasons and
// needed more control over the stream's lifecycle.
class StreamFromIterable extends Duplex {
private readonly asyncIterable: AsyncIterable<Message>;
private iterator: AsyncIterator<Message>;
private reading: boolean;
constructor(asyncIterable: AsyncIterable<Message>, options = {}) {
super({ ...options, objectMode: true });
this.asyncIterable = asyncIterable;
this.iterator = this.asyncIterable[Symbol.asyncIterator]();
this.reading = false;
}
async _read(_size: number) {
if (this.reading) return;
this.reading = true;
try {
while (true) {
const { value, done } = await this.iterator.next();
if (done) {
this.push(null);
break;
}
if (!this.push(value)) break;
}
} catch (err) {
this.destroy(err);
} finally {
this.reading = false;
}
}
_write(_chunk: any, _encoding: string, callback: () => void) {
callback();
}
_final(callback: () => void) {
callback();
}
}

@@ -41,7 +41,7 @@ const RATE_LIMIT_LOCKOUT = 4000;
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 250;
const KEY_REUSE_DELAY = 500;
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
readonly service = "aws";