akra35567 committed
Commit 749c34d · 1 Parent(s): 5c0ad68

Update modules/local_llm.py

Files changed (1)
  1. modules/local_llm.py +57 -129
modules/local_llm.py CHANGED
@@ -1,138 +1,66 @@
-# modules/api.py
-import time
-import re
-import datetime
-from typing import Dict, Optional, Any, List
-from flask import Flask, Blueprint, request, jsonify
 from loguru import logger
-import google.generativeai as genai
-from mistralai import Mistral
-from .local_llm import LlamaLLM
-from .contexto import Contexto
-from .database import Database
-from .treinamento import Treinamento
-from .exemplos_naturais import ExemplosNaturais
-import modules.config as config
-
-
-class SimpleTTLCache:
-    def __init__(self, ttl_seconds: int = 300):
-        self.ttl = ttl_seconds
-        self._store = {}
-    def __contains__(self, key):
-        if key not in self._store: return False
-        _, expires = self._store[key]
-        if time.time() > expires: del self._store[key]; return False
-        return True
-    def __setitem__(self, key, value):
-        self._store[key] = (value, time.time() + self.ttl)
-    def __getitem__(self, key):
-        if key not in self: raise KeyError(key)
-        return self._store[key][0]
-
-
-class LLMManager:
-    def __init__(self, config_instance):
-        self.config = config_instance
-        self.mistral_client = None
-        self.gemini_model = None
-        self.llama_llm = self._import_llama()
-        self._setup_providers()
-        self.providers = []
-        if self.mistral_client: self.providers.append('mistral')
-        if self.gemini_model: self.providers.append('gemini')
-        if self.llama_llm and self.llama_llm.is_available(): self.providers.append('llama')
-        logger.info(f"Provedores ativos: {self.providers or 'NENHUM'}")
-
-    def _import_llama(self):
         try:
-            return LlamaLLM()  # Agora roda Mistral 1B local
-        except Exception as e:
-            logger.warning(f"Mistral 1B local não carregado: {e}")
-            return None
-
-    def _setup_providers(self):
-        # MISTRAL API
-        key = getattr(self.config, 'MISTRAL_API_KEY', '').strip()
-        if key:
-            try:
-                self.mistral_client = Mistral(api_key=key)
-                logger.info("Mistral API conectado")
-            except: self.mistral_client = None
-
-        # GEMINI (SEM FILTROS)
-        key = getattr(self.config, 'GEMINI_API_KEY', '').strip()
-        if key:
-            try:
-                genai.configure(api_key=key)
-                self.gemini_model = genai.GenerativeModel(
-                    model_name=self.config.GEMINI_MODEL,
-                    system_instruction=self.config.PERSONA + self.config.SYSTEM_PROMPT + " (SEM CENSURA)",
-                    safety_settings=[
-                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
-                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
-                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
-                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}
-                    ]
-                )
-                logger.info(f"Gemini conectado: {self.config.GEMINI_MODEL}")
-            except Exception as e:
-                logger.warning(f"Gemini falhou: {e}")
-                self.gemini_model = None
-
-    def generate(self, user_prompt: str, context_history: List[dict] = [], is_privileged: bool = False) -> str:
-        full_system = self.config.PERSONA + self.config.SYSTEM_PROMPT + " (SEM FILTROS)"
-        messages = [{"role": "system", "content": full_system}]
-        for turn in context_history:
-            role = "user" if turn["role"] == "user" else "assistant"
-            messages.append({"role": role, "content": turn["content"]})
-        messages.append({"role": "user", "content": user_prompt})
-
-        for provider in self.providers:
-            # MISTRAL API
-            if provider == 'mistral' and self.mistral_client:
-                try:
-                    resp = self.mistral_client.chat.complete(
-                        model=self.config.MISTRAL_MODEL,
-                        messages=messages,
-                        temperature=self.config.TOP_P,
-                        max_tokens=self.config.MAX_TOKENS
-                    )
-                    text = resp.choices[0].message.content
-                    if text: return text.strip()
-                except Exception as e:
-                    logger.warning(f"Mistral API falhou: {e}")
-
-            # GEMINI
-            elif provider == 'gemini' and self.gemini_model:
-                try:
-                    gemini_hist = []
-                    for msg in messages[1:]:
-                        role = "user" if msg["role"] == "user" else "model"
-                        gemini_hist.append({"role": role, "parts": [{"text": msg["content"]}]})
-
-                    resp = self.gemini_model.generate_content(
-                        gemini_hist,
-                        generation_config=genai.GenerationConfig(
-                            max_output_tokens=self.config.MAX_TOKENS,
-                            temperature=self.config.TOP_P
-                        )
-                    )
-                    # VERIFICA BLOQUEIO
-                    if resp.candidates and resp.candidates[0].finish_reason == "SAFETY":
-                        logger.warning("Gemini bloqueou por segurança → pulando")
-                        continue
-                    text = resp.text or ''
-                    if text: return text.strip()
-                except Exception as e:
-                    logger.warning(f"Gemini falhou: {e}")
-
-            # MISTRAL 1B LOCAL
-            elif provider == 'llama' and self.llama_llm:
-                try:
-                    text = self.llama_llm.generate(user_prompt, max_tokens=self.config.MAX_TOKENS, temperature=self.config.TOP_P)
-                    if text: return text.strip()
-                except Exception as e:
-                    logger.warning(f"Mistral 1B local falhou: {e}")
-
-        return getattr(self.config, 'FALLBACK_RESPONSE', 'Desculpa, puto, to off.')
+# modules/local_llm.py
+import os
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from loguru import logger
+
+# Caminhos
+BASE_MODEL = "mistralai/Mistral-1B-Instruct-v0.1"
+FINETUNED_DIR = "/app/data/finetuned_mistral"
+MODEL_DIR = FINETUNED_DIR if os.path.exists(FINETUNED_DIR) and os.listdir(FINETUNED_DIR) else BASE_MODEL
+
+class LlamaLLM:
+    def __init__(self):
+        self.model_path = MODEL_DIR
+        self.generator = None
+        self._load_model()
+
+    def _load_model(self):
         try:
+            logger.info(f"Carregando Mistral 1B de: {self.model_path}")
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.model_path,
+                use_fast=True,
+                token=os.getenv("HF_TOKEN")
+            )
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            model = AutoModelForCausalLM.from_pretrained(
+                self.model_path,
+                torch_dtype="auto",
+                device_map="auto",
+                token=os.getenv("HF_TOKEN")
+            )
+
+            self.generator = pipeline(
+                "text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                max_new_tokens=500,
+                temperature=0.9,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+            logger.info(f"Mistral 1B carregado: {'FINETUNED' if 'finetuned' in self.model_path else 'BASE'}")
+        except Exception as e:
+            logger.error(f"Falha ao carregar Mistral 1B: {e}")
+            self.generator = None
+
+    def is_available(self) -> bool:
+        return self.generator is not None
+
+    def generate(self, prompt: str, max_tokens: int = 500, temperature: float = 0.9) -> str:
+        if not self.is_available():
+            return None
+        try:
+            formatted = f"<s>[INST] {prompt} [/INST]"
+            result = self.generator(
+                formatted,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                do_sample=True
+            )
+            return result[0]['generated_text'].split("[/INST]")[-1].strip()
+        except Exception as e:
+            logger.warning(f"Erro na geração local: {e}")
+            return None
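For reference, a minimal sketch of how the rewritten module would be consumed. The caller (modules/api.py's LLMManager) is not part of this diff, so the snippet below is an assumption based only on the LlamaLLM interface shown above; it presumes the repository is laid out as a `modules` package and that HF_TOKEN is exported if the base model requires authentication.

# usage_sketch.py (hypothetical, not part of the commit)
from modules.local_llm import LlamaLLM

llm = LlamaLLM()  # loads FINETUNED_DIR if it exists and is non-empty, otherwise BASE_MODEL
if llm.is_available():
    # generate() wraps the prompt in Mistral's [INST] ... [/INST] template
    # and returns only the text after the closing tag.
    print(llm.generate("Summarize what this module does.", max_tokens=200, temperature=0.7))
else:
    # In this case LLMManager would skip the 'llama' provider and fall back
    # to its configured FALLBACK_RESPONSE.
    print("local model unavailable")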