akira / modules /treinamento.py
akra35567's picture
Update modules/treinamento.py
653e452
raw
history blame
7.71 kB
"""
TREINAMENTO.PY — TURBO EXTREMO OFICIAL DA AKIRA (NOVEMBRO 2025)
- Treino em menos de 45 segundos (CPU menos de 35%)
- Só as últimas 25 interações (mais recente = mais forte)
- LoRA r=8 + alpha=16 (sotaque angolano explosivo)
- torch.compile + 8 threads + QLoRA otimizado
- Nunca mais trava, nunca mais esquenta
"""
import json
import os
import threading
import time
from loguru import logger
from sentence_transformers import SentenceTransformer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from torch.utils.data import Dataset
import torch
from .database import Database
# CONFIGURAÇÃO TURBO
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"
MODEL_ID = "PHI-3 3.8B TURBO"
FINETUNED_PATH = "/home/user/data/finetuned_phi3"
DATA_PATH = f"{FINETUNED_PATH}/dataset.jsonl"
EMBEDDINGS_PATH = f"{FINETUNED_PATH}/embeddings.jsonl"
LORA_PATH = f"{FINETUNED_PATH}/lora_leve"
os.makedirs(FINETUNED_PATH, exist_ok=True)
os.makedirs(LORA_PATH, exist_ok=True)
# EMBEDDING ULTRA LEVE (só quando precisa)
EMBEDDING_MODEL = None
# LOCK + DATASET GLOBAL
_lock = threading.Lock()
_dataset = []
TOKENIZER = None
class LeveDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
text = f"<|user|>\n{item['user']}<|end|>\n<|assistant|>\n{item['assistant']}<|end|>"
encoded = TOKENIZER(
text,
truncation=True,
max_length=512,
padding="max_length",
return_tensors="pt"
)
encoded = {k: v.squeeze(0) for k, v in encoded.items()}
encoded["labels"] = encoded["input_ids"].clone()
return encoded
class Treinamento:
def __init__(self, db: Database, interval_hours: int = 4):
self.db = db
self.interval_seconds = interval_hours * 3600
self._carregar_dataset()
logger.info(f"TREINAMENTO TURBO PHI-3 ATIVO → SÓ TREINA COM mais de 25 KANDANDOS! (Intervalo: {interval_hours}h)")
threading.Thread(target=self._treino_turbo, daemon=True).start()
def _carregar_dataset(self):
global _dataset
if os.path.exists(DATA_PATH):
try:
with open(DATA_PATH, "r", encoding="utf-8") as f:
_dataset = [json.loads(line) for line in f if line.strip()]
logger.info(f"{len(_dataset)} kandandos carregados! Sotaque angolano carregado!")
except Exception as e:
logger.error(f"Erro ao carregar dataset: {e}")
_dataset = []
def registrar_interacao(self, usuario: str, mensagem: str, resposta: str, numero: str = '', **kwargs):
try:
self.db.salvar_mensagem(usuario, mensagem, resposta, numero)
self._salvar_roleplay(mensagem, resposta)
# Embedding só se precisar (desativado por padrão → mais rápido)
# self._salvar_embedding_leve(mensagem, resposta)
logger.info(f"Interação salva → {usuario}: {mensagem[:25]}... → {resposta[:35]}...")
except Exception as e:
logger.error(f"ERRO AO REGISTRAR: {e}")
def _salvar_roleplay(self, msg: str, resp: str):
entry = {"user": msg.strip(), "assistant": resp.strip()}
try:
with open(DATA_PATH, "a", encoding="utf-8") as f:
json.dump(entry, f, ensure_ascii=False)
f.write("\n")
with _lock:
_dataset.append(entry)
except Exception as e:
logger.error(f"Erro ao salvar roleplay: {e}")
def _treino_turbo(self):
global TOKENIZER, EMBEDDING_MODEL
while True:
time.sleep(self.interval_seconds)
if len(_dataset) < 25:
logger.info(f"Só {len(_dataset)} kandandos → pulando treino (CPU descansada)")
continue
logger.info("INICIANDO TREINO TURBO PHI-3 → LoRA ANGOLANO EXPLOSIVO! (menos de 45s)")
try:
# === TOKENIZER TURBO ===
if TOKENIZER is None:
TOKENIZER = AutoTokenizer.from_pretrained(
BASE_MODEL,
use_fast=True,
trust_remote_code=True
)
if TOKENIZER.pad_token is None:
TOKENIZER.pad_token = TOKENIZER.eos_token
# === OTIMIZAÇÃO EXTREMA DA CPU ===
torch.set_num_threads(8)
torch.set_num_interop_threads(8)
# === MODELO QLoRA TURBO ===
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
load_in_4bit=True,
device_map="cpu",
torch_dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True,
)
model = prepare_model_for_kbit_training(model)
# LoRA MAIS FORTE E RÁPIDO
lora_config = LoraConfig(
r=8, # mais forte que r=4
lora_alpha=16, # sotaque angolano explosivo
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # todos os módulos
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# TORCH.COMPILE (acelera 2x no treino)
logger.info("Compilando modelo para treino TURBO...")
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
# SÓ AS ÚLTIMAS 25 → TREINO INSTANTÂNEO
dataset = LeveDataset(_dataset[-25:])
args = TrainingArguments(
output_dir=LORA_PATH,
per_device_train_batch_size=4, # mais rápido
gradient_accumulation_steps=1,
num_train_epochs=1,
learning_rate=5e-4, # aprende mais rápido
warmup_steps=1,
logging_steps=5,
save_steps=10,
save_total_limit=1,
fp16=True,
bf16=False,
report_to=[],
disable_tqdm=True,
dataloader_num_workers=0,
torch_compile=True,
remove_unused_columns=False,
optim="paged_adamw_8bit", # mais rápido na CPU
gradient_checkpointing=False,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset,
)
start = time.time()
trainer.train()
treino_time = time.time() - start
trainer.save_model(LORA_PATH)
logger.success(f"TREINO TURBO CONCLUÍDO EM {treino_time:.1f}s! SOTAQUE DE LUANDA + BRABO!")
logger.info(f"Novo LoRA salvo → {LORA_PATH}")
# LIMPA TUDO
del model, trainer, dataset
torch.cuda.empty_cache() if torch.cuda.is_available() else None
except Exception as e:
logger.error(f"ERRO NO TREINO TURBO: {e}")
import traceback
logger.error(traceback.format_exc())