From e71471bec95823ef69daaeb03c4657b9b5211a02 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 15 May 2023 23:36:30 +0200 Subject: [PATCH] feat: add snapshot testing (#282) --- .github/workflows/build.yaml | 62 +- Makefile | 11 +- README.md | 2 + integration-tests/conftest.py | 146 ++++ .../models/__snapshots__/test_bloom_560m.ambr | 627 ++++++++++++++++ .../test_bloom_560m_sharded.ambr | 542 ++++++++++++++ .../__snapshots__/test_flash_llama.ambr | 465 ++++++++++++ .../models/__snapshots__/test_flash_neox.ambr | 682 ++++++++++++++++++ .../__snapshots__/test_flash_santacoder.ambr | 472 ++++++++++++ .../__snapshots__/test_flash_starcoder.ambr | 573 +++++++++++++++ .../models/__snapshots__/test_mt0_base.ambr | 306 ++++++++ integration-tests/models/test_bloom_560m.py | 63 ++ .../models/test_bloom_560m_sharded.py | 42 ++ integration-tests/models/test_flash_llama.py | 56 ++ integration-tests/models/test_flash_neox.py | 38 + .../models/test_flash_santacoder.py | 32 + .../models/test_flash_starcoder.py | 47 ++ integration-tests/models/test_mt0_base.py | 63 ++ integration-tests/models/utils.py | 15 + integration-tests/requirements.txt | 5 + launcher/tests/bloom_560m.json | 142 ---- launcher/tests/integration_tests.rs | 172 ----- launcher/tests/mt0_base.json | 137 ---- server/text_generation_server/models/bloom.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 18 +- .../custom_modeling/flash_neox_modeling.py | 17 +- .../flash_santacoder_modeling.py | 16 +- .../models/flash_llama.py | 15 +- .../models/flash_neox.py | 14 +- .../models/flash_santacoder.py | 18 +- .../models/galactica.py | 2 +- .../text_generation_server/models/gpt_neox.py | 2 +- server/text_generation_server/models/opt.py | 7 +- server/text_generation_server/models/t5.py | 2 +- server/text_generation_server/utils/layers.py | 9 +- 35 files changed, 4313 insertions(+), 509 deletions(-) create mode 100644 integration-tests/conftest.py create mode 100644 integration-tests/models/__snapshots__/test_bloom_560m.ambr create mode 100644 integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr create mode 100644 integration-tests/models/__snapshots__/test_flash_llama.ambr create mode 100644 integration-tests/models/__snapshots__/test_flash_neox.ambr create mode 100644 integration-tests/models/__snapshots__/test_flash_santacoder.ambr create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder.ambr create mode 100644 integration-tests/models/__snapshots__/test_mt0_base.ambr create mode 100644 integration-tests/models/test_bloom_560m.py create mode 100644 integration-tests/models/test_bloom_560m_sharded.py create mode 100644 integration-tests/models/test_flash_llama.py create mode 100644 integration-tests/models/test_flash_neox.py create mode 100644 integration-tests/models/test_flash_santacoder.py create mode 100644 integration-tests/models/test_flash_starcoder.py create mode 100644 integration-tests/models/test_mt0_base.py create mode 100644 integration-tests/models/utils.py create mode 100644 integration-tests/requirements.txt delete mode 100644 launcher/tests/bloom_560m.json delete mode 100644 launcher/tests/integration_tests.rs delete mode 100644 launcher/tests/mt0_base.json diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0a99fb52..c2aba160 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -20,10 +20,6 @@ on: branches: - 'main' -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - 
jobs: start-runner: name: Start self-hosted EC2 runner @@ -61,6 +57,9 @@ jobs: ] build-and-push-image: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true needs: start-runner # required to start the main job when the runner is ready runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner permissions: @@ -108,7 +107,19 @@ jobs: username: ${{ secrets.AZURE_DOCKER_USERNAME }} password: ${{ secrets.AZURE_DOCKER_PASSWORD }} registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io + # If pull request - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name == 'pull_request' }} + id: meta-pr + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/community/text-generation-inference + tags: | + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} + # If main, release or tag + - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name != 'pull_request' }} id: meta uses: docker/metadata-action@v4.3.0 with: @@ -129,13 +140,13 @@ jobs: with: context: . file: Dockerfile - push: ${{ github.event_name != 'pull_request' }} + push: true platforms: 'linux/amd64' build-args: | GIT_SHA=${{ env.GITHUB_SHA }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} + labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max # Sign the resulting Docker image digest except on PRs. 
@@ -172,11 +183,48 @@ jobs: with: sarif_file: 'trivy-results.sarif' + integration-tests: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + needs: + - start-runner + - build-and-push-image # Wait for the docker image to be built + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + env: + DOCKER_VOLUME: /cache + steps: + - uses: actions/checkout@v2 + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.9 + - name: Tailscale + uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Prepare disks + run: | + sudo mkfs -t ext4 /dev/nvme1n1 + sudo mkdir ${{ env.DOCKER_VOLUME }} + sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }} + - name: Install + run: | + make install-integration-tests + - name: Run tests + run: | + export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} + export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} + pytest -s -vv integration-tests + stop-runner: name: Stop self-hosted EC2 runner needs: - start-runner - build-and-push-image + - integration-tests runs-on: ubuntu-latest env: AWS_REGION: us-east-1 diff --git a/Makefile b/Makefile index 032a49de..29c318fa 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ install-server: cd server && make install +install-integration-tests: + cd integration-tests && pip install -r requirements.txt + install-router: cd router && cargo install --path . @@ -18,9 +21,15 @@ server-dev: router-dev: cd router && cargo run -- --port 8080 -integration-tests: install-router install-launcher +rust-tests: install-router install-launcher cargo test +integration-tests: install-integration-tests + pytest -s -vv -m "not private" integration-tests + +update-integration-tests: install-integration-tests + pytest -s -vv --snapshot-update integration-tests + python-server-tests: HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests diff --git a/README.md b/README.md index 756d7e35..918ea5e2 100644 --- a/README.md +++ b/README.md @@ -253,5 +253,7 @@ make python-client-tests # or both server and client tests make python-tests # rust cargo tests +make rust-tests +# integration tests make integration-tests ``` diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py new file mode 100644 index 00000000..e9c51c37 --- /dev/null +++ b/integration-tests/conftest.py @@ -0,0 +1,146 @@ +import subprocess +import contextlib +import pytest +import asyncio +import os +import docker + +from docker.errors import NotFound +from typing import Optional, List +from syrupy.filters import props + +from text_generation import AsyncClient +from text_generation.types import Response + +DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) +HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) +DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") + + +@pytest.fixture +def snapshot_test(snapshot): + return lambda value: value == snapshot(exclude=props("logprob")) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="module") +def launcher(event_loop): + @contextlib.contextmanager + def local_launcher( + model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] 
= None + ): + port = 9999 + master_port = 19999 + + shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server" + + args = [ + "text-generation-launcher", + "--model-id", + model_id, + "--port", + str(port), + "--master-port", + str(master_port), + "--shard-uds-path", + shard_uds_path, + ] + + if num_shard is not None: + args.extend(["--num-shard", str(num_shard)]) + if quantize: + args.append("--quantize") + + with subprocess.Popen( + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) as process: + yield AsyncClient(f"http://localhost:{port}") + + process.terminate() + process.wait(60) + + launcher_output = process.stdout.read().decode("utf-8") + print(launcher_output) + + process.stdout.close() + process.stderr.close() + + @contextlib.contextmanager + def docker_launcher( + model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None + ): + port = 9999 + + args = ["--model-id", model_id, "--env"] + + if num_shard is not None: + args.extend(["--num-shard", str(num_shard)]) + if quantize: + args.append("--quantize") + + client = docker.from_env() + + container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}" + + try: + container = client.containers.get(container_name) + container.stop() + container.wait() + except NotFound: + pass + + gpu_count = num_shard if num_shard is not None else 1 + + env = {} + if HUGGING_FACE_HUB_TOKEN is not None: + env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN + + volumes = [] + if DOCKER_VOLUME: + volumes = [f"{DOCKER_VOLUME}:/data"] + + container = client.containers.run( + DOCKER_IMAGE, + command=args, + name=container_name, + environment=env, + auto_remove=True, + detach=True, + device_requests=[ + docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) + ], + volumes=volumes, + ports={"80/tcp": port}, + ) + + yield AsyncClient(f"http://localhost:{port}") + + container.stop() + + container_output = container.logs().decode("utf-8") + print(container_output) + + if DOCKER_IMAGE is not None: + return docker_launcher + return local_launcher + + +@pytest.fixture(scope="module") +def generate_load(): + async def generate_load_inner( + client: AsyncClient, prompt: str, max_new_tokens: int, n: int + ) -> List[Response]: + futures = [ + client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n) + ] + + results = await asyncio.gather(*futures) + return [r.dict() for r in results] + + return generate_load_inner diff --git a/integration-tests/models/__snapshots__/test_bloom_560m.ambr b/integration-tests/models/__snapshots__/test_bloom_560m.ambr new file mode 100644 index 00000000..1067513d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m.ambr @@ -0,0 +1,627 @@ +# serializer version: 1 +# name: test_bloom_560m + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 
5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 159570, + 'special': False, + 'text': ' réch', + }), + dict({ + 'id': 810, + 'special': False, + 'text': 'au', + }), + dict({ + 'id': 12736, + 'special': False, + 'text': 'ffer', + }), + dict({ + 'id': 1742, + 'special': False, + 'text': ' au', + }), + dict({ + 'id': 6105, + 'special': False, + 'text': ' bain', + }), + dict({ + 'id': 88254, + 'special': False, + 'text': '-mar', + }), + dict({ + 'id': 641, + 'special': False, + 'text': 'ie', + }), + dict({ + 'id': 2940, + 'special': False, + 'text': ' avec', + }), + ]), + }), + 'generated_text': ' le faire réchauffer au bain-marie avec', + }) +# --- +# name: test_bloom_560m_all_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 408, + 'special': False, + 'text': ' que', + }), + dict({ + 'id': 20288, + 'special': False, + 'text': " l'on", + }), + dict({ + 'id': 22255, + 'special': False, + 'text': ' trouve', + }), + dict({ + 'id': 1622, + 'special': False, + 'text': ' une', + }), + dict({ + 'id': 187079, + 'special': False, + 'text': ' posture', + }), + dict({ + 'id': 501, + 'special': False, + 'text': ' par', + }), + dict({ + 'id': 8741, + 'special': False, + 'text': ' rapport', + }), + dict({ + 'id': 693, + 'special': False, + 'text': ' à', + }), + dict({ + 'id': 366, + 'special': False, + 'text': ' la', + }), + dict({ + 'id': 36503, + 'special': False, + 'text': ' pratique', + }), + ]), + }), + 'generated_text': "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique", + }) +# --- +# name: test_bloom_560m_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': 
list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 
5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr new file mode 100644 index 00000000..667a0373 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr @@ -0,0 +1,542 @@ +# serializer version: 1 +# name: test_bloom_560m_sharded + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 159570, + 'special': False, + 'text': ' réch', + }), + dict({ + 'id': 810, + 'special': False, + 'text': 'au', + }), + dict({ + 'id': 12736, + 'special': False, + 'text': 'ffer', + }), + dict({ + 'id': 1742, + 'special': False, + 'text': ' au', + }), + dict({ + 'id': 6105, + 'special': False, + 'text': ' bain', + }), + dict({ + 'id': 88254, + 'special': False, + 'text': '-mar', + }), + dict({ + 'id': 641, + 'special': False, + 'text': 'ie', + }), + dict({ + 'id': 2940, + 'special': False, + 'text': ' avec', + }), + ]), + }), + 'generated_text': ' le faire réchauffer au bain-marie avec', + }) +# --- +# name: test_bloom_560m_sharded_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' 
dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' 
dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_llama.ambr b/integration-tests/models/__snapshots__/test_flash_llama.ambr new file mode 100644 index 00000000..f4e3a4c1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama.ambr @@ -0,0 +1,465 @@ +# serializer version: 1 +# name: test_flash_llama + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }) +# --- +# name: test_flash_llama_all_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 5229, + 'special': False, + 'text': ' failed', + }), + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 5641, + 'special': False, + 'text': ' IP', + }), + dict({ + 'id': 16428, + 'special': False, + 'text': ' Address', + }), + dict({ + 'id': 29901, + 'special': False, + 'text': ':', + }), + dict({ + 'id': 525, + 'special': False, + 'text': " '", + }), + dict({ + 'id': 8516, + 'special': False, + 'text': 'None', + }), + dict({ + 'id': 4286, + 'special': False, + 'text': "'.", + }), + dict({ + 'id': 13, + 'special': False, + 'text': ''' + + + 
''', + }), + dict({ + 'id': 294, + 'special': False, + 'text': 'as', + }), + ]), + }), + 'generated_text': ''' + Test requestfailed for IP Address: 'None'. + as + ''', + }) +# --- +# name: test_flash_llama_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + 
}), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_neox.ambr b/integration-tests/models/__snapshots__/test_flash_neox.ambr new file mode 100644 index 00000000..4330db6b --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_neox.ambr @@ -0,0 +1,682 @@ +# serializer version: 1 +# name: test_flash_neox + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }) +# --- +# name: test_flash_neox_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, 
+ 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': 
' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_santacoder.ambr b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr new file mode 100644 index 00000000..030820cb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr @@ -0,0 +1,472 @@ +# serializer version: 1 +# name: test_flash_santacoder + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + 
}), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }) +# --- +# name: test_flash_santacoder_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 
'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder.ambr b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr new file mode 100644 index 00000000..e0f4b568 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr @@ -0,0 +1,573 @@ +# serializer version: 1 +# name: test_flash_starcoder + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }) +# --- +# name: test_flash_starcoder_default_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 12, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 
'hello', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 5741, + 'special': False, + 'text': ' logging', + }), + dict({ + 'id': 32, + 'special': False, + 'text': '.', + }), + dict({ + 'id': 1338, + 'special': False, + 'text': 'info', + }), + dict({ + 'id': 463, + 'special': False, + 'text': "('", + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 30, + 'special': False, + 'text': ',', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 683, + 'special': False, + 'text': "')", + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 0, + 'special': True, + 'text': '<|endoftext|>', + }), + ]), + }), + 'generated_text': ''' + (): + logging.info('Hello, World') + <|endoftext|> + ''', + }) +# --- +# name: test_flash_starcoder_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 
'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_mt0_base.ambr b/integration-tests/models/__snapshots__/test_mt0_base.ambr new file mode 100644 index 00000000..d7c6eaf6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mt0_base.ambr @@ -0,0 +1,306 @@ +# serializer version: 1 +# name: test_mt0_base + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 5, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 926, + 'special': False, + 'text': 'To', + }), + dict({ + 'id': 18295, + 'special': False, + 'text': ' sell', + }), + dict({ + 'id': 7868, + 'special': False, + 'text': ' things', + }), + dict({ + 'id': 260, + 'special': False, + 'text': '.', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'To sell things.', + }) +# --- +# name: test_mt0_base_all_params + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': 0, + 'tokens': list([ + dict({ + 'id': 16017, + 'special': False, + 'text': 'blue', + }), + dict({ + 'id': 20495, + 'special': False, + 'text': ' sky', + }), + dict({ + 'id': 259, + 'special': False, + 'text': ' ', + }), + dict({ + 'id': 15484, + 'special': False, + 'text': 'appear', + }), + dict({ + 'id': 345, + 'special': False, + 'text': 'ed', + }), + dict({ + 'id': 288, + 'special': False, + 'text': ' to', + }), + dict({ + 'id': 35622, + 'special': False, + 
'text': ' cloud', + }), + dict({ + 'id': 263, + 'special': False, + 'text': 's', + }), + dict({ + 'id': 14701, + 'special': False, + 'text': ' above', + }), + dict({ + 'id': 751, + 'special': False, + 'text': ' all', + }), + ]), + }), + 'generated_text': 'Why is the sky blue?blue sky appeared to clouds above all', + }) +# --- +# name: test_mt0_base_load + list([ + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + ]) +# --- diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py new file mode 100644 index 00000000..e13606f7 --- /dev/null +++ b/integration-tests/models/test_bloom_560m.py @@ -0,0 +1,63 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def bloom_560(launcher): + with launcher("bigscience/bloom-560m") as client: + yield client + + +@pytest.mark.asyncio +async def test_bloom_560m(bloom_560, snapshot_test): + await health_check(bloom_560, 60) + + response = await 
bloom_560.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + top_p=0.9, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_bloom_560m_all_params(bloom_560, snapshot_test): + await health_check(bloom_560, 60) + + response = await bloom_560.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_bloom_560m_load(bloom_560, generate_load, snapshot_test): + await health_check(bloom_560, 60) + + responses = await generate_load( + bloom_560, + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py new file mode 100644 index 00000000..bfb70253 --- /dev/null +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -0,0 +1,42 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def bloom_560m_sharded(launcher): + with launcher("bigscience/bloom-560m", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot_test): + await health_check(bloom_560m_sharded, 60) + + response = await bloom_560m_sharded.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + top_p=0.9, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_bloom_560m_sharded_load( + bloom_560m_sharded, generate_load, snapshot_test +): + await health_check(bloom_560m_sharded, 60) + + responses = await generate_load( + bloom_560m_sharded, + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py new file mode 100644 index 00000000..4d1f2bcf --- /dev/null +++ b/integration-tests/models/test_flash_llama.py @@ -0,0 +1,56 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_llama(launcher): + with launcher("huggingface/llama-7b", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama(flash_llama, snapshot_test): + await health_check(flash_llama, 120) + + response = await flash_llama.generate("Test request", max_new_tokens=10) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_all_params(flash_llama, snapshot_test): + await health_check(flash_llama, 120) + + response = await flash_llama.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +@pytest.mark.private +async def 
test_flash_llama_load(flash_llama, generate_load, snapshot_test): + await health_check(flash_llama, 120) + + responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py new file mode 100644 index 00000000..8c981028 --- /dev/null +++ b/integration-tests/models/test_flash_neox.py @@ -0,0 +1,38 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_neox(launcher): + with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +async def test_flash_neox(flash_neox, snapshot_test): + await health_check(flash_neox, 240) + + response = await flash_neox.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + ) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_flash_neox_load(flash_neox, generate_load, snapshot_test): + await health_check(flash_neox, 240) + + responses = await generate_load( + flash_neox, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py new file mode 100644 index 00000000..64a59d78 --- /dev/null +++ b/integration-tests/models/test_flash_santacoder.py @@ -0,0 +1,32 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_santacoder(launcher): + with launcher("bigcode/santacoder") as client: + yield client + + +@pytest.mark.asyncio +async def test_flash_santacoder(flash_santacoder, snapshot_test): + await health_check(flash_santacoder, 60) + + response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot_test): + await health_check(flash_santacoder, 60) + + responses = await generate_load( + flash_santacoder, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py new file mode 100644 index 00000000..d43e92dc --- /dev/null +++ b/integration-tests/models/test_flash_starcoder.py @@ -0,0 +1,47 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_starcoder(launcher): + with launcher("bigcode/starcoder", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder(flash_starcoder, snapshot_test): + await health_check(flash_starcoder, 240) + + response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_default_params(flash_starcoder, snapshot_test): + await health_check(flash_starcoder, 240) + + response = await flash_starcoder.generate( + "def print_hello", max_new_tokens=60, 
temperature=0.2, top_p=0.95, seed=0 + ) + + assert response.details.generated_tokens == 12 + assert snapshot_test(response) + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot_test): + await health_check(flash_starcoder, 240) + + responses = await generate_load( + flash_starcoder, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py new file mode 100644 index 00000000..7310a30f --- /dev/null +++ b/integration-tests/models/test_mt0_base.py @@ -0,0 +1,63 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def mt0_base(launcher): + with launcher("bigscience/mt0-base") as client: + yield client + + +@pytest.mark.asyncio +async def test_mt0_base(mt0_base, snapshot_test): + await health_check(mt0_base, 60) + + response = await mt0_base.generate( + "Why is the sky blue?", + max_new_tokens=10, + top_p=0.9, + seed=0, + ) + + assert response.details.generated_tokens == 5 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_mt0_base_all_params(mt0_base, snapshot_test): + await health_check(mt0_base, 60) + + response = await mt0_base.generate( + "Why is the sky blue?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert snapshot_test(response) + + +@pytest.mark.asyncio +async def test_mt0_base_load(mt0_base, generate_load, snapshot_test): + await health_check(mt0_base, 60) + + responses = await generate_load( + mt0_base, + "Why is the sky blue?", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + + assert snapshot_test(responses) diff --git a/integration-tests/models/utils.py b/integration-tests/models/utils.py new file mode 100644 index 00000000..c47e4871 --- /dev/null +++ b/integration-tests/models/utils.py @@ -0,0 +1,15 @@ +import time + +from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError +from text_generation import AsyncClient + + +async def health_check(client: AsyncClient, timeout: int = 60): + assert timeout > 0 + for _ in range(timeout): + try: + await client.generate("test") + return + except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: + time.sleep(1) + raise RuntimeError("Health check failed") diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt new file mode 100644 index 00000000..9ecbb2ee --- /dev/null +++ b/integration-tests/requirements.txt @@ -0,0 +1,5 @@ +syrupy +text-generation==0.5.1 +pytest +pytest-asyncio==0.17.2 +docker \ No newline at end of file diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json deleted file mode 100644 index 96f89f6b..00000000 --- a/launcher/tests/bloom_560m.json +++ /dev/null @@ -1,142 +0,0 @@ -{ - "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException", - "details": { - "finish_reason": "length", - "generated_tokens": 20, - "seed": null, - "prefill": [ - { - "id": 10264, - "text": "Test", - "logprob": null - }, - { - "id": 8821, - "text": " request", - "logprob": -11.894989 - } - ], - "tokens": [ - { - "id": 17, - "text": ".", - "logprob": -1.8267672, - "special": false - }, - { - "id": 1587, - "text": 
"get", - "logprob": -2.4674969, - "special": false - }, - { - "id": 11, - "text": "(", - "logprob": -1.906001, - "special": false - }, - { - "id": 5, - "text": "\"", - "logprob": -1.2279545, - "special": false - }, - { - "id": 4899, - "text": "action", - "logprob": -4.170299, - "special": false - }, - { - "id": 5, - "text": "\"", - "logprob": -0.32478866, - "special": false - }, - { - "id": 12, - "text": ")", - "logprob": -1.0773665, - "special": false - }, - { - "id": 30, - "text": ";", - "logprob": -0.27640742, - "special": false - }, - { - "id": 837, - "text": "\n ", - "logprob": -1.6970354, - "special": false - }, - { - "id": 1320, - "text": " if", - "logprob": -1.4495516, - "special": false - }, - { - "id": 375, - "text": " (", - "logprob": -0.23609057, - "special": false - }, - { - "id": 4899, - "text": "action", - "logprob": -1.1916996, - "special": false - }, - { - "id": 3535, - "text": " ==", - "logprob": -0.8918753, - "special": false - }, - { - "id": 5109, - "text": " null", - "logprob": -0.3933342, - "special": false - }, - { - "id": 12, - "text": ")", - "logprob": -0.43212673, - "special": false - }, - { - "id": 731, - "text": " {", - "logprob": -0.17702064, - "special": false - }, - { - "id": 1260, - "text": "\n ", - "logprob": -0.07027565, - "special": false - }, - { - "id": 10519, - "text": " throw", - "logprob": -1.3915029, - "special": false - }, - { - "id": 2084, - "text": " new", - "logprob": -0.04201372, - "special": false - }, - { - "id": 150858, - "text": " RuntimeException", - "logprob": -1.7329919, - "special": false - } - ] - } -} \ No newline at end of file diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs deleted file mode 100644 index 0d2b6c74..00000000 --- a/launcher/tests/integration_tests.rs +++ /dev/null @@ -1,172 +0,0 @@ -use float_eq::assert_float_eq; -use serde::Deserialize; -use serde_json::Value; -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::path::PathBuf; -use std::thread; -use std::thread::sleep; -use std::time::Duration; -use subprocess::{Popen, PopenConfig, Redirection}; - -#[derive(Deserialize)] -pub struct Token { - id: u32, - text: String, - logprob: Option, - special: bool, -} - -#[derive(Deserialize)] -struct Details { - finish_reason: String, - generated_tokens: u32, - tokens: Vec, -} - -#[derive(Deserialize)] -struct GeneratedText { - generated_text: String, - details: Details, -} - -fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen { - let argv = vec![ - "text-generation-launcher".to_string(), - "--model-id".to_string(), - model_id.clone(), - "--num-shard".to_string(), - num_shard.to_string(), - "--port".to_string(), - port.to_string(), - "--master-port".to_string(), - master_port.to_string(), - "--shard-uds-path".to_string(), - format!("/tmp/test-{}-{}-{}", num_shard, port, master_port), - ]; - - let mut launcher = Popen::create( - &argv, - PopenConfig { - stdout: Redirection::Pipe, - stderr: Redirection::Merge, - ..Default::default() - }, - ) - .expect("Could not start launcher"); - - // Redirect STDOUT and STDERR to the console - // (STDERR is merged into STDOUT) - let launcher_stdout = launcher.stdout.take().unwrap(); - - thread::spawn(move || { - let stdout = BufReader::new(launcher_stdout); - for line in stdout.lines() { - println!("{}", line.unwrap()); - } - }); - - for _ in 0..60 { - let health = reqwest::blocking::get(format!("http://localhost:{}/health", port)); - if health.is_ok() { - return launcher; - } - 
sleep(Duration::from_secs(2)); - } - - launcher.terminate().unwrap(); - launcher.wait().unwrap(); - panic!("failed to launch {}", model_id) -} - -fn test_model( - model_id: String, - num_shard: usize, - port: usize, - master_port: usize, -) -> GeneratedText { - let mut launcher = start_launcher(model_id, num_shard, port, master_port); - - let data = r#" - { - "inputs": "Test request", - "parameters": { - "details": true - } - }"#; - let req: Value = serde_json::from_str(data).unwrap(); - - let client = reqwest::blocking::Client::new(); - let res = client - .post(format!("http://localhost:{}/generate", port)) - .json(&req) - .send(); - - launcher.terminate().unwrap(); - launcher.wait().unwrap(); - - let result: GeneratedText = res.unwrap().json().unwrap(); - result -} - -fn read_json(name: &str) -> GeneratedText { - let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - d.push("tests/"); - d.push(name); - - let file = File::open(d).unwrap(); - let reader = BufReader::new(file); - - let result: GeneratedText = serde_json::from_reader(reader).unwrap(); - result -} - -fn compare_results(result: GeneratedText, expected: GeneratedText) { - assert_eq!(result.generated_text, expected.generated_text); - assert_eq!(result.details.finish_reason, expected.details.finish_reason); - assert_eq!( - result.details.generated_tokens, - expected.details.generated_tokens - ); - - for (token, expected_token) in result - .details - .tokens - .into_iter() - .zip(expected.details.tokens.into_iter()) - { - assert_eq!(token.id, expected_token.id); - assert_eq!(token.text, expected_token.text); - assert_eq!(token.special, expected_token.special); - if let Some(logprob) = token.logprob { - let expected_logprob = expected_token.logprob.unwrap(); - assert_float_eq!(logprob, expected_logprob, abs <= 0.001); - } else { - assert_eq!(token.logprob, expected_token.logprob); - } - } -} - -#[test] -fn test_bloom_560m() { - let expected = read_json("bloom_560m.json"); - - let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500); - compare_results(result, expected); -} - -#[test] -fn test_bloom_560m_distributed() { - let expected = read_json("bloom_560m.json"); - - let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501); - compare_results(result, expected); -} - -#[test] -fn test_mt0_base() { - let expected = read_json("mt0_base.json"); - - let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502); - compare_results(result, expected); -} diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json deleted file mode 100644 index f5be63f9..00000000 --- a/launcher/tests/mt0_base.json +++ /dev/null @@ -1,137 +0,0 @@ -{ - "generated_text": "\"\"\"Test the contents of the contents of the contents. 
\"\"\" test_test", - "details": { - "finish_reason": "length", - "generated_tokens": 20, - "seed": null, - "prefill": [ - { - "id": 0, - "text": "", - "logprob": null - } - ], - "tokens": [ - { - "id": 259, - "text": "", - "logprob": -1.3656927, - "special": false - }, - { - "id": 215100, - "text": "\"\"\"", - "logprob": -2.6551573, - "special": false - }, - { - "id": 46138, - "text": "Test", - "logprob": -1.8059857, - "special": false - }, - { - "id": 287, - "text": " the", - "logprob": -1.2102449, - "special": false - }, - { - "id": 259, - "text": " ", - "logprob": -1.6057279, - "special": false - }, - { - "id": 49076, - "text": "contents", - "logprob": -3.6060903, - "special": false - }, - { - "id": 304, - "text": " of", - "logprob": -0.5270343, - "special": false - }, - { - "id": 287, - "text": " the", - "logprob": -0.62522805, - "special": false - }, - { - "id": 259, - "text": " ", - "logprob": -1.4069618, - "special": false - }, - { - "id": 49076, - "text": "contents", - "logprob": -2.621994, - "special": false - }, - { - "id": 304, - "text": " of", - "logprob": -1.3172221, - "special": false - }, - { - "id": 287, - "text": " the", - "logprob": -0.3501925, - "special": false - }, - { - "id": 259, - "text": " ", - "logprob": -0.7219573, - "special": false - }, - { - "id": 49076, - "text": "contents", - "logprob": -1.0494149, - "special": false - }, - { - "id": 260, - "text": ".", - "logprob": -1.0803378, - "special": false - }, - { - "id": 259, - "text": " ", - "logprob": -0.32933083, - "special": false - }, - { - "id": 215100, - "text": "\"\"\"", - "logprob": -0.11268901, - "special": false - }, - { - "id": 2978, - "text": " test", - "logprob": -1.5846587, - "special": false - }, - { - "id": 290, - "text": "_", - "logprob": -0.49796978, - "special": false - }, - { - "id": 4125, - "text": "test", - "logprob": -2.0026445, - "special": false - } - ] - } -} \ No newline at end of file diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index ed959291..9029e954 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -129,7 +129,7 @@ class BLOOMSharded(BLOOM): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): full_name = f"transformer.{name}" diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 11f3766e..54670b79 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -21,16 +21,14 @@ import torch import torch.distributed -from torch.nn import functional as F - from torch import nn from transformers.activations import ACT2FN from typing import Optional # Flash attention imports import flash_attn_cuda +import dropout_layer_norm -from flash_attn.layers.rotary import RotaryEmbedding from text_generation_server.utils.layers import ( FastLinear, TensorParallelRowLinear, @@ -332,15 +330,15 @@ class FlashLlamaModel(torch.nn.Module): self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads - def post_load_weights(self, load_in_8bit: bool = False): + def post_load_weights(self, quantize: Optional[str] 
= None): if isinstance(self.embed_tokens, TensorParallelEmbedding): self.embed_tokens.add_null_idx() for layer in self.layers: layer: FlashLlamaLayer - layer.self_attn.query_key_value.prepare_weights(load_in_8bit) - layer.self_attn.o_proj.prepare_weights(load_in_8bit) - layer.mlp.gate_up_proj.prepare_weights(load_in_8bit) - layer.mlp.down_proj.prepare_weights(load_in_8bit) + layer.self_attn.query_key_value.prepare_weights(quantize) + layer.self_attn.o_proj.prepare_weights(quantize) + layer.mlp.gate_up_proj.prepare_weights(quantize) + layer.mlp.down_proj.prepare_weights(quantize) def forward( self, @@ -429,8 +427,8 @@ class FlashLlamaForCausalLM(torch.nn.Module): else: self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - def post_load_weights(self, load_in_8bit: bool = False): - self.model.post_load_weights(load_in_8bit) + def post_load_weights(self, quantize: Optional[str] = None): + self.model.post_load_weights(quantize) self.lm_head.prepare_weights() def forward( diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 369e8d4f..2c6b8da6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -21,8 +21,6 @@ import torch import torch.distributed -from torch.nn import functional as F - from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel @@ -32,7 +30,6 @@ from typing import Optional # Flash attention imports import flash_attn_cuda -from flash_attn.layers.rotary import RotaryEmbedding from text_generation_server.utils.layers import ( FastLinear, TensorParallelRowLinear, @@ -345,16 +342,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self.head_size = self.layers[0].attention.head_size self.num_heads = self.layers[0].attention.num_heads - def post_load_weights(self, load_in_8bit=False): + def post_load_weights(self, quantize: Optional[str] = None): if isinstance(self.embed_in, TensorParallelEmbedding): self.embed_in.add_null_idx() for layer in self.layers: layer: FlashNeoXLayer layer.attention.shuffle_qkv_dims() - layer.attention.query_key_value.prepare_weights(load_in_8bit) - layer.attention.dense.prepare_weights(load_in_8bit) - layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit) - layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit) + layer.attention.query_key_value.prepare_weights(quantize) + layer.attention.dense.prepare_weights(quantize) + layer.mlp.dense_h_to_4h.prepare_weights(quantize) + layer.mlp.dense_4h_to_h.prepare_weights(quantize) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): @@ -457,8 +454,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): config.hidden_size, config.vocab_size, bias=False ) - def post_load_weights(self, load_in_8bit=False): - self.gpt_neox.post_load_weights(load_in_8bit) + def post_load_weights(self, quantize: Optional[str] = None): + self.gpt_neox.post_load_weights(quantize) self.embed_out.prepare_weights() @classmethod diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 9451b01a..9bded805 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,8 +1,6 @@ import torch import torch.distributed -import torch.nn.functional as F - from torch import nn from transformers.activations import ACT2FN from typing import Optional @@ -261,16 +259,16 @@ class FlashSantacoderModel(nn.Module): self.head_size = self.h[0].attn.head_size self.num_heads = self.h[0].attn.num_heads - def post_load_weights(self, load_in_8bit: bool = False): + def post_load_weights(self, quantize: Optional[str] = None): if self.tp_embeddings: self.wte.add_null_idx() self.wpe.add_null_idx() for layer in self.h: layer: Block - layer.attn.c_attn.prepare_weights(load_in_8bit) - layer.attn.c_proj.prepare_weights(load_in_8bit) - layer.mlp.c_fc.prepare_weights(load_in_8bit) - layer.mlp.c_proj.prepare_weights(load_in_8bit) + layer.attn.c_attn.prepare_weights(quantize) + layer.attn.c_proj.prepare_weights(quantize) + layer.mlp.c_fc.prepare_weights(quantize) + layer.mlp.c_proj.prepare_weights(quantize) def forward( self, @@ -347,8 +345,8 @@ class FlashSantacoderForCausalLM(nn.Module): else: self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - def post_load_weights(self, load_in_8bit: bool = False): - self.transformer.post_load_weights(load_in_8bit) + def post_load_weights(self, quantize: Optional[str] = None): + self.transformer.post_load_weights(quantize) self.lm_head.prepare_weights() def forward( diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index aa0b4483..b775bd79 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -28,7 +28,12 @@ tracer = trace.get_tracer(__name__) class FlashLlama(FlashCausalLM): - def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + ): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 @@ -72,14 +77,14 @@ class FlashLlama(FlashCausalLM): def load_weights( model, filenames: List[Path], - quantize: bool, + quantize: Optional[str], device: torch.device, dtype: torch.dtype, ): for filename in filenames: state_dict = torch.load(filename, map_location="cpu") for key, value in state_dict.items(): - value = value.to(device if not quantize else "cpu").to(dtype) + value = value.to(device if quantize is None else "cpu").to(dtype) layer_name = ".".join(key.split(".")[:4]) @@ -199,7 +204,7 @@ class FlashLlamaSharded(FlashLlama): def load_weights( model, filenames: List[str], - quantize: bool, + quantize: Optional[str], device: torch.device, dtype: torch.dtype, rank: int, @@ -207,7 +212,7 @@ class FlashLlamaSharded(FlashLlama): ): for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): slice_ = f.get_slice(name) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index fc741f55..0924f107 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -23,7 +23,12 @@ tracer = trace.get_tracer(__name__) class FlashNeoX(FlashCausalLM): - def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + def __init__( + self, + model_id: str, + revision: 
Optional[str] = None, + quantize: Optional[str] = None, + ): super(FlashNeoX, self).__init__( FlashGPTNeoXForCausalLM, model_id, revision, quantize ) @@ -31,7 +36,10 @@ class FlashNeoX(FlashCausalLM): class FlashNeoXSharded(FlashNeoX): def __init__( - self, model_id: str, revision: Optional[str] = None, quantize: bool = False + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -89,7 +97,7 @@ class FlashNeoXSharded(FlashNeoX): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): module_name, param_name = name.rsplit(".", 1) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index f810bb0b..031a67eb 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -27,7 +27,12 @@ tracer = trace.get_tracer(__name__) class FlashSantacoder(FlashCausalLM): - def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + ): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 @@ -84,7 +89,7 @@ class FlashSantacoder(FlashCausalLM): for filename in filenames: state_dict = torch.load(filename, map_location="cpu") for key, value in state_dict.items(): - value = value.to(device if not quantize else "cpu").to(dtype) + value = value.to(device if quantize is None else "cpu").to(dtype) layer_name = ".".join(key.split(".")[:4]) @@ -170,7 +175,10 @@ class FlashSantacoder(FlashCausalLM): class FlashSantacoderSharded(FlashSantacoder): def __init__( - self, model_id: str, revision: Optional[str] = None, quantize: bool = False + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -221,7 +229,7 @@ class FlashSantacoderSharded(FlashSantacoder): def load_weights( model, filenames: List[str], - quantize: bool, + quantize: Optional[str], device: torch.device, dtype: torch.dtype, rank: int, @@ -230,7 +238,7 @@ class FlashSantacoderSharded(FlashSantacoder): ): for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for key in f.keys(): slice_ = f.get_slice(key) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index a0111250..d1e5e841 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -255,7 +255,7 @@ class GalacticaSharded(Galactica): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): if name == "lm_head.weight": diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 3e8557b2..f95e5be2 100644 --- 
a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -94,7 +94,7 @@ class GPTNeoxSharded(CausalLM): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): module_name, param_name = name.rsplit(".", 1) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index c83c3351..093cf70a 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -48,7 +48,10 @@ class OPT(CausalLM): class OPTSharded(OPT): def __init__( - self, model_id: str, revision: Optional[str] = None, quantize: bool = False + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -107,7 +110,7 @@ class OPTSharded(OPT): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): if name == "lm_head.weight": diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index c8521dbf..8e3826a4 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -97,7 +97,7 @@ class T5Sharded(Seq2SeqLM): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): module_name, param_name = name.rsplit(".", 1) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 3383bf4b..7605639d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -1,6 +1,8 @@ import torch from torch import nn +from torch.nn import functional as F +from typing import Optional HAS_BITS_AND_BYTES = True try: @@ -22,7 +24,7 @@ class FastLinear(nn.Linear): self.quantized = False self.bnb_linear = None - def prepare_weights(self, quantize: bool = False): + def prepare_weights(self, quantize: Optional[str] = None): if quantize == "bitsandbytes": if not HAS_BITS_AND_BYTES: raise ImportError( @@ -126,6 +128,7 @@ class TensorParallelEmbedding(nn.Embedding): num_embeddings, embedding_dim, process_group: torch.distributed.ProcessGroup, + reduce=True, padding_idx=None, max_norm=None, norm_type=2.0, @@ -135,6 +138,7 @@ class TensorParallelEmbedding(nn.Embedding): device=None, dtype=None, ): + self.reduce = reduce self.process_group = process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() @@ -177,7 +181,8 @@ class TensorParallelEmbedding(nn.Embedding): input - self.min_id, ) out = super().forward(input) - torch.distributed.all_reduce(out, group=self.process_group) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) return out
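
The new integration tests added above all share one shape: a module-scoped fixture built on the `launcher` fixture that yields a client for a given model id (optionally sharded), a `health_check(client, timeout)` call that retries `generate` until the server answers, one or more `generate` / `generate_load` calls, and a `snapshot_test` assertion that compares the full response against the recorded syrupy snapshot. A minimal sketch of that pattern for a future model test is shown below; `org/some-model`, `some_model`, and the test names are placeholders for illustration only, not files or models added by this PR.

import pytest

from utils import health_check


@pytest.fixture(scope="module")
def some_model(launcher):
    # "org/some-model" is a placeholder model id, not one covered by this PR.
    # The `launcher` fixture (provided by the test suite's conftest) starts the
    # server for that model and yields an async client.
    with launcher("org/some-model", num_shard=2) as client:
        yield client


@pytest.mark.asyncio
async def test_some_model(some_model, snapshot_test):
    # Retry `generate` for up to ~60 seconds until the server is healthy.
    await health_check(some_model, 60)

    response = await some_model.generate("Test request", max_new_tokens=10)

    assert response.details.generated_tokens == 10
    # Compare the full response object against the stored syrupy snapshot.
    assert snapshot_test(response)


@pytest.mark.asyncio
async def test_some_model_load(some_model, generate_load, snapshot_test):
    await health_check(some_model, 60)

    # Generate four responses for the same prompt and snapshot them together.
    responses = await generate_load(
        some_model, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert snapshot_test(responses)

Since the snapshots are recorded with syrupy (listed in integration-tests/requirements.txt), the .ambr files can typically be regenerated with pytest's --snapshot-update flag when a model's expected output legitimately changes.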