| | """Merge configuration YAML generator with presets and validation.""" |
| |
|
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| | import yaml |
| |
|
| |
|
| | |
| |
|
# Registry of supported merge algorithms and their constraints.
# Each entry drives both validation (MergeConfig.validate) and YAML generation:
#   min_models / max_models — allowed size of the models list
#   needs_base              — whether a base_model must be supplied
#   params                  — per-model parameters emitted under each model entry
#   global_params           — method-level parameters emitted once in "parameters"
#   supports_slices         — whether layer-slice definitions are accepted
#   requires_slices         — slices are mandatory (passthrough only)
MERGE_METHODS = {
    "dare_ties": {
        "name": "DARE-TIES",
        "description": "Drop And REscale with TIES — trims low-magnitude parameters and resolves sign conflicts. Best for combining 2+ specialist models.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    "ties": {
        "name": "TIES",
        "description": "Trim, Elect Sign, Merge — resolves parameter interference between models. Similar to DARE-TIES but without the drop step.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    "slerp": {
        "name": "SLERP",
        "description": "Spherical Linear Interpolation — smoothly blends two models along a curved path in weight space. Best for two-model merges.",
        # Exactly two models; interpolation factor "t" is handled specially
        # in generate_yaml rather than via the global-params loop.
        # NOTE(review): mergekit's slerp config typically expects a base_model —
        # confirm needs_base=False is intended here.
        "min_models": 2,
        "max_models": 2,
        "needs_base": False,
        "params": [],
        "global_params": ["t"],
        "supports_slices": True,
    },
    "linear": {
        "name": "Linear",
        "description": "Simple weighted average of model parameters. Fast and predictable baseline.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": False,
        "params": ["weight"],
        "global_params": ["normalize"],
        "supports_slices": True,
    },
    "task_arithmetic": {
        "name": "Task Arithmetic",
        "description": "Add or subtract task vectors from a base model. Use negative weights to remove capabilities.",
        # min_models=1: a single task vector added to/subtracted from base is valid.
        "min_models": 1,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight"],
        "global_params": [],
        "supports_slices": False,
    },
    "passthrough": {
        "name": "Passthrough (Frankenmerge)",
        "description": "Stack layers from different models. Can create larger models from smaller ones. Supports different layer counts.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": False,
        "params": [],
        "global_params": [],
        "supports_slices": True,
        # Only method with this flag; validate() errors when slices are missing.
        "requires_slices": True,
    },
}
| |
|
| |
|
| | |
| |
|
@dataclass
class MergePreset:
    """A named weighting recipe that turns a model list into weights/densities.

    Attributes map one-to-one to the PRESETS registry entries: a display
    name, a description, the merge method key, and the strategy used by
    apply() to distribute weights.
    """
    name: str
    description: str
    method: str
    weight_strategy: str

    def apply(self, model_ids: list[str]) -> tuple[list[float], list[float]]:
        """Generate weights and densities for given models."""
        count = len(model_ids)
        if not model_ids:
            return [], []

        if self.weight_strategy == "first_dominant":
            # 0.6 to the first model, the remaining 0.4 split evenly.
            if count > 1:
                share = round(0.4 / (count - 1), 3)
                weights = [0.6] + [share] * (count - 1)
            else:
                weights = [1.0]
            densities = [0.7] + [0.5] * (count - 1)
        elif self.weight_strategy == "last_dominant":
            # Mirror of first_dominant: 0.6 to the last model.
            if count > 1:
                share = round(0.4 / (count - 1), 3)
                weights = [share] * (count - 1) + [0.6]
            else:
                weights = [1.0]
            densities = [0.5] * (count - 1) + [0.7]
        elif self.weight_strategy == "auto_detect":
            weights, densities = _auto_detect_weights(model_ids)
        else:
            # "equal" and any unrecognized strategy: uniform split.
            weights = [round(1.0 / count, 3)] * count
            densities = [0.6] * count

        return weights, densities
| |
|
| |
|
| | def _auto_detect_weights(model_ids: list[str]) -> tuple[list[float], list[float]]: |
| | """Auto-detect optimal weights based on model names/tags.""" |
| | n = len(model_ids) |
| | weights = [] |
| | densities = [] |
| |
|
| | for mid in model_ids: |
| | name = mid.lower() |
| | if "code" in name or "coder" in name: |
| | weights.append(0.5) |
| | densities.append(0.7) |
| | elif "math" in name: |
| | weights.append(0.4) |
| | densities.append(0.6) |
| | elif "instruct" in name and "code" not in name: |
| | weights.append(0.3) |
| | densities.append(0.5) |
| | else: |
| | weights.append(0.3) |
| | densities.append(0.5) |
| |
|
| | |
| | total = sum(weights) |
| | if total > 0: |
| | weights = [round(w / total, 3) for w in weights] |
| |
|
| | return weights, densities |
| |
|
| |
|
# Ready-made weighting recipes keyed by preset id.
# See MergePreset.apply for how each weight_strategy maps to weights/densities;
# generate_from_preset looks presets up by these keys.
PRESETS = {
    "equal": MergePreset("Equal", "Equal weights for all models", "dare_ties", "equal"),
    "first_dominant": MergePreset("First Model Dominant", "Prioritize the first model", "dare_ties", "first_dominant"),
    "last_dominant": MergePreset("Last Model Dominant", "Prioritize the last model", "dare_ties", "last_dominant"),
    "coding_focus": MergePreset("Coding Focus", "Higher weight for code-related models", "dare_ties", "auto_detect"),
    "balanced_slerp": MergePreset("Balanced SLERP", "50/50 interpolation between two models", "slerp", "equal"),
}
| |
|
| |
|
| | |
| |
|
@dataclass
class MergeConfig:
    """Complete merge configuration."""
    method: str = "dare_ties"
    models: list[str] = field(default_factory=list)
    base_model: str = ""
    weights: list[float] = field(default_factory=list)
    densities: list[float] = field(default_factory=list)
    tokenizer_source: str = ""
    dtype: str = "bfloat16"

    # Method-specific knobs (only consulted by the relevant methods).
    slerp_t: float = 0.5
    int8_mask: bool = True
    normalize: bool = True

    # Layer-slice definitions (passthrough / frankenmerge).
    slices: list[dict] = field(default_factory=list)

    output_name: str = ""

    def validate(self) -> list[str]:
        """Validate the configuration. Returns list of error messages."""
        problems: list[str] = []
        spec = MERGE_METHODS.get(self.method)
        if spec is None:
            # No point validating further without a known method spec.
            problems.append(f"Unknown merge method: {self.method}")
            return problems

        count = len(self.models)
        if count < spec["min_models"]:
            problems.append(f"{spec['name']} requires at least {spec['min_models']} models")
        if count > spec["max_models"]:
            problems.append(f"{spec['name']} supports at most {spec['max_models']} models")

        if spec["needs_base"] and not self.base_model:
            problems.append(f"{spec['name']} requires a base_model")

        # Per-model parameter lists are optional, but when present they must
        # match the model count and stay within sensible ranges.
        if "weight" in spec["params"] and self.weights:
            if len(self.weights) != count:
                problems.append(f"Expected {count} weights, got {len(self.weights)}")
            if any(w < -1 or w > 2 for w in self.weights):
                problems.append("Weights should be between -1 and 2")

        if "density" in spec["params"] and self.densities:
            if len(self.densities) != count:
                problems.append(f"Expected {count} densities, got {len(self.densities)}")
            if any(d < 0 or d > 1 for d in self.densities):
                problems.append("Densities must be between 0 and 1")

        if self.method == "slerp" and (self.slerp_t < 0 or self.slerp_t > 1):
            problems.append("SLERP t parameter must be between 0 and 1")

        if spec.get("requires_slices") and not self.slices:
            problems.append(f"{spec['name']} requires slice definitions")

        return problems
| |
|
| |
|
def generate_yaml(config: MergeConfig) -> str:
    """Generate mergekit-compatible YAML configuration.

    Args:
        config: MergeConfig with all parameters

    Returns:
        YAML string ready for mergekit. If validation fails, a string of
        YAML comment lines listing the errors is returned instead.
    """
    errors = config.validate()
    if errors:
        # Fixed header string: no placeholders, so no f-prefix (was f"...").
        return "# VALIDATION ERRORS:\n" + "\n".join(f"# - {e}" for e in errors)

    method_info = MERGE_METHODS[config.method]
    doc = {}

    # Passthrough merges are slice-driven: emit slices + method + dtype only.
    if config.method == "passthrough":
        # NOTE(review): validate() rejects passthrough configs without slices,
        # so the _default_slices fallback below is currently unreachable —
        # confirm whether slice-less passthrough should be allowed instead.
        doc["slices"] = config.slices or _default_slices(config)
        doc["merge_method"] = config.method
        doc["dtype"] = config.dtype
        return yaml.dump(doc, default_flow_style=False, sort_keys=False)

    doc["merge_method"] = config.method

    if method_info["needs_base"]:
        doc["base_model"] = config.base_model

    # SLERP takes a single global interpolation factor instead of
    # per-model weight/density parameters.
    if config.method == "slerp":
        doc["models"] = [{"model": m} for m in config.models]
        doc["parameters"] = {"t": config.slerp_t}
    else:
        models_list = []
        for i, model_id in enumerate(config.models):
            entry = {"model": model_id}
            params = {}
            # Only attach parameters the method actually supports, and only
            # when the caller supplied values (lists may be legitimately empty).
            if "weight" in method_info["params"] and config.weights:
                params["weight"] = config.weights[i]
            if "density" in method_info["params"] and config.densities:
                params["density"] = config.densities[i]
            if params:
                entry["parameters"] = params
            models_list.append(entry)
        doc["models"] = models_list

    # Method-level (global) parameters; the "t" of slerp is handled above.
    global_params = {}
    if "int8_mask" in method_info.get("global_params", []):
        global_params["int8_mask"] = config.int8_mask
    if "normalize" in method_info.get("global_params", []):
        global_params["normalize"] = config.normalize

    if global_params:
        doc["parameters"] = global_params

    doc["dtype"] = config.dtype

    if config.tokenizer_source:
        doc["tokenizer_source"] = config.tokenizer_source

    return yaml.dump(doc, default_flow_style=False, sort_keys=False)
| |
|
| |
|
| | def _default_slices(config: MergeConfig) -> list[dict]: |
| | """Generate default slice config for passthrough merges.""" |
| | slices = [] |
| | for model_id in config.models: |
| | slices.append({ |
| | "sources": [{"model": model_id, "layer_range": [0, 32]}] |
| | }) |
| | return slices |
| |
|
| |
|
def generate_from_preset(
    preset_name: str,
    model_ids: list[str],
    base_model: str = "",
    tokenizer_source: str = "",
    dtype: str = "bfloat16",
) -> str:
    """Quick config generation from a preset name.

    Args:
        preset_name: Key from PRESETS dict
        model_ids: List of model IDs to merge
        base_model: Base model for methods that need one
        tokenizer_source: Which model's tokenizer to use
        dtype: Data type for merge

    Returns:
        YAML string (or a comment block naming the unknown preset).
    """
    try:
        preset = PRESETS[preset_name]
    except KeyError:
        return f"# Unknown preset: {preset_name}\n# Available: {', '.join(PRESETS.keys())}"

    weights, densities = preset.apply(model_ids)

    # Fall back to the first model for base/tokenizer when none is given.
    fallback = model_ids[0] if model_ids else ""
    return generate_yaml(
        MergeConfig(
            method=preset.method,
            models=model_ids,
            base_model=base_model or fallback,
            weights=weights,
            densities=densities,
            tokenizer_source=tokenizer_source or base_model or fallback,
            dtype=dtype,
        )
    )
| |
|
| |
|
def get_method_info(method: str) -> dict:
    """Get human-readable info about a merge method."""
    # Unknown methods get a minimal stub instead of raising.
    fallback = {"name": "Unknown", "description": "Unknown method"}
    return MERGE_METHODS.get(method, fallback)
| |
|