//
// Created by mfuntowicz on 10/3/24.
//
#include <cstring>
#include <filesystem>
#include <memory>
#include <vector>
#include <llama.h>
#include <spdlog/spdlog.h>
#include <spdlog/fmt/ranges.h>
#include "../csrc/backend.hpp"
using namespace huggingface::tgi::backends::llamacpp;
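// Custom deleter so the unique_ptr below releases the llama.cpp model on destruction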
const auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
int main(int argc, char **argv) {
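    // Expect the model folder as the first command-line argument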
    if (argc < 2) {
        fmt::print("No model folder provided\n");
        return 1;
    }
    spdlog::set_level(spdlog::level::debug);
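    // Resolve the model location given on the command line to an absolute path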
    const auto modelPath = std::filesystem::absolute(std::filesystem::path(argv[1]));
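    // Load the model with default parameters; the unique_ptr owns it through the custom deleter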
    const auto params = llama_model_default_params();
    auto model = std::unique_ptr<llama_model, decltype(llama_model_deleter)>(
            llama_load_model_from_file(modelPath.c_str(), params)
    );
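    // Tokenize the prompt into a buffer of at most 16 tokens, then shrink it to the actual count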
    auto prompt = "My name is Morgan";
    auto tokens = std::vector<llama_token>(16);
    // sizeof(prompt) would measure the pointer, not the text: use strlen for the byte length
    const auto nb_tokens = llama_tokenize(model.get(), prompt, std::strlen(prompt), tokens.data(), tokens.size(),
                                          true, false);
    tokens.resize(nb_tokens);
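    // Hand the model over to a worker configured for batch size 1 and 4 threads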
    auto backend = worker_t{std::move(model), {.n_batch = 1, .n_threads = 4}};
fmt::println("Tokenized: {}", tokens);
    // Generate new tokens from the prompt, streaming each one into the callback below
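    // Output buffer pre-sized to max_new_tokens; trimmed once generation completes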
    auto generated_tokens = std::vector<llama_token>(32);
    const auto n_generated_tokens = backend.generate(
            {{.max_new_tokens = 32}, {.top_k = 40}, tokens},
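            // Invoked once per generated token; steps are assumed to be 1-based here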
            [&generated_tokens](llama_token new_token_id, float_t logit, bool is_eos, size_t step) -> bool {
                generated_tokens[step - 1] = new_token_id;
                return false;
            }
    );
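    // generate is assumed to return the number of tokens produced; trim the buffer to that count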
    generated_tokens.resize(n_generated_tokens.value());
fmt::println("Generated {} tokens", generated_tokens);
}