Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import os | |
| import ctypes | |
| from typing import ( | |
| List, | |
| Optional, | |
| Sequence, | |
| ) | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import numpy.typing as npt | |
| from .llama_types import * | |
| from .llama_grammar import LlamaGrammar | |
| from ._utils import suppress_stdout_stderr | |
| import llama_cpp.llama_cpp as llama_cpp | |
| # Python wrappers over llama.h structs | |
class _LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.

    NOTE: For stability it's recommended you use the Llama class instead."""

    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
    _llama_free_model = None

    def __init__(
        self,
        *,
        path_model: str,
        params: llama_cpp.llama_model_params,
        verbose: bool = True,
    ):
        """Load the model file at ``path_model`` using ``params``.

        Raises:
            ValueError: if ``path_model`` does not exist, or if
                ``llama_load_model_from_file`` returns ``None``.
        """
        self.path_model = path_model
        self.params = params
        self.verbose = verbose

        self._llama_free_model = llama_cpp._lib.llama_free_model  # type: ignore

        # Set before anything can raise so that __del__ always finds the attribute.
        self.model = None

        if not os.path.exists(path_model):
            raise ValueError(f"Model path does not exist: {path_model}")

        with suppress_stdout_stderr(disable=verbose):
            self.model = llama_cpp.llama_load_model_from_file(
                self.path_model.encode("utf-8"), self.params
            )

        if self.model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")

    def __del__(self):
        """Free the underlying model handle, if one is still held."""
        if self.model is not None and self._llama_free_model is not None:
            self._llama_free_model(self.model)
            self.model = None

    def vocab_type(self) -> int:
        """Return ``llama_vocab_type`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_vocab_type(self.model)

    def n_vocab(self) -> int:
        """Return ``llama_n_vocab`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_n_vocab(self.model)

    def n_ctx_train(self) -> int:
        """Return ``llama_n_ctx_train`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_n_ctx_train(self.model)

    def n_embd(self) -> int:
        """Return ``llama_n_embd`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_n_embd(self.model)

    def rope_freq_scale_train(self) -> float:
        """Return ``llama_rope_freq_scale_train`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_rope_freq_scale_train(self.model)

    def desc(self) -> str:
        """Return the model description written by ``llama_model_desc`` (1024-byte buffer)."""
        assert self.model is not None
        buf = ctypes.create_string_buffer(1024)
        llama_cpp.llama_model_desc(self.model, buf, 1024)
        return buf.value.decode("utf-8")

    def size(self) -> int:
        """Return ``llama_model_size`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_model_size(self.model)

    def n_params(self) -> int:
        """Return ``llama_model_n_params`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_model_n_params(self.model)

    def get_tensor(self, name: str) -> ctypes.c_void_p:
        """Return ``llama_get_model_tensor`` for the UTF-8 encoded ``name``."""
        assert self.model is not None
        return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))

    def apply_lora_from_file(
        self,
        lora_path: str,
        scale: float,
        path_base_model: Optional[str],
        n_threads: int,
    ):
        """Call ``llama_model_apply_lora_from_file`` and return its result.

        A ``None`` ``path_base_model`` is passed to the C API as a NULL ``char *``.
        """
        assert self.model is not None
        if path_base_model is not None:
            base_model_arg = path_base_model.encode("utf-8")
        else:
            base_model_arg = ctypes.c_char_p(0)
        return llama_cpp.llama_model_apply_lora_from_file(
            self.model,
            lora_path.encode("utf-8"),
            scale,
            base_model_arg,
            n_threads,
        )

    # Vocab

    def token_get_text(self, token: int) -> str:
        """Return ``llama_token_get_text`` for ``token`` decoded as UTF-8."""
        # TODO: Fix
        assert self.model is not None
        return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")

    def token_get_score(self, token: int) -> float:
        """Return ``llama_token_get_score`` for ``token``."""
        assert self.model is not None
        return llama_cpp.llama_token_get_score(self.model, token)

    def token_get_type(self, token: int) -> int:
        """Return ``llama_token_get_type`` for ``token``."""
        assert self.model is not None
        return llama_cpp.llama_token_get_type(self.model, token)

    # Special tokens

    def token_bos(self) -> int:
        """Return ``llama_token_bos`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_bos(self.model)

    def token_eos(self) -> int:
        """Return ``llama_token_eos`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_eos(self.model)

    def token_nl(self) -> int:
        """Return ``llama_token_nl`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_nl(self.model)

    def token_prefix(self) -> int:
        """Return ``llama_token_prefix`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_prefix(self.model)

    def token_middle(self) -> int:
        """Return ``llama_token_middle`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_middle(self.model)

    def token_suffix(self) -> int:
        """Return ``llama_token_suffix`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_suffix(self.model)

    def token_eot(self) -> int:
        """Return ``llama_token_eot`` for this model."""
        assert self.model is not None
        return llama_cpp.llama_token_eot(self.model)

    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
        """Tokenize ``text`` and return the token ids as a list.

        The first attempt uses a buffer of ``n_ctx_train()`` tokens. A negative
        result is treated as the negated required size and triggers one retry
        with a buffer of exactly that size.

        Raises:
            RuntimeError: if the retry also reports a negative count.
        """
        assert self.model is not None
        n_ctx = self.n_ctx_train()
        tokens = (llama_cpp.llama_token * n_ctx)()
        n_tokens = llama_cpp.llama_tokenize(
            self.model, text, len(text), tokens, n_ctx, add_bos, special
        )
        if n_tokens < 0:
            n_tokens = abs(n_tokens)
            tokens = (llama_cpp.llama_token * n_tokens)()
            n_tokens = llama_cpp.llama_tokenize(
                self.model, text, len(text), tokens, n_tokens, add_bos, special
            )
            if n_tokens < 0:
                raise RuntimeError(
                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
                )
        return list(tokens[:n_tokens])

    def _token_piece_into(self, token: int, special: bool, buffer):
        """Render ``token`` into ``buffer``; return ``(buffer, length)``.

        A negative return from ``llama_token_to_piece`` is treated as the
        negated required size (the convention the module-level
        ``_token_to_piece`` helper also relies on); in that case a larger
        buffer is allocated and returned so the caller can keep reusing it.
        """
        n = llama_cpp.llama_token_to_piece(
            self.model, llama_cpp.llama_token(token), buffer, len(buffer), special
        )
        if n < 0:
            buffer = (ctypes.c_char * -n)()
            n = llama_cpp.llama_token_to_piece(
                self.model, llama_cpp.llama_token(token), buffer, len(buffer), special
            )
            if n < 0:
                raise RuntimeError(f"Failed to get piece: token={token}")
        return buffer, n

    def token_to_piece(self, token: int, special: bool = False) -> bytes:
        """Return the bytes of the piece for ``token``.

        Only the bytes actually written are returned, not the NUL padding of
        the scratch buffer.
        """
        assert self.model is not None
        buf, n = self._token_piece_into(token, special, (ctypes.c_char * 32)())
        return buf[:n]

    def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
        """Concatenate the pieces of ``tokens`` into one bytes object.

        If the first token equals ``token_bos()`` and the output starts with a
        space, that single leading space is dropped.
        """
        assert self.model is not None
        buffer = (ctypes.c_char * 32)()
        pieces = []
        for token in tokens:
            # The buffer is reused across tokens and only replaced when a
            # piece does not fit.
            buffer, n = self._token_piece_into(token, special, buffer)
            pieces.append(buffer[:n])
        output = b"".join(pieces)
        # NOTE: Llama1 models automatically added a space at the start of the prompt
        # this line removes a leading space if the first token is a beginning of sentence token
        if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" ":
            return output[1:]
        return output

    # Extra

    def metadata(self) -> dict[str, str]:
        """Return the model metadata as a ``{key: value}`` dict of strings.

        Keys and values are read by index through
        ``llama_model_meta_key_by_index`` / ``llama_model_meta_val_str_by_index``
        into a shared buffer that grows when an entry does not fit.
        """
        assert self.model is not None
        metadata: dict[str, str] = {}
        buffer_size = 1024
        # create_string_buffer already returns a zero-filled buffer.
        buffer = ctypes.create_string_buffer(buffer_size)

        def read(fetch, index: int) -> str:
            nonlocal buffer, buffer_size
            nbytes = fetch(self.model, index, buffer, buffer_size)
            # The reallocation size is nbytes + 1, i.e. one extra byte is
            # needed beyond the reported length, so a result equal to
            # buffer_size did not fit either.
            if nbytes >= buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                fetch(self.model, index, buffer, buffer_size)
            return buffer.value.decode("utf-8")

        # iterate over model keys
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            key = read(llama_cpp.llama_model_meta_key_by_index, i)
            value = read(llama_cpp.llama_model_meta_val_str_by_index, i)
            metadata[key] = value
        return metadata

    @staticmethod
    def default_params():
        """Get the default llama_model_params."""
        return llama_cpp.llama_model_default_params()
class _LlamaContext:
    """Intermediate Python wrapper for a llama.cpp llama_context.

    NOTE: For stability it's recommended you use the Llama class instead."""

    # Saved on the class so __del__ can still look it up safely.
    _llama_free = None

    def __init__(
        self,
        *,
        model: _LlamaModel,
        params: llama_cpp.llama_context_params,
        verbose: bool = True,
    ):
        """Create a context for ``model`` using ``params``.

        Raises:
            ValueError: if ``llama_new_context_with_model`` returns ``None``.
        """
        self.model = model
        self.params = params
        self.verbose = verbose

        self._llama_free = llama_cpp._lib.llama_free  # type: ignore

        # Set before anything can raise so that __del__ always finds the attribute.
        self.ctx = None

        assert self.model.model is not None

        self.ctx = llama_cpp.llama_new_context_with_model(
            self.model.model, self.params
        )

        if self.ctx is None:
            raise ValueError("Failed to create llama_context")

    def __del__(self):
        """Free the underlying context handle, if one is still held."""
        if self.ctx is not None and self._llama_free is not None:
            self._llama_free(self.ctx)
            self.ctx = None

    def n_ctx(self) -> int:
        """Return ``llama_n_ctx`` for this context."""
        assert self.ctx is not None
        return llama_cpp.llama_n_ctx(self.ctx)

    def pooling_type(self) -> int:
        """Return ``llama_pooling_type`` for this context."""
        assert self.ctx is not None
        return llama_cpp.llama_pooling_type(self.ctx)

    def kv_cache_clear(self):
        """Call ``llama_kv_cache_clear`` on this context."""
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_clear(self.ctx)

    def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
        """Call ``llama_kv_cache_seq_rm`` with the given sequence id and positions."""
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)

    def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
        """Call ``llama_kv_cache_seq_cp`` from ``seq_id_src`` to ``seq_id_dst``."""
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)

    def kv_cache_seq_keep(self, seq_id: int):
        """Call ``llama_kv_cache_seq_keep`` for ``seq_id``."""
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)

    def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
        """Call ``llama_kv_cache_seq_add`` with ``shift`` as the delta argument."""
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)

    def get_state_size(self) -> int:
        """Return ``llama_get_state_size`` for this context."""
        assert self.ctx is not None
        return llama_cpp.llama_get_state_size(self.ctx)

    # TODO: copy_state_data

    # TODO: set_state_data

    # TODO: llama_load_session_file

    # TODO: llama_save_session_file

    def decode(self, batch: "_LlamaBatch"):
        """Run ``llama_decode`` on ``batch``.

        Raises:
            RuntimeError: if ``llama_decode`` returns a non-zero code.
        """
        assert self.ctx is not None
        assert batch.batch is not None
        return_code = llama_cpp.llama_decode(
            self.ctx,
            batch.batch,
        )
        if return_code != 0:
            raise RuntimeError(f"llama_decode returned {return_code}")

    def set_n_threads(self, n_threads: int, n_threads_batch: int):
        """Call ``llama_set_n_threads`` with both thread counts."""
        assert self.ctx is not None
        llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)

    def get_logits(self):
        """Return the result of ``llama_get_logits``."""
        assert self.ctx is not None
        return llama_cpp.llama_get_logits(self.ctx)

    def get_logits_ith(self, i: int):
        """Return the result of ``llama_get_logits_ith`` for index ``i``."""
        assert self.ctx is not None
        return llama_cpp.llama_get_logits_ith(self.ctx, i)

    def get_embeddings(self):
        """Return the result of ``llama_get_embeddings``."""
        assert self.ctx is not None
        return llama_cpp.llama_get_embeddings(self.ctx)

    # Sampling functions

    def set_rng_seed(self, seed: int):
        """Call ``llama_set_rng_seed`` with ``seed``."""
        assert self.ctx is not None
        llama_cpp.llama_set_rng_seed(self.ctx, seed)

    def sample_repetition_penalties(
        self,
        candidates: "_LlamaTokenDataArray",
        last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
    ):
        """Call ``llama_sample_repetition_penalties`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_repetition_penalties(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            last_tokens_data,
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        )

    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
        """Call ``llama_sample_softmax`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_softmax(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
        """Call ``llama_sample_top_k`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_top_k(
            self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
        )

    def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        """Call ``llama_sample_top_p`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_top_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        """Call ``llama_sample_min_p`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_min_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_tail_free(
        self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
    ):
        """Call ``llama_sample_tail_free`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_tail_free(
            self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
        )

    def sample_typical(
        self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
    ):
        """Call ``llama_sample_typical`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_typical(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
        """Call ``llama_sample_temp`` on ``candidates`` (by reference)."""
        assert self.ctx is not None
        llama_cpp.llama_sample_temp(
            self.ctx, llama_cpp.byref(candidates.candidates), temp
        )

    def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
        """Call ``llama_sample_grammar`` on ``candidates`` with ``grammar.grammar``."""
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_sample_grammar(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            grammar.grammar,
        )

    def sample_token_mirostat(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        m: int,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        """Return the result of ``llama_sample_token_mirostat``; ``mu`` is passed through as given."""
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_mirostat(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            tau,
            eta,
            m,
            mu,
        )

    def sample_token_mirostat_v2(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        """Return the result of ``llama_sample_token_mirostat_v2``; ``mu`` is passed through as given."""
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_mirostat_v2(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            tau,
            eta,
            mu,
        )

    def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
        """Return the result of ``llama_sample_token_greedy`` on ``candidates``."""
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_greedy(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
        """Return the result of ``llama_sample_token`` on ``candidates``."""
        assert self.ctx is not None
        return llama_cpp.llama_sample_token(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    # Grammar

    def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
        """Call ``llama_grammar_accept_token`` with ``grammar.grammar`` and ``token``."""
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)

    def reset_timings(self):
        """Call ``llama_reset_timings`` on this context."""
        assert self.ctx is not None
        llama_cpp.llama_reset_timings(self.ctx)

    def print_timings(self):
        """Call ``llama_print_timings`` on this context."""
        assert self.ctx is not None
        llama_cpp.llama_print_timings(self.ctx)

    # Utility functions

    @staticmethod
    def default_params():
        """Get the default llama_context_params."""
        return llama_cpp.llama_context_default_params()
class _LlamaBatch:
    """Python wrapper around a ``llama_batch`` created by ``llama_batch_init``."""

    # Saved on the class so __del__ can still look it up safely.
    _llama_batch_free = None

    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
        """Allocate a batch via ``llama_batch_init(n_tokens, embd, n_seq_max)``."""
        self._n_tokens = n_tokens
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose

        self._llama_batch_free = llama_cpp._lib.llama_batch_free  # type: ignore

        # Set before the allocation so that __del__ always finds the attribute.
        self.batch = None
        self.batch = llama_cpp.llama_batch_init(
            self._n_tokens, self.embd, self.n_seq_max
        )

    def __del__(self):
        """Free the underlying batch, if one is still held."""
        if self.batch is not None and self._llama_batch_free is not None:
            self._llama_batch_free(self.batch)
            self.batch = None

    def n_tokens(self) -> int:
        """Return the number of tokens currently stored in the batch."""
        assert self.batch is not None
        return self.batch.n_tokens

    def reset(self):
        """Mark the batch as empty by setting its token count to zero."""
        assert self.batch is not None
        self.batch.n_tokens = 0

    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
        """Replace the batch contents with ``batch`` as sequence 0.

        Token ``i`` gets position ``n_past + i``. Every token's logits flag is
        ``logits_all``, except the last one, which is always enabled.
        """
        assert self.batch is not None
        n_tokens = len(batch)
        self.batch.n_tokens = n_tokens
        for i in range(n_tokens):
            self.batch.token[i] = batch[i]
            self.batch.pos[i] = n_past + i
            self.batch.seq_id[i][0] = 0
            self.batch.n_seq_id[i] = 1
            self.batch.logits[i] = logits_all
        # Guard the empty case: index -1 would write outside the logits array.
        if n_tokens > 0:
            self.batch.logits[n_tokens - 1] = True

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        """Append ``batch`` after the existing tokens as sequence ``seq_id``.

        Positions restart at 0 for the appended tokens. Every appended token's
        logits flag is ``logits_all``, except the last appended one, which is
        always enabled.
        """
        assert self.batch is not None
        n_tokens = len(batch)
        n_tokens0 = self.batch.n_tokens
        self.batch.n_tokens += n_tokens
        for i in range(n_tokens):
            j = n_tokens0 + i
            self.batch.token[j] = batch[i]
            self.batch.pos[j] = i
            self.batch.seq_id[j][0] = seq_id
            self.batch.n_seq_id[j] = 1
            self.batch.logits[j] = logits_all
        # The last appended token lives at offset n_tokens0 + n_tokens - 1,
        # not n_tokens - 1, once earlier sequences are already in the batch.
        if n_tokens > 0:
            self.batch.logits[n_tokens0 + n_tokens - 1] = True
class _LlamaTokenDataArray:
    """NumPy-backed storage for a ``llama_token_data_array`` of ``n_vocab`` candidates."""

    def __init__(self, *, n_vocab: int):
        """Allocate the candidate buffer and the ctypes struct that points at it."""
        self.n_vocab = n_vocab
        # One record per vocabulary entry. A 1-D array keeps field indexing
        # such as candidates_data["logit"][token] element-wise; a (3, n_vocab)
        # array would make that index select a whole row.
        self.candidates_data = np.zeros(
            self.n_vocab,
            dtype=np.dtype(
                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
            ),
        )
        self.candidates = llama_cpp.llama_token_data_array(
            data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
            size=self.n_vocab,
            sorted=False,
        )
        # Precomputed columns copied in by copy_logits on every call.
        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
        self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)

    def copy_logits(self, logits: npt.NDArray[np.single]):
        """Reset the candidates from ``logits``.

        Ids are reset to ``0..n_vocab-1``, probabilities to zero, and the
        ctypes struct is re-pointed at the buffer with ``sorted=False`` and
        ``size=n_vocab``.
        """
        self.candidates_data["id"][:] = self.default_candidates_data_id
        self.candidates_data["logit"][:] = logits
        self.candidates_data["p"][:] = self.default_candidates_data_p
        self.candidates.data = self.candidates_data.ctypes.data_as(
            llama_cpp.llama_token_data_p
        )
        self.candidates.sorted = ctypes.c_bool(False)
        self.candidates.size = ctypes.c_size_t(self.n_vocab)
| # Python wrappers over common/common | |
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
    """Tokenize ``text`` with ``model`` and return the token ids.

    The text is encoded to UTF-8 once, and the encoded byte length is what is
    passed to ``llama_tokenize``. Using the character count would cut off
    text containing multi-byte characters.

    A negative result from the first call is treated as the negated required
    buffer size and triggers one retry with exactly that many slots.

    Raises:
        RuntimeError: if the retry does not return the expected token count.
    """
    assert model.model is not None
    text_bytes = text.encode("utf-8")
    n_tokens = len(text_bytes) + 1 if add_bos else len(text_bytes)
    result = (llama_cpp.llama_token * n_tokens)()
    n_tokens = llama_cpp.llama_tokenize(
        model.model,
        text_bytes,
        len(text_bytes),
        result,
        n_tokens,
        add_bos,
        special,
    )
    if n_tokens < 0:
        result = (llama_cpp.llama_token * -n_tokens)()
        check = llama_cpp.llama_tokenize(
            model.model,
            text_bytes,
            len(text_bytes),
            result,
            len(result),
            add_bos,
            special,
        )
        if check != -n_tokens:
            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
        return list(result)
    return list(result[:n_tokens])
def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
    """Return the text piece for ``token``, decoded as UTF-8.

    An 8-byte buffer is tried first. A negative result is treated as the
    negated required size and triggers one retry with exactly that size.

    Raises:
        RuntimeError: if the retry does not report the expected length.
    """
    assert model.model is not None
    buf = (ctypes.c_char * 8)(0)
    written = llama_cpp.llama_token_to_piece(model.model, token, buf, len(buf), special)
    if written >= 0:
        # Fits in the small buffer: keep only the bytes that were written.
        return bytes(buf[:written]).decode("utf-8")

    needed = -written
    buf = (ctypes.c_char * needed)(0)
    check = llama_cpp.llama_token_to_piece(model.model, token, buf, len(buf), special)
    if check != needed:
        raise RuntimeError(f"Failed to get piece: token={token}")
    return bytes(buf).decode("utf-8")
def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str:
    """Join the pieces of ``tokens`` into a string, SPM style.

    A single leading space is stripped from the first non-BOS piece: the
    piece at index 1 when ``tokens[0]`` is the BOS token, otherwise the piece
    at index 0. Empty pieces are handled without raising.
    """
    bos_id = model.token_bos()
    # Index of the piece whose leading space should be removed.
    strip_index = 1 if tokens and tokens[0] == bos_id else 0
    pieces = []
    for i, token in enumerate(tokens):
        piece = _token_to_piece(model, token)
        # startswith is safe on an empty piece, unlike piece[0].
        if i == strip_index and piece.startswith(" "):
            piece = piece[1:]
        pieces.append(piece)
    return "".join(pieces)
def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str:
    """Concatenate the pieces of ``tokens`` in order, with no post-processing."""
    return "".join(_token_to_piece(model, token) for token in tokens)
def _should_add_bos(model: _LlamaModel) -> bool:
    """Decide whether a BOS token should be added for ``model``.

    Uses ``llama_add_bos_token`` unless it returns -1; in that case the answer
    is whether the vocabulary type equals ``LLAMA_VOCAB_TYPE_SPM``.
    """
    assert model.model is not None
    flag = llama_cpp.llama_add_bos_token(model.model)
    if flag == -1:
        vocab_type = llama_cpp.llama_vocab_type(model.model)
        return vocab_type == llama_cpp.LLAMA_VOCAB_TYPE_SPM
    return flag != 0
| # Embedding functions | |
def _normalize_embedding(embedding):
    """Scale ``embedding`` to unit Euclidean length.

    Returns the input object unchanged when its norm is zero; otherwise
    returns a new list with every component divided by the norm.
    """
    length = float(np.linalg.norm(embedding))
    if length != 0.0:
        return [component / length for component in embedding]
    return embedding
| # Python wrappers over common/sampling structs | |
@dataclass
class _LlamaSamplingParams:
    """Sampling settings consumed by ``_LlamaSamplingContext``.

    Declared as a dataclass so that ``field(default_factory=dict)`` yields a
    fresh ``logit_bias`` dict per instance and values can be set by keyword.
    """

    n_prev: int = 64
    n_probs: int = 0
    top_k: int = 40
    top_p: float = 0.95
    min_p: float = 0.05
    tfs_z: float = 1.00
    typical_p: float = 1.00
    temp: float = 0.80
    penalty_last_n: int = 64
    penalty_repeat: float = 1.10
    penalty_freq: float = 0.00
    penalty_present: float = 0.00
    mirostat: int = 0
    mirostat_tau: float = 5.00
    mirostat_eta: float = 0.10
    penalize_nl: bool = True

    grammar: str = ""

    cfg_negative_prompt: str = ""
    cfg_scale: float = 1.00

    # Maps token id -> value added to that token's logit before sampling.
    logit_bias: dict[int, float] = field(default_factory=dict)
@dataclass
class _LlamaSamplingContext:
    """Per-generation sampling state: parameters, mirostat state, grammar and history.

    Declared as a dataclass so the ``field(default_factory=...)`` defaults are
    created per instance and ``cp()`` can construct a copy by keyword.
    """

    params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams)
    mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
    grammar: Optional[LlamaGrammar] = None
    # NOTE: Missing parsed_grammar
    prev: list[int] = field(default_factory=list)
    cur: list[llama_cpp.llama_token_data] = field(default_factory=list)

    def reset(self):
        """Clear the token history and candidates, and reset the grammar if present."""
        self.prev = []
        self.cur = []
        if self.grammar is not None:
            self.grammar.reset()

    def cp(self):
        """Return a new context with copied ``prev``/``cur`` lists.

        ``params``, ``mirostat_mu`` and ``grammar`` are shared with the
        original, not copied.
        """
        return _LlamaSamplingContext(
            params=self.params,
            mirostat_mu=self.mirostat_mu,
            grammar=self.grammar,
            prev=self.prev.copy(),
            cur=self.cur.copy(),
        )

    def last(self) -> Optional[int]:
        """Return the most recently accepted token, or ``None`` if there is none."""
        if len(self.prev) > 0:
            return self.prev[-1]
        else:
            return None

    def prev_str(self, ctx_main: _LlamaContext, n: int) -> str:
        """Detokenize the last ``n`` accepted tokens and decode them as UTF-8.

        NOTE: the slice is ``self.prev[-n:]``, so ``n == 0`` selects the whole
        history.
        """
        return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")

    def sample(
        self,
        ctx_main: _LlamaContext,
        idx: int = 0,
        logits_array: Optional[npt.NDArray[np.single]] = None,
    ):
        """Sample one token id using ``self.params``.

        When ``logits_array`` is ``None`` the logits are copied from
        ``ctx_main.get_logits_ith(idx)``. ``logit_bias`` is added to
        ``logits_array`` in place, so a caller-supplied array is modified.

        Order of operations: logit bias, repetition penalties (optionally
        restoring the newline logit), grammar, then one of: softmax and take
        the first candidate (``temp < 0``), greedy (``temp == 0``), mirostat
        v1/v2, or the top-k / tail-free / typical / top-p / min-p /
        temperature chain followed by ``sample_token``.
        """
        n_vocab = ctx_main.model.n_vocab()
        id: int = 0

        if logits_array is None:
            logits = ctx_main.get_logits_ith(idx)
            logits_array = np.array(
                ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
                dtype=np.single,
            )

        # apply logit_bias
        for token, logit_bias in self.params.logit_bias.items():
            logits_array[token] += logit_bias

        token_data_array = _LlamaTokenDataArray(
            n_vocab=n_vocab
        )  # TODO: Only create this once
        token_data_array.copy_logits(logits_array)

        # apply penalties
        if len(self.prev) > 0:
            nl_token = ctx_main.model.token_nl()
            # Remembered so the newline logit can be restored after penalties.
            nl_logit = logits_array[nl_token]
            last_tokens = self.prev[-self.params.penalty_last_n:]
            last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
            if last_tokens_size > 0:
                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
                ctx_main.sample_repetition_penalties(
                    token_data_array,
                    last_tokens_p,
                    last_tokens_size,
                    self.params.penalty_repeat,
                    self.params.penalty_freq,
                    self.params.penalty_present,
                )
            if not self.params.penalize_nl:
                token_data_array.candidates_data["logit"][nl_token] = nl_logit

        if self.grammar is not None:
            ctx_main.sample_grammar(token_data_array, self.grammar)

        if self.params.temp < 0:
            ctx_main.sample_softmax(token_data_array)
            id = token_data_array.candidates_data["id"][0]
        elif self.params.temp == 0:
            id = ctx_main.sample_token_greedy(token_data_array)
        else:
            if self.params.mirostat == 1:
                mirostat_m = 100
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token_mirostat(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    mirostat_m,
                    ctypes.pointer(self.mirostat_mu),
                )
            elif self.params.mirostat == 2:
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token_mirostat_v2(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    ctypes.pointer(self.mirostat_mu),
                )
            else:
                min_keep = max(1, self.params.n_probs)
                ctx_main.sample_top_k(
                    token_data_array, self.params.top_k, min_keep=min_keep
                )
                ctx_main.sample_tail_free(
                    token_data_array, self.params.tfs_z, min_keep=min_keep
                )
                ctx_main.sample_typical(
                    token_data_array, self.params.typical_p, min_keep=min_keep
                )
                ctx_main.sample_top_p(
                    token_data_array, self.params.top_p, min_keep=min_keep
                )
                ctx_main.sample_min_p(
                    token_data_array, self.params.min_p, min_keep=min_keep
                )
                ctx_main.sample_temp(token_data_array, self.params.temp)
                id = ctx_main.sample_token(token_data_array)
        return id

    def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
        """Record ``id`` in the history, first feeding it to the grammar when requested."""
        if apply_grammar and self.grammar is not None:
            ctx_main.grammar_accept_token(self.grammar, id)
        self.prev.append(id)