hf_text-generation-inference/backends/llamacpp/offline/main.cpp

39 lines
1.1 KiB
C++
Raw Normal View History

2024-10-04 02:42:31 -06:00
//
// Created by mfuntowicz on 10/3/24.
//
#include <cstdint>
#include <cstdio>
#include <filesystem>

#include <fmt/color.h>
#include <fmt/format.h>
#include <fmt/std.h>
#include <fmt/ranges.h>

#include <spdlog/spdlog.h>

#include "../csrc/backend.hpp"

using namespace huggingface::tgi::backends::llamacpp;
2024-10-26 14:24:05 -06:00
int main(int argc, char **argv) {
if (argc < 2) {
2024-10-04 02:42:31 -06:00
fmt::print("No model folder provider");
return 1;
}
spdlog::set_level(spdlog::level::debug);
2024-10-31 10:52:18 -06:00
const auto modelPath = absolute(std::filesystem::path(argv[1]));
2024-10-30 15:40:49 -06:00
const auto params = llama_model_default_params();
auto *model = llama_load_model_from_file(modelPath.c_str(), params);
auto backend = single_worker_backend_t(model, {});
// generate
2024-10-31 10:52:18 -06:00
const auto promptTokens = {128000, 5159, 836, 374, 23809, 11};
2024-10-30 15:40:49 -06:00
const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
if (out.has_value())
fmt::print(FMT_STRING("Generated: {}"), *out);
else {
const auto err = out.error();
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
}
}