From 54fec9319371b2792526e0cbfebe6cee66ed3980 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 31 Jan 2023 16:01:15 +0100 Subject: [PATCH] fix(server): fix seeding with multiple shards (#44) --- Cargo.lock | 1 + proto/generate.proto | 2 +- router/Cargo.toml | 1 + router/src/db.rs | 3 +- router/src/validation.rs | 15 ++- server/poetry.lock | 132 +++++++++++---------- server/pyproject.toml | 2 +- server/text_generation/models/bloom.py | 2 - server/text_generation/models/galactica.py | 1 - server/text_generation/utils.py | 18 +-- 10 files changed, 91 insertions(+), 86 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 33f5d181..1030e8fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1834,6 +1834,7 @@ dependencies = [ "futures", "nohash-hasher", "parking_lot", + "rand", "serde", "serde_json", "text-generation-client", diff --git a/proto/generate.proto b/proto/generate.proto index 921bd5c0..81039a7c 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -37,7 +37,7 @@ message NextTokenChooserParameters { /// apply sampling on the logits bool do_sample = 4; /// random seed for sampling - optional uint64 seed = 5; + uint64 seed = 5; } message StoppingCriteriaParameters { diff --git a/router/Cargo.toml b/router/Cargo.toml index 546f127f..17724bcc 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -19,6 +19,7 @@ clap = { version = "4.0.15", features = ["derive", "env"] } futures = "0.3.24" nohash-hasher = "0.2.0" parking_lot = "0.12.1" +rand = "0.8.5" serde = "1.0.145" serde_json = "1.0.85" thiserror = "1.0.37" diff --git a/router/src/db.rs b/router/src/db.rs index 15007b64..442d7b9c 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -166,7 +166,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters { top_k: parameters.top_k as u32, top_p: parameters.top_p, do_sample: parameters.do_sample, - seed: parameters.seed, + // FIXME: remove unwrap + seed: parameters.seed.unwrap(), } } } diff --git a/router/src/validation.rs b/router/src/validation.rs index aabc82a6..674635d3 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -2,6 +2,8 @@ use crate::{ErrorResponse, GenerateRequest}; use axum::http::StatusCode; use axum::Json; +use rand::rngs::ThreadRng; +use rand::Rng; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; @@ -92,18 +94,22 @@ fn validation_worker( max_input_length: usize, mut receiver: mpsc::Receiver, ) { + // Seed rng + let mut rng = rand::thread_rng(); + // Loop over requests while let Some((request, response_tx)) = receiver.blocking_recv() { response_tx - .send(validate(request, &tokenizer, max_input_length)) + .send(validate(request, &tokenizer, max_input_length, &mut rng)) .unwrap_or(()) } } fn validate( - request: GenerateRequest, + mut request: GenerateRequest, tokenizer: &Tokenizer, max_input_length: usize, + rng: &mut ThreadRng, ) -> Result<(usize, GenerateRequest), ValidationError> { if request.parameters.temperature <= 0.0 { return Err(ValidationError::Temperature); @@ -124,6 +130,11 @@ fn validate( )); } + // If seed is None, assign a random one + if request.parameters.seed.is_none() { + request.parameters.seed = Some(rng.gen()); + } + // Get the number of tokens in the input match tokenizer.encode(request.inputs.clone(), true) { Ok(inputs) => { diff --git a/server/poetry.lock b/server/poetry.lock index d610fcdf..b236e0a7 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,6 +1,6 @@ [[package]] name = "accelerate" -version = "0.12.0" +version = "0.15.0" description = "Accelerate" category = "main" optional = false @@ -14,13 +14,14 @@ pyyaml = "*" torch = ">=1.4.0" [package.extras] -dev = ["black (>=22.0,<23.0)", "datasets", "deepspeed (<0.7.0)", "evaluate", "flake8 (>=3.8.3)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"] +dev = ["black (>=22.0,<23.0)", "datasets", "deepspeed (<0.7.0)", "evaluate", "flake8 (>=3.8.3)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "scikit-learn", "scipy", "tqdm", "transformers"] quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)"] +rich = ["rich"] sagemaker = ["sagemaker"] -test_dev = ["datasets", "deepspeed (<0.7.0)", "evaluate", "scipy", "sklearn", "tqdm", "transformers"] +test_dev = ["datasets", "deepspeed (<0.7.0)", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] test_trackers = ["comet-ml", "tensorboard", "wandb"] -testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"] +testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"] [[package]] name = "attrs" @@ -78,7 +79,7 @@ test = ["pytest (>=6)"] [[package]] name = "googleapis-common-protos" -version = "1.57.0" +version = "1.58.0" description = "Common protobufs used in Google APIs" category = "main" optional = false @@ -155,11 +156,11 @@ setuptools = "*" [[package]] name = "iniconfig" -version = "1.1.1" -description = "iniconfig: brain-dead simple config-ini parsing" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" category = "dev" optional = false -python-versions = "*" +python-versions = ">=3.7" [[package]] name = "loguru" @@ -234,7 +235,7 @@ wheel = "*" [[package]] name = "packaging" -version = "22.0" +version = "23.0" description = "Core utilities for Python packages" category = "main" optional = false @@ -273,7 +274,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] [[package]] name = "pytest" -version = "7.2.0" +version = "7.2.1" description = "pytest: simple powerful testing with Python" category = "dev" optional = false @@ -301,7 +302,7 @@ python-versions = ">=3.6" [[package]] name = "safetensors" -version = "0.2.7" +version = "0.2.8" description = "Fast and Safe Tensor serialization" category = "main" optional = false @@ -319,14 +320,14 @@ torch = ["torch"] [[package]] name = "setuptools" -version = "65.6.3" +version = "67.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "main" optional = false python-versions = ">=3.7" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] @@ -409,12 +410,12 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "bc59a70e1e112ce2e173289c0c5285121550511236d4866f0053ae18f90c98aa" +content-hash = "c920067f39ade631d1022c0560c3a4336c82d8c28355509bdfff751bbcfc01cd" [metadata.files] accelerate = [ - {file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"}, - {file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"}, + {file = "accelerate-0.15.0-py3-none-any.whl", hash = "sha256:014833307424cd0a22f89815802e00653756257c45dfdba2453e52d428931c65"}, + {file = "accelerate-0.15.0.tar.gz", hash = "sha256:438e25a01afa6e3ffbd25353e76a68be49677c3050f10bfac7beafaf53503efc"}, ] attrs = [ {file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"}, @@ -437,8 +438,8 @@ exceptiongroup = [ {file = "exceptiongroup-1.1.0.tar.gz", hash = "sha256:bcb67d800a4497e1b404c2dd44fca47d3b7a5e5433dbab67f96c1a685cdfdf23"}, ] googleapis-common-protos = [ - {file = "googleapis-common-protos-1.57.0.tar.gz", hash = "sha256:27a849d6205838fb6cc3c1c21cb9800707a661bb21c6ce7fb13e99eb1f8a0c46"}, - {file = "googleapis_common_protos-1.57.0-py2.py3-none-any.whl", hash = "sha256:a9f4a1d7f6d9809657b7f1316a1aa527f6664891531bcfcc13b6696e685f443c"}, + {file = "googleapis-common-protos-1.58.0.tar.gz", hash = "sha256:c727251ec025947d545184ba17e3578840fc3a24a0516a020479edab660457df"}, + {file = "googleapis_common_protos-1.58.0-py2.py3-none-any.whl", hash = "sha256:ca3befcd4580dab6ad49356b46bf165bb68ff4b32389f028f1abd7c10ab9519a"}, ] grpc-interceptor = [ {file = "grpc-interceptor-0.15.0.tar.gz", hash = "sha256:5c1aa9680b1d7e12259960c38057b121826860b05ebbc1001c74343b7ad1455e"}, @@ -547,8 +548,8 @@ grpcio-tools = [ {file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"}, ] iniconfig = [ - {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, - {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] loguru = [ {file = "loguru-0.6.0-py3-none-any.whl", hash = "sha256:4e2414d534a2ab57573365b3e6d0234dfb1d84b68b7f3b948e6fb743860a77c3"}, @@ -602,8 +603,8 @@ nvidia-cudnn-cu11 = [ {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, ] packaging = [ - {file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"}, - {file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"}, + {file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"}, + {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, ] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, @@ -642,8 +643,8 @@ psutil = [ {file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"}, ] pytest = [ - {file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"}, - {file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"}, + {file = "pytest-7.2.1-py3-none-any.whl", hash = "sha256:c7c6ca206e93355074ae32f7403e8ea12163b1163c976fee7d4d84027c162be5"}, + {file = "pytest-7.2.1.tar.gz", hash = "sha256:d45e0952f3727241918b8fd0f376f5ff6b301cc0777c6f9a556935c92d8a7d42"}, ] PyYAML = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, @@ -688,49 +689,50 @@ PyYAML = [ {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] safetensors = [ - {file = "safetensors-0.2.7-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:3123fdbb35fdedaedd39cd50a44493783f204821ae8b79012820020bb7e9ea1e"}, - {file = "safetensors-0.2.7-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d8dd3d47609c51b6bd6e473f32d74eeb90a59482f194df6db570ebbb829a948f"}, - {file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:041aa5e13f1bbc0c0d441f692b4103531b9e867ccc281c2c3946fe51deb7eccd"}, - {file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2d730ac6716ab4105000aa85e7be7d771b199ec9ab77df50eff6cc0ddce6fcd"}, - {file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:acef097785a9f03172151c7a8a4c6743d7ec77d1473c38da3aebc8f004d620a8"}, - {file = "safetensors-0.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f54630c845e489123a00bfb93787086c96de9abc5e72fbec52c1d2a55f8147fa"}, - {file = "safetensors-0.2.7-cp310-cp310-win32.whl", hash = "sha256:3c4c0d7f3d6922dcae178f19daf52c877674906d60944273e0fdf7a73b7a33e7"}, - {file = "safetensors-0.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:3c99cccbbe1da7a1bdeb7e4333046577c9785c8b4bb81912b8a134a66570fc0f"}, - {file = "safetensors-0.2.7-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:0a6f3a1db227aeb1152fb8676d94fea3f97d9e017b4b82f7ce5447b88c3a2126"}, - {file = "safetensors-0.2.7-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8338b875e7e86f50cbffd8c79f28fb0fe2ed56bebe1f87c95a26c501ee50fc51"}, - {file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8403e93665d87ff1006b5678c23566e7a4bf7f2cfdb2666c3955350b2841379"}, - {file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d993db8e3f63cd71994544c130c1e227a6522c63ddaf2b9f85a33c8e789283b0"}, - {file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4901cac6c4bd3d52cb68c86311f86a28d33fa2503bd7c32012c1489cd9a52c"}, - {file = "safetensors-0.2.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3b9acd8d864c284d8fe300b290092a2cc7ae76319e7bdd5cbba00a8b4ec2dc0"}, - {file = "safetensors-0.2.7-cp311-cp311-win32.whl", hash = "sha256:dd63fc6bb6f7b78a957c84692afdbaa39168978aac61dfd5ea8c2bbbd459f8a6"}, - {file = "safetensors-0.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:848af8bca3ee8e4f68bb828e7fbc1d4022d3cde17e8bd098324ef93ace4779e6"}, - {file = "safetensors-0.2.7-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:44ab41d080be381550a0cdc1ded9afc5de899dd733f341c902834571a6f60ca7"}, - {file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:964357808ad70e001d4d6779418e48616eb1f103acf6acdb72bb6a51c05a9da4"}, - {file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb21eb7b86692cb1d5bc95b0926acd0356df408990de63ae09c3242594170156"}, - {file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c3c421b55b3baf2ce373836155ffb58d530ec0752837c5fec2f8b841019b49c"}, - {file = "safetensors-0.2.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5f21fe636693f9f40594b357f3d4586cb23e98e044e5e78b21e814543d69c3b"}, - {file = "safetensors-0.2.7-cp37-cp37m-win32.whl", hash = "sha256:1212ec6e113625d9eea662815a25c993a76860ec51f6cc26ac33f4e317abecb5"}, - {file = "safetensors-0.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:6b8076a21241dcad8848b42d90dc7fa1401c89bee622465d019988df62175ae1"}, - {file = "safetensors-0.2.7-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:e27907078ab3924c8135f229c9dff06126fe123b609a22cf2ae96593ca4833bc"}, - {file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89cc90cb042654b224fa98a8acbfc5f795f33bd8c5ff6431ad9884370c9c0caf"}, - {file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:092c34408391d6fc96617093d59f00ffd23ca624530591d225b0f824cb105367"}, - {file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:795310f19396f01cfbe00170575daa52bc1019068ef75b13ba487155fba4e9bd"}, - {file = "safetensors-0.2.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3924794458e716dc7d2f3e264c98455de3d59b71e0335cbabef727083d537535"}, - {file = "safetensors-0.2.7-cp38-cp38-win32.whl", hash = "sha256:9205b4a12939db1135fff92b4ea4f62bf131bfc9d54df31881fb6a50eb2ba5fa"}, - {file = "safetensors-0.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:b84662f8d8dd49443a8f8d1d9a89a79974a5f03f4e06526de170ef02f8511236"}, - {file = "safetensors-0.2.7-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:657b01a122da21028a1d32570e0173fa0ef1865d3480cf5bcca80ec92cae421c"}, - {file = "safetensors-0.2.7-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:788a242ea5fab371dcc42cb10ed21c60ea8c5ff7fad41c9fb3d334420c80f156"}, - {file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6bfbc7123d0c08754b6a11848b0298070705f60ab450e3f249624bba1571040"}, - {file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1982f00d34d8c72e3339c458e5e2fb3eaf9458c55eae6d3bddd61649666db130"}, - {file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:330570d27ebf5bdb3f63580baa4a4bf92ad11e1df17b96703181cfaa11cae90e"}, - {file = "safetensors-0.2.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f112b20899902dfc5ba27cdea49e57c275c7824a3e24764fca93d3ad436a160"}, - {file = "safetensors-0.2.7-cp39-cp39-win32.whl", hash = "sha256:29fbc1d1c802a5eced1675e93e130367109b2af3bf74af8c415a141b0f5fe568"}, - {file = "safetensors-0.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:054306d050453a92d3cc6cde15e7cb282ef4299f602433bbd70a9b1b6963c9f4"}, - {file = "safetensors-0.2.7.tar.gz", hash = "sha256:37192b456fdfe09762d6a2ef3322d2379fee52eb2db6d887ebac6e3349de596d"}, + {file = "safetensors-0.2.8-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:8df8af89a0b6b535b47c077e33e5cd4941ef4b067e7c1dd1a05647dec0cf2eea"}, + {file = "safetensors-0.2.8-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:9cabae651542b22d229b6439029fb4a3abc7868cb123af336b2877541ad9ab39"}, + {file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:384df328523c567a8b2193ebb02456300a43a3adbc316823d4c0d16f7ac9e89d"}, + {file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa5f4b31f2e283b83894b388298368860367a1cb773228f9bb65f8db65da1549"}, + {file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70f322d170a17b6ecb9c85b15e67f4a9aa3e561594e2dfc7709c0ae0000ebfff"}, + {file = "safetensors-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50f7e90ed02ef5e4263aadce23d39a3595156816c42c4643003791b45f81fd31"}, + {file = "safetensors-0.2.8-cp310-cp310-win32.whl", hash = "sha256:726ad66231286157dd505b5ab85fd903114dcb5f9180e79fd5540d68d6249dd0"}, + {file = "safetensors-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:d14e7c15e5acac6efcd5f5565e38b1b33600101387e5d16059de44adab87405f"}, + {file = "safetensors-0.2.8-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:6273dd3edd5db05b2da0090bc3d491bf25f6d1d7e8a4423477377649e9e38c37"}, + {file = "safetensors-0.2.8-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:34705920c8a02f9ea6101bae8403db5f4aa18ec3eaccd8eab6701b1c88ee5bed"}, + {file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efff5ce2d3f349d5360a5ca5901ae56c7e24f8da623d57cd08f8e8b1cd9cb1f8"}, + {file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dec20b9b1fc90b7b4e588b4f0e9e266bd8f26d405e08b0b6ecad3136d92d007a"}, + {file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7716bab34ca9651895126c720df1bf07f464184a7409138179f38db488ca9f15"}, + {file = "safetensors-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb28e5e6257f705828fd39b9ba28248b593f46e722d8d6beedbc8d1f194e2297"}, + {file = "safetensors-0.2.8-cp311-cp311-win32.whl", hash = "sha256:c71e147a8c2942690e8707adbcc4ab60775bc78dfdbd7f9e696dd411adef82bb"}, + {file = "safetensors-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:2f40604c50d4a08a4c74f37fef735cd1e203a59aeda66ea23a01d76fb35cf407"}, + {file = "safetensors-0.2.8-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:f3c7242debe713a87ca6deaadb0d7120e61c92435414142d526e8c32da970020"}, + {file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9366140785d567049395c1c03ac69ee8d322fabcdc8fab7d389e933596b383da"}, + {file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5d720c7640ad5f95f47780bdc35777ed8371afa14d8d63d6375cfe36df83fb4"}, + {file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a0c7ed4d75c2f248791dbe64309c98ada6f40a6949147ca2eaebd278906c918b"}, + {file = "safetensors-0.2.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b905cf3563da4fe8dd46fa777321516f5aa8f6bc1b884158be03a235478e96d"}, + {file = "safetensors-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:2a47327cfe57b6ee8c5dc246f9141f4f6b368e4444dd7f174c025b1d34446730"}, + {file = "safetensors-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:02e67b906ad9bbb87308a34e4d2d33c9eb69baf342b7e5c872728556baf3f0b6"}, + {file = "safetensors-0.2.8-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ba7c2496765de45a84f5c79b2b12a14a568643d8966ef9bb3f8b16217a39457c"}, + {file = "safetensors-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:90cd22387b1520c4465033b986f79f0d24cc41aabae1903a22eff3b42cee9ad5"}, + {file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f279e0fd917a886e1936d534734593f780afdddc6ed62f8ebf2f59de811cdd7c"}, + {file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:af4fce0565574ec3dbe997b021ed32d3dc229547c4f7fca2459be1f725c26f88"}, + {file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff0096f5a765e6e3f7f3156f568f59499edeade19e813d610a124ca70b42cdda"}, + {file = "safetensors-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c02c1cab1d23e6cc8ac1ef865efe76f2d0800e203c877288ebd31747f47a6940"}, + {file = "safetensors-0.2.8-cp38-cp38-win32.whl", hash = "sha256:832f27f6f379fb15d87f3a10eb87e2480434a89c483d7edc8a0c262364ba72c1"}, + {file = "safetensors-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:89da823f9801e92b2c48e8fad1e2f7f0cb696a8a93dab4c6700e8de06fe87392"}, + {file = "safetensors-0.2.8-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:ee8e169040468d176172b4f84a160ece2064abcc77294c4994a9d2bb5255cd75"}, + {file = "safetensors-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac2197dbbf7cbd269faf8cebab4220dba5aa2ac8beacbce8fdcb9800776781ca"}, + {file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f90513eee0a11902df7e51b07c3d9c328828b9dd692d6c74140bed932e7a491"}, + {file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2bd2e4c6258dd0f4e3d554f2f59789f25ef4757431e83c927016de6339e54811"}, + {file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01000a4bfdf0474bb1bdf369de1284c93b4e6047524fe9dc55d77586cb9d0243"}, + {file = "safetensors-0.2.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:466f92a384e4fcb6b1b9811e7488516c4638119c302a959b89bbe1be826d5e25"}, + {file = "safetensors-0.2.8-cp39-cp39-win32.whl", hash = "sha256:2f16e5ee70ae4218474493eff8d998e632a896a6b8aff9273511d654cdf367ab"}, + {file = "safetensors-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:ba3dc236a2344b7feadc9868307f42ba5e4804c9d68a80a35aac831349b31f6f"}, + {file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"}, ] setuptools = [ - {file = "setuptools-65.6.3-py3-none-any.whl", hash = "sha256:57f6f22bde4e042978bcd50176fdb381d7c21a9efa4041202288d3737a0c6a54"}, - {file = "setuptools-65.6.3.tar.gz", hash = "sha256:a7620757bf984b58deaf32fc8a4577a9bbc0850cf92c20e1ce41c38c19e5fb75"}, + {file = "setuptools-67.0.0-py3-none-any.whl", hash = "sha256:9d790961ba6219e9ff7d9557622d2fe136816a264dd01d5997cfc057d804853d"}, + {file = "setuptools-67.0.0.tar.gz", hash = "sha256:883131c5b6efa70b9101c7ef30b2b7b780a4283d5fc1616383cdf22c83cbefe6"}, ] tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index 4a2ccce7..98d91125 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -15,7 +15,7 @@ grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" -accelerate = "^0.12.0" +accelerate = "^0.15.0" bitsandbytes = "^0.35.1" safetensors = "^0.2.4" loguru = "^0.6.0" diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 35f46bc2..1b7635c5 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -33,8 +33,6 @@ try: except Exception as e: HAS_BITS_AND_BYTES = False -torch.manual_seed(0) - class BloomCausalLMBatch(CausalLMBatch): @classmethod diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 9bec1dde..5cc55865 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -36,7 +36,6 @@ try: except Exception as e: HAS_BITS_AND_BYTES = False -torch.manual_seed(0) # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index c93e783b..a2029911 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -24,12 +24,10 @@ from text_generation.pb import generate_pb2 class Sampling: - def __init__(self, seed: Optional[int] = None, device: str = "cpu"): + def __init__(self, seed: int, device: str = "cpu"): self.generator = torch.Generator(device) - if seed is not None: - self.generator.manual_seed(seed) - else: - self.generator.seed() + self.generator.manual_seed(seed) + self.seed = seed def __call__(self, logits): probs = torch.nn.functional.softmax(logits, dim=-1) @@ -38,10 +36,6 @@ class Sampling: ).squeeze(1) return next_tokens - @property - def seed(self) -> int: - return self.generator.initial_seed() - class Greedy: def __call__(self, logits): @@ -55,7 +49,7 @@ class NextTokenChooser: top_k=None, top_p=None, do_sample=False, - seed=None, + seed=0, device="cpu", ): warpers = LogitsProcessorList() @@ -89,14 +83,12 @@ class NextTokenChooser: def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device ) -> "NextTokenChooser": - # handle protobuf making default values 0 - seed = pb.seed if pb.HasField("seed") else None return NextTokenChooser( temperature=pb.temperature, top_k=pb.top_k, top_p=pb.top_p, do_sample=pb.do_sample, - seed=seed, + seed=pb.seed, device=str(device), )