misc(offline): update model creation as std::shared_ptr

This commit is contained in:
Morgan Funtowicz 2024-11-28 17:45:22 +01:00
parent 9d659f1e23
commit 6c5a75b593
1 changed files with 5 additions and 4 deletions

View File

@ -5,7 +5,7 @@
#include <llama.h>
#include <spdlog/spdlog.h>
#include <spdlog/fmt/ranges.h>s
#include <spdlog/fmt/ranges.h>
#include "../csrc/backend.hpp"
using namespace huggingface::tgi::backends::llamacpp;
@ -22,8 +22,9 @@ int main(int argc, char **argv) {
const auto modelPath = absolute(std::filesystem::path(argv[1]));
const auto params = llama_model_default_params();
auto model = std::unique_ptr<llama_model, decltype(llama_model_deleter)>(
llama_load_model_from_file(modelPath.c_str(), params)
auto model = std::shared_ptr<llama_model>(
llama_load_model_from_file(modelPath.c_str(), params),
llama_model_deleter
);
auto prompt = "My name is Morgan";
@ -31,7 +32,7 @@ int main(int argc, char **argv) {
const auto nb_tokens = llama_tokenize(model.get(), prompt, sizeof(prompt), tokens.data(), tokens.size(), true,
false);
tokens.resize(nb_tokens);
auto backend = worker_t{std::move(model), {.n_batch = 1, .n_threads = 4}};
auto backend = worker_t(std::move(model), {.n_batch = 1, .n_threads = 4});
fmt::println("Tokenized: {}", tokens);