feat: add snapshot testing (#282)
This commit is contained in:
parent
f58f0a0364
commit
e71471bec9
|
@ -20,10 +20,6 @@ on:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
start-runner:
|
start-runner:
|
||||||
name: Start self-hosted EC2 runner
|
name: Start self-hosted EC2 runner
|
||||||
|
@ -61,6 +57,9 @@ jobs:
|
||||||
]
|
]
|
||||||
|
|
||||||
build-and-push-image:
|
build-and-push-image:
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
needs: start-runner # required to start the main job when the runner is ready
|
needs: start-runner # required to start the main job when the runner is ready
|
||||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||||
permissions:
|
permissions:
|
||||||
|
@ -108,7 +107,19 @@ jobs:
|
||||||
username: ${{ secrets.AZURE_DOCKER_USERNAME }}
|
username: ${{ secrets.AZURE_DOCKER_USERNAME }}
|
||||||
password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
|
password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
|
||||||
registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
|
registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
|
||||||
|
# If pull request
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
- name: Extract metadata (tags, labels) for Docker
|
||||||
|
if: ${{ github.event_name == 'pull_request' }}
|
||||||
|
id: meta-pr
|
||||||
|
uses: docker/metadata-action@v4.3.0
|
||||||
|
with:
|
||||||
|
images: |
|
||||||
|
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||||
|
tags: |
|
||||||
|
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
|
||||||
|
# If main, release or tag
|
||||||
|
- name: Extract metadata (tags, labels) for Docker
|
||||||
|
if: ${{ github.event_name != 'pull_request' }}
|
||||||
id: meta
|
id: meta
|
||||||
uses: docker/metadata-action@v4.3.0
|
uses: docker/metadata-action@v4.3.0
|
||||||
with:
|
with:
|
||||||
|
@ -129,13 +140,13 @@ jobs:
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: Dockerfile
|
file: Dockerfile
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: true
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
build-args: |
|
build-args: |
|
||||||
GIT_SHA=${{ env.GITHUB_SHA }}
|
GIT_SHA=${{ env.GITHUB_SHA }}
|
||||||
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
|
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
|
||||||
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
|
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
|
||||||
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
|
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
|
||||||
# Sign the resulting Docker image digest except on PRs.
|
# Sign the resulting Docker image digest except on PRs.
|
||||||
|
@ -172,11 +183,48 @@ jobs:
|
||||||
with:
|
with:
|
||||||
sarif_file: 'trivy-results.sarif'
|
sarif_file: 'trivy-results.sarif'
|
||||||
|
|
||||||
|
integration-tests:
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
needs:
|
||||||
|
- start-runner
|
||||||
|
- build-and-push-image # Wait for the docker image to be built
|
||||||
|
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||||
|
env:
|
||||||
|
DOCKER_VOLUME: /cache
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Inject slug/short variables
|
||||||
|
uses: rlespinasse/github-slug-action@v4.4.1
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.9
|
||||||
|
- name: Tailscale
|
||||||
|
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
|
||||||
|
with:
|
||||||
|
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
|
||||||
|
- name: Prepare disks
|
||||||
|
run: |
|
||||||
|
sudo mkfs -t ext4 /dev/nvme1n1
|
||||||
|
sudo mkdir ${{ env.DOCKER_VOLUME }}
|
||||||
|
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
|
||||||
|
- name: Install
|
||||||
|
run: |
|
||||||
|
make install-integration-tests
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
|
||||||
|
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||||
|
pytest -s -vv integration-tests
|
||||||
|
|
||||||
stop-runner:
|
stop-runner:
|
||||||
name: Stop self-hosted EC2 runner
|
name: Stop self-hosted EC2 runner
|
||||||
needs:
|
needs:
|
||||||
- start-runner
|
- start-runner
|
||||||
- build-and-push-image
|
- build-and-push-image
|
||||||
|
- integration-tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
AWS_REGION: us-east-1
|
AWS_REGION: us-east-1
|
||||||
|
|
11
Makefile
11
Makefile
|
@ -1,6 +1,9 @@
|
||||||
install-server:
|
install-server:
|
||||||
cd server && make install
|
cd server && make install
|
||||||
|
|
||||||
|
install-integration-tests:
|
||||||
|
cd integration-tests && pip install -r requirements.txt
|
||||||
|
|
||||||
install-router:
|
install-router:
|
||||||
cd router && cargo install --path .
|
cd router && cargo install --path .
|
||||||
|
|
||||||
|
@ -18,9 +21,15 @@ server-dev:
|
||||||
router-dev:
|
router-dev:
|
||||||
cd router && cargo run -- --port 8080
|
cd router && cargo run -- --port 8080
|
||||||
|
|
||||||
integration-tests: install-router install-launcher
|
rust-tests: install-router install-launcher
|
||||||
cargo test
|
cargo test
|
||||||
|
|
||||||
|
integration-tests: install-integration-tests
|
||||||
|
pytest -s -vv -m "not private" integration-tests
|
||||||
|
|
||||||
|
update-integration-tests: install-integration-tests
|
||||||
|
pytest -s -vv --snapshot-update integration-tests
|
||||||
|
|
||||||
python-server-tests:
|
python-server-tests:
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
|
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
|
||||||
|
|
||||||
|
|
|
@ -253,5 +253,7 @@ make python-client-tests
|
||||||
# or both server and client tests
|
# or both server and client tests
|
||||||
make python-tests
|
make python-tests
|
||||||
# rust cargo tests
|
# rust cargo tests
|
||||||
|
make rust-tests
|
||||||
|
# integration tests
|
||||||
make integration-tests
|
make integration-tests
|
||||||
```
|
```
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
import subprocess
|
||||||
|
import contextlib
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import docker
|
||||||
|
|
||||||
|
from docker.errors import NotFound
|
||||||
|
from typing import Optional, List
|
||||||
|
from syrupy.filters import props
|
||||||
|
|
||||||
|
from text_generation import AsyncClient
|
||||||
|
from text_generation.types import Response
|
||||||
|
|
||||||
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
|
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||||
|
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def snapshot_test(snapshot):
|
||||||
|
return lambda value: value == snapshot(exclude=props("logprob"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def event_loop():
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def launcher(event_loop):
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def local_launcher(
|
||||||
|
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||||
|
):
|
||||||
|
port = 9999
|
||||||
|
master_port = 19999
|
||||||
|
|
||||||
|
shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"text-generation-launcher",
|
||||||
|
"--model-id",
|
||||||
|
model_id,
|
||||||
|
"--port",
|
||||||
|
str(port),
|
||||||
|
"--master-port",
|
||||||
|
str(master_port),
|
||||||
|
"--shard-uds-path",
|
||||||
|
shard_uds_path,
|
||||||
|
]
|
||||||
|
|
||||||
|
if num_shard is not None:
|
||||||
|
args.extend(["--num-shard", str(num_shard)])
|
||||||
|
if quantize:
|
||||||
|
args.append("--quantize")
|
||||||
|
|
||||||
|
with subprocess.Popen(
|
||||||
|
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
) as process:
|
||||||
|
yield AsyncClient(f"http://localhost:{port}")
|
||||||
|
|
||||||
|
process.terminate()
|
||||||
|
process.wait(60)
|
||||||
|
|
||||||
|
launcher_output = process.stdout.read().decode("utf-8")
|
||||||
|
print(launcher_output)
|
||||||
|
|
||||||
|
process.stdout.close()
|
||||||
|
process.stderr.close()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def docker_launcher(
|
||||||
|
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||||
|
):
|
||||||
|
port = 9999
|
||||||
|
|
||||||
|
args = ["--model-id", model_id, "--env"]
|
||||||
|
|
||||||
|
if num_shard is not None:
|
||||||
|
args.extend(["--num-shard", str(num_shard)])
|
||||||
|
if quantize:
|
||||||
|
args.append("--quantize")
|
||||||
|
|
||||||
|
client = docker.from_env()
|
||||||
|
|
||||||
|
container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
container = client.containers.get(container_name)
|
||||||
|
container.stop()
|
||||||
|
container.wait()
|
||||||
|
except NotFound:
|
||||||
|
pass
|
||||||
|
|
||||||
|
gpu_count = num_shard if num_shard is not None else 1
|
||||||
|
|
||||||
|
env = {}
|
||||||
|
if HUGGING_FACE_HUB_TOKEN is not None:
|
||||||
|
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
|
||||||
|
|
||||||
|
volumes = []
|
||||||
|
if DOCKER_VOLUME:
|
||||||
|
volumes = [f"{DOCKER_VOLUME}:/data"]
|
||||||
|
|
||||||
|
container = client.containers.run(
|
||||||
|
DOCKER_IMAGE,
|
||||||
|
command=args,
|
||||||
|
name=container_name,
|
||||||
|
environment=env,
|
||||||
|
auto_remove=True,
|
||||||
|
detach=True,
|
||||||
|
device_requests=[
|
||||||
|
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
|
||||||
|
],
|
||||||
|
volumes=volumes,
|
||||||
|
ports={"80/tcp": port},
|
||||||
|
)
|
||||||
|
|
||||||
|
yield AsyncClient(f"http://localhost:{port}")
|
||||||
|
|
||||||
|
container.stop()
|
||||||
|
|
||||||
|
container_output = container.logs().decode("utf-8")
|
||||||
|
print(container_output)
|
||||||
|
|
||||||
|
if DOCKER_IMAGE is not None:
|
||||||
|
return docker_launcher
|
||||||
|
return local_launcher
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def generate_load():
|
||||||
|
async def generate_load_inner(
|
||||||
|
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
||||||
|
) -> List[Response]:
|
||||||
|
futures = [
|
||||||
|
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*futures)
|
||||||
|
return [r.dict() for r in results]
|
||||||
|
|
||||||
|
return generate_load_inner
|
|
@ -0,0 +1,627 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_bloom_560m
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 159570,
|
||||||
|
'special': False,
|
||||||
|
'text': ' réch',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 810,
|
||||||
|
'special': False,
|
||||||
|
'text': 'au',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 12736,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ffer',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1742,
|
||||||
|
'special': False,
|
||||||
|
'text': ' au',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6105,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bain',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 88254,
|
||||||
|
'special': False,
|
||||||
|
'text': '-mar',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 641,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ie',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2940,
|
||||||
|
'special': False,
|
||||||
|
'text': ' avec',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': ' le faire réchauffer au bain-marie avec',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_bloom_560m_all_params
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 408,
|
||||||
|
'special': False,
|
||||||
|
'text': ' que',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20288,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'on",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 22255,
|
||||||
|
'special': False,
|
||||||
|
'text': ' trouve',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1622,
|
||||||
|
'special': False,
|
||||||
|
'text': ' une',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 187079,
|
||||||
|
'special': False,
|
||||||
|
'text': ' posture',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 501,
|
||||||
|
'special': False,
|
||||||
|
'text': ' par',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8741,
|
||||||
|
'special': False,
|
||||||
|
'text': ' rapport',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 693,
|
||||||
|
'special': False,
|
||||||
|
'text': ' à',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 366,
|
||||||
|
'special': False,
|
||||||
|
'text': ' la',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 36503,
|
||||||
|
'special': False,
|
||||||
|
'text': ' pratique',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique",
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_bloom_560m_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,542 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_bloom_560m_sharded
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 159570,
|
||||||
|
'special': False,
|
||||||
|
'text': ' réch',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 810,
|
||||||
|
'special': False,
|
||||||
|
'text': 'au',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 12736,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ffer',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1742,
|
||||||
|
'special': False,
|
||||||
|
'text': ' au',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6105,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bain',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 88254,
|
||||||
|
'special': False,
|
||||||
|
'text': '-mar',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 641,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ie',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2940,
|
||||||
|
'special': False,
|
||||||
|
'text': ' avec',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': ' le faire réchauffer au bain-marie avec',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_bloom_560m_sharded_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 17934,
|
||||||
|
'text': 'Pour',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 49833,
|
||||||
|
'text': ' dég',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 21543,
|
||||||
|
'text': 'uster',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 447,
|
||||||
|
'text': ' un',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 46341,
|
||||||
|
'text': ' ort',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35567,
|
||||||
|
'text': 'olan',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1669,
|
||||||
|
'text': ' il',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 11580,
|
||||||
|
'text': ' faut',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3913,
|
||||||
|
'text': ' tout',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'text': " d'abord",
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 578,
|
||||||
|
'special': False,
|
||||||
|
'text': ' le',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5608,
|
||||||
|
'special': False,
|
||||||
|
'text': ' faire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1767,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cu',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1273,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ire',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1486,
|
||||||
|
'special': False,
|
||||||
|
'text': ' dans',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 283,
|
||||||
|
'special': False,
|
||||||
|
'text': ' de',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 40410,
|
||||||
|
'special': False,
|
||||||
|
'text': " l'eau",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20226,
|
||||||
|
'special': False,
|
||||||
|
'text': ' bou',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 172483,
|
||||||
|
'special': False,
|
||||||
|
'text': 'illante',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2805,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sal',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': " le faire cuire dans de l'eau bouillante sal",
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,465 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_flash_llama
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'text': '<s>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4321,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2009,
|
||||||
|
'text': 'request',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 363,
|
||||||
|
'special': False,
|
||||||
|
'text': ' for',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 847,
|
||||||
|
'special': False,
|
||||||
|
'text': ' /',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2754,
|
||||||
|
'special': False,
|
||||||
|
'text': 'api',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29894,
|
||||||
|
'special': False,
|
||||||
|
'text': 'v',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16418,
|
||||||
|
'special': False,
|
||||||
|
'text': 'projects',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'for /api/v1/projects/1',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_llama_all_params
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'text': '<s>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4321,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2009,
|
||||||
|
'text': 'request',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 5229,
|
||||||
|
'special': False,
|
||||||
|
'text': ' failed',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 363,
|
||||||
|
'special': False,
|
||||||
|
'text': ' for',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5641,
|
||||||
|
'special': False,
|
||||||
|
'text': ' IP',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16428,
|
||||||
|
'special': False,
|
||||||
|
'text': ' Address',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29901,
|
||||||
|
'special': False,
|
||||||
|
'text': ':',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 525,
|
||||||
|
'special': False,
|
||||||
|
'text': " '",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8516,
|
||||||
|
'special': False,
|
||||||
|
'text': 'None',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4286,
|
||||||
|
'special': False,
|
||||||
|
'text': "'.",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 294,
|
||||||
|
'special': False,
|
||||||
|
'text': 'as',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
Test requestfailed for IP Address: 'None'.
|
||||||
|
as
|
||||||
|
''',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_llama_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'text': '<s>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4321,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2009,
|
||||||
|
'text': 'request',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 363,
|
||||||
|
'special': False,
|
||||||
|
'text': ' for',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 847,
|
||||||
|
'special': False,
|
||||||
|
'text': ' /',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2754,
|
||||||
|
'special': False,
|
||||||
|
'text': 'api',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29894,
|
||||||
|
'special': False,
|
||||||
|
'text': 'v',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16418,
|
||||||
|
'special': False,
|
||||||
|
'text': 'projects',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'for /api/v1/projects/1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'text': '<s>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4321,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2009,
|
||||||
|
'text': 'request',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 363,
|
||||||
|
'special': False,
|
||||||
|
'text': ' for',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 847,
|
||||||
|
'special': False,
|
||||||
|
'text': ' /',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2754,
|
||||||
|
'special': False,
|
||||||
|
'text': 'api',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29894,
|
||||||
|
'special': False,
|
||||||
|
'text': 'v',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16418,
|
||||||
|
'special': False,
|
||||||
|
'text': 'projects',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'for /api/v1/projects/1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'text': '<s>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4321,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2009,
|
||||||
|
'text': 'request',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 363,
|
||||||
|
'special': False,
|
||||||
|
'text': ' for',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 847,
|
||||||
|
'special': False,
|
||||||
|
'text': ' /',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2754,
|
||||||
|
'special': False,
|
||||||
|
'text': 'api',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29894,
|
||||||
|
'special': False,
|
||||||
|
'text': 'v',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16418,
|
||||||
|
'special': False,
|
||||||
|
'text': 'projects',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'for /api/v1/projects/1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'text': '<s>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 4321,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2009,
|
||||||
|
'text': 'request',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 363,
|
||||||
|
'special': False,
|
||||||
|
'text': ' for',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 847,
|
||||||
|
'special': False,
|
||||||
|
'text': ' /',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2754,
|
||||||
|
'special': False,
|
||||||
|
'text': 'api',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29894,
|
||||||
|
'special': False,
|
||||||
|
'text': 'v',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16418,
|
||||||
|
'special': False,
|
||||||
|
'text': 'projects',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29914,
|
||||||
|
'special': False,
|
||||||
|
'text': '/',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 29896,
|
||||||
|
'special': False,
|
||||||
|
'text': '1',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'for /api/v1/projects/1',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,682 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_flash_neox
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 50278,
|
||||||
|
'text': '<|prompter|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1276,
|
||||||
|
'text': 'What',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 310,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 247,
|
||||||
|
'text': ' a',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1167,
|
||||||
|
'text': ' mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 285,
|
||||||
|
'text': ' and',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 752,
|
||||||
|
'text': ' what',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 434,
|
||||||
|
'text': "'s",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 253,
|
||||||
|
'text': ' the',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2892,
|
||||||
|
'text': ' history',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3212,
|
||||||
|
'text': ' behind',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 436,
|
||||||
|
'text': ' this',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 32,
|
||||||
|
'text': '?',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<|endoftext|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 50281,
|
||||||
|
'text': '<|assistant|>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 510,
|
||||||
|
'special': False,
|
||||||
|
'text': 'The',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'special': False,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 346,
|
||||||
|
'special': False,
|
||||||
|
'text': ' "',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6441,
|
||||||
|
'special': False,
|
||||||
|
'text': 'mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'special': False,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3,
|
||||||
|
'special': False,
|
||||||
|
'text': '"',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 369,
|
||||||
|
'special': False,
|
||||||
|
'text': ' was',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 806,
|
||||||
|
'special': False,
|
||||||
|
'text': ' first',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 908,
|
||||||
|
'special': False,
|
||||||
|
'text': ' used',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 275,
|
||||||
|
'special': False,
|
||||||
|
'text': ' in',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'The word "meme" was first used in',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_neox_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 50278,
|
||||||
|
'text': '<|prompter|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1276,
|
||||||
|
'text': 'What',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 310,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 247,
|
||||||
|
'text': ' a',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1167,
|
||||||
|
'text': ' mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 285,
|
||||||
|
'text': ' and',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 752,
|
||||||
|
'text': ' what',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 434,
|
||||||
|
'text': "'s",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 253,
|
||||||
|
'text': ' the',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2892,
|
||||||
|
'text': ' history',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3212,
|
||||||
|
'text': ' behind',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 436,
|
||||||
|
'text': ' this',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 32,
|
||||||
|
'text': '?',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<|endoftext|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 50281,
|
||||||
|
'text': '<|assistant|>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 510,
|
||||||
|
'special': False,
|
||||||
|
'text': 'The',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'special': False,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 346,
|
||||||
|
'special': False,
|
||||||
|
'text': ' "',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6441,
|
||||||
|
'special': False,
|
||||||
|
'text': 'mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'special': False,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3,
|
||||||
|
'special': False,
|
||||||
|
'text': '"',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 369,
|
||||||
|
'special': False,
|
||||||
|
'text': ' was',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 806,
|
||||||
|
'special': False,
|
||||||
|
'text': ' first',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 908,
|
||||||
|
'special': False,
|
||||||
|
'text': ' used',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 275,
|
||||||
|
'special': False,
|
||||||
|
'text': ' in',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'The word "meme" was first used in',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 50278,
|
||||||
|
'text': '<|prompter|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1276,
|
||||||
|
'text': 'What',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 310,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 247,
|
||||||
|
'text': ' a',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1167,
|
||||||
|
'text': ' mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 285,
|
||||||
|
'text': ' and',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 752,
|
||||||
|
'text': ' what',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 434,
|
||||||
|
'text': "'s",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 253,
|
||||||
|
'text': ' the',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2892,
|
||||||
|
'text': ' history',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3212,
|
||||||
|
'text': ' behind',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 436,
|
||||||
|
'text': ' this',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 32,
|
||||||
|
'text': '?',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<|endoftext|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 50281,
|
||||||
|
'text': '<|assistant|>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 510,
|
||||||
|
'special': False,
|
||||||
|
'text': 'The',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'special': False,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 346,
|
||||||
|
'special': False,
|
||||||
|
'text': ' "',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6441,
|
||||||
|
'special': False,
|
||||||
|
'text': 'mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'special': False,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3,
|
||||||
|
'special': False,
|
||||||
|
'text': '"',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 369,
|
||||||
|
'special': False,
|
||||||
|
'text': ' was',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 806,
|
||||||
|
'special': False,
|
||||||
|
'text': ' first',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 908,
|
||||||
|
'special': False,
|
||||||
|
'text': ' used',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 275,
|
||||||
|
'special': False,
|
||||||
|
'text': ' in',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'The word "meme" was first used in',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 50278,
|
||||||
|
'text': '<|prompter|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1276,
|
||||||
|
'text': 'What',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 310,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 247,
|
||||||
|
'text': ' a',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1167,
|
||||||
|
'text': ' mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 285,
|
||||||
|
'text': ' and',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 752,
|
||||||
|
'text': ' what',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 434,
|
||||||
|
'text': "'s",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 253,
|
||||||
|
'text': ' the',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2892,
|
||||||
|
'text': ' history',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3212,
|
||||||
|
'text': ' behind',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 436,
|
||||||
|
'text': ' this',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 32,
|
||||||
|
'text': '?',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<|endoftext|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 50281,
|
||||||
|
'text': '<|assistant|>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 510,
|
||||||
|
'special': False,
|
||||||
|
'text': 'The',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'special': False,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 346,
|
||||||
|
'special': False,
|
||||||
|
'text': ' "',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6441,
|
||||||
|
'special': False,
|
||||||
|
'text': 'mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'special': False,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3,
|
||||||
|
'special': False,
|
||||||
|
'text': '"',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 369,
|
||||||
|
'special': False,
|
||||||
|
'text': ' was',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 806,
|
||||||
|
'special': False,
|
||||||
|
'text': ' first',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 908,
|
||||||
|
'special': False,
|
||||||
|
'text': ' used',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 275,
|
||||||
|
'special': False,
|
||||||
|
'text': ' in',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'The word "meme" was first used in',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 50278,
|
||||||
|
'text': '<|prompter|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1276,
|
||||||
|
'text': 'What',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 310,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 247,
|
||||||
|
'text': ' a',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1167,
|
||||||
|
'text': ' mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 285,
|
||||||
|
'text': ' and',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 752,
|
||||||
|
'text': ' what',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 434,
|
||||||
|
'text': "'s",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 253,
|
||||||
|
'text': ' the',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 2892,
|
||||||
|
'text': ' history',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3212,
|
||||||
|
'text': ' behind',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 436,
|
||||||
|
'text': ' this',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 32,
|
||||||
|
'text': '?',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<|endoftext|>',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 50281,
|
||||||
|
'text': '<|assistant|>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 510,
|
||||||
|
'special': False,
|
||||||
|
'text': 'The',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3159,
|
||||||
|
'special': False,
|
||||||
|
'text': ' word',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 346,
|
||||||
|
'special': False,
|
||||||
|
'text': ' "',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 6441,
|
||||||
|
'special': False,
|
||||||
|
'text': 'mem',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 70,
|
||||||
|
'special': False,
|
||||||
|
'text': 'e',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3,
|
||||||
|
'special': False,
|
||||||
|
'text': '"',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 369,
|
||||||
|
'special': False,
|
||||||
|
'text': ' was',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 806,
|
||||||
|
'special': False,
|
||||||
|
'text': ' first',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 908,
|
||||||
|
'special': False,
|
||||||
|
'text': ' used',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 275,
|
||||||
|
'special': False,
|
||||||
|
'text': ' in',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'The word "meme" was first used in',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,472 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_flash_santacoder
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 563,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 62,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7196,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 1241,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 258,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 372,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7371,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 9956,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8657,
|
||||||
|
'special': False,
|
||||||
|
'text': '!")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1018,
|
||||||
|
'special': False,
|
||||||
|
'text': 'print',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World!")
|
||||||
|
|
||||||
|
print
|
||||||
|
''',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_santacoder_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 563,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 62,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7196,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 1241,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 258,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 372,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7371,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 9956,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8657,
|
||||||
|
'special': False,
|
||||||
|
'text': '!")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1018,
|
||||||
|
'special': False,
|
||||||
|
'text': 'print',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World!")
|
||||||
|
|
||||||
|
print
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 563,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 62,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7196,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 1241,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 258,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 372,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7371,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 9956,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8657,
|
||||||
|
'special': False,
|
||||||
|
'text': '!")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1018,
|
||||||
|
'special': False,
|
||||||
|
'text': 'print',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World!")
|
||||||
|
|
||||||
|
print
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 563,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 62,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7196,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 1241,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 258,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 372,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7371,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 9956,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8657,
|
||||||
|
'special': False,
|
||||||
|
'text': '!")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1018,
|
||||||
|
'special': False,
|
||||||
|
'text': 'print',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World!")
|
||||||
|
|
||||||
|
print
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 563,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 62,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7196,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 1241,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 258,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 942,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 372,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7371,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 9956,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8657,
|
||||||
|
'special': False,
|
||||||
|
'text': '!")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 185,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1018,
|
||||||
|
'special': False,
|
||||||
|
'text': 'print',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World!")
|
||||||
|
|
||||||
|
print
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,573 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_flash_starcoder
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 81,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7656,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 2262,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 284,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 440,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8279,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 10896,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 657,
|
||||||
|
'special': False,
|
||||||
|
'text': '")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'special': False,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World")
|
||||||
|
|
||||||
|
def
|
||||||
|
''',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_starcoder_default_params
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||||
|
'generated_tokens': 12,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 81,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7656,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 2262,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 284,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 5741,
|
||||||
|
'special': False,
|
||||||
|
'text': ' logging',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 32,
|
||||||
|
'special': False,
|
||||||
|
'text': '.',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1338,
|
||||||
|
'special': False,
|
||||||
|
'text': 'info',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 463,
|
||||||
|
'special': False,
|
||||||
|
'text': "('",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8279,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 30,
|
||||||
|
'special': False,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 10896,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 683,
|
||||||
|
'special': False,
|
||||||
|
'text': "')",
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'special': True,
|
||||||
|
'text': '<|endoftext|>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
logging.info('Hello, World')
|
||||||
|
<|endoftext|>
|
||||||
|
''',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_starcoder_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 81,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7656,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 2262,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 284,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 440,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8279,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 10896,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 657,
|
||||||
|
'special': False,
|
||||||
|
'text': '")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'special': False,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World")
|
||||||
|
|
||||||
|
def
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 81,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7656,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 2262,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 284,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 440,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8279,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 10896,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 657,
|
||||||
|
'special': False,
|
||||||
|
'text': '")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'special': False,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World")
|
||||||
|
|
||||||
|
def
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 81,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7656,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 2262,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 284,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 440,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8279,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 10896,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 657,
|
||||||
|
'special': False,
|
||||||
|
'text': '")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'special': False,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World")
|
||||||
|
|
||||||
|
def
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 81,
|
||||||
|
'text': '_',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7656,
|
||||||
|
'text': 'hello',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 2262,
|
||||||
|
'special': False,
|
||||||
|
'text': '():',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 284,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1459,
|
||||||
|
'special': False,
|
||||||
|
'text': ' print',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 440,
|
||||||
|
'special': False,
|
||||||
|
'text': '("',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 8279,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Hello',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 10896,
|
||||||
|
'special': False,
|
||||||
|
'text': ' World',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 657,
|
||||||
|
'special': False,
|
||||||
|
'text': '")',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 203,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 589,
|
||||||
|
'special': False,
|
||||||
|
'text': 'def',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
():
|
||||||
|
print("Hello World")
|
||||||
|
|
||||||
|
def
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,306 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_mt0_base
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||||
|
'generated_tokens': 5,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<pad>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 926,
|
||||||
|
'special': False,
|
||||||
|
'text': 'To',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 18295,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sell',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 7868,
|
||||||
|
'special': False,
|
||||||
|
'text': ' things',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 260,
|
||||||
|
'special': False,
|
||||||
|
'text': '.',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'special': True,
|
||||||
|
'text': '</s>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'To sell things.',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_mt0_base_all_params
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<pad>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': 0,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 16017,
|
||||||
|
'special': False,
|
||||||
|
'text': 'blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 20495,
|
||||||
|
'special': False,
|
||||||
|
'text': ' sky',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 259,
|
||||||
|
'special': False,
|
||||||
|
'text': ' ',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 15484,
|
||||||
|
'special': False,
|
||||||
|
'text': 'appear',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 345,
|
||||||
|
'special': False,
|
||||||
|
'text': 'ed',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 288,
|
||||||
|
'special': False,
|
||||||
|
'text': ' to',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 35622,
|
||||||
|
'special': False,
|
||||||
|
'text': ' cloud',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 263,
|
||||||
|
'special': False,
|
||||||
|
'text': 's',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 14701,
|
||||||
|
'special': False,
|
||||||
|
'text': ' above',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 751,
|
||||||
|
'special': False,
|
||||||
|
'text': ' all',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'Why is the sky blue?blue sky appeared to clouds above all',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_mt0_base_load
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||||
|
'generated_tokens': 6,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<pad>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 259,
|
||||||
|
'special': False,
|
||||||
|
'text': '',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Because',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 609,
|
||||||
|
'special': False,
|
||||||
|
'text': ' it',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 339,
|
||||||
|
'special': False,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16017,
|
||||||
|
'special': False,
|
||||||
|
'text': ' blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'special': True,
|
||||||
|
'text': '</s>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'Because it is blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||||
|
'generated_tokens': 6,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<pad>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 259,
|
||||||
|
'special': False,
|
||||||
|
'text': '',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Because',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 609,
|
||||||
|
'special': False,
|
||||||
|
'text': ' it',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 339,
|
||||||
|
'special': False,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16017,
|
||||||
|
'special': False,
|
||||||
|
'text': ' blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'special': True,
|
||||||
|
'text': '</s>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'Because it is blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||||
|
'generated_tokens': 6,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<pad>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 259,
|
||||||
|
'special': False,
|
||||||
|
'text': '',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Because',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 609,
|
||||||
|
'special': False,
|
||||||
|
'text': ' it',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 339,
|
||||||
|
'special': False,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16017,
|
||||||
|
'special': False,
|
||||||
|
'text': ' blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'special': True,
|
||||||
|
'text': '</s>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'Because it is blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
|
||||||
|
'generated_tokens': 6,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 0,
|
||||||
|
'text': '<pad>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 259,
|
||||||
|
'special': False,
|
||||||
|
'text': '',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 39261,
|
||||||
|
'special': False,
|
||||||
|
'text': 'Because',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 609,
|
||||||
|
'special': False,
|
||||||
|
'text': ' it',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 339,
|
||||||
|
'special': False,
|
||||||
|
'text': ' is',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 16017,
|
||||||
|
'special': False,
|
||||||
|
'text': ' blue',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'special': True,
|
||||||
|
'text': '</s>',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': 'Because it is blue',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
|
@ -0,0 +1,63 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def bloom_560(launcher):
|
||||||
|
with launcher("bigscience/bloom-560m") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bloom_560m(bloom_560, snapshot_test):
|
||||||
|
await health_check(bloom_560, 60)
|
||||||
|
|
||||||
|
response = await bloom_560.generate(
|
||||||
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
|
max_new_tokens=10,
|
||||||
|
top_p=0.9,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bloom_560m_all_params(bloom_560, snapshot_test):
|
||||||
|
await health_check(bloom_560, 60)
|
||||||
|
|
||||||
|
response = await bloom_560.generate(
|
||||||
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bloom_560m_load(bloom_560, generate_load, snapshot_test):
|
||||||
|
await health_check(bloom_560, 60)
|
||||||
|
|
||||||
|
responses = await generate_load(
|
||||||
|
bloom_560,
|
||||||
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,42 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def bloom_560m_sharded(launcher):
|
||||||
|
with launcher("bigscience/bloom-560m", num_shard=2) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot_test):
|
||||||
|
await health_check(bloom_560m_sharded, 60)
|
||||||
|
|
||||||
|
response = await bloom_560m_sharded.generate(
|
||||||
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
|
max_new_tokens=10,
|
||||||
|
top_p=0.9,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bloom_560m_sharded_load(
|
||||||
|
bloom_560m_sharded, generate_load, snapshot_test
|
||||||
|
):
|
||||||
|
await health_check(bloom_560m_sharded, 60)
|
||||||
|
|
||||||
|
responses = await generate_load(
|
||||||
|
bloom_560m_sharded,
|
||||||
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,56 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_llama(launcher):
|
||||||
|
with launcher("huggingface/llama-7b", num_shard=2) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama(flash_llama, snapshot_test):
|
||||||
|
await health_check(flash_llama, 120)
|
||||||
|
|
||||||
|
response = await flash_llama.generate("Test request", max_new_tokens=10)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_all_params(flash_llama, snapshot_test):
|
||||||
|
await health_check(flash_llama, 120)
|
||||||
|
|
||||||
|
response = await flash_llama.generate(
|
||||||
|
"Test request",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_load(flash_llama, generate_load, snapshot_test):
|
||||||
|
await health_check(flash_llama, 120)
|
||||||
|
|
||||||
|
responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,38 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_neox(launcher):
|
||||||
|
with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_neox(flash_neox, snapshot_test):
|
||||||
|
await health_check(flash_neox, 240)
|
||||||
|
|
||||||
|
response = await flash_neox.generate(
|
||||||
|
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||||
|
max_new_tokens=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_neox_load(flash_neox, generate_load, snapshot_test):
|
||||||
|
await health_check(flash_neox, 240)
|
||||||
|
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_neox,
|
||||||
|
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,32 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_santacoder(launcher):
|
||||||
|
with launcher("bigcode/santacoder") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_santacoder(flash_santacoder, snapshot_test):
|
||||||
|
await health_check(flash_santacoder, 60)
|
||||||
|
|
||||||
|
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot_test):
|
||||||
|
await health_check(flash_santacoder, 60)
|
||||||
|
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_santacoder, "def print_hello", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,47 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_starcoder(launcher):
|
||||||
|
with launcher("bigcode/starcoder", num_shard=2) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_starcoder(flash_starcoder, snapshot_test):
|
||||||
|
await health_check(flash_starcoder, 240)
|
||||||
|
|
||||||
|
response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_starcoder_default_params(flash_starcoder, snapshot_test):
|
||||||
|
await health_check(flash_starcoder, 240)
|
||||||
|
|
||||||
|
response = await flash_starcoder.generate(
|
||||||
|
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 12
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot_test):
|
||||||
|
await health_check(flash_starcoder, 240)
|
||||||
|
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_starcoder, "def print_hello", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,63 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mt0_base(launcher):
|
||||||
|
with launcher("bigscience/mt0-base") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mt0_base(mt0_base, snapshot_test):
|
||||||
|
await health_check(mt0_base, 60)
|
||||||
|
|
||||||
|
response = await mt0_base.generate(
|
||||||
|
"Why is the sky blue?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
top_p=0.9,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 5
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mt0_base_all_params(mt0_base, snapshot_test):
|
||||||
|
await health_check(mt0_base, 60)
|
||||||
|
|
||||||
|
response = await mt0_base.generate(
|
||||||
|
"Why is the sky blue?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert snapshot_test(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mt0_base_load(mt0_base, generate_load, snapshot_test):
|
||||||
|
await health_check(mt0_base, 60)
|
||||||
|
|
||||||
|
responses = await generate_load(
|
||||||
|
mt0_base,
|
||||||
|
"Why is the sky blue?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
assert snapshot_test(responses)
|
|
@ -0,0 +1,15 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
|
from text_generation import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
async def health_check(client: AsyncClient, timeout: int = 60):
|
||||||
|
assert timeout > 0
|
||||||
|
for _ in range(timeout):
|
||||||
|
try:
|
||||||
|
await client.generate("test")
|
||||||
|
return
|
||||||
|
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
|
||||||
|
time.sleep(1)
|
||||||
|
raise RuntimeError("Health check failed")
|
|
@ -0,0 +1,5 @@
|
||||||
|
syrupy
|
||||||
|
text-generation==0.5.1
|
||||||
|
pytest
|
||||||
|
pytest-asyncio==0.17.2
|
||||||
|
docker
|
|
@ -1,142 +0,0 @@
|
||||||
{
|
|
||||||
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
|
|
||||||
"details": {
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 20,
|
|
||||||
"seed": null,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 10264,
|
|
||||||
"text": "Test",
|
|
||||||
"logprob": null
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 8821,
|
|
||||||
"text": " request",
|
|
||||||
"logprob": -11.894989
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 17,
|
|
||||||
"text": ".",
|
|
||||||
"logprob": -1.8267672,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1587,
|
|
||||||
"text": "get",
|
|
||||||
"logprob": -2.4674969,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 11,
|
|
||||||
"text": "(",
|
|
||||||
"logprob": -1.906001,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5,
|
|
||||||
"text": "\"",
|
|
||||||
"logprob": -1.2279545,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 4899,
|
|
||||||
"text": "action",
|
|
||||||
"logprob": -4.170299,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5,
|
|
||||||
"text": "\"",
|
|
||||||
"logprob": -0.32478866,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 12,
|
|
||||||
"text": ")",
|
|
||||||
"logprob": -1.0773665,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 30,
|
|
||||||
"text": ";",
|
|
||||||
"logprob": -0.27640742,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 837,
|
|
||||||
"text": "\n ",
|
|
||||||
"logprob": -1.6970354,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1320,
|
|
||||||
"text": " if",
|
|
||||||
"logprob": -1.4495516,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 375,
|
|
||||||
"text": " (",
|
|
||||||
"logprob": -0.23609057,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 4899,
|
|
||||||
"text": "action",
|
|
||||||
"logprob": -1.1916996,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3535,
|
|
||||||
"text": " ==",
|
|
||||||
"logprob": -0.8918753,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5109,
|
|
||||||
"text": " null",
|
|
||||||
"logprob": -0.3933342,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 12,
|
|
||||||
"text": ")",
|
|
||||||
"logprob": -0.43212673,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 731,
|
|
||||||
"text": " {",
|
|
||||||
"logprob": -0.17702064,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1260,
|
|
||||||
"text": "\n ",
|
|
||||||
"logprob": -0.07027565,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 10519,
|
|
||||||
"text": " throw",
|
|
||||||
"logprob": -1.3915029,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2084,
|
|
||||||
"text": " new",
|
|
||||||
"logprob": -0.04201372,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 150858,
|
|
||||||
"text": " RuntimeException",
|
|
||||||
"logprob": -1.7329919,
|
|
||||||
"special": false
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
use float_eq::assert_float_eq;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::{BufRead, BufReader};
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::thread;
|
|
||||||
use std::thread::sleep;
|
|
||||||
use std::time::Duration;
|
|
||||||
use subprocess::{Popen, PopenConfig, Redirection};
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct Token {
|
|
||||||
id: u32,
|
|
||||||
text: String,
|
|
||||||
logprob: Option<f32>,
|
|
||||||
special: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct Details {
|
|
||||||
finish_reason: String,
|
|
||||||
generated_tokens: u32,
|
|
||||||
tokens: Vec<Token>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct GeneratedText {
|
|
||||||
generated_text: String,
|
|
||||||
details: Details,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
|
|
||||||
let argv = vec![
|
|
||||||
"text-generation-launcher".to_string(),
|
|
||||||
"--model-id".to_string(),
|
|
||||||
model_id.clone(),
|
|
||||||
"--num-shard".to_string(),
|
|
||||||
num_shard.to_string(),
|
|
||||||
"--port".to_string(),
|
|
||||||
port.to_string(),
|
|
||||||
"--master-port".to_string(),
|
|
||||||
master_port.to_string(),
|
|
||||||
"--shard-uds-path".to_string(),
|
|
||||||
format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
|
|
||||||
];
|
|
||||||
|
|
||||||
let mut launcher = Popen::create(
|
|
||||||
&argv,
|
|
||||||
PopenConfig {
|
|
||||||
stdout: Redirection::Pipe,
|
|
||||||
stderr: Redirection::Merge,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.expect("Could not start launcher");
|
|
||||||
|
|
||||||
// Redirect STDOUT and STDERR to the console
|
|
||||||
// (STDERR is merged into STDOUT)
|
|
||||||
let launcher_stdout = launcher.stdout.take().unwrap();
|
|
||||||
|
|
||||||
thread::spawn(move || {
|
|
||||||
let stdout = BufReader::new(launcher_stdout);
|
|
||||||
for line in stdout.lines() {
|
|
||||||
println!("{}", line.unwrap());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
for _ in 0..60 {
|
|
||||||
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
|
|
||||||
if health.is_ok() {
|
|
||||||
return launcher;
|
|
||||||
}
|
|
||||||
sleep(Duration::from_secs(2));
|
|
||||||
}
|
|
||||||
|
|
||||||
launcher.terminate().unwrap();
|
|
||||||
launcher.wait().unwrap();
|
|
||||||
panic!("failed to launch {}", model_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn test_model(
|
|
||||||
model_id: String,
|
|
||||||
num_shard: usize,
|
|
||||||
port: usize,
|
|
||||||
master_port: usize,
|
|
||||||
) -> GeneratedText {
|
|
||||||
let mut launcher = start_launcher(model_id, num_shard, port, master_port);
|
|
||||||
|
|
||||||
let data = r#"
|
|
||||||
{
|
|
||||||
"inputs": "Test request",
|
|
||||||
"parameters": {
|
|
||||||
"details": true
|
|
||||||
}
|
|
||||||
}"#;
|
|
||||||
let req: Value = serde_json::from_str(data).unwrap();
|
|
||||||
|
|
||||||
let client = reqwest::blocking::Client::new();
|
|
||||||
let res = client
|
|
||||||
.post(format!("http://localhost:{}/generate", port))
|
|
||||||
.json(&req)
|
|
||||||
.send();
|
|
||||||
|
|
||||||
launcher.terminate().unwrap();
|
|
||||||
launcher.wait().unwrap();
|
|
||||||
|
|
||||||
let result: GeneratedText = res.unwrap().json().unwrap();
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_json(name: &str) -> GeneratedText {
|
|
||||||
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
|
||||||
d.push("tests/");
|
|
||||||
d.push(name);
|
|
||||||
|
|
||||||
let file = File::open(d).unwrap();
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
let result: GeneratedText = serde_json::from_reader(reader).unwrap();
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
|
||||||
assert_eq!(result.generated_text, expected.generated_text);
|
|
||||||
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
|
|
||||||
assert_eq!(
|
|
||||||
result.details.generated_tokens,
|
|
||||||
expected.details.generated_tokens
|
|
||||||
);
|
|
||||||
|
|
||||||
for (token, expected_token) in result
|
|
||||||
.details
|
|
||||||
.tokens
|
|
||||||
.into_iter()
|
|
||||||
.zip(expected.details.tokens.into_iter())
|
|
||||||
{
|
|
||||||
assert_eq!(token.id, expected_token.id);
|
|
||||||
assert_eq!(token.text, expected_token.text);
|
|
||||||
assert_eq!(token.special, expected_token.special);
|
|
||||||
if let Some(logprob) = token.logprob {
|
|
||||||
let expected_logprob = expected_token.logprob.unwrap();
|
|
||||||
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
|
|
||||||
} else {
|
|
||||||
assert_eq!(token.logprob, expected_token.logprob);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_bloom_560m() {
|
|
||||||
let expected = read_json("bloom_560m.json");
|
|
||||||
|
|
||||||
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
|
|
||||||
compare_results(result, expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_bloom_560m_distributed() {
|
|
||||||
let expected = read_json("bloom_560m.json");
|
|
||||||
|
|
||||||
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
|
|
||||||
compare_results(result, expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_mt0_base() {
|
|
||||||
let expected = read_json("mt0_base.json");
|
|
||||||
|
|
||||||
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
|
|
||||||
compare_results(result, expected);
|
|
||||||
}
|
|
|
@ -1,137 +0,0 @@
|
||||||
{
|
|
||||||
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
|
|
||||||
"details": {
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 20,
|
|
||||||
"seed": null,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 0,
|
|
||||||
"text": "<pad>",
|
|
||||||
"logprob": null
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"text": "",
|
|
||||||
"logprob": -1.3656927,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 215100,
|
|
||||||
"text": "\"\"\"",
|
|
||||||
"logprob": -2.6551573,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 46138,
|
|
||||||
"text": "Test",
|
|
||||||
"logprob": -1.8059857,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 287,
|
|
||||||
"text": " the",
|
|
||||||
"logprob": -1.2102449,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"text": " ",
|
|
||||||
"logprob": -1.6057279,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 49076,
|
|
||||||
"text": "contents",
|
|
||||||
"logprob": -3.6060903,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 304,
|
|
||||||
"text": " of",
|
|
||||||
"logprob": -0.5270343,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 287,
|
|
||||||
"text": " the",
|
|
||||||
"logprob": -0.62522805,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"text": " ",
|
|
||||||
"logprob": -1.4069618,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 49076,
|
|
||||||
"text": "contents",
|
|
||||||
"logprob": -2.621994,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 304,
|
|
||||||
"text": " of",
|
|
||||||
"logprob": -1.3172221,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 287,
|
|
||||||
"text": " the",
|
|
||||||
"logprob": -0.3501925,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"text": " ",
|
|
||||||
"logprob": -0.7219573,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 49076,
|
|
||||||
"text": "contents",
|
|
||||||
"logprob": -1.0494149,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 260,
|
|
||||||
"text": ".",
|
|
||||||
"logprob": -1.0803378,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"text": " ",
|
|
||||||
"logprob": -0.32933083,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 215100,
|
|
||||||
"text": "\"\"\"",
|
|
||||||
"logprob": -0.11268901,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2978,
|
|
||||||
"text": " test",
|
|
||||||
"logprob": -1.5846587,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 290,
|
|
||||||
"text": "_",
|
|
||||||
"logprob": -0.49796978,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 4125,
|
|
||||||
"text": "test",
|
|
||||||
"logprob": -2.0026445,
|
|
||||||
"special": false
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -129,7 +129,7 @@ class BLOOMSharded(BLOOM):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
full_name = f"transformer.{name}"
|
full_name = f"transformer.{name}"
|
||||||
|
|
|
@ -21,16 +21,14 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
# Flash attention imports
|
# Flash attention imports
|
||||||
import flash_attn_cuda
|
import flash_attn_cuda
|
||||||
|
import dropout_layer_norm
|
||||||
|
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding
|
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -332,15 +330,15 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
self.head_size = self.layers[0].self_attn.head_size
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
self.num_heads = self.layers[0].self_attn.num_heads
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
|
||||||
def post_load_weights(self, load_in_8bit: bool = False):
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
if isinstance(self.embed_tokens, TensorParallelEmbedding):
|
if isinstance(self.embed_tokens, TensorParallelEmbedding):
|
||||||
self.embed_tokens.add_null_idx()
|
self.embed_tokens.add_null_idx()
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
layer: FlashLlamaLayer
|
layer: FlashLlamaLayer
|
||||||
layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
|
layer.self_attn.query_key_value.prepare_weights(quantize)
|
||||||
layer.self_attn.o_proj.prepare_weights(load_in_8bit)
|
layer.self_attn.o_proj.prepare_weights(quantize)
|
||||||
layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
|
layer.mlp.gate_up_proj.prepare_weights(quantize)
|
||||||
layer.mlp.down_proj.prepare_weights(load_in_8bit)
|
layer.mlp.down_proj.prepare_weights(quantize)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -429,8 +427,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def post_load_weights(self, load_in_8bit: bool = False):
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
self.model.post_load_weights(load_in_8bit)
|
self.model.post_load_weights(quantize)
|
||||||
self.lm_head.prepare_weights()
|
self.lm_head.prepare_weights()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
|
@ -21,8 +21,6 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
@ -32,7 +30,6 @@ from typing import Optional
|
||||||
# Flash attention imports
|
# Flash attention imports
|
||||||
import flash_attn_cuda
|
import flash_attn_cuda
|
||||||
|
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding
|
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -345,16 +342,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||||
self.head_size = self.layers[0].attention.head_size
|
self.head_size = self.layers[0].attention.head_size
|
||||||
self.num_heads = self.layers[0].attention.num_heads
|
self.num_heads = self.layers[0].attention.num_heads
|
||||||
|
|
||||||
def post_load_weights(self, load_in_8bit=False):
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
if isinstance(self.embed_in, TensorParallelEmbedding):
|
if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||||
self.embed_in.add_null_idx()
|
self.embed_in.add_null_idx()
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
layer: FlashNeoXLayer
|
layer: FlashNeoXLayer
|
||||||
layer.attention.shuffle_qkv_dims()
|
layer.attention.shuffle_qkv_dims()
|
||||||
layer.attention.query_key_value.prepare_weights(load_in_8bit)
|
layer.attention.query_key_value.prepare_weights(quantize)
|
||||||
layer.attention.dense.prepare_weights(load_in_8bit)
|
layer.attention.dense.prepare_weights(quantize)
|
||||||
layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
|
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
|
||||||
layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
|
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
@ -457,8 +454,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||||
config.hidden_size, config.vocab_size, bias=False
|
config.hidden_size, config.vocab_size, bias=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def post_load_weights(self, load_in_8bit=False):
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
self.gpt_neox.post_load_weights(load_in_8bit)
|
self.gpt_neox.post_load_weights(quantize)
|
||||||
self.embed_out.prepare_weights()
|
self.embed_out.prepare_weights()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
@ -261,16 +259,16 @@ class FlashSantacoderModel(nn.Module):
|
||||||
self.head_size = self.h[0].attn.head_size
|
self.head_size = self.h[0].attn.head_size
|
||||||
self.num_heads = self.h[0].attn.num_heads
|
self.num_heads = self.h[0].attn.num_heads
|
||||||
|
|
||||||
def post_load_weights(self, load_in_8bit: bool = False):
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
if self.tp_embeddings:
|
if self.tp_embeddings:
|
||||||
self.wte.add_null_idx()
|
self.wte.add_null_idx()
|
||||||
self.wpe.add_null_idx()
|
self.wpe.add_null_idx()
|
||||||
for layer in self.h:
|
for layer in self.h:
|
||||||
layer: Block
|
layer: Block
|
||||||
layer.attn.c_attn.prepare_weights(load_in_8bit)
|
layer.attn.c_attn.prepare_weights(quantize)
|
||||||
layer.attn.c_proj.prepare_weights(load_in_8bit)
|
layer.attn.c_proj.prepare_weights(quantize)
|
||||||
layer.mlp.c_fc.prepare_weights(load_in_8bit)
|
layer.mlp.c_fc.prepare_weights(quantize)
|
||||||
layer.mlp.c_proj.prepare_weights(load_in_8bit)
|
layer.mlp.c_proj.prepare_weights(quantize)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -347,8 +345,8 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def post_load_weights(self, load_in_8bit: bool = False):
|
def post_load_weights(self, quantize: Optional[str] = None):
|
||||||
self.transformer.post_load_weights(load_in_8bit)
|
self.transformer.post_load_weights(quantize)
|
||||||
self.lm_head.prepare_weights()
|
self.lm_head.prepare_weights()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
|
@ -28,7 +28,12 @@ tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FlashLlama(FlashCausalLM):
|
class FlashLlama(FlashCausalLM):
|
||||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
@ -72,14 +77,14 @@ class FlashLlama(FlashCausalLM):
|
||||||
def load_weights(
|
def load_weights(
|
||||||
model,
|
model,
|
||||||
filenames: List[Path],
|
filenames: List[Path],
|
||||||
quantize: bool,
|
quantize: Optional[str],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
state_dict = torch.load(filename, map_location="cpu")
|
state_dict = torch.load(filename, map_location="cpu")
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
value = value.to(device if not quantize else "cpu").to(dtype)
|
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||||
|
|
||||||
layer_name = ".".join(key.split(".")[:4])
|
layer_name = ".".join(key.split(".")[:4])
|
||||||
|
|
||||||
|
@ -199,7 +204,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||||
def load_weights(
|
def load_weights(
|
||||||
model,
|
model,
|
||||||
filenames: List[str],
|
filenames: List[str],
|
||||||
quantize: bool,
|
quantize: Optional[str],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
rank: int,
|
rank: int,
|
||||||
|
@ -207,7 +212,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||||
):
|
):
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
slice_ = f.get_slice(name)
|
slice_ = f.get_slice(name)
|
||||||
|
|
|
@ -23,7 +23,12 @@ tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FlashNeoX(FlashCausalLM):
|
class FlashNeoX(FlashCausalLM):
|
||||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
):
|
||||||
super(FlashNeoX, self).__init__(
|
super(FlashNeoX, self).__init__(
|
||||||
FlashGPTNeoXForCausalLM, model_id, revision, quantize
|
FlashGPTNeoXForCausalLM, model_id, revision, quantize
|
||||||
)
|
)
|
||||||
|
@ -31,7 +36,10 @@ class FlashNeoX(FlashCausalLM):
|
||||||
|
|
||||||
class FlashNeoXSharded(FlashNeoX):
|
class FlashNeoXSharded(FlashNeoX):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -89,7 +97,7 @@ class FlashNeoXSharded(FlashNeoX):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
module_name, param_name = name.rsplit(".", 1)
|
module_name, param_name = name.rsplit(".", 1)
|
||||||
|
|
|
@ -27,7 +27,12 @@ tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FlashSantacoder(FlashCausalLM):
|
class FlashSantacoder(FlashCausalLM):
|
||||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
@ -84,7 +89,7 @@ class FlashSantacoder(FlashCausalLM):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
state_dict = torch.load(filename, map_location="cpu")
|
state_dict = torch.load(filename, map_location="cpu")
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
value = value.to(device if not quantize else "cpu").to(dtype)
|
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||||
|
|
||||||
layer_name = ".".join(key.split(".")[:4])
|
layer_name = ".".join(key.split(".")[:4])
|
||||||
|
|
||||||
|
@ -170,7 +175,10 @@ class FlashSantacoder(FlashCausalLM):
|
||||||
|
|
||||||
class FlashSantacoderSharded(FlashSantacoder):
|
class FlashSantacoderSharded(FlashSantacoder):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -221,7 +229,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||||
def load_weights(
|
def load_weights(
|
||||||
model,
|
model,
|
||||||
filenames: List[str],
|
filenames: List[str],
|
||||||
quantize: bool,
|
quantize: Optional[str],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
rank: int,
|
rank: int,
|
||||||
|
@ -230,7 +238,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||||
):
|
):
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
slice_ = f.get_slice(key)
|
slice_ = f.get_slice(key)
|
||||||
|
|
|
@ -255,7 +255,7 @@ class GalacticaSharded(Galactica):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
if name == "lm_head.weight":
|
if name == "lm_head.weight":
|
||||||
|
|
|
@ -94,7 +94,7 @@ class GPTNeoxSharded(CausalLM):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
module_name, param_name = name.rsplit(".", 1)
|
module_name, param_name = name.rsplit(".", 1)
|
||||||
|
|
|
@ -48,7 +48,10 @@ class OPT(CausalLM):
|
||||||
|
|
||||||
class OPTSharded(OPT):
|
class OPTSharded(OPT):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -107,7 +110,7 @@ class OPTSharded(OPT):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
if name == "lm_head.weight":
|
if name == "lm_head.weight":
|
||||||
|
|
|
@ -97,7 +97,7 @@ class T5Sharded(Seq2SeqLM):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
module_name, param_name = name.rsplit(".", 1)
|
module_name, param_name = name.rsplit(".", 1)
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
HAS_BITS_AND_BYTES = True
|
HAS_BITS_AND_BYTES = True
|
||||||
try:
|
try:
|
||||||
|
@ -22,7 +24,7 @@ class FastLinear(nn.Linear):
|
||||||
self.quantized = False
|
self.quantized = False
|
||||||
self.bnb_linear = None
|
self.bnb_linear = None
|
||||||
|
|
||||||
def prepare_weights(self, quantize: bool = False):
|
def prepare_weights(self, quantize: Optional[str] = None):
|
||||||
if quantize == "bitsandbytes":
|
if quantize == "bitsandbytes":
|
||||||
if not HAS_BITS_AND_BYTES:
|
if not HAS_BITS_AND_BYTES:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
@ -126,6 +128,7 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
process_group: torch.distributed.ProcessGroup,
|
process_group: torch.distributed.ProcessGroup,
|
||||||
|
reduce=True,
|
||||||
padding_idx=None,
|
padding_idx=None,
|
||||||
max_norm=None,
|
max_norm=None,
|
||||||
norm_type=2.0,
|
norm_type=2.0,
|
||||||
|
@ -135,6 +138,7 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
|
self.reduce = reduce
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
|
@ -177,6 +181,7 @@ class TensorParallelEmbedding(nn.Embedding):
|
||||||
input - self.min_id,
|
input - self.min_id,
|
||||||
)
|
)
|
||||||
out = super().forward(input)
|
out = super().forward(input)
|
||||||
|
if self.reduce:
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue