akra35567 commited on
Commit
9ade958
·
1 Parent(s): 762d515

Update modules/treinamento.py

Browse files
Files changed (1) hide show
  1. modules/treinamento.py +43 -138
modules/treinamento.py CHANGED
@@ -1,34 +1,30 @@
1
- # modules/treinamento.py
2
- """
3
- Sistema de treinamento avançado para Akira IA.
4
- Focado agora em heurística leve (tom, gírias) para adaptar prompts de APIs (Mistral/Gemini).
5
- """
6
  import threading
7
  import time
8
- import logging
9
- import re
10
  import json
11
  import collections
12
- from typing import Optional, Any, List, Dict, Tuple
13
  from dataclasses import dataclass
14
- from .database import Database # Importação adicionada para type hints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Usamos loguru no resto do projeto, vamos manter a consistência
17
- from loguru import logger
18
-
19
- # ================================================================
20
- # REMOÇÃO DO MODELO NLP PESADO
21
- # Os embeddings foram removidos para garantir o deploy leve e rápido.
22
  # ================================================================
23
- SentenceTransformer = None
24
- MODEL_NAME = None
25
- logger.info("Modelo NLP Pesado (Embedding) Desativado. Foco em Heurística.")
26
-
27
- # Listas angolanas (MANTIDAS)
28
- PALAVRAS_POSITIVAS = ['bom', 'ótimo', 'incrível', 'feliz', 'alegre', 'fixe', 'bué', 'top', 'show', 'adoro', 'rsrs', 'kkk']
29
- PALAVRAS_NEGATIVAS = ['ruim', 'péssimo', 'triste', 'ódio', 'puto', 'merda', 'caralho', 'chateado']
30
- GIRIAS_ANGOLANAS = ['mano', 'puto', 'cota', 'mwangolé', 'kota', 'oroh', 'bué', 'fixe', 'baza', 'kuduro']
31
- PALAVRAS_RUDES = ['caralho', 'puto', 'merda', 'fdp', 'vsf', 'burro', 'idiota', 'parvo']
32
 
33
  @dataclass
34
  class Interacao:
@@ -40,136 +36,40 @@ class Interacao:
40
  mensagem_original: str = ""
41
 
42
  class Treinamento:
43
- """
44
- Treinamento contínuo da Akira:
45
- - Registra interações
46
- - Analisa tom, emoção, gírias (Heurística)
47
- - Adapta prompts de APIs (Mistral/Gemini)
48
- """
49
- def __init__(self, db: Database, contexto: Optional[Any] = None, interval_hours: int = 1):
50
  self.db = db
51
- self.contexto = contexto
52
  self.interval_hours = interval_hours
53
  self._thread = None
54
  self._running = False
55
- self._model = None # Mantido como None
56
- self.privileged_users = ['244937035662', 'isaac', 'isaac quarenta']
57
-
58
- def _ensure_nlp_model(self):
59
- """Função esvaziada. O modelo pesado não é mais necessário."""
60
- return
61
 
62
- def registrar_interacao(self, usuario: str, mensagem: str, resposta: str, numero: str = '', is_reply: bool = False, mensagem_original: str = ''):
63
- """Salva + aprende na hora"""
64
- try:
65
- self.db.salvar_mensagem(usuario, mensagem, resposta, numero, is_reply, mensagem_original)
66
- self._aprender_em_tempo_real(numero, mensagem, resposta)
67
- logger.info(f"Interação aprendida: {numero}")
68
- except Exception as e:
69
- logger.warning(f'Erro ao registrar: {e}')
70
 
71
- def _aprender_em_tempo_real(self, numero: str, msg: str, resp: str):
72
- if not numero or numero == 'unknown':
73
- return
74
  texto = f"{msg} {resp}".lower()
75
-
76
- # === ANÁLISE NLP (Embedding REMOVIDO) ===
77
- # O código de embedding foi removido aqui.
78
-
79
- # === ANÁLISE HEURÍSTICA (MANTIDA) ===
80
  rude = any(p in texto for p in PALAVRAS_RUDES)
81
  tom = 'rude' if rude else 'casual'
82
- palavras = [p for p in re.findall(r'\b\w{4,}\b', texto)
83
- if p not in {'não', 'que', 'com', 'pra', 'uma', 'ele', 'ela'}]
84
- contador = collections.Counter(palavras)
85
- top_girias = [w for w, c in contador.most_common(5) if c > 1]
86
-
87
- # Salvar tom
88
- intensidade = 0.9 if rude else 0.6
89
- self.db.registrar_tom_usuario(numero, tom, intensidade, texto[:100])
90
-
91
- # Salvar g��rias
92
- for giria in top_girias:
93
- significado = "gíria rude" if rude else "gíria local"
94
- self.db.salvar_giria_aprendida(numero, giria, significado, texto[:100])
95
-
96
- # Emoção: Usar o contexto se disponível
97
- if self.contexto and hasattr(self.contexto, 'analisar_emocoes_mensagem'):
98
- emocao_str = self.contexto.analisar_emocoes_mensagem(msg)
99
- analise = {emocao_str: 1.0, "texto_original": msg}
100
- self.db.salvar_aprendizado_detalhado(numero, "emocao_recente", json.dumps(analise))
101
- else:
102
- logger.debug(f"Contexto não disponível para análise de emoção em tempo real para {numero}.")
103
 
104
-
105
- # ================================================================
106
- # HEURÍSTICO PARA MISTRAL (SEM FINE-TUNING PESADO) - MANTIDO
107
- # ================================================================
108
- def _prepare_prompt_for_mistral(self, interacoes: List[Interacao]) -> str:
109
- """Prepara prompt para Mistral baseado em interações"""
110
- examples = []
111
- for i in interacoes:
112
- prompt = f"Usuário: {i.mensagem}\nAkira: {i.resposta}\n"
113
- examples.append(prompt)
114
- return "\n".join(examples)
115
-
116
- def train_once(self):
117
- """Treinamento heurístico para Mistral"""
118
- logger.info("Treinamento heurístico iniciado...")
119
- self._analisar_usuarios()
120
- self._salvar_ultimo_treino()
121
- logger.info("Treinamento concluído.")
122
-
123
- def _analisar_usuarios(self):
124
- usuarios = set()
125
- rows = self.db._execute_with_retry("SELECT DISTINCT numero FROM mensagens WHERE numero IS NOT NULL AND numero != ''")
126
- if rows:
127
- for r in rows:
128
- usuarios.add(r[0])
129
- for num in usuarios:
130
- msgs = self.db.recuperar_mensagens(num, limite=20)
131
- if len(msgs) < 3: continue
132
- tom = self._detectar_tom(msgs, num)
133
- self.db.salvar_preferencia_tom(num, tom)
134
-
135
- def _detectar_tom(self, mensagens: List[Tuple], numero: str) -> str:
136
- if numero in self.privileged_users:
137
- return 'formal'
138
- counter = collections.Counter()
139
- for msg, _ in mensagens:
140
- msg_l = (msg or '').lower()
141
- if any(p in msg_l for p in PALAVRAS_RUDES):
142
- counter['rude'] += 1
143
- elif any(p in msg_l for p in ['por favor', 'obrigado']):
144
- counter['formal'] += 1
145
- elif any(p in msg_l for p in GIRIAS_ANGOLANAS):
146
- counter['casual'] += 1
147
- else:
148
- counter['neutro'] += 1
149
- return counter.most_common(1)[0][0] if counter else 'neutro'
150
-
151
- def _salvar_ultimo_treino(self):
152
- try:
153
- self.db.salvar_info_geral('ultimo_treino', str(time.time()))
154
- except: pass
155
-
156
- # ================================================================
157
- # LOOP DE TREINAMENTO - MANTIDO
158
- # ================================================================
159
  def _run_loop(self):
160
- interval = max(1, self.interval_hours) * 3600
161
- logger.info(f"Treinamento heurístico a cada {self.interval_hours}h")
162
  while self._running:
163
  try:
164
  self.train_once()
165
  except Exception as e:
166
  logger.exception(f"Erro no treinamento: {e}")
167
  for _ in range(int(interval)):
168
- if not self._running:
169
- break
170
  time.sleep(1)
171
- logger.info("Treinamento parado.")
172
-
173
  def start_periodic_training(self):
174
  if self._running: return
175
  self._running = True
@@ -178,5 +78,10 @@ class Treinamento:
178
 
179
  def stop(self):
180
  self._running = False
181
- if self._thread:
182
- self._thread.join(timeout=5)
 
 
 
 
 
 
 
 
 
 
 
1
  import threading
2
  import time
 
 
3
  import json
4
  import collections
5
+ from typing import Optional, List, Tuple
6
  from dataclasses import dataclass
7
+ from .database import Database
8
+ from loguru import logger
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModel
11
+
12
+ # Embeddings
13
+ MODEL_NAME = "GanymedeNil/text-embedding-3-large"
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
+ model = AutoModel.from_pretrained(MODEL_NAME)
16
+
17
+ def gerar_embedding(text: str):
18
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
19
+ with torch.no_grad():
20
+ outputs = model(**inputs)
21
+ emb = outputs.last_hidden_state.mean(dim=1)
22
+ return emb.squeeze().cpu().numpy()
23
 
 
 
 
 
 
 
24
  # ================================================================
25
+ # Heurística mantida
26
+ PALAVRAS_RUDES = ['caralho','puto','merda','fdp','vsf','burro','idiota','parvo']
27
+ GIRIAS_ANGOLANAS = ['mano','puto','cota','mwangolé','kota','oroh','bué','fixe','baza','kuduro']
 
 
 
 
 
 
28
 
29
  @dataclass
30
  class Interacao:
 
36
  mensagem_original: str = ""
37
 
38
  class Treinamento:
39
+ def __init__(self, db: Database, interval_hours: int = 1):
 
 
 
 
 
 
40
  self.db = db
 
41
  self.interval_hours = interval_hours
42
  self._thread = None
43
  self._running = False
44
+ self.privileged_users = ['244937035662','isaac','isaac quarenta']
 
 
 
 
 
45
 
46
+ def registrar_interacao(self, usuario, mensagem, resposta, numero='', is_reply=False, mensagem_original=''):
47
+ # salva no DB
48
+ self.db.salvar_mensagem(usuario, mensagem, resposta, numero, is_reply, mensagem_original)
49
+ self._aprender_em_tempo_real(numero, mensagem, resposta)
 
 
 
 
50
 
51
+ def _aprender_em_tempo_real(self, numero, msg, resp):
52
+ if not numero: return
 
53
  texto = f"{msg} {resp}".lower()
54
+ embedding = gerar_embedding(texto)
55
+ self.db.salvar_embedding(numero, msg, resp, embedding)
56
+ # heurística leve
 
 
57
  rude = any(p in texto for p in PALAVRAS_RUDES)
58
  tom = 'rude' if rude else 'casual'
59
+ self.db.registrar_tom_usuario(numero, tom, 0.9 if rude else 0.6, texto[:100])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Loop periódico
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def _run_loop(self):
63
+ interval = max(1,self.interval_hours)*3600
 
64
  while self._running:
65
  try:
66
  self.train_once()
67
  except Exception as e:
68
  logger.exception(f"Erro no treinamento: {e}")
69
  for _ in range(int(interval)):
70
+ if not self._running: break
 
71
  time.sleep(1)
72
+
 
73
  def start_periodic_training(self):
74
  if self._running: return
75
  self._running = True
 
78
 
79
  def stop(self):
80
  self._running = False
81
+ if self._thread: self._thread.join(timeout=5)
82
+
83
+ def train_once(self):
84
+ logger.info("Treinamento leve + embeddings iniciado...")
85
+ # Pode incluir treino de API baseado em histórico
86
+ # Apenas heurística + embeddings
87
+ logger.info("Treinamento concluído.")