AIencoder committed
Commit 964120b (verified)
1 Parent(s): 816c4d5

Create forgekit/ai_advisor.py

Files changed (1)
  1. forgekit/ai_advisor.py +224 -0
forgekit/ai_advisor.py ADDED
@@ -0,0 +1,224 @@
+"""AI-powered merge advisor using HuggingFace Inference API."""
+
+import json
+import requests
+from typing import Optional
+
+HF_INFERENCE_URL = "https://api-inference.huggingface.co/models"
+DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+
+
+def _query_llm(
+    prompt: str,
+    system: str = "",
+    model: str = DEFAULT_MODEL,
+    token: Optional[str] = None,
+    max_tokens: int = 800,
+) -> str:
+    """Query an LLM via HF Inference API.
+
+    Args:
+        prompt: User message
+        system: System prompt
+        model: HF model ID for inference
+        token: HF API token (recommended for higher rate limits)
+        max_tokens: Max response length
+
+    Returns:
+        Generated text response
+    """
+    headers = {"Content-Type": "application/json"}
+    if token:
+        headers["Authorization"] = f"Bearer {token}"
+
+    # Format as chat messages
+    messages = []
+    if system:
+        messages.append({"role": "system", "content": system})
+    messages.append({"role": "user", "content": prompt})
+
+    payload = {
+        "inputs": _format_chat(messages, model),
+        "parameters": {
+            "max_new_tokens": max_tokens,
+            "temperature": 0.7,
+            "do_sample": True,
+            "return_full_text": False,
+        },
+    }
+
+    try:
+        resp = requests.post(
+            f"{HF_INFERENCE_URL}/{model}",
+            headers=headers,
+            json=payload,
+            timeout=60,
+        )
+
+        if resp.status_code == 503:
+            # Model loading
+            return "⏳ The AI model is loading (this can take 1-2 minutes on first use). Please try again shortly."
+
+        if resp.status_code == 429:
+            return "⚠️ Rate limited — please wait a moment and try again, or add your HF token for higher limits."
+
+        if resp.status_code != 200:
+            return f"⚠️ AI service returned status {resp.status_code}. Try again or add an HF token."
+
+        data = resp.json()
+        if isinstance(data, list) and len(data) > 0:
+            text = data[0].get("generated_text", "")
+            # Clean up any leftover template tokens
+            for tag in ["</s>", "<|im_end|>", "<|eot_id|>", "[/INST]"]:
+                text = text.replace(tag, "")
+            return text.strip()
+
+        return "⚠️ No response generated. The model may be overloaded — try again."
+
+    except requests.exceptions.Timeout:
+        return "⚠️ Request timed out. The model may be loading — try again in a minute."
+    except Exception as e:
+        return f"⚠️ Error: {str(e)}"
+
+
+def _format_chat(messages: list[dict], model: str) -> str:
+    """Format messages into the model's expected chat template."""
+    # Mistral Instruct format
+    if "mistral" in model.lower() or "mixtral" in model.lower():
+        parts = []
+        for msg in messages:
+            if msg["role"] == "system":
+                parts.append(f"[INST] {msg['content']}\n")
+            elif msg["role"] == "user":
+                if parts:
+                    parts.append(f"{msg['content']} [/INST]")
+                else:
+                    parts.append(f"[INST] {msg['content']} [/INST]")
+        return "".join(parts)
+
+    # Generic ChatML fallback
+    parts = []
+    for msg in messages:
+        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>")
+    parts.append("<|im_start|>assistant\n")
+    return "\n".join(parts)
+
+
+# ===== AI FEATURES =====
+
+ADVISOR_SYSTEM = """You are ForgeKit AI, an expert assistant for merging large language models. You have deep knowledge of mergekit, model architectures, merge methods (DARE-TIES, TIES, SLERP, Linear, Task Arithmetic, Passthrough), and best practices for creating high-quality merged models.
+
+Be concise, practical, and specific. Give actionable recommendations with concrete numbers (weights, densities). Format your response with clear sections using markdown."""
+
+
+def merge_advisor(
+    models_text: str,
+    goal: str = "",
+    token: Optional[str] = None,
+) -> str:
+    """AI recommends the best merge method, weights, and configuration.
+
+    Args:
+        models_text: Newline-separated model IDs
+        goal: What the user wants the merged model to do
+        token: HF API token
+
+    Returns:
+        AI recommendation as markdown
+    """
+    models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
+    if len(models) < 2:
+        return "⚠️ Add at least 2 models to get a recommendation."
+
+    models_str = "\n".join(f"- {m}" for m in models)
+    goal_str = f"\n\nUser's goal: {goal}" if goal.strip() else ""
+
+    prompt = f"""I want to merge these models:
+{models_str}
+{goal_str}
+
+Recommend:
+1. **Best merge method** and why (DARE-TIES, SLERP, Linear, TIES, Task Arithmetic, or Passthrough)
+2. **Optimal weights** for each model (with reasoning)
+3. **Density values** if applicable
+4. **Which model to use as base** and why
+5. **Which tokenizer** to keep
+6. **Any warnings** or tips specific to these models
+
+Be specific with numbers and keep it practical."""
+
+    return _query_llm(prompt, system=ADVISOR_SYSTEM, token=token)
+
+
+def model_describer(
+    models_text: str,
+    method: str = "",
+    weights_text: str = "",
+    token: Optional[str] = None,
+) -> str:
+    """AI explains what the merged model will be good at.
+
+    Args:
+        models_text: Newline-separated model IDs
+        method: Merge method being used
+        weights_text: Comma-separated weights
+        token: HF API token
+
+    Returns:
+        AI description of expected capabilities
+    """
+    models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
+    if not models:
+        return "⚠️ Add models first."
+
+    models_str = "\n".join(f"- {m}" for m in models)
+    method_str = f" using {method}" if method else ""
+    weights_str = f"\nWeights: {weights_text}" if weights_text.strip() else ""
+
+    prompt = f"""I'm merging these models{method_str}:
+{models_str}{weights_str}
+
+Based on what each source model is known for, describe:
+1. **What the merged model will excel at** (specific tasks/benchmarks)
+2. **What it might struggle with** compared to the source models
+3. **Ideal use cases** for this merge
+4. **Expected quality** compared to each individual model
+5. **A creative name suggestion** for this merge
+
+Keep it concise and practical."""
+
+    return _query_llm(prompt, system=ADVISOR_SYSTEM, token=token)
+
+
+def config_explainer(
+    yaml_config: str,
+    token: Optional[str] = None,
+) -> str:
+    """AI explains a YAML merge config in plain English.
+
+    Args:
+        yaml_config: The YAML configuration string
+        token: HF API token
+
+    Returns:
+        Plain English explanation
+    """
+    if not yaml_config.strip() or yaml_config.startswith("# Add"):
+        return "⚠️ Generate a YAML config first."
+
+    prompt = f"""Explain this mergekit YAML configuration in plain English. Break it down so someone new to model merging can understand exactly what will happen:
+
+```yaml
+{yaml_config}
+```
+
+Explain:
+1. **What this config does** in simple terms
+2. **Why these specific settings** were chosen (method, weights, density)
+3. **What the output model will be like**
+4. **Any potential issues** to watch out for
+5. **Estimated resource requirements** (RAM, time)
+
+Be clear and beginner-friendly."""
+
+    return _query_llm(prompt, system=ADVISOR_SYSTEM, token=token)
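For context, a minimal usage sketch of the new module (not part of the commit). It assumes the repository root is on the import path, so the file is importable as forgekit.ai_advisor, and that an HF_TOKEN environment variable may hold a HuggingFace API token; the model IDs and goal string are placeholders.

# Minimal usage sketch (illustrative; not part of this commit).
# Assumes forgekit/ is importable and HF_TOKEN, if set, holds an HF token.
import os

from forgekit.ai_advisor import merge_advisor

token = os.environ.get("HF_TOKEN")  # optional, but raises rate limits

# Ask for a merge recommendation for two instruct models (placeholders).
advice = merge_advisor(
    models_text=(
        "mistralai/Mistral-7B-Instruct-v0.3\n"
        "HuggingFaceH4/zephyr-7b-beta"
    ),
    goal="a general assistant that is strong at coding",
    token=token,
)
print(advice)  # markdown recommendation, or a ⚠️/⏳ status string on failure

model_describer and config_explainer follow the same pattern: plain-string inputs in, markdown or a status string out, so a caller can surface the return value directly without extra error handling.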