Disable `decoder_input_details` on OpenAI-compatible chat streaming, pass temp and top-k from API (#1470)

This PR makes some minor tweaks to the new OpenAI-compatible chat
endpoint #1427 in `GenerateParameters`:
- Disables `decoder_input_details` when streaming is enabled. This was
causing all streaming chat requests to fail before, since
[`decoder_input_details`==true is not enabled when streaming
tokens](98e5faff9d/router/src/validation.rs (L406)).
- Passes through `temperature` and `top_p` hyperparameters from the API
request to `GenerateParameters`

## Testing

```bash
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true, 
  "max_tokens": 20
}' \                                   
    -H 'Content-Type: application/json'
```

Should work correctly. Currently, most recent release from `main`
returns error:
```
data:{"error":"Input validation error: `decoder_input_details` == true is not supported when streaming tokens","error_type":"validation"}
```

It's my first time contributing to this project, so I could be missing
something. Would especially appreciate @drbh's eyes on this one
This commit is contained in:
Jacob Keisling 2024-01-23 08:55:05 -06:00 committed by GitHub
parent 98e5faff9d
commit 82f87ada6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@ -365,6 +365,18 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
pub top_p: Option<f32>,
}
#[derive(Clone, Serialize, Deserialize)]

View File

@ -592,10 +592,10 @@ async fn chat_completions(
inputs: inputs.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
temperature: req.temperature,
repetition_penalty,
top_k: None,
top_p: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
max_new_tokens,
@ -604,7 +604,7 @@ async fn chat_completions(
truncate: None,
watermark: false,
details: true,
decoder_input_details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
},