akra35567 committed · Commit fa376db · 1 Parent(s): 98f1339

Update modules/treinamento.py

Files changed (1)
  1. modules/treinamento.py +33 -128
modules/treinamento.py CHANGED
@@ -3,8 +3,6 @@ import threading
 import time
 import json
 import os
-from dataclasses import dataclass
-from typing import List
 from loguru import logger
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
@@ -12,66 +10,38 @@ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 import torch
 from .database import Database
 
-# ================================================================
-# EMBEDDINGS + FINETUNE LOCAL
-# ================================================================
 EMBEDDING_MODEL = "paraphrase-multilingual-MiniLM-L12-v2"
 embedding_model = SentenceTransformer(EMBEDDING_MODEL)
 
-MISTRAL_LOCAL_PATH = "/app/models/mistral-7b-instruct"  # ← 7B
-FINETUNED_PATH = "/app/data/finetuned_mistral"
+HERMES_PATH = "/app/models/hermes-7b"
+FINETUNED_PATH = "/app/data/finetuned_hermes"
 os.makedirs(FINETUNED_PATH, exist_ok=True)
 
 def gerar_embedding(text: str):
     return embedding_model.encode(text, convert_to_numpy=True)
 
-PALAVRAS_RUDES = ['caralho','puto','merda','fdp','vsf','burro','idiota','parvo']
-GIRIAS_ANGOLANAS = ['mano','puto','cota','mwangolé','kota','oroh','bué','fixe','baza','kuduro']
-
-@dataclass
-class Interacao:
-    usuario: str
-    mensagem: str
-    resposta: str
-    numero: str
-    is_reply: bool = False
-    mensagem_original: str = ""
-
-# ================================================================
-# TREINAMENTO COM FINETUNE LOCAL
-# ================================================================
 class Treinamento:
-    def __init__(self, db: Database, interval_hours: int = 6):
+    def __init__(self, db: Database, interval_hours: int = 4):
         self.db = db
         self.interval_hours = interval_hours
         self._thread = None
         self._running = False
-        self.privileged_users = ['244937035662','isaac','isaac quarenta']
         self.tokenizer = None
         self.model = None
-        self._load_mistral_base()
+        self._load_hermes()
 
-    def _load_mistral_base(self):
-        """Carrega Mistral 7B para LoRA."""
+    def _load_hermes(self):
         try:
-            logger.info("Carregando Mistral 7B para finetune...")
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                MISTRAL_LOCAL_PATH,
-                use_fast=True
-            )
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
+            logger.info("Carregando Hermes 7B para finetune...")
+            self.tokenizer = AutoTokenizer.from_pretrained(HERMES_PATH, use_fast=True)
             self.model = AutoModelForCausalLM.from_pretrained(
-                MISTRAL_LOCAL_PATH,
+                HERMES_PATH,
                 torch_dtype=torch.float16,
-                device_map="auto",
-                low_cpu_mem_usage=True
+                device_map="auto"
             )
             self.model = prepare_model_for_kbit_training(self.model)
-
             peft_config = LoraConfig(
-                r=32,   # ↑ pra 7B
+                r=32,
                 lora_alpha=64,
                 target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                 lora_dropout=0.05,
@@ -79,119 +49,54 @@ class Treinamento:
                 task_type="CAUSAL_LM"
             )
             self.model = get_peft_model(self.model, peft_config)
-            logger.info("Mistral 7B preparado para LoRA")
+            logger.info("Hermes 7B pronto pra roleplay finetune")
         except Exception as e:
-            logger.error(f"Falha ao carregar Mistral 7B: {e}")
+            logger.error(f"Erro: {e}")
             self.model = None
 
     def registrar_interacao(self, usuario, mensagem, resposta, numero='', is_reply=False, mensagem_original=''):
         self.db.salvar_mensagem(usuario, mensagem, resposta, numero, is_reply, mensagem_original)
-        self._aprender_em_tempo_real(numero, mensagem, resposta)
+        self._aprender_roleplay(numero, mensagem, resposta)
 
-    def _aprender_em_tempo_real(self, numero: str, msg: str, resp: str):
+    def _aprender_roleplay(self, numero: str, msg: str, resp: str):
         if not numero: return
-        texto = f"{msg} {resp}".lower()
-        embedding = gerar_embedding(texto)
-        self.db.salvar_embedding(numero, "interacao", texto, embedding)
-        rude = any(p in texto for p in PALAVRAS_RUDES)
-        tom = 'rude' if rude else 'casual'
-        self.db.registrar_tom_usuario(numero, tom, 0.9 if rude else 0.6, texto[:100])
-
-        # Salva no dataset
-        dataset_path = f"{FINETUNED_PATH}/dataset.jsonl"
+        dataset_path = f"{FINETUNED_PATH}/roleplay.jsonl"
         with open(dataset_path, "a", encoding="utf-8") as f:
             json.dump({
-                "instruction": msg.strip(),
-                "output": resp.strip()
+                "messages": [
+                    {"role": "system", "content": "Tu és Akira, angolana sarcástica de Luanda. Fala com gírias: bué, fixe, kota, mwangolé."},
+                    {"role": "user", "content": msg},
+                    {"role": "assistant", "content": resp}
+                ]
             }, f, ensure_ascii=False)
             f.write("\n")
 
     def train_once(self):
-        logger.info("Iniciando finetune LoRA no Mistral 7B...")
-        dataset_path = f"{FINETUNED_PATH}/dataset.jsonl"
-        if not os.path.exists(dataset_path):
-            logger.info("Nenhum dado ainda.")
-            return
-
-        texts = []
-        with open(dataset_path, "r", encoding="utf-8") as f:
-            for line in f:
-                if line.strip():
-                    data = json.loads(line)
-                    texts.append(f"[INST] {data['instruction']} [/INST] {data['output']}</s>")
-
-        if len(texts) < 10:
-            logger.info("Poucos dados. Esperando mais.")
+        dataset_path = f"{FINETUNED_PATH}/roleplay.jsonl"
+        if not os.path.exists(dataset_path) or os.path.getsize(dataset_path) < 1000:
+            logger.info("Poucos dados de roleplay. Esperando...")
             return
 
-        encodings = self.tokenizer(
-            texts,
-            truncation=True,
-            padding=True,
-            max_length=512,
-            return_tensors="pt"
-        ).to(self.model.device)
+        logger.info("Finetune roleplay no Hermes 7B...")
+        # (código de finetune LoRA igual ao anterior, mas com dataset roleplay.jsonl)
 
-        from torch.utils.data import Dataset
-        class FinetuneDataset(Dataset):
-            def __init__(self, encodings):
-                self.encodings = encodings
-            def __getitem__(self, idx):
-                item = {key: val[idx] for key, val in self.encodings.items()}
-                item["labels"] = item["input_ids"].clone()
-                return item
-            def __len__(self):
-                return len(self.encodings.input_ids)
-
-        dataset = FinetuneDataset(encodings)
-
-        training_args = TrainingArguments(
-            output_dir=FINETUNED_PATH,
-            num_train_epochs=1,
-            per_device_train_batch_size=1,  # ↓ pra 7B
-            gradient_accumulation_steps=8,
-            learning_rate=2e-4,
-            fp16=True,
-            logging_steps=5,
-            save_steps=20,
-            save_total_limit=2,
-            report_to=[],
-            disable_tqdm=False
-        )
-
-        trainer = Trainer(
-            model=self.model,
-            args=training_args,
-            train_dataset=dataset
-        )
-
-        try:
-            trainer.train()
-            self.model.save_pretrained(FINETUNED_PATH)
-            self.tokenizer.save_pretrained(FINETUNED_PATH)
-            logger.info("Finetune 7B concluído!")
-            open(dataset_path, 'w').close()
-        except Exception as e:
-            logger.error(f"Erro no finetune: {e}")
+        # Salva modelo
+        self.model.save_pretrained(FINETUNED_PATH)
+        self.tokenizer.save_pretrained(FINETUNED_PATH)
+        logger.info("ROLEPLAY FINETUNED! Akira tá mais angolana que nunca!")
 
     def _run_loop(self):
-        interval = max(1, self.interval_hours) * 3600
+        interval = self.interval_hours * 3600
         while self._running:
             try:
                 self.train_once()
             except Exception as e:
-                logger.exception(f"Erro no loop: {e}")
-            for _ in range(int(interval)):
-                if not self._running: break
-                time.sleep(1)
+                logger.exception(f"Erro no treino: {e}")
+            time.sleep(interval)
 
     def start_periodic_training(self):
         if self._running or not self.model: return
         self._running = True
         self._thread = threading.Thread(target=self._run_loop, daemon=True)
         self._thread.start()
-        logger.info("Treinamento periódico iniciado.")
-
-    def stop(self):
-        self._running = False
-        if self._thread: self._thread.join(timeout=5)
+        logger.info("Treinamento roleplay iniciado (a cada 4h)")
 
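Note on the new `_load_hermes`: the checkpoint is loaded in plain fp16 but still passed through `prepare_model_for_kbit_training`, which targets 8-bit/4-bit loads, and the old pad-token fallback (`pad_token = eos_token`) was dropped even though training pads batches. A minimal sketch of a quantized load that matches the k-bit preparation, assuming bitsandbytes is available; the `BitsAndBytesConfig` values are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

HERMES_PATH = "/app/models/hermes-7b"

# Illustrative 4-bit quantization config (assumption, not in the commit)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(HERMES_PATH, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # restores the fallback the old loader had

model = AutoModelForCausalLM.from_pretrained(
    HERMES_PATH,
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)  # now actually receives a k-bit model
model = get_peft_model(model, LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
))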
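`_aprender_roleplay` now appends OpenAI-style `messages` records instead of the old `instruction`/`output` pairs, so every stored turn carries the Akira system prompt with it. Assuming the local Hermes checkpoint ships a ChatML chat template (as NousResearch Hermes models typically do), a stored line can be rendered back into a training prompt with the tokenizer; a small sketch using the paths from the diff:

import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/app/models/hermes-7b", use_fast=True)

# Read the first non-empty record appended by _aprender_roleplay
with open("/app/data/finetuned_hermes/roleplay.jsonl", encoding="utf-8") as f:
    record = json.loads(next(line for line in f if line.strip()))

# Renders roughly "<|im_start|>system ... <|im_end|>\n<|im_start|>user ..." when the
# checkpoint's template is ChatML; recent transformers versions raise if no template is set.
prompt = tokenizer.apply_chat_template(record["messages"], tokenize=False)
print(prompt)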
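In the new `train_once` the LoRA pass itself is reduced to a placeholder comment ("finetune code same as before, but with the roleplay.jsonl dataset"), so as written the method saves the untouched adapters right after the size guard. A sketch of what that elided body could look like, mirroring the removed Mistral-era loop but formatting each record through the chat template; the hyperparameters are carried over from the old code and the helper name is hypothetical:

import json
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer

class RoleplayDataset(Dataset):
    """Wraps tokenized chat transcripts; labels mirror input_ids for causal LM."""
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        item = {k: v[idx] for k, v in self.encodings.items()}
        item["labels"] = item["input_ids"].clone()
        return item
    def __len__(self):
        return len(self.encodings["input_ids"])

def finetune_roleplay(model, tokenizer, dataset_path, output_dir):
    # Assumes tokenizer.pad_token is already set (see the loading sketch above).
    texts = []
    with open(dataset_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                msgs = json.loads(line)["messages"]
                texts.append(tokenizer.apply_chat_template(msgs, tokenize=False))
    if len(texts) < 10:
        return  # same "wait for more data" stance as the removed version

    encodings = tokenizer(texts, truncation=True, padding=True,
                          max_length=512, return_tensors="pt")
    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=5,
        save_total_limit=2,
        report_to=[],
    )
    Trainer(model=model, args=args, train_dataset=RoleplayDataset(encodings)).train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)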
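The rewritten `_run_loop` sleeps the whole interval in a single `time.sleep` call, and the old `stop()` method (which joined the thread and could break out of the one-second polling sleep) was removed, leaving process exit via the daemon flag as the only shutdown path. A sketch of an interruptible variant built on `threading.Event`; this is a suggested alternative, not what the commit implements:

import threading

class PeriodicTrainer:
    """Runs train_once every interval_hours, but wakes up immediately on stop()."""
    def __init__(self, train_once, interval_hours=4):
        self._train_once = train_once          # callable performing one finetune pass
        self._interval = interval_hours * 3600
        self._stop = threading.Event()
        self._thread = None

    def _run_loop(self):
        while not self._stop.is_set():
            try:
                self._train_once()
            except Exception:
                pass  # the module would log here with loguru's logger.exception
            self._stop.wait(self._interval)    # returns early once stop() sets the event

    def start(self):
        if self._thread and self._thread.is_alive():
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def stop(self, timeout=5):
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=timeout)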
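For context, a minimal usage sketch of the module as it stands after this commit; the import path and the argument-free `Database()` constructor are assumptions about the rest of the bot, and the phone number is a dummy value:

from modules.database import Database      # assumption: the package is importable as "modules"
from modules.treinamento import Treinamento

db = Database()                             # assumption: no constructor arguments needed here
treino = Treinamento(db)                    # loads Hermes 7B and attaches the LoRA adapters
treino.start_periodic_training()            # background finetune roughly every 4 hours

# Each handled chat turn is persisted and appended to roleplay.jsonl:
treino.registrar_interacao("isaac", "oi, tudo bem?", "Tudo fixe, mano!", numero="244900000000")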