Jayashree Sridhar committed
Commit 292f6f6
1 Parent(s): 005cc1a

Refactor the code files to use TinyGPT2Model

agents/tools/llm_tools.py CHANGED
@@ -1,12 +1,13 @@
 """
 Mistral LLM Tools for CrewAI (modular class version)
 """
-from models.mistral_model import MistralModel
+#from models.mistral_model import MistralModel
+from models.tinygpt2_model import TinyGPT2Model

 class LLMTools:
     def __init__(self, config=None):
         self.config = config
-        self.model = MistralModel()
+        self.model =TinyGPT2Model()

     def mistral_chat(self, prompt: str, context: dict = None) -> str:
         """Chat with Mistral AI for intelligent responses."""
agents/tools/voice_tools.py CHANGED
@@ -5,7 +5,8 @@ from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq
 import asyncio
 import soundfile as sf
 import tempfile # Added the import for tempfile!
-from models.mistral_model import MistralModel
+#from models.mistral_model import MistralModel
+from models.tinygpt2_model import TinyGPT2Model

 class MultilingualVoiceProcessor:
     def __init__(self, model_name="openai/whisper-base", device=None):
@@ -48,7 +49,7 @@ class VoiceTools:
         return {"text": text, "language": detected_lang}

     def detect_emotion(self, text: str) -> dict:
-        model = MistralModel()
+        model = TinyGPT2Model()
        prompt = f"""
        Analyze the emotional state in this text: "{text}"
        Identify:
agents/tools/voice_tools_openaiwhisper.py CHANGED
@@ -3,7 +3,7 @@ Multilingual Voice Processing Tools - modular class version
 """
 import numpy as np
 import asyncio
-from models.mistral_model import MistralModel
+#from models.mistral_model import MistralModel
 import whisper
 import numpy as np
 from gtts import gTTS
@@ -15,6 +15,7 @@ from typing import Tuple, Optional
 import speech_recognition as sr
 from transformers import pipeline
 import whisper
+from models.tinygpt2_model import TinyGPT2Model

 # class MultilingualVoiceProcessor:
 #     """Handles multilingual STT and TTS"""
@@ -135,7 +136,7 @@ class VoiceTools:

     def detect_emotion(self, text: str) -> dict:
         """Detect emotional state from text using LLM."""
-        model = MistralModel()
+        model = TinyGPT2Model()
        prompt = f"""
        Analyze the emotional state in this text: "{text}"
        Identify:
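Both voice tool variants now build the emotion prompt and hand it to TinyGPT2Model. A hedged smoke test for that path, reusing the hypothetical generate() method from the sketch above; sshleifer/tiny-gpt2 is a tiny test checkpoint, so this checks the wiring rather than the quality of the emotion analysis:

# Hypothetical smoke test; TinyGPT2Model.generate() is assumed from the sketch
# above, not taken from this commit.
from models.tinygpt2_model import TinyGPT2Model

model = TinyGPT2Model()
text = "I have been waiting for two hours and nobody called me back."
prompt = f"""
Analyze the emotional state in this text: "{text}"
Identify:
"""
print(model.generate(prompt))  # tiny checkpoint output is noise, but the call path runs end to end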
models/_init_.py CHANGED
@@ -12,6 +12,7 @@ __version__ = "1.0.0"
 # Lazy imports
 if TYPE_CHECKING:
     from .mistral_model import MistralModel, MistralConfig, MistralPromptFormatter
+    from .tiny_gpt2_model import TinyGPT2Model

 # Public API
 __all__ = [
@@ -19,12 +20,13 @@ __all__ = [
     "MistralModel",
     "MistralConfig",
     "MistralPromptFormatter",
-
+    "TinyGPT2Model",
+
     # Model management
     "load_model",
     "get_model_info",
     "clear_model_cache",
-
+
     # Constants
     "AVAILABLE_MODELS",
     "MODEL_REQUIREMENTS",
@@ -46,6 +48,13 @@ AVAILABLE_MODELS = {
         "size": "7B",
         "context_length": 32768,
         "languages": ["multilingual"]
+    },
+    "tiny-gpt2": {
+        "model_id": "sshleifer/tiny-gpt2",
+        "type": "tiny",
+        "size": "small",
+        "context_length": 256,
+        "languages": ["en"]
     }
 }
@@ -56,19 +65,25 @@ MODEL_REQUIREMENTS = {
         "vram": "8GB (GPU) or 16GB (CPU)",
         "disk": "15GB",
         "compute": "GPU recommended"
+    },
+    "tiny-gpt2": {
+        "ram": "≤1GB",
+        "vram": "CPU only",
+        "disk": "<1GB",
+        "compute": "CPU"
     }
 }

-# Default configuration
+# Default configuration: Set to CPU/float32
 DEFAULT_MODEL_CONFIG = {
-    "max_length": 2048,
+    "max_length": 256,
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 50,
     "do_sample": True,
     "num_return_sequences": 1,
-    "device": "cuda" if torch.cuda.is_available() else "cpu",
-    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
+    "device": "cpu",
+    "torch_dtype": torch.float32,
     "load_in_8bit": False,
     "cache_dir": ".cache/models"
 }
@@ -76,10 +91,10 @@ DEFAULT_MODEL_CONFIG = {
 # Model instance cache
 _model_cache: Dict[str, Any] = {}

-def load_model(model_name: str = "mistral-7b-instruct", config: Optional[Dict[str, Any]] = None):
+def load_model(model_name: str = "tiny-gpt2", config: Optional[Dict[str, Any]] = None):
     """
     Load a model with caching support
-
+
     Args:
         model_name: Name of the model to load
         config: Optional configuration override
@@ -91,41 +106,39 @@ def load_model(model_name: str = "mistral-7b-instruct", config: Optional[Dict[st
     cache_key = f"{model_name}_{str(config)}"
     if cache_key in _model_cache:
         return _model_cache[cache_key]
-
+
     # Import here to avoid circular imports
-    from .mistral_model import MistralModel, MistralConfig
-
-    # Get model info
-    model_info = AVAILABLE_MODELS.get(model_name)
-    if not model_info:
+    if model_name == "tiny-gpt2":
+        from .tiny_gpt2_model import TinyGPT2Model
+        # No config needed for TinyGPT2, ignore config for now
+        model = TinyGPT2Model()
+    elif model_name in ["mistral-7b-instruct", "mistral-7b"]:
+        from .mistral_model import MistralModel, MistralConfig
+        model_info = AVAILABLE_MODELS.get(model_name)
+        if not model_info:
+            raise ValueError(f"Unknown model: {model_name}")
+        model_config = DEFAULT_MODEL_CONFIG.copy()
+        if config:
+            model_config.update(config)
+        mistral_config = MistralConfig(
+            model_id=model_info["model_id"],
+            **model_config
+        )
+        model = MistralModel(mistral_config)
+    else:
         raise ValueError(f"Unknown model: {model_name}")
-
-    # Merge configurations
-    model_config = DEFAULT_MODEL_CONFIG.copy()
-    if config:
-        model_config.update(config)
-
-    # Create config object
-    mistral_config = MistralConfig(
-        model_id=model_info["model_id"],
-        **model_config
-    )
-
-    # Load model
-    model = MistralModel(mistral_config)
-
+
     # Cache it
     _model_cache[cache_key] = model
-
     return model

 def get_model_info(model_name: str) -> Optional[Dict[str, Any]]:
     """
     Get information about a model
-
+
     Args:
         model_name: Name of the model
-
+
     Returns:
         Model information dictionary or None
     """
@@ -133,23 +146,24 @@ def get_model_info(model_name: str) -> Optional[Dict[str, Any]]:
     if info:
         # Add requirements
         requirements = MODEL_REQUIREMENTS.get(model_name, {})
+        info = info.copy()  # avoid mutating global dict!
         info["requirements"] = requirements
-
+
         # Add loading status
         cache_keys = [k for k in _model_cache.keys() if k.startswith(model_name)]
         info["is_loaded"] = len(cache_keys) > 0
-
+
     return info

 def clear_model_cache(model_name: Optional[str] = None):
     """
     Clear model cache to free memory
-
+
     Args:
         model_name: Specific model to clear, or None for all
     """
     global _model_cache
-
+
     if model_name:
         # Clear specific model
         keys_to_remove = [k for k in _model_cache.keys() if k.startswith(model_name)]
@@ -158,11 +172,11 @@
     else:
         # Clear all
         _model_cache.clear()
-
+
     # Force garbage collection
     import gc
     gc.collect()
-
+
     # Clear GPU cache if using CUDA
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -171,20 +185,25 @@
 def estimate_memory_usage(model_name: str) -> Dict[str, Any]:
     """
     Estimate memory usage for a model
-
+
     Args:
         model_name: Name of the model
-
+
     Returns:
         Memory estimation dictionary
     """
     model_info = AVAILABLE_MODELS.get(model_name)
     if not model_info:
         return {}
-
+
     size = model_info.get("size", "7B")
-    size_gb = float(size.replace("B", ""))
-
+    if size.endswith("B"):
+        size_gb = float(size.replace("B", ""))  # e.g. "7B"
+    elif size == "small":
+        size_gb = 0.02  # Arbitrary tiny model size in GB
+    else:
+        size_gb = 0.1  # catchall
+
     estimates = {
         "model_size_gb": size_gb,
         "fp32_memory_gb": size_gb * 4,  # 4 bytes per parameter
@@ -193,7 +212,7 @@ def estimate_memory_usage(model_name: str) -> Dict[str, Any]:
         "recommended_ram_gb": size_gb * 2.5,
         "recommended_vram_gb": size_gb * 1.5
     }
-
+
     return estimates

 def get_device_info() -> Dict[str, Any]:
@@ -204,14 +223,14 @@ def get_device_info() -> Dict[str, Any]:
         "current_device": torch.cuda.current_device() if torch.cuda.is_available() else None,
         "device_name": torch.cuda.get_device_name() if torch.cuda.is_available() else "CPU"
     }
-
+
     if torch.cuda.is_available():
         info["gpu_memory"] = {
             "allocated": torch.cuda.memory_allocated() / 1024**3,  # GB
             "reserved": torch.cuda.memory_reserved() / 1024**3,  # GB
             "total": torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
         }
-
+
     return info

 # Module initialization
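With the registry entries and the dispatch in load_model() above, the tiny model can be exercised end to end. A minimal usage sketch, assuming models/_init_.py is the package initializer so these helpers are importable from the models package:

# Hedged usage sketch of the refactored registry; the helper names come from the
# diff above, the import path is an assumption.
from models import load_model, get_model_info, clear_model_cache

print(get_model_info("tiny-gpt2"))   # copy of the registry entry plus "requirements" and "is_loaded"

model = load_model("tiny-gpt2")      # builds TinyGPT2Model and stores it in _model_cache
again = load_model("tiny-gpt2")      # same cache key ("tiny-gpt2_None"), so the cached instance is returned
assert model is again

clear_model_cache("tiny-gpt2")       # drops the cached instance and triggers garbage collection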