fix(server): fix seeding with multiple shards (#44)

This commit is contained in:
OlivierDehaene 2023-01-31 16:01:15 +01:00 committed by GitHub
parent 03bdf18290
commit 54fec93193
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 91 additions and 86 deletions

1
Cargo.lock generated
View File

@ -1834,6 +1834,7 @@ dependencies = [
"futures",
"nohash-hasher",
"parking_lot",
"rand",
"serde",
"serde_json",
"text-generation-client",

View File

@ -37,7 +37,7 @@ message NextTokenChooserParameters {
/// apply sampling on the logits
bool do_sample = 4;
/// random seed for sampling
optional uint64 seed = 5;
uint64 seed = 5;
}
message StoppingCriteriaParameters {

View File

@ -19,6 +19,7 @@ clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
nohash-hasher = "0.2.0"
parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"

View File

@ -166,7 +166,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
seed: parameters.seed,
// FIXME: remove unwrap
seed: parameters.seed.unwrap(),
}
}
}

View File

@ -2,6 +2,8 @@
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use rand::rngs::ThreadRng;
use rand::Rng;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
@ -92,18 +94,22 @@ fn validation_worker(
max_input_length: usize,
mut receiver: mpsc::Receiver<ValidationRequest>,
) {
// Seed rng
let mut rng = rand::thread_rng();
// Loop over requests
while let Some((request, response_tx)) = receiver.blocking_recv() {
response_tx
.send(validate(request, &tokenizer, max_input_length))
.send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(())
}
}
fn validate(
request: GenerateRequest,
mut request: GenerateRequest,
tokenizer: &Tokenizer,
max_input_length: usize,
rng: &mut ThreadRng,
) -> Result<(usize, GenerateRequest), ValidationError> {
if request.parameters.temperature <= 0.0 {
return Err(ValidationError::Temperature);
@ -124,6 +130,11 @@ fn validate(
));
}
// If seed is None, assign a random one
if request.parameters.seed.is_none() {
request.parameters.seed = Some(rng.gen());
}
// Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) {
Ok(inputs) => {

132
server/poetry.lock generated
View File

@ -1,6 +1,6 @@
[[package]]
name = "accelerate"
version = "0.12.0"
version = "0.15.0"
description = "Accelerate"
category = "main"
optional = false
@ -14,13 +14,14 @@ pyyaml = "*"
torch = ">=1.4.0"
[package.extras]
dev = ["black (>=22.0,<23.0)", "datasets", "deepspeed (<0.7.0)", "evaluate", "flake8 (>=3.8.3)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
dev = ["black (>=22.0,<23.0)", "datasets", "deepspeed (<0.7.0)", "evaluate", "flake8 (>=3.8.3)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "scikit-learn", "scipy", "tqdm", "transformers"]
quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)"]
rich = ["rich"]
sagemaker = ["sagemaker"]
test_dev = ["datasets", "deepspeed (<0.7.0)", "evaluate", "scipy", "sklearn", "tqdm", "transformers"]
test_dev = ["datasets", "deepspeed (<0.7.0)", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"]
test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
test_trackers = ["comet-ml", "tensorboard", "wandb"]
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"]
[[package]]
name = "attrs"
@ -78,7 +79,7 @@ test = ["pytest (>=6)"]
[[package]]
name = "googleapis-common-protos"
version = "1.57.0"
version = "1.58.0"
description = "Common protobufs used in Google APIs"
category = "main"
optional = false
@ -155,11 +156,11 @@ setuptools = "*"
[[package]]
name = "iniconfig"
version = "1.1.1"
description = "iniconfig: brain-dead simple config-ini parsing"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = "*"
python-versions = ">=3.7"
[[package]]
name = "loguru"
@ -234,7 +235,7 @@ wheel = "*"
[[package]]
name = "packaging"
version = "22.0"
version = "23.0"
description = "Core utilities for Python packages"
category = "main"
optional = false
@ -273,7 +274,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
name = "pytest"
version = "7.2.0"
version = "7.2.1"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
@ -301,7 +302,7 @@ python-versions = ">=3.6"
[[package]]
name = "safetensors"
version = "0.2.7"
version = "0.2.8"
description = "Fast and Safe Tensor serialization"
category = "main"
optional = false
@ -319,14 +320,14 @@ torch = ["torch"]
[[package]]
name = "setuptools"
version = "65.6.3"
version = "67.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
category = "main"
optional = false
python-versions = ">=3.7"
[package.extras]
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
@ -409,12 +410,12 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "bc59a70e1e112ce2e173289c0c5285121550511236d4866f0053ae18f90c98aa"
content-hash = "c920067f39ade631d1022c0560c3a4336c82d8c28355509bdfff751bbcfc01cd"
[metadata.files]
accelerate = [
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
{file = "accelerate-0.15.0-py3-none-any.whl", hash = "sha256:014833307424cd0a22f89815802e00653756257c45dfdba2453e52d428931c65"},
{file = "accelerate-0.15.0.tar.gz", hash = "sha256:438e25a01afa6e3ffbd25353e76a68be49677c3050f10bfac7beafaf53503efc"},
]
attrs = [
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
@ -437,8 +438,8 @@ exceptiongroup = [
{file = "exceptiongroup-1.1.0.tar.gz", hash = "sha256:bcb67d800a4497e1b404c2dd44fca47d3b7a5e5433dbab67f96c1a685cdfdf23"},
]
googleapis-common-protos = [
{file = "googleapis-common-protos-1.57.0.tar.gz", hash = "sha256:27a849d6205838fb6cc3c1c21cb9800707a661bb21c6ce7fb13e99eb1f8a0c46"},
{file = "googleapis_common_protos-1.57.0-py2.py3-none-any.whl", hash = "sha256:a9f4a1d7f6d9809657b7f1316a1aa527f6664891531bcfcc13b6696e685f443c"},
{file = "googleapis-common-protos-1.58.0.tar.gz", hash = "sha256:c727251ec025947d545184ba17e3578840fc3a24a0516a020479edab660457df"},
{file = "googleapis_common_protos-1.58.0-py2.py3-none-any.whl", hash = "sha256:ca3befcd4580dab6ad49356b46bf165bb68ff4b32389f028f1abd7c10ab9519a"},
]
grpc-interceptor = [
{file = "grpc-interceptor-0.15.0.tar.gz", hash = "sha256:5c1aa9680b1d7e12259960c38057b121826860b05ebbc1001c74343b7ad1455e"},
@ -547,8 +548,8 @@ grpcio-tools = [
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
loguru = [
{file = "loguru-0.6.0-py3-none-any.whl", hash = "sha256:4e2414d534a2ab57573365b3e6d0234dfb1d84b68b7f3b948e6fb743860a77c3"},
@ -602,8 +603,8 @@ nvidia-cudnn-cu11 = [
{file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"},
]
packaging = [
{file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"},
{file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"},
{file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"},
{file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"},
]
pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
@ -642,8 +643,8 @@ psutil = [
{file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"},
]
pytest = [
{file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"},
{file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"},
{file = "pytest-7.2.1-py3-none-any.whl", hash = "sha256:c7c6ca206e93355074ae32f7403e8ea12163b1163c976fee7d4d84027c162be5"},
{file = "pytest-7.2.1.tar.gz", hash = "sha256:d45e0952f3727241918b8fd0f376f5ff6b301cc0777c6f9a556935c92d8a7d42"},
]
PyYAML = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
@ -688,49 +689,50 @@ PyYAML = [
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
]
safetensors = [
{file = "safetensors-0.2.7-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:3123fdbb35fdedaedd39cd50a44493783f204821ae8b79012820020bb7e9ea1e"},
{file = "safetensors-0.2.7-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d8dd3d47609c51b6bd6e473f32d74eeb90a59482f194df6db570ebbb829a948f"},
{file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:041aa5e13f1bbc0c0d441f692b4103531b9e867ccc281c2c3946fe51deb7eccd"},
{file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2d730ac6716ab4105000aa85e7be7d771b199ec9ab77df50eff6cc0ddce6fcd"},
{file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:acef097785a9f03172151c7a8a4c6743d7ec77d1473c38da3aebc8f004d620a8"},
{file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f54630c845e489123a00bfb93787086c96de9abc5e72fbec52c1d2a55f8147fa"},
{file = "safetensors-0.2.7-cp310-cp310-win32.whl", hash = "sha256:3c4c0d7f3d6922dcae178f19daf52c877674906d60944273e0fdf7a73b7a33e7"},
{file = "safetensors-0.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:3c99cccbbe1da7a1bdeb7e4333046577c9785c8b4bb81912b8a134a66570fc0f"},
{file = "safetensors-0.2.7-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:0a6f3a1db227aeb1152fb8676d94fea3f97d9e017b4b82f7ce5447b88c3a2126"},
{file = "safetensors-0.2.7-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8338b875e7e86f50cbffd8c79f28fb0fe2ed56bebe1f87c95a26c501ee50fc51"},
{file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8403e93665d87ff1006b5678c23566e7a4bf7f2cfdb2666c3955350b2841379"},
{file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d993db8e3f63cd71994544c130c1e227a6522c63ddaf2b9f85a33c8e789283b0"},
{file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4901cac6c4bd3d52cb68c86311f86a28d33fa2503bd7c32012c1489cd9a52c"},
{file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3b9acd8d864c284d8fe300b290092a2cc7ae76319e7bdd5cbba00a8b4ec2dc0"},
{file = "safetensors-0.2.7-cp311-cp311-win32.whl", hash = "sha256:dd63fc6bb6f7b78a957c84692afdbaa39168978aac61dfd5ea8c2bbbd459f8a6"},
{file = "safetensors-0.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:848af8bca3ee8e4f68bb828e7fbc1d4022d3cde17e8bd098324ef93ace4779e6"},
{file = "safetensors-0.2.7-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:44ab41d080be381550a0cdc1ded9afc5de899dd733f341c902834571a6f60ca7"},
{file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:964357808ad70e001d4d6779418e48616eb1f103acf6acdb72bb6a51c05a9da4"},
{file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb21eb7b86692cb1d5bc95b0926acd0356df408990de63ae09c3242594170156"},
{file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c3c421b55b3baf2ce373836155ffb58d530ec0752837c5fec2f8b841019b49c"},
{file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5f21fe636693f9f40594b357f3d4586cb23e98e044e5e78b21e814543d69c3b"},
{file = "safetensors-0.2.7-cp37-cp37m-win32.whl", hash = "sha256:1212ec6e113625d9eea662815a25c993a76860ec51f6cc26ac33f4e317abecb5"},
{file = "safetensors-0.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:6b8076a21241dcad8848b42d90dc7fa1401c89bee622465d019988df62175ae1"},
{file = "safetensors-0.2.7-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:e27907078ab3924c8135f229c9dff06126fe123b609a22cf2ae96593ca4833bc"},
{file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89cc90cb042654b224fa98a8acbfc5f795f33bd8c5ff6431ad9884370c9c0caf"},
{file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:092c34408391d6fc96617093d59f00ffd23ca624530591d225b0f824cb105367"},
{file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:795310f19396f01cfbe00170575daa52bc1019068ef75b13ba487155fba4e9bd"},
{file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3924794458e716dc7d2f3e264c98455de3d59b71e0335cbabef727083d537535"},
{file = "safetensors-0.2.7-cp38-cp38-win32.whl", hash = "sha256:9205b4a12939db1135fff92b4ea4f62bf131bfc9d54df31881fb6a50eb2ba5fa"},
{file = "safetensors-0.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:b84662f8d8dd49443a8f8d1d9a89a79974a5f03f4e06526de170ef02f8511236"},
{file = "safetensors-0.2.7-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:657b01a122da21028a1d32570e0173fa0ef1865d3480cf5bcca80ec92cae421c"},
{file = "safetensors-0.2.7-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:788a242ea5fab371dcc42cb10ed21c60ea8c5ff7fad41c9fb3d334420c80f156"},
{file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6bfbc7123d0c08754b6a11848b0298070705f60ab450e3f249624bba1571040"},
{file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1982f00d34d8c72e3339c458e5e2fb3eaf9458c55eae6d3bddd61649666db130"},
{file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:330570d27ebf5bdb3f63580baa4a4bf92ad11e1df17b96703181cfaa11cae90e"},
{file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f112b20899902dfc5ba27cdea49e57c275c7824a3e24764fca93d3ad436a160"},
{file = "safetensors-0.2.7-cp39-cp39-win32.whl", hash = "sha256:29fbc1d1c802a5eced1675e93e130367109b2af3bf74af8c415a141b0f5fe568"},
{file = "safetensors-0.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:054306d050453a92d3cc6cde15e7cb282ef4299f602433bbd70a9b1b6963c9f4"},
{file = "safetensors-0.2.7.tar.gz", hash = "sha256:37192b456fdfe09762d6a2ef3322d2379fee52eb2db6d887ebac6e3349de596d"},
{file = "safetensors-0.2.8-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:8df8af89a0b6b535b47c077e33e5cd4941ef4b067e7c1dd1a05647dec0cf2eea"},
{file = "safetensors-0.2.8-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:9cabae651542b22d229b6439029fb4a3abc7868cb123af336b2877541ad9ab39"},
{file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:384df328523c567a8b2193ebb02456300a43a3adbc316823d4c0d16f7ac9e89d"},
{file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa5f4b31f2e283b83894b388298368860367a1cb773228f9bb65f8db65da1549"},
{file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70f322d170a17b6ecb9c85b15e67f4a9aa3e561594e2dfc7709c0ae0000ebfff"},
{file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50f7e90ed02ef5e4263aadce23d39a3595156816c42c4643003791b45f81fd31"},
{file = "safetensors-0.2.8-cp310-cp310-win32.whl", hash = "sha256:726ad66231286157dd505b5ab85fd903114dcb5f9180e79fd5540d68d6249dd0"},
{file = "safetensors-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:d14e7c15e5acac6efcd5f5565e38b1b33600101387e5d16059de44adab87405f"},
{file = "safetensors-0.2.8-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:6273dd3edd5db05b2da0090bc3d491bf25f6d1d7e8a4423477377649e9e38c37"},
{file = "safetensors-0.2.8-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:34705920c8a02f9ea6101bae8403db5f4aa18ec3eaccd8eab6701b1c88ee5bed"},
{file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efff5ce2d3f349d5360a5ca5901ae56c7e24f8da623d57cd08f8e8b1cd9cb1f8"},
{file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dec20b9b1fc90b7b4e588b4f0e9e266bd8f26d405e08b0b6ecad3136d92d007a"},
{file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7716bab34ca9651895126c720df1bf07f464184a7409138179f38db488ca9f15"},
{file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb28e5e6257f705828fd39b9ba28248b593f46e722d8d6beedbc8d1f194e2297"},
{file = "safetensors-0.2.8-cp311-cp311-win32.whl", hash = "sha256:c71e147a8c2942690e8707adbcc4ab60775bc78dfdbd7f9e696dd411adef82bb"},
{file = "safetensors-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:2f40604c50d4a08a4c74f37fef735cd1e203a59aeda66ea23a01d76fb35cf407"},
{file = "safetensors-0.2.8-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:f3c7242debe713a87ca6deaadb0d7120e61c92435414142d526e8c32da970020"},
{file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9366140785d567049395c1c03ac69ee8d322fabcdc8fab7d389e933596b383da"},
{file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5d720c7640ad5f95f47780bdc35777ed8371afa14d8d63d6375cfe36df83fb4"},
{file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a0c7ed4d75c2f248791dbe64309c98ada6f40a6949147ca2eaebd278906c918b"},
{file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b905cf3563da4fe8dd46fa777321516f5aa8f6bc1b884158be03a235478e96d"},
{file = "safetensors-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:2a47327cfe57b6ee8c5dc246f9141f4f6b368e4444dd7f174c025b1d34446730"},
{file = "safetensors-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:02e67b906ad9bbb87308a34e4d2d33c9eb69baf342b7e5c872728556baf3f0b6"},
{file = "safetensors-0.2.8-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ba7c2496765de45a84f5c79b2b12a14a568643d8966ef9bb3f8b16217a39457c"},
{file = "safetensors-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:90cd22387b1520c4465033b986f79f0d24cc41aabae1903a22eff3b42cee9ad5"},
{file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f279e0fd917a886e1936d534734593f780afdddc6ed62f8ebf2f59de811cdd7c"},
{file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:af4fce0565574ec3dbe997b021ed32d3dc229547c4f7fca2459be1f725c26f88"},
{file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff0096f5a765e6e3f7f3156f568f59499edeade19e813d610a124ca70b42cdda"},
{file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c02c1cab1d23e6cc8ac1ef865efe76f2d0800e203c877288ebd31747f47a6940"},
{file = "safetensors-0.2.8-cp38-cp38-win32.whl", hash = "sha256:832f27f6f379fb15d87f3a10eb87e2480434a89c483d7edc8a0c262364ba72c1"},
{file = "safetensors-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:89da823f9801e92b2c48e8fad1e2f7f0cb696a8a93dab4c6700e8de06fe87392"},
{file = "safetensors-0.2.8-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:ee8e169040468d176172b4f84a160ece2064abcc77294c4994a9d2bb5255cd75"},
{file = "safetensors-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac2197dbbf7cbd269faf8cebab4220dba5aa2ac8beacbce8fdcb9800776781ca"},
{file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f90513eee0a11902df7e51b07c3d9c328828b9dd692d6c74140bed932e7a491"},
{file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2bd2e4c6258dd0f4e3d554f2f59789f25ef4757431e83c927016de6339e54811"},
{file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01000a4bfdf0474bb1bdf369de1284c93b4e6047524fe9dc55d77586cb9d0243"},
{file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:466f92a384e4fcb6b1b9811e7488516c4638119c302a959b89bbe1be826d5e25"},
{file = "safetensors-0.2.8-cp39-cp39-win32.whl", hash = "sha256:2f16e5ee70ae4218474493eff8d998e632a896a6b8aff9273511d654cdf367ab"},
{file = "safetensors-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:ba3dc236a2344b7feadc9868307f42ba5e4804c9d68a80a35aac831349b31f6f"},
{file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"},
]
setuptools = [
{file = "setuptools-65.6.3-py3-none-any.whl", hash = "sha256:57f6f22bde4e042978bcd50176fdb381d7c21a9efa4041202288d3737a0c6a54"},
{file = "setuptools-65.6.3.tar.gz", hash = "sha256:a7620757bf984b58deaf32fc8a4577a9bbc0850cf92c20e1ce41c38c19e5fb75"},
{file = "setuptools-67.0.0-py3-none-any.whl", hash = "sha256:9d790961ba6219e9ff7d9557622d2fe136816a264dd01d5997cfc057d804853d"},
{file = "setuptools-67.0.0.tar.gz", hash = "sha256:883131c5b6efa70b9101c7ef30b2b7b780a4283d5fc1616383cdf22c83cbefe6"},
]
tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},

View File

@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = "^0.12.0"
accelerate = "^0.15.0"
bitsandbytes = "^0.35.1"
safetensors = "^0.2.4"
loguru = "^0.6.0"

View File

@ -33,8 +33,6 @@ try:
except Exception as e:
HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
class BloomCausalLMBatch(CausalLMBatch):
@classmethod

View File

@ -36,7 +36,6 @@ try:
except Exception as e:
HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

View File

@ -24,12 +24,10 @@ from text_generation.pb import generate_pb2
class Sampling:
def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
if seed is not None:
self.generator.manual_seed(seed)
else:
self.generator.seed()
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, dim=-1)
@ -38,10 +36,6 @@ class Sampling:
).squeeze(1)
return next_tokens
@property
def seed(self) -> int:
return self.generator.initial_seed()
class Greedy:
def __call__(self, logits):
@ -55,7 +49,7 @@ class NextTokenChooser:
top_k=None,
top_p=None,
do_sample=False,
seed=None,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
@ -89,14 +83,12 @@ class NextTokenChooser:
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
# handle protobuf making default values 0
seed = pb.seed if pb.HasField("seed") else None
return NextTokenChooser(
temperature=pb.temperature,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=seed,
seed=pb.seed,
device=str(device),
)