fixes 'Premature close' caused by the AWS unmarshaller's malformed error events
parent a2ae9f32db
commit 35dc0f4826
@@ -1,19 +1,20 @@
-import { pipeline, Transform, Readable } from "stream";
+import { pipeline, Readable, Transform } from "stream";
 import StreamArray from "stream-json/streamers/StreamArray";
 import { StringDecoder } from "string_decoder";
 import { promisify } from "util";
 import { APIFormat, keyPool } from "../../../shared/key-management";
 import {
-  makeCompletionSSE,
   copySseResponseHeaders,
   initializeSseStream,
+  makeCompletionSSE,
 } from "../../../shared/streaming";
+import type { logger } from "../../../logger";
 import { enqueue } from "../../queue";
 import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
+import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
 import { EventAggregator } from "./streaming/event-aggregator";
 import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
 import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
-import { viaEventStreamMarshaller } from "./streaming/via-event-stream-marshaller";
 
 const pipelineAsync = promisify(pipeline);
 
@@ -63,12 +64,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
   }
 
   const prefersNativeEvents = req.inboundApi === req.outboundApi;
-  const contentType = proxyRes.headers["content-type"];
-  const streamOptions = { contentType, api: req.outboundApi, logger: req.log };
+  const streamOptions = {
+    contentType: proxyRes.headers["content-type"],
+    api: req.outboundApi,
+    logger: req.log,
+  };
 
   // Decoder turns the raw response stream into a stream of events in some
   // format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
-  const decoder = selectDecoderStream({ ...streamOptions, input: proxyRes });
+  const decoder = getDecoder({ ...streamOptions, input: proxyRes });
   // Adapter transforms the decoded events into server-sent events.
   const adapter = new SSEStreamAdapter(streamOptions);
   // Aggregator compiles all events into a single response object.
@@ -88,8 +92,6 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
     .on("data", (msg) => {
       if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
       aggregator.addEvent(msg);
-    }).on("end", () => {
-      req.log.debug({ key: hash }, `Finished streaming response.`);
     });
 
   try {
@@ -125,14 +127,15 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
   }
 };
 
-function selectDecoderStream(options: {
+function getDecoder(options: {
   input: Readable;
   api: APIFormat;
+  logger: typeof logger;
   contentType?: string;
 }) {
-  const { api, contentType, input } = options;
+  const { api, contentType, input, logger } = options;
   if (contentType?.includes("application/vnd.amazon.eventstream")) {
-    return viaEventStreamMarshaller(input);
+    return getAwsEventStreamDecoder({ input, logger });
   } else if (api === "google-ai") {
     return StreamArray.withParser();
   } else {
@@ -0,0 +1,93 @@
+import pino from "pino";
+import { Duplex, Readable } from "stream";
+import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
+import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
+import { Message } from "@smithy/eventstream-codec";
+
+/**
+ * Decodes a Readable stream, such as a proxied HTTP response, into a stream of
+ * Message objects using the AWS SDK's EventStreamMarshaller. Error events in
+ * the amazon eventstream protocol are decoded as Message objects and will not
+ * emit an error event on the decoder stream.
+ */
+export function getAwsEventStreamDecoder(params: {
+  input: Readable;
+  logger: pino.Logger;
+}): Duplex {
+  const { input, logger } = params;
+  const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
+  const eventStream = new EventStreamMarshaller(config).deserialize(
+    input,
+    async (input: Record<string, Message>) => {
+      const eventType = Object.keys(input)[0];
+      let result;
+      if (eventType === "chunk") {
+        result = input[eventType];
+      } else {
+        // AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
+        result = { [eventType]: input[eventType] } as any;
+      }
+      return result;
+    }
+  );
+  return new AWSEventStreamDecoder(eventStream, { logger });
+}
+
+class AWSEventStreamDecoder extends Duplex {
+  private readonly asyncIterable: AsyncIterable<Message>;
+  private iterator: AsyncIterator<Message>;
+  private reading: boolean;
+  private logger: pino.Logger;
+
+  constructor(
+    asyncIterable: AsyncIterable<Message>,
+    options: { logger: pino.Logger }
+  ) {
+    super({ ...options, objectMode: true });
+    this.asyncIterable = asyncIterable;
+    this.iterator = this.asyncIterable[Symbol.asyncIterator]();
+    this.reading = false;
+    this.logger = options.logger.child({ module: "aws-eventstream-decoder" });
+  }
+
+  async _read(_size: number) {
+    if (this.reading) return;
+    this.reading = true;
+
+    try {
+      while (true) {
+        const { value, done } = await this.iterator.next();
+        if (done) {
+          this.push(null);
+          break;
+        }
+        if (!this.push(value)) break;
+      }
+    } catch (err) {
+      // AWS SDK's EventStreamMarshaller emits errors in the stream itself as
+      // whatever our deserializer returns, which will not be Error objects
+      // because we want to pass the Message to the next stream for processing.
+      // Any actual Error thrown here is some failure during deserialization.
+      const isAwsError = !(err instanceof Error);
+
+      if (isAwsError) {
+        this.logger.warn({ err: err.headers }, "Received AWS error event");
+        this.push(err);
+        this.push(null);
+      } else {
+        this.logger.error(err, "Error during AWS stream deserialization");
+        this.destroy(err);
+      }
+    } finally {
+      this.reading = false;
+    }
+  }
+
+  _write(_chunk: any, _encoding: string, callback: () => void) {
+    callback();
+  }
+
+  _final(callback: () => void) {
+    callback();
+  }
+}
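The contract of the new module, in short: every event in the amazon eventstream protocol, including errors, surfaces as a `Message` on the readable side, and the stream's own `error` event is reserved for genuine deserialization failures. A minimal sketch of the decoder in a pipeline, mirroring the `getDecoder` call above (the logging harness is illustrative, not code from this commit):

```ts
import { Readable, pipeline } from "stream";
import { promisify } from "util";
import pino from "pino";
import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";

const pipelineAsync = promisify(pipeline);

// Illustrative harness; `input` stands in for the proxied HTTP response.
async function logAwsEvents(input: Readable) {
  const decoder = getAwsEventStreamDecoder({ input, logger: pino() });

  // Chunks and AWS error events alike arrive here as `data` Messages, so a
  // mid-stream AWS error can no longer kill the pipeline with 'Premature close'.
  decoder.on("data", (msg) => console.log("event:", msg));

  await pipelineAsync(input, decoder);
}
```

The no-op `_write`/`_final` are what make this wiring legal: the marshaller reads `input` directly, so the decoder's writable side exists only to let `pipeline` manage its lifecycle.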
@@ -51,7 +51,7 @@ export class SSEStreamAdapter extends Transform {
         const event = Buffer.from(bytes, "base64").toString("utf8");
         return ["event: completion", `data: ${event}`].join(`\n`);
       }
-      // Intentional fallthrough, non-JSON events will be something very weird
+      // Intentional fallthrough, as non-JSON events may as well be errors
       // noinspection FallThroughInSwitchStatementJS
       case "exception":
      case "error":
@@ -61,7 +61,6 @@ export class SSEStreamAdapter extends Transform {
         switch (type) {
           case "throttlingexception":
             this.log.warn(
-              { message, type },
               "AWS request throttled after streaming has already started; retrying"
             );
             throw new RetryableError("AWS request throttled mid-stream");
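The `RetryableError` thrown here propagates out of the stream pipeline so the response handler can put the request back on the queue instead of sending the client a truncated SSE stream. A hedged sketch of that catch path (the real block lives in `handleStreamedResponse`, outside these hunks; the `enqueue(req)` call is an assumed usage of the import shown in the first hunk):

```ts
import { pipeline } from "stream";
import { promisify } from "util";
import { RetryableError } from ".";
import { enqueue } from "../../queue";

const pipelineAsync = promisify(pipeline);

// Sketch only: `req`, `proxyRes`, and `stages` stand in for the handler's
// real locals (decoder, adapter, transformer, etc).
async function streamWithRetry(req: any, proxyRes: any, stages: any[]) {
  try {
    await pipelineAsync(proxyRes, ...stages);
  } catch (err) {
    if (err instanceof RetryableError) {
      // Key was throttled mid-stream; re-queue the request on another key
      // rather than surface a half-finished SSE stream to the client.
      return enqueue(req); // assumed signature; see ../../queue
    }
    throw err;
  }
}
```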
@@ -142,7 +141,6 @@ export class SSEStreamAdapter extends Transform {
   }
 
   _flush(callback: (err?: Error | null) => void) {
-    this.log.debug("SSEStreamAdapter flushing");
     callback();
   }
 }
@@ -1,65 +0,0 @@
-import { Duplex, Readable } from "stream";
-import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
-import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
-import { Message } from "@smithy/eventstream-codec";
-
-/**
- * Decodes a Readable stream, such as a proxied HTTP response, into a stream of
- * Message objects using the AWS SDK's EventStreamMarshaller.
- * @param input
- */
-export function viaEventStreamMarshaller(input: Readable): Duplex {
-  const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
-  const eventStream = new EventStreamMarshaller(config).deserialize(
-    input,
-    // deserializer is always an object with one key. we just extract the value
-    // and pipe it to SSEStreamAdapter for it to turn it into an SSE stream
-    async (input: Record<string, Message>) => Object.values(input)[0]
-  );
-  return new StreamFromIterable(eventStream);
-}
-
-// In theory, Duplex.from(eventStream) would have rendered this wrapper
-// unnecessary, but I was not able to get it to work for a number of reasons and
-// needed more control over the stream's lifecycle.
-
-class StreamFromIterable extends Duplex {
-  private readonly asyncIterable: AsyncIterable<Message>;
-  private iterator: AsyncIterator<Message>;
-  private reading: boolean;
-
-  constructor(asyncIterable: AsyncIterable<Message>, options = {}) {
-    super({ ...options, objectMode: true });
-    this.asyncIterable = asyncIterable;
-    this.iterator = this.asyncIterable[Symbol.asyncIterator]();
-    this.reading = false;
-  }
-
-  async _read(_size: number) {
-    if (this.reading) return;
-    this.reading = true;
-
-    try {
-      while (true) {
-        const { value, done } = await this.iterator.next();
-        if (done) {
-          this.push(null);
-          break;
-        }
-        if (!this.push(value)) break;
-      }
-    } catch (err) {
-      this.destroy(err);
-    } finally {
-      this.reading = false;
-    }
-  }
-
-  _write(_chunk: any, _encoding: string, callback: () => void) {
-    callback();
-  }
-
-  _final(callback: () => void) {
-    callback();
-  }
-}
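For reference, the shortcut the deleted comment alludes to would look roughly like this (Node v16.8+ `Duplex.from` accepts an async iterable). The catch: any value the iterable throws destroys the stream, which is exactly the 'Premature close' behavior the replacement `AWSEventStreamDecoder` intercepts for AWS error `Message`s:

```ts
import { Duplex } from "stream";
import { Message } from "@smithy/eventstream-codec";

// Sketch only: Duplex.from wires the iterable straight into a stream but
// offers no hook to re-route thrown AWS error Messages as data instead of
// fatal stream errors.
function wrapEventStream(eventStream: AsyncIterable<Message>): Duplex {
  return Duplex.from(eventStream);
}
```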
@@ -41,7 +41,7 @@ const RATE_LIMIT_LOCKOUT = 4000;
  * to be used again. This is to prevent the queue from flooding a key with too
  * many requests while we wait to learn whether previous ones succeeded.
  */
-const KEY_REUSE_DELAY = 250;
+const KEY_REUSE_DELAY = 500;
 
 export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
   readonly service = "aws";