diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 7db12424..902a7158 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -205,7 +205,10 @@ def event_loop(): def launcher(event_loop): @contextlib.contextmanager def local_launcher( - model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None + model_id: str, + num_shard: Optional[int] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -230,6 +233,9 @@ def launcher(event_loop): args.extend(["--num-shard", str(num_shard)]) if quantize: args.append("--quantize") + args.append("bitsandbytes") + if trust_remote_code: + args.append("--trust-remote-code") env = os.environ env["LOG_LEVEL"] = "info,text_generation_router=debug" @@ -250,7 +256,10 @@ def launcher(event_loop): @contextlib.contextmanager def docker_launcher( - model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None + model_id: str, + num_shard: Optional[int] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, ): port = random.randint(8000, 10_000) @@ -260,6 +269,9 @@ def launcher(event_loop): args.extend(["--num-shard", str(num_shard)]) if quantize: args.append("--quantize") + args.append("bitsandbytes") + if trust_remote_code: + args.append("--trust-remote-code") client = docker.from_env() diff --git a/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json new file mode 100644 index 00000000..488f3de3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json @@ -0,0 +1,378 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.96875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.6132812, + "text": "af" + }, + { + "id": 249, + "logprob": -6.5039062, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.078125, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.3261719, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.59375, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048339844, + "text": " with" + }, + { + "id": 26680, + "logprob": -4.0, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.07556152, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.0067749023, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.3320312, + "text": " the" + }, + { + "id": 758, + "logprob": -3.734375, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.109375, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.09375, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1835938, + "text": " on" + }, + { + "id": 248, + "logprob": -0.77685547, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.3828125, + "text": " face" + }, + { + "id": 275, + "logprob": -0.004432678, + "text": " of" + }, + { + "id": 414, + "logprob": -1.9677734, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.046875, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28198242, + "text": "." + }, + { + "id": 401, + "logprob": -7.9179688, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.2753906, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.6230469, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.20874023, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.5507812, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5664062, + "text": " all" + }, + { + "id": 599, + "logprob": -2.7402344, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21948242, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.7675781, + "text": " are" + }, + { + "id": 23981, + "logprob": -5.0, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.234375, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.5131836, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103637695, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58447266, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6835938, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8173828, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.23510742, + "text": " of" + }, + { + "id": 248, + "logprob": -0.35473633, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24633789, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.02960205, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17333984, + "text": "." + }, + { + "id": 193, + "logprob": -1.3935547, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.99365234, + "text": "," + }, + { + "id": 29033, + "logprob": -2.2324219, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10809326, + "text": "af" + }, + { + "id": 249, + "logprob": -0.042663574, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0024776459, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4277344, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1015625, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.05709839, + "text": "G" + }, + { + "id": 330, + "logprob": -0.13208008, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.0071487427, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008468628, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.00068998337, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074691772, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8251953, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.3173828, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23803711, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.56933594, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.61279297, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.41967773, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023403168, + "special": false, + "text": ":" + }, + { + "id": 1634, + "logprob": -2.0605469, + "special": false, + "text": " What" + }, + { + "id": 18, + "logprob": -1.5292969, + "special": false, + "text": "'" + }, + { + "id": 94, + "logprob": -0.007904053, + "special": false, + "text": "s" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: What's" +} diff --git a/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json new file mode 100644 index 00000000..cd35186d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json @@ -0,0 +1,98 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 330, + "logprob": null, + "text": "ir" + }, + { + "id": 1622, + "logprob": -7.8125, + "text": "af" + }, + { + "id": 249, + "logprob": -4.5, + "text": "at" + }, + { + "id": 1480, + "logprob": -10.875, + "text": "ron" + }, + { + "id": 37, + "logprob": -3.6875, + "text": ":" + } + ], + "seed": 0, + "tokens": [ + { + "id": 836, + "logprob": -1.265625, + "special": false, + "text": " i" + }, + { + "id": 18, + "logprob": -0.119628906, + "special": false, + "text": "'" + }, + { + "id": 298, + "logprob": -2.265625, + "special": false, + "text": "ve" + }, + { + "id": 650, + "logprob": -0.49804688, + "special": false, + "text": " been" + }, + { + "id": 1241, + "logprob": 0.0, + "special": false, + "text": " using" + }, + { + "id": 334, + "logprob": 0.0, + "special": false, + "text": " it" + }, + { + "id": 312, + "logprob": -1.2421875, + "special": false, + "text": " for" + }, + { + "id": 909, + "logprob": -0.99609375, + "special": false, + "text": " years" + }, + { + "id": 193, + "logprob": -0.30273438, + "special": false, + "text": "\n" + }, + { + "id": 807, + "logprob": -1.078125, + "special": false, + "text": "ik" + } + ] + }, + "generated_text": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron: i've been using it for years\nik" +} diff --git a/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json new file mode 100644 index 00000000..0fb1be75 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json @@ -0,0 +1,1514 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.71875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.625, + "text": "af" + }, + { + "id": 249, + "logprob": -6.53125, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0625, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.625, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048583984, + "text": " with" + }, + { + "id": 26680, + "logprob": -3.984375, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.076171875, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.0066833496, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.34375, + "text": " the" + }, + { + "id": 758, + "logprob": -3.734375, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.125, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.078125, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1953125, + "text": " on" + }, + { + "id": 248, + "logprob": -0.78125, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.390625, + "text": " face" + }, + { + "id": 275, + "logprob": -0.0044555664, + "text": " of" + }, + { + "id": 414, + "logprob": -1.984375, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.03125, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28320312, + "text": "." + }, + { + "id": 401, + "logprob": -7.90625, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.265625, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.640625, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.203125, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.53125, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5625, + "text": " all" + }, + { + "id": 599, + "logprob": -2.75, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21875, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.76171875, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.96875, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.21875, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.51953125, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103515625, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58984375, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6875, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8359375, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.24316406, + "text": " of" + }, + { + "id": 248, + "logprob": -0.3515625, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24414062, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.03100586, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17382812, + "text": "." + }, + { + "id": 193, + "logprob": -1.3984375, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -1.0, + "text": "," + }, + { + "id": 29033, + "logprob": -2.21875, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10644531, + "text": "af" + }, + { + "id": 249, + "logprob": -0.041992188, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0025024414, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4296875, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1015625, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.05810547, + "text": "G" + }, + { + "id": 330, + "logprob": -0.12597656, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007080078, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008300781, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0006866455, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074157715, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8203125, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.32226562, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23828125, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5859375, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.6171875, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.39648438, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023345947, + "special": false, + "text": ":" + }, + { + "id": 295, + "logprob": -2.078125, + "special": false, + "text": " I" + }, + { + "id": 18, + "logprob": -1.453125, + "special": false, + "text": "'" + }, + { + "id": 88, + "logprob": -0.47460938, + "special": false, + "text": "m" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: I'm" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.71875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.625, + "text": "af" + }, + { + "id": 249, + "logprob": -6.53125, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0625, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.625, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048583984, + "text": " with" + }, + { + "id": 26680, + "logprob": -3.984375, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.076171875, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.0066833496, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.34375, + "text": " the" + }, + { + "id": 758, + "logprob": -3.734375, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.125, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.078125, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1953125, + "text": " on" + }, + { + "id": 248, + "logprob": -0.78125, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.390625, + "text": " face" + }, + { + "id": 275, + "logprob": -0.0044555664, + "text": " of" + }, + { + "id": 414, + "logprob": -1.984375, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.03125, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28320312, + "text": "." + }, + { + "id": 401, + "logprob": -7.90625, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.265625, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.640625, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.203125, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.53125, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5625, + "text": " all" + }, + { + "id": 599, + "logprob": -2.75, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21875, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.76171875, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.96875, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.21875, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.51953125, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103515625, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58984375, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6875, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8359375, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.24316406, + "text": " of" + }, + { + "id": 248, + "logprob": -0.3515625, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24414062, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.03100586, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17382812, + "text": "." + }, + { + "id": 193, + "logprob": -1.3984375, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -1.0, + "text": "," + }, + { + "id": 29033, + "logprob": -2.21875, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10644531, + "text": "af" + }, + { + "id": 249, + "logprob": -0.041992188, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0025024414, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4296875, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1015625, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.05810547, + "text": "G" + }, + { + "id": 330, + "logprob": -0.12597656, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007080078, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008300781, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0006866455, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074157715, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8203125, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.32226562, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23828125, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5859375, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.6171875, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.39648438, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023345947, + "special": false, + "text": ":" + }, + { + "id": 295, + "logprob": -2.078125, + "special": false, + "text": " I" + }, + { + "id": 18, + "logprob": -1.453125, + "special": false, + "text": "'" + }, + { + "id": 88, + "logprob": -0.47460938, + "special": false, + "text": "m" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: I'm" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.71875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.625, + "text": "af" + }, + { + "id": 249, + "logprob": -6.53125, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0625, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.625, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048583984, + "text": " with" + }, + { + "id": 26680, + "logprob": -3.984375, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.076171875, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.0066833496, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.34375, + "text": " the" + }, + { + "id": 758, + "logprob": -3.734375, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.125, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.078125, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1953125, + "text": " on" + }, + { + "id": 248, + "logprob": -0.78125, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.390625, + "text": " face" + }, + { + "id": 275, + "logprob": -0.0044555664, + "text": " of" + }, + { + "id": 414, + "logprob": -1.984375, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.03125, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28320312, + "text": "." + }, + { + "id": 401, + "logprob": -7.90625, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.265625, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.640625, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.203125, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.53125, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5625, + "text": " all" + }, + { + "id": 599, + "logprob": -2.75, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21875, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.76171875, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.96875, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.21875, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.51953125, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103515625, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58984375, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6875, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8359375, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.24316406, + "text": " of" + }, + { + "id": 248, + "logprob": -0.3515625, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24414062, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.03100586, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17382812, + "text": "." + }, + { + "id": 193, + "logprob": -1.3984375, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -1.0, + "text": "," + }, + { + "id": 29033, + "logprob": -2.21875, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10644531, + "text": "af" + }, + { + "id": 249, + "logprob": -0.041992188, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0025024414, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4296875, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1015625, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.05810547, + "text": "G" + }, + { + "id": 330, + "logprob": -0.12597656, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007080078, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008300781, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0006866455, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074157715, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8203125, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.32226562, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23828125, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5859375, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.6171875, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.39648438, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023345947, + "special": false, + "text": ":" + }, + { + "id": 295, + "logprob": -2.078125, + "special": false, + "text": " I" + }, + { + "id": 18, + "logprob": -1.453125, + "special": false, + "text": "'" + }, + { + "id": 88, + "logprob": -0.47460938, + "special": false, + "text": "m" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: I'm" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.71875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.625, + "text": "af" + }, + { + "id": 249, + "logprob": -6.53125, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0625, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.625, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048583984, + "text": " with" + }, + { + "id": 26680, + "logprob": -3.984375, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.076171875, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.0066833496, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.34375, + "text": " the" + }, + { + "id": 758, + "logprob": -3.734375, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.125, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.078125, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1953125, + "text": " on" + }, + { + "id": 248, + "logprob": -0.78125, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.390625, + "text": " face" + }, + { + "id": 275, + "logprob": -0.0044555664, + "text": " of" + }, + { + "id": 414, + "logprob": -1.984375, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.03125, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28320312, + "text": "." + }, + { + "id": 401, + "logprob": -7.90625, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.265625, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.640625, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.203125, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.53125, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5625, + "text": " all" + }, + { + "id": 599, + "logprob": -2.75, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21875, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.76171875, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.96875, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.21875, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.51953125, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103515625, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58984375, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6875, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8359375, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.24316406, + "text": " of" + }, + { + "id": 248, + "logprob": -0.3515625, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24414062, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.03100586, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17382812, + "text": "." + }, + { + "id": 193, + "logprob": -1.3984375, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -1.0, + "text": "," + }, + { + "id": 29033, + "logprob": -2.21875, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10644531, + "text": "af" + }, + { + "id": 249, + "logprob": -0.041992188, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0025024414, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4296875, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1015625, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.05810547, + "text": "G" + }, + { + "id": 330, + "logprob": -0.12597656, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007080078, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008300781, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0006866455, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074157715, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8203125, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.32226562, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23828125, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5859375, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.6171875, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.39648438, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023345947, + "special": false, + "text": ":" + }, + { + "id": 295, + "logprob": -2.078125, + "special": false, + "text": " I" + }, + { + "id": 18, + "logprob": -1.453125, + "special": false, + "text": "'" + }, + { + "id": 88, + "logprob": -0.47460938, + "special": false, + "text": "m" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: I'm" + } +] diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py new file mode 100644 index 00000000..ce27731d --- /dev/null +++ b/integration-tests/models/test_flash_falcon.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_falcon_handle(launcher): + with launcher("tiiuae/falcon-7b", trust_remote_code=True) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_falcon(flash_falcon_handle): + await flash_falcon_handle.health(120) + return flash_falcon_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_falcon(flash_falcon, response_snapshot): + response = await flash_falcon.generate( + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + max_new_tokens=10, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_falcon_all_params(flash_falcon, response_snapshot): + response = await flash_falcon.generate( + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot): + responses = await generate_load( + flash_falcon, + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 46a4563b..4adf1381 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM +from text_generation_server.models.rw import RW from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.santacoder import SantaCoder @@ -30,6 +31,7 @@ try: ) from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded from text_generation_server.models.flash_llama import ( FlashLlama, FlashLlamaSharded, @@ -68,6 +70,8 @@ __all__ = [ if FLASH_ATTENTION: __all__.append(FlashNeoX) __all__.append(FlashNeoXSharded) + __all__.append(FlashRW) + __all__.append(FlashRWSharded) __all__.append(FlashSantacoder) __all__.append(FlashSantacoderSharded) __all__.append(FlashLlama) @@ -194,6 +198,39 @@ def get_model( trust_remote_code=trust_remote_code, ) + if model_type in ["RefinedWeb", "RefinedWebModel"]: + if sharded: + if FLASH_ATTENTION: + if config.alibi or ( + config.model_type == "RefinedWebModel" + and config.n_head_kv != config.n_head + ): + raise NotImplementedError("sharded is not supported for this model") + return FlashRWSharded( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb") + ) + else: + if FLASH_ATTENTION and not config.alibi: + return FlashRW( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) + else: + return RW( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) + if model_type == "llama": if sharded: if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 54670b79..2dcb6ed8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module): ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) - qkv_rot = self.rotary_emb(qkv, cos, sin) + + # Inplace rotary + self.rotary_emb(qkv[:, 0], cos, sin) + self.rotary_emb(qkv[:, 1], cos, sin) # Prefill if layer_past_present_indices is None: # Copy to layer past - layer_past[...] = qkv_rot[:, 1:] + layer_past[...] = qkv[:, 1:] # output - attn_output = torch.empty_like(qkv_rot[:, 0]) + attn_output = torch.empty_like(qkv[:, 0]) # flash attention flash_attn_cuda.fwd( - qkv_rot[:, 0], - qkv_rot[:, 1], - qkv_rot[:, 2], + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], attn_output, cu_seqlens, cu_seqlens, @@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - query = qkv_rot[:, 0] + query = qkv[:, 0] # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + layer_past[layer_past_present_indices] = qkv[:, 1:] # output attn_output = torch.empty_like(query) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b7834157..26e21753 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module): ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) - qkv_rot = self.rotary_emb(qkv, cos, sin) + + # Inplace rotary + self.rotary_emb(qkv[:, 0], cos, sin) + self.rotary_emb(qkv[:, 1], cos, sin) # Prefill if layer_past_present_indices is None: # Copy to layer past - layer_past[...] = qkv_rot[:, 1:] + layer_past[...] = qkv[:, 1:] # output - attn_output = torch.empty_like(qkv_rot[:, 0]) + attn_output = torch.empty_like(qkv[:, 0]) # flash attention flash_attn_cuda.fwd( - qkv_rot[:, 0], - qkv_rot[:, 1], - qkv_rot[:, 2], + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], attn_output, cu_seqlens, cu_seqlens, @@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - query = qkv_rot[:, 0] + query = qkv[:, 0] # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + layer_past[layer_past_present_indices] = qkv[:, 1:] # output attn_output = torch.empty_like(query) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py new file mode 100644 index 00000000..545da26a --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -0,0 +1,774 @@ +import os + +import torch +import torch.distributed + +from torch import nn +from transformers.modeling_utils import PreTrainedModel +from transformers.configuration_utils import PretrainedConfig +from typing import Optional + +# Flash attention imports +import flash_attn_cuda + +from text_generation_server.utils.layers import ( + FastLinear, + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + FastLayerNorm, + PositionRotaryEmbedding, +) + + +class RWConfig(PretrainedConfig): + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + } + + def __init__( + self, + model_type="RefinedWeb", + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + n_head_kv=None, + multi_query=False, + alibi=False, + bias=False, + parallel_attn=False, + **kwargs, + ): + if alibi: + raise NotImplementedError( + "alibi is not supported by this version of the model" + ) + + self.model_type = model_type + self.alibi = False + self.rotary = True + + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.bias = bias + self.parallel_attn = parallel_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + if n_head_kv is not None: + self.n_head_kv = n_head_kv + else: + self.n_head_kv = 1 if multi_query else n_head + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class FlashRWAttention(torch.nn.Module): + def __init__( + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + process_group=None, + reduce=True, + ): + super().__init__() + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.query_key_value = FastLinear( + hidden_size, + self.head_size * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + ) + self.dense = FastLinear(hidden_size, hidden_size, bias=bias) + else: + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + self.head_size * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + process_group=process_group, + ) + self.dense = TensorParallelRowLinear( + hidden_size, + hidden_size, + bias=bias, + process_group=process_group, + reduce=reduce, + ) + self.num_heads = self.num_heads // process_group.size() + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = self.query_key_value(hidden_states) + + # Split query from key_value + query, kv = qkv.split( + [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], + dim=1, + ) + + # Prepare query and key_value for indexing + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_heads_kv, self.head_size) + + # Inplace rotary + self.rotary_emb(query, cos, sin) + self.rotary_emb(kv[:, 0], cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = kv + # Expand to query shape + kv = kv.expand(-1, 2, self.num_heads, self.head_size) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + kv[:, 0], + kv[:, 1], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = kv + # Expand to query shape + kv = layer_past.expand(-1, 2, self.num_heads, self.head_size) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + kv[:, 0], + kv[:, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + + +class FlashRWLargeAttention(torch.nn.Module): + def __init__( + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + process_group=None, + reduce=True, + ): + super().__init__() + + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.softmax_scale = self.head_size ** (-0.5) + + self.num_groups = num_heads // (num_heads_kv * 2) + self.num_heads = num_heads // self.num_groups + self.num_heads_kv = num_heads_kv // self.num_groups + + if process_group is None: + self.query_key_value = FastLinear( + hidden_size, + self.num_groups + * self.head_size + * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + ) + self.dense = FastLinear(hidden_size, hidden_size, bias=bias) + else: + if process_group.size() > self.num_groups: + raise NotImplementedError( + f"Tensor Parallelism is not implemented for world_size > n groups" + ) + + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + self.num_groups + * self.head_size + * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + process_group=process_group, + ) + self.dense = TensorParallelRowLinear( + hidden_size, + hidden_size, + bias=bias, + process_group=process_group, + reduce=reduce, + ) + + self.num_groups = self.num_groups // process_group.size() + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) + + # Split on group dimension + query, kv = qkv.split( + [self.num_heads, 2], + dim=2, + ) + # Merge groups and heads + query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size) + + # Inplace rotary + self.rotary_emb(query, cos, sin) + self.rotary_emb(kv[:, :, 0], cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = kv + # Expand to query shape + kv = ( + kv.unsqueeze(2) + .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) + .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) + ) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + kv[:, :, 0], + kv[:, :, 1], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = kv + # Expand to query shape + kv = ( + layer_past.unsqueeze(2) + .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) + .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) + ) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + kv[:, :, 0], + kv[:, :, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.dense( + attn_output.view(-1, self.num_groups * self.num_heads * self.head_size) + ) + + +class FlashMLP(nn.Module): + def __init__(self, hidden_size, bias, process_group=None, reduce=True): + super().__init__() + self.act = torch.nn.functional.gelu + + if process_group is None: + self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias) + self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias) + else: + self.dense_h_to_4h = TensorParallelColumnLinear( + hidden_size, + 4 * hidden_size, + bias=bias, + process_group=process_group, + ) + self.dense_4h_to_h = TensorParallelRowLinear( + 4 * hidden_size, + hidden_size, + bias=bias, + process_group=process_group, + reduce=reduce, + ) + self.process_group = process_group + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class FlashRWLayer(nn.Module): + def __init__( + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + layer_norm_eps, + parallel_attn, + process_group=None, + ): + super().__init__() + + self.parallel_attn = parallel_attn + + self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.self_attention = FlashRWAttention( + num_heads, + num_heads_kv, + hidden_size, + bias, + process_group=process_group, + reduce=False, + ) + self.post_attention_layernorm = ( + FastLayerNorm(hidden_size, eps=layer_norm_eps) + if not parallel_attn + else None + ) + + self.mlp = FlashMLP( + hidden_size, bias, process_group=process_group, reduce=False + ) + + self.process_group = process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + if self.parallel_attn: + ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_output = self.self_attention( + ln_hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + mlp_output = self.mlp(ln_hidden_states) + intermediate = mlp_output + attn_output + + # Only reduce once and after the addition instead of once per layer + if self.process_group is not None: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate, residual + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attention( + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashRWLargeLayer(nn.Module): + def __init__( + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + layer_norm_eps, + process_group=None, + ): + super().__init__() + self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps) + + self.self_attention = FlashRWLargeAttention( + num_heads, + num_heads_kv, + hidden_size, + bias, + process_group=process_group, + reduce=False, + ) + + self.mlp = FlashMLP( + hidden_size, bias, process_group=process_group, reduce=False + ) + + self.process_group = process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + ln_attn, residual = self.ln_attn(hidden_states, residual) + ln_mlp, _ = self.ln_mlp(residual) + + # Self attention. + attn_output = self.self_attention( + ln_attn, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + # MLP. + mlp_output = self.mlp(ln_mlp) + + intermediate = attn_output + mlp_output + + # Only reduce once and after the addition instead of once per layer + if self.process_group is not None: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate, residual + + +class FlashRWPreTrainedModel(PreTrainedModel): + config_class = RWConfig + + +class FlashRWModel(FlashRWPreTrainedModel): + def __init__(self, config, process_group=None): + super().__init__(config) + self.config = config + + self.tp_embeddings = False + if process_group is not None: + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True + + if self.tp_embeddings: + self.word_embeddings = TensorParallelEmbedding( + config.vocab_size, config.hidden_size, process_group=process_group + ) + else: + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + + if config.model_type == "RefinedWebModel": + self.h = nn.ModuleList( + [ + FlashRWLayer( + config.n_head, + config.n_head_kv, + config.hidden_size, + config.bias, + config.layer_norm_epsilon, + config.parallel_attn, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.cache_size = ( + 2, + self.h[0].self_attention.num_heads_kv, + self.h[0].self_attention.head_size, + ) + elif config.model_type == "RefinedWeb": + self.h = nn.ModuleList( + [ + FlashRWLargeLayer( + config.n_head, + config.n_head_kv, + config.hidden_size, + config.bias, + config.layer_norm_epsilon, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.cache_size = ( + self.h[0].self_attention.num_groups, + 2, + self.h[0].self_attention.head_size, + ) + else: + raise NotImplementedError( + f"model_type {config.model_type} is not supported." + ) + + self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.head_size = self.h[0].self_attention.head_size + + def post_load_weights(self, quantize: Optional[str] = None): + if isinstance(self.word_embeddings, TensorParallelEmbedding): + self.word_embeddings.add_null_idx() + for layer in self.h: + layer: FlashRWLayer + layer.self_attention.query_key_value.prepare_weights(quantize) + layer.self_attention.dense.prepare_weights(quantize) + layer.mlp.dense_h_to_4h.prepare_weights(quantize) + layer.mlp.dense_4h_to_h.prepare_weights(quantize) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # Pop here as we will replace the layer in our own logic and don't want from_pretrained + # to do it for us + load_in_8bit = kwargs.pop("load_in_8bit", False) + model = super(FlashRWModel, cls).from_pretrained( + pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + ) + + model.post_load_weights("bitsandbytes" if load_in_8bit else None) + return model + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + cu_seqlens_q, + max_s, + past_key_values=None, + pre_allocate_past_size: Optional[int] = None, + ): + hidden_states = self.word_embeddings(input_ids) + + # Prefill + if past_key_values is None: + # Create past tensor + past_key_values = hidden_states.new_empty( + ( + len(self.h), + len(hidden_states) + if pre_allocate_past_size is None + else pre_allocate_past_size, + *self.cache_size, + ) + ) + layer_past_present_indices = None + slice_past_index = len(hidden_states) + # Decode + else: + # Create indices from cumulative sequence lengths + layer_past_present_indices = cu_seqlens[1:] - 1 + slice_past_index = None + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.h): + # We added padding that we now need to slice + layer_past_key_values = ( + past_key_values[i] + if slice_past_index is None + else past_key_values[i, :slice_past_index] + ) + + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past_key_values, + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states, past_key_values + + +class FlashRWForCausalLM(FlashRWPreTrainedModel): + def __init__(self, config, process_group=None): + super().__init__(config) + + self.process_group = process_group + if self.process_group is not None: + self.world_size = self.process_group.size() + else: + self.world_size = 1 + + self.transformer = FlashRWModel(config, process_group) + + if self.transformer.tp_embeddings: + self.lm_head = FastLinear( + config.hidden_size, + config.vocab_size // process_group.size(), + bias=False, + ) + else: + self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) + + def post_load_weights(self, quantize: Optional[str] = None): + self.transformer.post_load_weights(quantize) + self.lm_head.prepare_weights() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # Pop here as we will replace the layer in our own logic and don't want from_pretrained + # to do it for us + load_in_8bit = kwargs.pop("load_in_8bit", False) + model = super(FlashRWForCausalLM, cls).from_pretrained( + pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + ) + model.post_load_weights("bitsandbytes" if load_in_8bit else None) + return model + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + cu_seqlens_q, + max_s, + past_key_values: Optional[torch.Tensor] = None, + pre_allocate_past_size: Optional[int] = None, + ): + hidden_states, present = self.transformer( + input_ids, + position_ids, + cu_seqlens, + cu_seqlens_q, + max_s, + past_key_values, + pre_allocate_past_size, + ) + logits = self.lm_head(hidden_states) + + if self.transformer.tp_embeddings: + # Logits are sharded, so we need to gather them + world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] + torch.distributed.all_gather(world_logits, logits, group=self.process_group) + world_logits = torch.cat(world_logits, dim=1) + + return world_logits, present + return logits, present diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py new file mode 100644 index 00000000..44915ff5 --- /dev/null +++ b/server/text_generation_server/models/flash_rw.py @@ -0,0 +1,244 @@ +import torch +import torch.distributed + +from pathlib import Path +from accelerate import init_empty_weights +from opentelemetry import trace +from safetensors import safe_open +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, List + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + download_weights, + weight_hub_files, + LocalEntryNotFoundError, +) + +tracer = trace.get_tracer(__name__) + + +class FlashRW(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 + else: + raise NotImplementedError("RW is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = RWConfig.from_pretrained( + model_id, + revision=revision, + ) + + # We do not use from_pretrained as it is too slow + try: + filenames = weight_files(model_id, revision, ".bin") + # Local files not found + except LocalEntryNotFoundError: + hub_files = weight_hub_files(model_id, revision, ".bin") + filenames = download_weights(hub_files, model_id, revision) + + with init_empty_weights(): + model = FlashRWForCausalLM(config) + + self.load_weights( + model, + filenames, + quantize, + device, + dtype, + ) + + super(FlashCausalLM, self).__init__( + model=model.to(device), + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + ) + + @staticmethod + def load_weights( + model: FlashRWForCausalLM, + filenames: List[Path], + quantize: Optional[str], + device: torch.device, + dtype: torch.dtype, + ): + for filename in filenames: + state_dict = torch.load(filename, map_location="cpu") + for key, value in state_dict.items(): + value = value.to(device if quantize is None else "cpu").to(dtype) + + module_name, param_name = key.rsplit(".", 1) + module = model.get_submodule(module_name) + + try: + current_parameter_tensor = module._parameters[param_name] + if current_parameter_tensor.shape != value.shape: + raise ValueError( + f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}" + ) + module._parameters[param_name] = value + except KeyError: + module._buffers[param_name] = value + + del value + + torch.cuda.empty_cache() + model.post_load_weights(quantize) + + +class FlashRWSharded(FlashRW): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 + else: + raise NotImplementedError("FlashRW is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = RWConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = FlashRWForCausalLM(config, self.process_group) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + dtype=dtype, + rank=rank, + world_size=world_size, + ) + torch.distributed.barrier(group=self.process_group) + super(FlashCausalLM, self).__init__( + model=model.to(device), + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: Optional[str], + device: torch.device, + dtype: torch.dtype, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if quantize is None else "cpu" + ) as f: + for name in f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "lm_head.weight" and model.transformer.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous().to(dtype) + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py new file mode 100644 index 00000000..dd389027 --- /dev/null +++ b/server/text_generation_server/models/rw.py @@ -0,0 +1,86 @@ +import torch + +from transformers import AutoTokenizer, AutoModelForCausalLM +from typing import List, Optional, Tuple + +from text_generation_server.models import CausalLM + + +class RW(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map="auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + if past_key_values is not None: + reshaped_past_key_values = [] + for layer in past_key_values: + past_keys, past_values = layer + reshaped_past_key_values.append( + ( + past_keys.view(-1, *past_keys.shape[-2:]), + past_values.view(-1, *past_values.shape[-2:]), + ) + ) + past_key_values = reshaped_past_key_values + + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + return outputs.logits, outputs.past_key_values diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 7605639d..127f9ba4 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -262,16 +262,13 @@ try: sin = torch.index_select(self._sin_cached, 0, position_ids) return cos.unsqueeze(1), sin.unsqueeze(1) - def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): rotary_dim = cos.shape[-1] - q1 = qkv[:, 0, :, :rotary_dim] - q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] - k1 = qkv[:, 1, :, :rotary_dim] - k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + x1 = x[..., :rotary_dim] + x2 = x[..., rotary_dim : 2 * rotary_dim] - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return qkv + rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) + return x except ImportError: pass