fixes 'Premature close' caused by malformed AWS unmarshaller error events
This commit is contained in:
parent
a2ae9f32db
commit
35dc0f4826
|
@ -1,19 +1,20 @@
|
|||
import { pipeline, Transform, Readable } from "stream";
|
||||
import { pipeline, Readable, Transform } from "stream";
|
||||
import StreamArray from "stream-json/streamers/StreamArray";
|
||||
import { StringDecoder } from "string_decoder";
|
||||
import { promisify } from "util";
|
||||
import { APIFormat, keyPool } from "../../../shared/key-management";
|
||||
import {
|
||||
makeCompletionSSE,
|
||||
copySseResponseHeaders,
|
||||
initializeSseStream,
|
||||
makeCompletionSSE,
|
||||
} from "../../../shared/streaming";
|
||||
import type { logger } from "../../../logger";
|
||||
import { enqueue } from "../../queue";
|
||||
import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
|
||||
import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
|
||||
import { EventAggregator } from "./streaming/event-aggregator";
|
||||
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
|
||||
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
|
||||
import { viaEventStreamMarshaller } from "./streaming/via-event-stream-marshaller";
|
||||
|
||||
const pipelineAsync = promisify(pipeline);
|
||||
|
||||
|
@ -63,12 +64,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
}
|
||||
|
||||
const prefersNativeEvents = req.inboundApi === req.outboundApi;
|
||||
const contentType = proxyRes.headers["content-type"];
|
||||
const streamOptions = { contentType, api: req.outboundApi, logger: req.log };
|
||||
const streamOptions = {
|
||||
contentType: proxyRes.headers["content-type"],
|
||||
api: req.outboundApi,
|
||||
logger: req.log,
|
||||
};
|
||||
|
||||
// Decoder turns the raw response stream into a stream of events in some
|
||||
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
|
||||
const decoder = selectDecoderStream({ ...streamOptions, input: proxyRes });
|
||||
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
|
||||
// Adapter transforms the decoded events into server-sent events.
|
||||
const adapter = new SSEStreamAdapter(streamOptions);
|
||||
// Aggregator compiles all events into a single response object.
|
||||
|
@ -88,8 +92,6 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
.on("data", (msg) => {
|
||||
if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
|
||||
aggregator.addEvent(msg);
|
||||
}).on("end", () => {
|
||||
req.log.debug({ key: hash }, `Finished streaming response.`);
|
||||
});
|
||||
|
||||
try {
|
||||
|
@ -125,14 +127,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
}
|
||||
};
|
||||
|
||||
function selectDecoderStream(options: {
|
||||
function getDecoder(options: {
|
||||
input: Readable;
|
||||
api: APIFormat;
|
||||
logger: typeof logger;
|
||||
contentType?: string;
|
||||
}) {
|
||||
const { api, contentType, input } = options;
|
||||
const { api, contentType, input, logger } = options;
|
||||
if (contentType?.includes("application/vnd.amazon.eventstream")) {
|
||||
return viaEventStreamMarshaller(input);
|
||||
return getAwsEventStreamDecoder({ input, logger });
|
||||
} else if (api === "google-ai") {
|
||||
return StreamArray.withParser();
|
||||
} else {
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
import pino from "pino";
|
||||
import { Duplex, Readable } from "stream";
|
||||
import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
|
||||
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
|
||||
import { Message } from "@smithy/eventstream-codec";
|
||||
|
||||
/**
|
||||
* Decodes a Readable stream, such as a proxied HTTP response, into a stream of
|
||||
* Message objects using the AWS SDK's EventStreamMarshaller. Error events in
|
||||
* the amazon eventstream protocol are decoded as Message objects and will not
|
||||
* emit an error event on the decoder stream.
|
||||
*/
|
||||
export function getAwsEventStreamDecoder(params: {
|
||||
input: Readable;
|
||||
logger: pino.Logger;
|
||||
}): Duplex {
|
||||
const { input, logger } = params;
|
||||
const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
|
||||
const eventStream = new EventStreamMarshaller(config).deserialize(
|
||||
input,
|
||||
async (input: Record<string, Message>) => {
|
||||
const eventType = Object.keys(input)[0];
|
||||
let result;
|
||||
if (eventType === "chunk") {
|
||||
result = input[eventType];
|
||||
} else {
|
||||
// AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
|
||||
result = { [eventType]: input[eventType] } as any;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
);
|
||||
return new AWSEventStreamDecoder(eventStream, { logger });
|
||||
}
|
||||
|
||||
class AWSEventStreamDecoder extends Duplex {
|
||||
private readonly asyncIterable: AsyncIterable<Message>;
|
||||
private iterator: AsyncIterator<Message>;
|
||||
private reading: boolean;
|
||||
private logger: pino.Logger;
|
||||
|
||||
constructor(
|
||||
asyncIterable: AsyncIterable<Message>,
|
||||
options: { logger: pino.Logger }
|
||||
) {
|
||||
super({ ...options, objectMode: true });
|
||||
this.asyncIterable = asyncIterable;
|
||||
this.iterator = this.asyncIterable[Symbol.asyncIterator]();
|
||||
this.reading = false;
|
||||
this.logger = options.logger.child({ module: "aws-eventstream-decoder" });
|
||||
}
|
||||
|
||||
async _read(_size: number) {
|
||||
if (this.reading) return;
|
||||
this.reading = true;
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { value, done } = await this.iterator.next();
|
||||
if (done) {
|
||||
this.push(null);
|
||||
break;
|
||||
}
|
||||
if (!this.push(value)) break;
|
||||
}
|
||||
} catch (err) {
|
||||
// AWS SDK's EventStreamMarshaller emits errors in the stream itself as
|
||||
// whatever our deserializer returns, which will not be Error objects
|
||||
// because we want to pass the Message to the next stream for processing.
|
||||
// Any actual Error thrown here is some failure during deserialization.
|
||||
const isAwsError = !(err instanceof Error);
|
||||
|
||||
if (isAwsError) {
|
||||
this.logger.warn({ err: err.headers }, "Received AWS error event");
|
||||
this.push(err);
|
||||
this.push(null);
|
||||
} else {
|
||||
this.logger.error(err, "Error during AWS stream deserialization");
|
||||
this.destroy(err);
|
||||
}
|
||||
} finally {
|
||||
this.reading = false;
|
||||
}
|
||||
}
|
||||
|
||||
_write(_chunk: any, _encoding: string, callback: () => void) {
|
||||
callback();
|
||||
}
|
||||
|
||||
_final(callback: () => void) {
|
||||
callback();
|
||||
}
|
||||
}
|
|
@ -51,7 +51,7 @@ export class SSEStreamAdapter extends Transform {
|
|||
const event = Buffer.from(bytes, "base64").toString("utf8");
|
||||
return ["event: completion", `data: ${event}`].join(`\n`);
|
||||
}
|
||||
// Intentional fallthrough, non-JSON events will be something very weird
|
||||
// Intentional fallthrough, as non-JSON events may as well be errors
|
||||
// noinspection FallThroughInSwitchStatementJS
|
||||
case "exception":
|
||||
case "error":
|
||||
|
@ -61,7 +61,6 @@ export class SSEStreamAdapter extends Transform {
|
|||
switch (type) {
|
||||
case "throttlingexception":
|
||||
this.log.warn(
|
||||
{ message, type },
|
||||
"AWS request throttled after streaming has already started; retrying"
|
||||
);
|
||||
throw new RetryableError("AWS request throttled mid-stream");
|
||||
|
@ -142,7 +141,6 @@ export class SSEStreamAdapter extends Transform {
|
|||
}
|
||||
|
||||
_flush(callback: (err?: Error | null) => void) {
|
||||
this.log.debug("SSEStreamAdapter flushing");
|
||||
callback();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
import { Duplex, Readable } from "stream";
|
||||
import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
|
||||
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
|
||||
import { Message } from "@smithy/eventstream-codec";
|
||||
|
||||
/**
|
||||
* Decodes a Readable stream, such as a proxied HTTP response, into a stream of
|
||||
* Message objects using the AWS SDK's EventStreamMarshaller.
|
||||
* @param input
|
||||
*/
|
||||
export function viaEventStreamMarshaller(input: Readable): Duplex {
|
||||
const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
|
||||
const eventStream = new EventStreamMarshaller(config).deserialize(
|
||||
input,
|
||||
// deserializer is always an object with one key. we just extract the value
|
||||
// and pipe it to SSEStreamAdapter for it to turn it into an SSE stream
|
||||
async (input: Record<string, Message>) => Object.values(input)[0]
|
||||
);
|
||||
return new StreamFromIterable(eventStream);
|
||||
}
|
||||
|
||||
// In theory, Duplex.from(eventStream) would have rendered this wrapper
|
||||
// unnecessary, but I was not able to get it to work for a number of reasons and
|
||||
// needed more control over the stream's lifecycle.
|
||||
|
||||
class StreamFromIterable extends Duplex {
|
||||
private readonly asyncIterable: AsyncIterable<Message>;
|
||||
private iterator: AsyncIterator<Message>;
|
||||
private reading: boolean;
|
||||
|
||||
constructor(asyncIterable: AsyncIterable<Message>, options = {}) {
|
||||
super({ ...options, objectMode: true });
|
||||
this.asyncIterable = asyncIterable;
|
||||
this.iterator = this.asyncIterable[Symbol.asyncIterator]();
|
||||
this.reading = false;
|
||||
}
|
||||
|
||||
async _read(_size: number) {
|
||||
if (this.reading) return;
|
||||
this.reading = true;
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { value, done } = await this.iterator.next();
|
||||
if (done) {
|
||||
this.push(null);
|
||||
break;
|
||||
}
|
||||
if (!this.push(value)) break;
|
||||
}
|
||||
} catch (err) {
|
||||
this.destroy(err);
|
||||
} finally {
|
||||
this.reading = false;
|
||||
}
|
||||
}
|
||||
|
||||
_write(_chunk: any, _encoding: string, callback: () => void) {
|
||||
callback();
|
||||
}
|
||||
|
||||
_final(callback: () => void) {
|
||||
callback();
|
||||
}
|
||||
}
|
|
@ -41,7 +41,7 @@ const RATE_LIMIT_LOCKOUT = 4000;
|
|||
* to be used again. This is to prevent the queue from flooding a key with too
|
||||
* many requests while we wait to learn whether previous ones succeeded.
|
||||
*/
|
||||
const KEY_REUSE_DELAY = 250;
|
||||
const KEY_REUSE_DELAY = 500;
|
||||
|
||||
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
readonly service = "aws";
|
||||
|
|
Loading…
Reference in New Issue