stmasson commited on
Commit
a779a89
·
verified ·
1 Parent(s): 70f98a4

Upload scripts/eval_n8n_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/eval_n8n_model.py +409 -0
scripts/eval_n8n_model.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "datasets>=3.0.0",
6
+ # "accelerate>=1.0.0",
7
+ # "huggingface_hub>=0.26.0",
8
+ # "torch>=2.4.0",
9
+ # "tqdm>=4.66.0",
10
+ # "pandas>=2.0.0",
11
+ # ]
12
+ # [tool.uv]
13
+ # extra-index-url = ["https://download.pytorch.org/whl/cu124"]
14
+ # ///
15
+ """
16
+ Script d'évaluation pour le modèle n8n Expert.
17
+
18
+ Métriques:
19
+ 1. JSON Validity - Le output est-il du JSON valide?
20
+ 2. Schema Compliance - Le workflow suit-il le schéma n8n?
21
+ 3. Node Accuracy - Les types de nodes sont-ils corrects?
22
+ 4. Connection Logic - Les connexions sont-elles cohérentes?
23
+ 5. Thinking Quality - Le raisonnement est-il présent et structuré?
24
+
25
+ Usage:
26
+ python eval_n8n_model.py --model stmasson/n8n-expert-14b --samples 100
27
+ """
28
+
29
+ import os
30
+ import json
31
+ import argparse
32
+ import re
33
+ from typing import Dict, List, Any, Tuple
34
+ from dataclasses import dataclass
35
+ from tqdm import tqdm
36
+ import pandas as pd
37
+ import torch
38
+ from datasets import load_dataset
39
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
40
+ from huggingface_hub import login
41
+
42
+ # ============================================================================
43
+ # CONFIGURATION
44
+ # ============================================================================
45
+
46
+ # Types de nodes n8n valides (liste partielle)
47
+ VALID_NODE_TYPES = {
48
+ # Triggers
49
+ "n8n-nodes-base.webhookTrigger",
50
+ "n8n-nodes-base.scheduleTrigger",
51
+ "n8n-nodes-base.manualTrigger",
52
+ "n8n-nodes-base.emailTrigger",
53
+ # Actions
54
+ "n8n-nodes-base.httpRequest",
55
+ "n8n-nodes-base.set",
56
+ "n8n-nodes-base.if",
57
+ "n8n-nodes-base.switch",
58
+ "n8n-nodes-base.merge",
59
+ "n8n-nodes-base.splitInBatches",
60
+ "n8n-nodes-base.function",
61
+ "n8n-nodes-base.code",
62
+ "n8n-nodes-base.noOp",
63
+ # Intégrations
64
+ "n8n-nodes-base.slack",
65
+ "n8n-nodes-base.gmail",
66
+ "n8n-nodes-base.googleSheets",
67
+ "n8n-nodes-base.airtable",
68
+ "n8n-nodes-base.notion",
69
+ "n8n-nodes-base.discord",
70
+ "n8n-nodes-base.telegram",
71
+ "n8n-nodes-base.openAi",
72
+ "n8n-nodes-base.postgres",
73
+ "n8n-nodes-base.mysql",
74
+ "n8n-nodes-base.mongodb",
75
+ # AI
76
+ "@n8n/n8n-nodes-langchain.agent",
77
+ "@n8n/n8n-nodes-langchain.chainLlm",
78
+ }
79
+
80
+ # ============================================================================
81
+ # MÉTRIQUES
82
+ # ============================================================================
83
+
84
+ @dataclass
85
+ class EvalResult:
86
+ """Résultat d'évaluation pour un exemple"""
87
+ task_type: str
88
+ valid_json: bool
89
+ has_nodes: bool
90
+ has_connections: bool
91
+ nodes_valid: bool
92
+ has_thinking: bool
93
+ thinking_structured: bool
94
+ error: str = ""
95
+
96
+ @property
97
+ def score(self) -> float:
98
+ """Score global 0-1"""
99
+ scores = [
100
+ self.valid_json,
101
+ self.has_nodes,
102
+ self.has_connections,
103
+ self.nodes_valid,
104
+ self.has_thinking,
105
+ self.thinking_structured,
106
+ ]
107
+ return sum(scores) / len(scores)
108
+
109
+
110
+ def extract_workflow_json(text: str) -> Tuple[str, str]:
111
+ """
112
+ Extrait le JSON du workflow et le thinking de la réponse.
113
+ Retourne (thinking, workflow_json)
114
+ """
115
+ thinking = ""
116
+ workflow_json = ""
117
+
118
+ # Extraire le thinking
119
+ thinking_match = re.search(r'<thinking>(.*?)</thinking>', text, re.DOTALL)
120
+ if thinking_match:
121
+ thinking = thinking_match.group(1).strip()
122
+
123
+ # Extraire le JSON (après le thinking ou dans un bloc code)
124
+ # Méthode 1: Bloc code JSON
125
+ json_block = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
126
+ if json_block:
127
+ workflow_json = json_block.group(1).strip()
128
+ else:
129
+ # Méthode 2: JSON brut après le thinking
130
+ after_thinking = text
131
+ if thinking_match:
132
+ after_thinking = text[thinking_match.end():]
133
+
134
+ # Chercher un objet JSON
135
+ json_match = re.search(r'\{[\s\S]*\}', after_thinking)
136
+ if json_match:
137
+ workflow_json = json_match.group(0).strip()
138
+
139
+ return thinking, workflow_json
140
+
141
+
142
+ def validate_workflow(workflow_json: str) -> Dict[str, Any]:
143
+ """Valide un workflow n8n"""
144
+ result = {
145
+ "valid_json": False,
146
+ "has_nodes": False,
147
+ "has_connections": False,
148
+ "nodes_valid": False,
149
+ "node_count": 0,
150
+ "connection_count": 0,
151
+ "invalid_nodes": [],
152
+ }
153
+
154
+ # Test JSON valide
155
+ try:
156
+ wf = json.loads(workflow_json)
157
+ result["valid_json"] = True
158
+ except json.JSONDecodeError as e:
159
+ result["error"] = str(e)
160
+ return result
161
+
162
+ # Test nodes présents
163
+ nodes = wf.get("nodes", [])
164
+ result["has_nodes"] = len(nodes) > 0
165
+ result["node_count"] = len(nodes)
166
+
167
+ # Test connexions présentes
168
+ connections = wf.get("connections", {})
169
+ result["has_connections"] = len(connections) > 0
170
+ result["connection_count"] = sum(len(v) for v in connections.values())
171
+
172
+ # Test types de nodes valides
173
+ invalid_nodes = []
174
+ for node in nodes:
175
+ node_type = node.get("type", "")
176
+ if node_type and node_type not in VALID_NODE_TYPES:
177
+ # Accepter les types qui ressemblent à des nodes n8n
178
+ if not (node_type.startswith("n8n-nodes-base.") or
179
+ node_type.startswith("@n8n/")):
180
+ invalid_nodes.append(node_type)
181
+
182
+ result["invalid_nodes"] = invalid_nodes
183
+ result["nodes_valid"] = len(invalid_nodes) == 0
184
+
185
+ return result
186
+
187
+
188
+ def validate_thinking(thinking: str) -> Dict[str, bool]:
189
+ """Valide la qualité du thinking"""
190
+ result = {
191
+ "has_thinking": len(thinking) > 50, # Au moins 50 caractères
192
+ "thinking_structured": False,
193
+ }
194
+
195
+ # Vérifier si le thinking est structuré (contient des points numérotés ou tirets)
196
+ if thinking:
197
+ has_structure = (
198
+ re.search(r'\d+\.', thinking) is not None or # Points numérotés
199
+ re.search(r'^-\s', thinking, re.MULTILINE) is not None or # Tirets
200
+ re.search(r'^\*\s', thinking, re.MULTILINE) is not None or # Étoiles
201
+ "étape" in thinking.lower() or
202
+ "step" in thinking.lower()
203
+ )
204
+ result["thinking_structured"] = has_structure
205
+
206
+ return result
207
+
208
+
209
+ def evaluate_example(
210
+ model_output: str,
211
+ task_type: str,
212
+ ) -> EvalResult:
213
+ """Évalue un exemple généré par le modèle"""
214
+ # Extraire thinking et JSON
215
+ thinking, workflow_json = extract_workflow_json(model_output)
216
+
217
+ # Valider le workflow
218
+ wf_validation = validate_workflow(workflow_json)
219
+
220
+ # Valider le thinking
221
+ thinking_validation = validate_thinking(thinking)
222
+
223
+ return EvalResult(
224
+ task_type=task_type,
225
+ valid_json=wf_validation["valid_json"],
226
+ has_nodes=wf_validation["has_nodes"],
227
+ has_connections=wf_validation["has_connections"],
228
+ nodes_valid=wf_validation["nodes_valid"],
229
+ has_thinking=thinking_validation["has_thinking"],
230
+ thinking_structured=thinking_validation["thinking_structured"],
231
+ error=wf_validation.get("error", ""),
232
+ )
233
+
234
+
235
+ # ============================================================================
236
+ # ÉVALUATION
237
+ # ============================================================================
238
+
239
+ def run_evaluation(
240
+ model_path: str,
241
+ dataset_repo: str = "stmasson/n8n-agentic-multitask",
242
+ data_file: str = "data/multitask_large/val.jsonl",
243
+ num_samples: int = 100,
244
+ output_file: str = "eval_results.json",
245
+ ):
246
+ """Lance l'évaluation complète du modèle"""
247
+
248
+ print("=" * 60)
249
+ print("ÉVALUATION DU MODÈLE N8N EXPERT")
250
+ print("=" * 60)
251
+
252
+ # Auth
253
+ hf_token = os.environ.get("HF_TOKEN")
254
+ if hf_token:
255
+ login(token=hf_token)
256
+
257
+ # Charger le modèle
258
+ print(f"\nChargement du modèle: {model_path}")
259
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
260
+ model = AutoModelForCausalLM.from_pretrained(
261
+ model_path,
262
+ torch_dtype=torch.bfloat16,
263
+ device_map="auto",
264
+ trust_remote_code=True,
265
+ )
266
+
267
+ pipe = pipeline(
268
+ "text-generation",
269
+ model=model,
270
+ tokenizer=tokenizer,
271
+ device_map="auto",
272
+ )
273
+
274
+ # Charger le dataset
275
+ print(f"\nChargement du dataset: {dataset_repo}")
276
+ dataset = load_dataset(
277
+ dataset_repo,
278
+ data_files={"validation": data_file},
279
+ split="validation"
280
+ )
281
+
282
+ # Échantillonner
283
+ if num_samples < len(dataset):
284
+ dataset = dataset.shuffle(seed=42).select(range(num_samples))
285
+
286
+ print(f"Évaluation sur {len(dataset)} exemples")
287
+
288
+ # Évaluer
289
+ results = []
290
+ task_counts = {}
291
+
292
+ for example in tqdm(dataset, desc="Évaluation"):
293
+ messages = example["messages"]
294
+
295
+ # Déterminer le type de tâche
296
+ system_msg = messages[0]["content"] if messages else ""
297
+ if "génère" in system_msg.lower() or "generate" in system_msg.lower():
298
+ task_type = "generate"
299
+ elif "édite" in system_msg.lower() or "edit" in system_msg.lower():
300
+ task_type = "edit"
301
+ elif "corrige" in system_msg.lower() or "fix" in system_msg.lower():
302
+ task_type = "fix"
303
+ elif "améliore" in system_msg.lower() or "improve" in system_msg.lower():
304
+ task_type = "improve"
305
+ elif "explique" in system_msg.lower() or "explain" in system_msg.lower():
306
+ task_type = "explain"
307
+ elif "débogue" in system_msg.lower() or "debug" in system_msg.lower():
308
+ task_type = "debug"
309
+ else:
310
+ task_type = "unknown"
311
+
312
+ task_counts[task_type] = task_counts.get(task_type, 0) + 1
313
+
314
+ # Construire le prompt
315
+ prompt = tokenizer.apply_chat_template(
316
+ messages[:-1], # Exclure la réponse attendue
317
+ tokenize=False,
318
+ add_generation_prompt=True,
319
+ )
320
+
321
+ # Générer
322
+ try:
323
+ output = pipe(
324
+ prompt,
325
+ max_new_tokens=4096,
326
+ do_sample=False,
327
+ temperature=None,
328
+ top_p=None,
329
+ return_full_text=False,
330
+ )
331
+ generated = output[0]["generated_text"]
332
+ except Exception as e:
333
+ generated = f"ERROR: {str(e)}"
334
+
335
+ # Évaluer
336
+ eval_result = evaluate_example(generated, task_type)
337
+ results.append(eval_result)
338
+
339
+ # Calculer les statistiques
340
+ print("\n" + "=" * 60)
341
+ print("RÉSULTATS")
342
+ print("=" * 60)
343
+
344
+ total = len(results)
345
+
346
+ # Métriques globales
347
+ metrics = {
348
+ "valid_json": sum(r.valid_json for r in results) / total,
349
+ "has_nodes": sum(r.has_nodes for r in results) / total,
350
+ "has_connections": sum(r.has_connections for r in results) / total,
351
+ "nodes_valid": sum(r.nodes_valid for r in results) / total,
352
+ "has_thinking": sum(r.has_thinking for r in results) / total,
353
+ "thinking_structured": sum(r.thinking_structured for r in results) / total,
354
+ "overall_score": sum(r.score for r in results) / total,
355
+ }
356
+
357
+ print("\nMétriques globales:")
358
+ for metric, value in metrics.items():
359
+ print(f" {metric}: {value:.1%}")
360
+
361
+ # Métriques par tâche
362
+ print("\nMétriques par tâche:")
363
+ for task_type in sorted(task_counts.keys()):
364
+ task_results = [r for r in results if r.task_type == task_type]
365
+ if task_results:
366
+ task_score = sum(r.score for r in task_results) / len(task_results)
367
+ task_json = sum(r.valid_json for r in task_results) / len(task_results)
368
+ print(f" {task_type}: score={task_score:.1%}, json={task_json:.1%} (n={len(task_results)})")
369
+
370
+ # Sauvegarder les résultats
371
+ output = {
372
+ "model": model_path,
373
+ "num_samples": total,
374
+ "metrics": metrics,
375
+ "by_task": {
376
+ task: {
377
+ "count": len([r for r in results if r.task_type == task]),
378
+ "score": sum(r.score for r in results if r.task_type == task) /
379
+ max(1, len([r for r in results if r.task_type == task])),
380
+ }
381
+ for task in task_counts.keys()
382
+ },
383
+ }
384
+
385
+ with open(output_file, "w") as f:
386
+ json.dump(output, f, indent=2)
387
+
388
+ print(f"\nRésultats sauvegardés dans: {output_file}")
389
+
390
+ return metrics
391
+
392
+
393
+ # ============================================================================
394
+ # MAIN
395
+ # ============================================================================
396
+
397
+ if __name__ == "__main__":
398
+ parser = argparse.ArgumentParser(description="Évaluation du modèle n8n Expert")
399
+ parser.add_argument("--model", type=str, required=True, help="Chemin du modèle à évaluer")
400
+ parser.add_argument("--samples", type=int, default=100, help="Nombre d'exemples à évaluer")
401
+ parser.add_argument("--output", type=str, default="eval_results.json", help="Fichier de sortie")
402
+
403
+ args = parser.parse_args()
404
+
405
+ run_evaluation(
406
+ model_path=args.model,
407
+ num_samples=args.samples,
408
+ output_file=args.output,
409
+ )