| | """HuggingFace Hub API wrapper for model discovery and info retrieval.""" |
| |
|
| | import json |
| | import time |
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| | from functools import lru_cache |
| |
|
| | import requests |
| |
|
| | HF_API = "https://huggingface.co/api" |
| | _session = requests.Session() |
| | _session.headers.update({"Accept": "application/json"}) |
| |
|
| | |
| | _cache: dict[str, tuple[float, any]] = {} |
| | CACHE_TTL = 300 |
| |
|
| |
|


def _cached_get(url: str, token: Optional[str] = None, ttl: int = CACHE_TTL) -> dict:
    """GET with caching and rate-limit handling."""
    now = time.time()
    if url in _cache and (now - _cache[url][0]) < ttl:
        return _cache[url][1]

    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    resp = _session.get(url, headers=headers, timeout=15)

    # On rate limiting, honor Retry-After once; the header may be an HTTP date
    # rather than a number of seconds, so fall back to a fixed delay.
    if resp.status_code == 429:
        try:
            retry = int(resp.headers.get("Retry-After", 5))
        except ValueError:
            retry = 5
        time.sleep(retry)
        resp = _session.get(url, headers=headers, timeout=15)

    resp.raise_for_status()
    data = resp.json()
    _cache[url] = (now, data)
    return data
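

# Illustrative behavior of the cache (gpt2 is a real public model, but any Hub
# API endpoint works): a second call within CACHE_TTL seconds is served from
# _cache without a network round trip.
#
#   first = _cached_get(f"{HF_API}/models/gpt2")   # network request
#   again = _cached_get(f"{HF_API}/models/gpt2")   # cache hit, no request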


@dataclass
class ModelInfo:
    """Parsed model information from HF Hub."""
    model_id: str
    model_type: str = "unknown"
    architectures: list[str] = field(default_factory=list)
    vocab_size: int = 0
    hidden_size: int = 0
    intermediate_size: int = 0
    num_hidden_layers: int = 0
    num_attention_heads: int = 0
    num_key_value_heads: int = 0
    max_position_embeddings: int = 0
    torch_dtype: str = "unknown"
    pipeline_tag: str = ""
    tags: list[str] = field(default_factory=list)
    downloads: int = 0
    likes: int = 0
    size_bytes: int = 0
    gated: bool = False
    private: bool = False
    trust_remote_code: bool = False
    error: Optional[str] = None

    @property
    def param_estimate(self) -> str:
        """Rough parameter-count estimate derived from total weight-file size."""
        if self.size_bytes > 0:
            # Assume ~2 bytes per parameter (fp16/bf16 checkpoints).
            params = self.size_bytes / 2
            if params > 1e9:
                return f"{params/1e9:.1f}B"
            elif params > 1e6:
                return f"{params/1e6:.0f}M"
        return "unknown"

    @property
    def arch_signature(self) -> str:
        """Unique signature for architecture matching."""
        return f"{self.model_type}|{self.hidden_size}|{self.intermediate_size}"

    @property
    def display_name(self) -> str:
        """Short display name (without org prefix)."""
        return self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id

    @property
    def ram_estimate_gb(self) -> float:
        """Estimated RAM needed for merging (roughly 2.5x model size for a bf16 merge)."""
        if self.size_bytes > 0:
            return round(self.size_bytes * 2.5 / (1024**3), 1)
        return 0.0

    def to_dict(self) -> dict:
        return {
            "model_id": self.model_id,
            "model_type": self.model_type,
            "architectures": self.architectures,
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "intermediate_size": self.intermediate_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "torch_dtype": self.torch_dtype,
            "pipeline_tag": self.pipeline_tag,
            "downloads": self.downloads,
            "likes": self.likes,
            "param_estimate": self.param_estimate,
            "ram_estimate_gb": self.ram_estimate_gb,
            "gated": self.gated,
            "private": self.private,
        }
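
    # Worked example (illustrative numbers, hypothetical model ID): a bf16
    # model whose weight files total 14 GB implies ~7B parameters at 2
    # bytes/param, and ~32.6 GB of merge headroom at the 2.5x heuristic.
    #
    #   info = ModelInfo(model_id="example-org/example-7b",
    #                    size_bytes=14_000_000_000)
    #   assert info.param_estimate == "7.0B"
    #   assert info.ram_estimate_gb == 32.6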


def fetch_model_info(model_id: str, token: Optional[str] = None) -> ModelInfo:
    """Fetch comprehensive model information from HF Hub.

    Args:
        model_id: Full model ID (e.g., "Qwen/Qwen2.5-Coder-7B-Instruct")
        token: Optional HF API token for gated/private models

    Returns:
        ModelInfo dataclass with all available information
    """
    info = ModelInfo(model_id=model_id)

    try:
        data = _cached_get(f"{HF_API}/models/{model_id}", token=token)
    except requests.exceptions.HTTPError as e:
        # 401 (no/invalid token) and 403 (token without access) both indicate
        # a gated or private repo.
        if e.response.status_code in (401, 403):
            info.error = "Gated or private model — HF token required"
            info.gated = True
        elif e.response.status_code == 404:
            info.error = f"Model not found: {model_id}"
        else:
            info.error = f"API error: {e.response.status_code}"
        return info
    except Exception as e:
        info.error = f"Connection error: {str(e)}"
        return info

    info.pipeline_tag = data.get("pipeline_tag", "")
    info.tags = data.get("tags", [])
    info.downloads = data.get("downloads", 0)
    info.likes = data.get("likes", 0)
    # "gated" is False or a gating mode string ("auto"/"manual"), so any
    # non-false value means the repo is gated.
    info.gated = data.get("gated", False) not in (False, None)
    info.private = data.get("private", False)

    # The API response may carry a lightweight config stub.
    config = data.get("config", {})
    if config:
        info.model_type = config.get("model_type", "unknown")
        info.architectures = config.get("architectures", [])

    # Pull the full config.json for the architecture details the stub lacks.
    try:
        full_config = _cached_get(
            f"https://huggingface.co/{model_id}/resolve/main/config.json",
            token=token,
        )
        info.model_type = full_config.get("model_type", info.model_type)
        info.architectures = full_config.get("architectures", info.architectures)
        info.vocab_size = full_config.get("vocab_size", 0)
        info.hidden_size = full_config.get("hidden_size", 0)
        info.intermediate_size = full_config.get("intermediate_size", 0)
        info.num_hidden_layers = full_config.get("num_hidden_layers", 0)
        info.num_attention_heads = full_config.get("num_attention_heads", 0)
        info.num_key_value_heads = full_config.get("num_key_value_heads", 0)
        info.max_position_embeddings = full_config.get("max_position_embeddings", 0)
        info.torch_dtype = full_config.get("torch_dtype", "unknown")

        # An "auto_map" entry signals custom modeling code on the Hub.
        if "auto_map" in full_config:
            info.trust_remote_code = True
    except Exception:
        # Fall back to whatever the API stub provided.
        if config:
            info.vocab_size = config.get("vocab_size", 0)
            info.hidden_size = config.get("hidden_size", 0)
        else:
            info.error = "Could not fetch config.json — model may need trust_remote_code=True"
            info.trust_remote_code = True

    # Sum the weight files to estimate total model size on disk.
    siblings = data.get("siblings", [])
    total_size = 0
    for f in siblings:
        fname = f.get("rfilename", "")
        size = f.get("size", 0) or 0
        if any(fname.endswith(ext) for ext in
               [".safetensors", ".bin", ".pt", ".pth", ".gguf"]):
            total_size += size
    info.size_bytes = total_size

    return info
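

# Usage sketch (requires network access; the model ID comes from the docstring
# example above, and gated repos additionally need a token with access):
#
#   info = fetch_model_info("Qwen/Qwen2.5-Coder-7B-Instruct")
#   if info.error:
#       print("lookup failed:", info.error)
#   else:
#       print(info.display_name, info.param_estimate, f"{info.ram_estimate_gb} GB")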


def search_models(
    query: str = "",
    author: str = "",
    architecture: str = "",
    limit: int = 20,
    sort: str = "downloads",
    token: Optional[str] = None,
) -> list[dict]:
    """Search HuggingFace Hub for models.

    Args:
        query: Search query string
        author: Filter by author/organization
        architecture: Filter by model_type (e.g., "llama", "qwen2")
        limit: Max results to return
        sort: Sort by "downloads", "likes", "created", "modified"
        token: Optional HF API token

    Returns:
        List of dicts with basic model info
    """
    params = {
        "limit": min(limit, 100),
        "sort": sort,
        "direction": -1,
        "config": "true",  # ask the API to include each model's config stub
    }
    if query:
        params["search"] = query
    if author:
        params["author"] = author

    try:
        # urlencode escapes spaces and special characters in the query.
        data = _cached_get(f"{HF_API}/models?{urlencode(params)}", token=token, ttl=60)
    except Exception as e:
        return [{"error": str(e)}]

    results = []
    for m in data:
        config = m.get("config", {}) or {}
        model_type = config.get("model_type", "")

        # The architecture filter is applied client-side against the config stub.
        if architecture and model_type.lower() != architecture.lower():
            continue

        results.append({
            "model_id": m.get("modelId", ""),
            "model_type": model_type,
            "pipeline_tag": m.get("pipeline_tag", ""),
            "downloads": m.get("downloads", 0),
            "likes": m.get("likes", 0),
            "tags": m.get("tags", [])[:5],
        })

    return results[:limit]
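

# Usage sketch (illustrative query; results depend on the live Hub index):
#
#   for hit in search_models(query="coder", architecture="qwen2", limit=5):
#       print(hit.get("model_id"), hit.get("downloads"))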


def get_popular_base_models(architecture: str = "", token: Optional[str] = None) -> list[dict]:
    """Get popular base models for a given architecture type.

    Useful for suggesting base_model in merge configs.
    """
    # Curated, well-known bases per model_type, checked before hitting the API.
    known_bases = {
        "llama": [
            "meta-llama/Llama-3.1-8B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-2-7b-hf",
        ],
        "mistral": [
            "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
        ],
        "qwen2": [
            "Qwen/Qwen2.5-7B-Instruct",
            "Qwen/Qwen2.5-14B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-72B-Instruct",
        ],
        "gemma2": [
            "google/gemma-2-9b-it",
            "google/gemma-2-27b-it",
        ],
        "phi3": [
            "microsoft/Phi-3-mini-4k-instruct",
            "microsoft/Phi-3-medium-4k-instruct",
        ],
    }

    if architecture.lower() in known_bases:
        return [{"model_id": m} for m in known_bases[architecture.lower()]]

    # Unknown architecture: fall back to a live search for instruct models.
    return search_models(
        query=f"{architecture} instruct",
        limit=5,
        sort="downloads",
        token=token,
    )
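

if __name__ == "__main__":
    # Minimal smoke test. The curated lookup needs no network; the fetch below
    # does, and uses a public model from the curated list above.
    for suggestion in get_popular_base_models("qwen2"):
        print(suggestion["model_id"])

    info = fetch_model_info("Qwen/Qwen2.5-7B-Instruct")
    print(info.error if info.error else info.to_dict())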