yuvraj-singh-9886 committed
Commit 3b70c60 · 1 Parent(s): 6be35bd

Add StoryKimi ZeroGPU implementation


- Add ZeroGPU-compatible app.py with @spaces.GPU decorator (see the sketch below)
- Copy all necessary model files (config.py, model.py, tokenizer.py, inference.py)
- Update requirements.txt with spaces and gradio dependencies
- Add comprehensive README.md based on the original StoryKimi, with HF Spaces adaptations
- Add .gitignore to exclude checkpoints and temporary files while keeping main model
- Configure metadata for ZeroGPU hardware in README frontmatter

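For orientation, here is a minimal sketch of the ZeroGPU pattern referenced above, assuming the `spaces` package that the Hugging Face Spaces runtime provides; the complete implementation is in the app.py diff below.

```python
# Minimal ZeroGPU sketch (assumes the `spaces` package available in HF Spaces).
# A GPU slice is attached only while the decorated function executes.
import spaces
import torch

@spaces.GPU(duration=120)  # request a GPU for up to 120 seconds per call
def run_on_gpu(prompt: str) -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ...load the model and run inference on `device` here...
    return f"generated on {device}: {prompt}"
```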
Files changed (8)
  1. .gitignore +217 -0
  2. README.md +140 -5
  3. app.py +202 -0
  4. config.py +151 -0
  5. inference.py +46 -0
  6. model.py +589 -0
  7. requirements.txt +9 -0
  8. tokenizer.py +18 -0
.gitignore ADDED
@@ -0,0 +1,217 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be added to the global gitignore or merged into this project gitignore. For a PyCharm
158
+ # project, uncomment the following line:
159
+ #.idea/
160
+
161
+ # Model checkpoints and weights (except the main one we want to keep)
162
+ checkpoints/
163
+ *.pt
164
+ *.pth
165
+ *.ckpt
166
+ *.safetensors
167
+ !checkpoint_2000.pt
168
+
169
+ # Wandb logs
170
+ wandb/
171
+ runs/
172
+
173
+ # Generated data
174
+ generated_data/
175
+ data/
176
+ datasets/
177
+
178
+ # Images (except for README)
179
+ images/
180
+ *.png
181
+ *.jpg
182
+ *.jpeg
183
+ *.gif
184
+ !images/image.png
185
+
186
+ # Gradio temporary files
187
+ gradio_cached_examples/
188
+ flagged/
189
+
190
+ # OS files
191
+ .DS_Store
192
+ .DS_Store?
193
+ ._*
194
+ .Spotlight-V100
195
+ .Trashes
196
+ ehthumbs.db
197
+ Thumbs.db
198
+
199
+ # IDE files
200
+ .vscode/
201
+ .idea/
202
+ *.swp
203
+ *.swo
204
+ *~
205
+
206
+ # Temporary files
207
+ *.tmp
208
+ *.temp
209
+ temp/
210
+
211
+ # Log files
212
+ *.log
213
+ logs/
214
+
215
+ # Test files
216
+ test_outputs/
217
+ test_results/
README.md CHANGED
@@ -1,13 +1,148 @@
1
  ---
2
  title: StoryKimi Zero
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: StoryKimi Zero
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ hardware: zero-gpu
12
+ short_description: Generate stories with StoryKimi model using ZeroGPU
13
  ---
14
 
15
+ # StoryKimi Zero - DeepSeek V3 Inspired Model on ZeroGPU
16
+
17
+ A PyTorch implementation of a DeepSeek V3 inspired transformer model with Mixture of Experts (MoE), Latent Attention, and other advanced features, deployed on Hugging Face Spaces with ZeroGPU for efficient inference.
18
+
19
+ ![StoryKimi Model](https://huggingface.co/YuvrajSingh9886/StoryKimi/resolve/main/images/image.png)
20
+
21
+ ## 📊 Training Results & Model Weights
22
+
23
+ **📈 View Training Report**: [StoryKimi Training Results on WandB](https://wandb.ai/rentio/DSV-Training/reports/SmolKimi-A-smaller-Kimi-K2---VmlldzoxMzYwNDQ4Mg?accessToken=lfs6n1y7gn8q0f0dwilta8yuwzxel45ztzbbcavwbqp7jsyv1p7cz9elflycv9fg)
24
+
25
+ **💾 Pre-trained Weights**:
26
+ - **Hugging Face Model**: [YuvrajSingh9886/StoryKimi](https://huggingface.co/YuvrajSingh9886/StoryKimi)
27
+ - **WandB Checkpoints**: Check the WandB report above for additional trained model checkpoints
28
+
29
+ ## 🌟 Features
30
+
31
+ - **ZeroGPU Integration**: Dynamic GPU allocation with NVIDIA H200 slices (70GB VRAM)
32
+ - **Latent Attention**: Efficient attention mechanism with compressed key-value representations
33
+ - **Mixture of Experts (MoE)**: 8 experts with top-2 routing and shared expert support
34
+ - **SWiGLU Activation**: Advanced activation function in expert layers
35
+ - **Sinusoidal Positional Embeddings**: Position encoding for sequence understanding
36
+ - **Interactive Interface**: User-friendly Gradio interface with real-time generation
37
+ - **Multiple Sampling Methods**: Top-k sampling with temperature control
38
+ - **Real-time Generation**: Fast inference with automatic scaling
39
+
40
+ ## 🔧 Model Architecture
41
+
42
+ ### Default Configuration
43
+ - **Embedding Dimensions**: 384
44
+ - **Decoder Layers**: 6
45
+ - **Attention Heads**: 8
46
+ - **MoE Experts**: 8 (top-2 routing)
47
+ - **Block Size**: 128 tokens
48
+ - **Vocabulary Size**: Based on Llama-2-7b tokenizer (~32,000 tokens)
49
+ - **Latent Dimension**: 64 (for compressed attention)
50
+
51
+ ### ZeroGPU Configuration
52
+ - **GPU Type**: NVIDIA H200 slice
53
+ - **Available VRAM**: 70GB per workload
54
+ - **Max Duration**: 120 seconds per generation
55
+ - **Deployment**: Hugging Face Spaces with automatic scaling
56
+
57
+ ## 🎯 Usage
58
+
59
+ 1. **Enter your story prompt** in the text box
60
+ 2. **Select model checkpoint** (Checkpoint 2000 available)
61
+ 3. **Adjust generation parameters**:
62
+ - **Max Length**: 10-128 tokens
63
+ - **Temperature**: 0.1-2.0 (creativity vs coherence)
64
+ - **Top-k**: 1-100 (vocabulary filtering)
65
+ 4. **Click "Generate Text"** to create your AI-generated story
66
+ 5. **Enjoy your personalized story!**
67
+
68
+ ## 💡 Generation Tips
69
+
70
+ - **Lower temperature** (0.1-0.7) for more coherent and focused stories
71
+ - **Higher temperature** (0.8-2.0) for more creative and diverse outputs
72
+ - **Adjust top-k** to control vocabulary diversity and randomness
73
+ - **Use descriptive prompts** for better and more relevant results
74
+ - **Experiment with different lengths** to find your preferred story format
75
+
76
+ ## 🔄 ZeroGPU Benefits
77
+
78
+ - **Free GPU Access**: No cost for users to generate stories
79
+ - **Efficient Resource Usage**: GPU allocated only when needed for inference
80
+ - **Automatic Scaling**: Handles multiple concurrent users seamlessly
81
+ - **High Performance**: NVIDIA H200 acceleration for fast generation
82
+ - **No Setup Required**: Ready-to-use interface with pre-loaded model
83
+
84
+ ## 🏗️ Technical Implementation
85
+
86
+ ### Model Features
87
+ - **Latent Attention**: Compressed key-value representations for efficiency
88
+ - **Mixture of Experts**: 8 experts with intelligent routing
89
+ - **Advanced Activation**: SWiGLU for better performance
90
+ - **Positional Encoding**: Sinusoidal embeddings for sequence understanding
91
+
92
+ ### Deployment Features
93
+ - **ZeroGPU Decorator**: `@spaces.GPU(duration=120)` for dynamic allocation
94
+ - **Optimized Loading**: Efficient model loading and initialization
95
+ - **Error Handling**: Robust error management for better user experience
96
+ - **Real-time Feedback**: Live generation status and results
97
+
98
+ ## 🚀 Local Development
99
+
100
+ Want to run this locally or contribute? Check out the full repository:
101
+
102
+ **📁 Source Code**: [YuvrajSingh-mist/SmolHub/StoryKimi](https://github.com/YuvrajSingh-mist/SmolHub/tree/main/StoryKimi)
103
+
104
+ ### Quick Local Setup
105
+ ```bash
106
+ # Clone the repository
107
+ git clone https://github.com/YuvrajSingh-mist/SmolHub.git
108
+ cd SmolHub/StoryKimi
109
+
110
+ # Install dependencies
111
+ chmod +x install.sh
112
+ ./install.sh
113
+
114
+ # Run Gradio interface
115
+ cd gradio
116
+ python app.py
117
+ ```
118
+
119
+ ### Training Your Own Model
120
+ ```bash
121
+ # Set your HF token for Llama-2 tokenizer access
122
+ export HF_TOKEN="your_token_here"
123
+
124
+ # Basic training
125
+ python trainer.py
126
+
127
+ # Advanced training with custom parameters
128
+ python trainer.py --embeddings_dims 512 --experts 16 --epochs 5
129
+ ```
130
+
131
+ ## 📊 Model Performance
132
+
133
+ The model has been trained on diverse text data and shows strong performance in:
134
+ - **Story Generation**: Creative and coherent narrative creation
135
+ - **Text Continuation**: Natural extension of given prompts
136
+ - **Style Adaptation**: Adapting to different writing styles and genres
137
+ - **Character Development**: Creating consistent characters and dialogue
138
+
139
+ ## 🔗 Related Links
140
+
141
+ - **Full Project**: [SmolHub Repository](https://github.com/YuvrajSingh-mist/SmolHub)
142
+ - **Model Weights**: [HuggingFace Model](https://huggingface.co/YuvrajSingh9886/StoryKimi)
143
+ - **Training Report**: [WandB Results](https://wandb.ai/rentio/DSV-Training/reports/SmolKimi-A-smaller-Kimi-K2---VmlldzoxMzYwNDQ4Mg?accessToken=lfs6n1y7gn8q0f0dwilta8yuwzxel45ztzbbcavwbqp7jsyv1p7cz9elflycv9fg)
144
+ - **Other Models**: [SmolMixtral](https://github.com/YuvrajSingh-mist/SmolHub/tree/main/SmolMixtral), [SmolTransformer](https://github.com/YuvrajSingh-mist/SmolHub/tree/main/SmolTransformer)
145
+
146
+ ## 📝 License
147
+
148
+ MIT License - See LICENSE file for details
app.py ADDED
@@ -0,0 +1,202 @@
1
+ import gradio as gr
2
+ import spaces # HF Spaces ZeroGPU decorator - only available in HF Spaces environment
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import os
6
+ import sys
7
+
8
+ from config import ModelArgs, get_args
9
+ from model import DeepSeekV3, initialize_tokenizer
10
+ from tokenizer import Tokenizer
11
+ from inference import topk_sampling
12
+
13
+ # Global variables
14
+ tk = None
15
+ model = None
16
+ model_args = None
17
+
18
+ # Model paths - using the checkpoint in the HF Space
19
+ model_paths = {
20
+ "Checkpoint 2000": "./checkpoint_2000.pt",
21
+ }
22
+
23
+ def initialize_app():
24
+ """Initialize the app with tokenizer and model args"""
25
+ global tk, model_args
26
+
27
+ # Initialize model args
28
+ model_args = ModelArgs()
29
+
30
+ # Initialize tokenizer (no HF token needed for basic operation)
31
+ if tk is None:
32
+ tk = Tokenizer(hf_token=None)
33
+ tk = tk.ready_tokenizer()
34
+
35
+ # Initialize the global tokenizer in model.py
36
+ initialize_tokenizer(hf_token=None)
37
+
38
+ def load_model(model_path, device, model_args):
39
+ """Load model from checkpoint"""
40
+ model = DeepSeekV3(
41
+ embeddings_dims=model_args.embeddings_dims,
42
+ block_size=model_args.block_size,
43
+ vocab_size=model_args.vocab_size,
44
+ dropout=model_args.dropout,
45
+ device=device
46
+ )
47
+
48
+ if os.path.exists(model_path):
49
+ checkpoint = torch.load(model_path, map_location=device)
50
+ model.load_state_dict(checkpoint)
51
+ model.eval()
52
+ print(f"Model loaded from {model_path}")
53
+ else:
54
+ print(f"Checkpoint {model_path} not found. Using randomly initialized model.")
55
+
56
+ return model
57
+
58
+ @spaces.GPU(duration=120)
59
+ def generate_text(prompt, model_choice, max_length, temperature, top_k):
60
+ """Generate text using the selected model and top-k sampling"""
61
+ global tk, model_args
62
+
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ print(f"Using device: {device}")
65
+
66
+ # Load the selected model
67
+ model_path = model_paths.get(model_choice, "./checkpoint_2000.pt")
68
+ model = load_model(model_path, device, model_args)
69
+ model = model.to(device)
70
+
71
+ try:
72
+ generated_text = topk_sampling(
73
+ model=model,
74
+ prompt=prompt,
75
+ device=device,
76
+ max_length=max_length,
77
+ top_k=top_k,
78
+ temperature=temperature,
79
+ tokenizer=tk
80
+ )
81
+
82
+ return generated_text
83
+
84
+ except Exception as e:
85
+ return f"Error generating text: {str(e)}"
86
+
87
+ def create_interface():
88
+ """Create the Gradio interface"""
89
+ global tk, model_args
90
+
91
+ # Initialize the app
92
+ initialize_app()
93
+
94
+ with gr.Blocks(title="StoryKimi Text Generator", theme=gr.themes.Soft()) as demo:
95
+ gr.Markdown("# 🚀 StoryKimi Text Generator")
96
+ gr.Markdown("Generate text using the Kimi K2 inspired StoryKimi model with ZeroGPU support.")
97
+ gr.Markdown("⚡ **Powered by ZeroGPU** - Dynamic GPU allocation for efficient inference")
98
+
99
+ with gr.Row():
100
+ with gr.Column(scale=2):
101
+ prompt_input = gr.Textbox(
102
+ label="Input Prompt",
103
+ placeholder="Enter your prompt here...",
104
+ lines=3,
105
+ value="Once upon a time there lived a baby deer named Bambi."
106
+ )
107
+
108
+ with gr.Row():
109
+ model_dropdown = gr.Dropdown(
110
+ choices=list(model_paths.keys()),
111
+ label="Model Checkpoint",
112
+ value="Checkpoint 2000"
113
+ )
114
+
115
+ with gr.Row():
116
+ max_length_slider = gr.Slider(
117
+ minimum=10,
118
+ maximum=128,
119
+ value=50,
120
+ step=10,
121
+ label="Max Length"
122
+ )
123
+
124
+ temperature_slider = gr.Slider(
125
+ minimum=0.1,
126
+ maximum=2.0,
127
+ value=0.9,
128
+ step=0.1,
129
+ label="Temperature"
130
+ )
131
+
132
+ with gr.Row():
133
+ top_k_slider = gr.Slider(
134
+ minimum=1,
135
+ maximum=100,
136
+ value=50,
137
+ step=1,
138
+ label="Top-k"
139
+ )
140
+
149
+
150
+ generate_btn = gr.Button("🎯 Generate Text", variant="primary", size="lg")
151
+
152
+ with gr.Column(scale=3):
153
+ output_text = gr.Textbox(
154
+ label="Generated Text",
155
+ lines=15,
156
+ interactive=False
157
+ )
158
+
159
+ with gr.Row():
160
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary")
161
+
162
+ # Event handlers
163
+ generate_btn.click(
164
+ fn=generate_text,
165
+ inputs=[
166
+ prompt_input,
167
+ model_dropdown,
168
+ max_length_slider,
169
+ temperature_slider,
170
+ top_k_slider
171
+ ],
172
+ outputs=output_text
173
+ )
174
+
175
+ clear_btn.click(
176
+ fn=lambda: ("", ""),
177
+ outputs=[prompt_input, output_text]
178
+ )
179
+
180
+ # Model information
181
+ gr.Markdown("## ℹ️ Model Information")
182
+ gr.Markdown("""
183
+ - **Model Architecture**: Kimi K2 inspired (StoryKimi)
184
+ - **ZeroGPU**: Dynamic GPU allocation with H200 slice (70GB VRAM)
185
+ - **GPU Duration**: 120 seconds maximum per generation
186
+ - **Deployment**: Hugging Face Spaces with automatic scaling
187
+ """)
188
+
189
+ gr.Markdown("## 🚀 Features")
190
+ gr.Markdown("""
191
+ - **Top-k Sampling**: Control randomness with top-k token selection
192
+ - **Temperature Control**: Adjust creativity vs coherence
193
+ - **Variable Length**: Generate 10-128 tokens
194
+ - **Real-time Generation**: Powered by ZeroGPU infrastructure
195
+ """)
196
+
197
+ return demo
198
+
199
+ if __name__ == "__main__":
200
+ # Create and launch the interface
201
+ demo = create_interface()
202
+ demo.launch()
config.py ADDED
@@ -0,0 +1,151 @@
1
+ import argparse
2
+ from dataclasses import dataclass
3
+
4
+ def get_args():
5
+ parser = argparse.ArgumentParser(description='SmolKimi - DeepSeek V3 Inspired Model Training')
6
+
7
+ # Model Architecture
8
+ parser.add_argument('--block_size', type=int, default=128, help='Maximum sequence length')
9
+ parser.add_argument('--batch_size', type=int, default=256, help='Training batch size')
10
+ parser.add_argument('--embeddings_dims', type=int, default=384, help='Model embedding dimensions')
11
+ parser.add_argument('--no_of_heads', type=int, default=8, help='Number of attention heads')
12
+ parser.add_argument('--no_of_decoder_layers', type=int, default=6, help='Number of decoder layers')
13
+ parser.add_argument('--latent_dim', type=int, default=64, help='Latent dimension for attention')
14
+
15
+ # MoE Configuration
16
+ parser.add_argument('--experts', type=int, default=8, help='Number of MoE experts')
17
+ parser.add_argument('--top_experts', type=int, default=2, help='Number of experts to route to (top-k)')
18
+ parser.add_argument('--use_shared_expert', action='store_true', default=True, help='Enable shared expert in MoE')
19
+ parser.add_argument('--noisy_topk', action='store_true', default=False, help='Use noisy top-k routing')
20
+ parser.add_argument('--useauxFreeLoadBalancingLoss', action='store_true', default=True, help='Use auxiliary-free load balancing loss')
21
+ parser.add_argument('--aux_free_bias_update_rate', type=float, default=0.001, help='Bias update rate for load balancing')
22
+ parser.add_argument('--loss_scale', type=float, default=0.3, help='Loss scaling factor')
23
+
24
+ # Training Hyperparameters
25
+ parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
26
+ parser.add_argument('--max_lr', type=float, default=6e-4, help='Maximum learning rate')
27
+ parser.add_argument('--weight_decay_optim', type=float, default=0.1, help='Weight decay for optimizer')
28
+ parser.add_argument('--beta_1', type=float, default=0.9, help='Beta1 for optimizer')
29
+ parser.add_argument('--beta_2', type=float, default=0.95, help='Beta2 for optimizer')
30
+ parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon for optimizer')
31
+ parser.add_argument('--clip', type=float, default=1.0, help='Gradient clipping value')
32
+
33
+ # Regularization
34
+ parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
35
+ parser.add_argument('--attn_dropout', type=float, default=0.1, help='Attention dropout rate')
36
+
37
+ # System Configuration
38
+ parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
39
+ parser.add_argument('--use_checkpointing', action='store_true', default=False, help='Use gradient checkpointing')
40
+ parser.add_argument('--use_liger', action='store_true', default=True, help='Use Liger kernels for optimization')
41
+ parser.add_argument('--ignore_pad_token_in_loss', action='store_true', default=True, help='Ignore padding tokens in loss calculation')
42
+
43
+ # Data Configuration
44
+ parser.add_argument('--vocab_size', type=int, default=32000 + 1 , help='Vocabulary size (updated based on tokenizer)')
45
+ parser.add_argument('--base_freq', type=int, default=100000, help='Base frequency for positional encoding')
46
+ parser.add_argument('--hf_token', type=str, default=None, help='Hugging Face token for accessing gated models like Llama-2')
47
+
48
+ # Dataset Selection
49
+ parser.add_argument('--dataset', type=str, default='tinystories', choices=['tinystories', 'fineweb', 'tinyshakespeare'], help='Dataset to use for training')
50
+
51
+ # Generation Parameters
52
+ parser.add_argument('--generation_max_length', type=int, default=50, help='Maximum length for text generation')
53
+ parser.add_argument('--generation_top_k', type=int, default=50, help='Top-k value for sampling during generation')
54
+ parser.add_argument('--generation_temperature', type=float, default=1.0, help='Temperature for sampling during generation')
55
+
56
+ # Logging and Checkpointing
57
+ parser.add_argument('--log_interval', type=int, default=100, help='Steps between logging')
58
+ parser.add_argument('--save_interval', type=int, default=2000, help='Steps between saving checkpoints')
59
+ parser.add_argument('--eval_interval', type=int, default=400, help='Steps between evaluation')
60
+ parser.add_argument('--eval_iters', type=int, default=400, help='Number of iterations for evaluation')
61
+ parser.add_argument('--warmup_iters', type=int, default=400, help='Number of warmup iterations')
62
+ parser.add_argument('--total_iters', type=int, default=10000, help='Total training iterations')
63
+ parser.add_argument('--lr_decay_iters', type=int, default=10000, help='Learning rate decay iterations')
64
+ parser.add_argument('--wandb_project', type=str, default='smolkimi', help='Wandb project name')
65
+ parser.add_argument('--wandb_run_name', type=str, default=None, help='Wandb run name')
66
+
67
+ # Batch Size Configuration
68
+ parser.add_argument('--total_batch_size', type=int, default=524288, help='Total batch size for gradient accumulation')
69
+ parser.add_argument('--micro_batch_size', type=int, default=None, help='Micro batch size (defaults to batch_size)')
70
+
71
+ # Distributed Training
72
+ parser.add_argument('--use_ddp', action='store_true', default=False, help='Use distributed data parallel')
73
+
74
+ return parser.parse_args()
75
+
76
+ @dataclass
77
+ class ModelArgs:
78
+ def __init__(self, args=None):
79
+ if args is None:
80
+ args = get_args()
81
+
82
+ # Model Architecture
83
+ self.block_size = args.block_size
84
+ self.batch_size = args.batch_size
85
+ self.embeddings_dims = args.embeddings_dims
86
+ self.no_of_heads = args.no_of_heads
87
+ self.no_of_decoder_layers = args.no_of_decoder_layers
88
+ self.latent_dim = args.latent_dim
89
+
90
+ # MoE Configuration
91
+ self.experts = args.experts
92
+ self.top_experts = args.top_experts
93
+ self.use_shared_expert = args.use_shared_expert
94
+ self.noisy_topk = args.noisy_topk
95
+ self.useauxFreeLoadBalancingLoss = args.useauxFreeLoadBalancingLoss
96
+ self.aux_free_bias_update_rate = args.aux_free_bias_update_rate
97
+ self.loss_scale = args.loss_scale
98
+
99
+ # Training Hyperparameters
100
+ self.epochs = args.epochs
101
+ self.max_lr = args.max_lr
102
+ self.weight_decay_optim = args.weight_decay_optim
103
+ self.beta_1 = args.beta_1
104
+ self.beta_2 = args.beta_2
105
+ self.eps = args.eps
106
+ self.clip = args.clip
107
+
108
+ # Regularization
109
+ self.dropout = args.dropout
110
+ self.attn_dropout = args.attn_dropout
111
+
112
+ # System Configuration
113
+ self.device = args.device
114
+ self.use_checkpointing = args.use_checkpointing
115
+ self.use_liger = args.use_liger
116
+ self.ignore_pad_token_in_loss = args.ignore_pad_token_in_loss
117
+
118
+ # Data Configuration
119
+ self.vocab_size = args.vocab_size
120
+ self.base_freq = args.base_freq
121
+ self.hf_token = args.hf_token
122
+ self.dataset = args.dataset
123
+
124
+ # Generation Parameters
125
+ self.generation_max_length = args.generation_max_length
126
+ self.generation_top_k = args.generation_top_k
127
+ self.generation_temperature = args.generation_temperature
128
+
129
+ # Logging and Checkpointing
130
+ self.log_interval = args.log_interval
131
+ self.save_interval = args.save_interval
132
+ self.eval_interval = args.eval_interval
133
+ self.eval_iters = args.eval_iters
134
+ self.warmup_iters = args.warmup_iters
135
+ self.total_iters = args.total_iters
136
+ self.lr_decay_iters = args.lr_decay_iters
137
+ self.wandb_project = args.wandb_project
138
+ self.wandb_run_name = args.wandb_run_name
139
+
140
+ # Batch Size Configuration
141
+ self.total_batch_size = args.total_batch_size
142
+ self.micro_batch_size = args.micro_batch_size if args.micro_batch_size else args.batch_size
143
+ self.gradient_accumulation_steps = self.total_batch_size // (self.micro_batch_size * self.block_size)
144
+
145
+ # Calculated parameters
146
+ self.min_lr = 0.1 * self.max_lr
147
+ self.save_checkpoint_iter = self.save_interval
148
+ self.eval_check = self.eval_interval
149
+
150
+ # Distributed Training
151
+ self.use_ddp = args.use_ddp
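As a quick check of the batch-size bookkeeping in config.py above, a small sketch of how `gradient_accumulation_steps` falls out of the default CLI values:

```python
# Derivation of gradient_accumulation_steps using the defaults from get_args().
total_batch_size = 524288   # tokens per optimizer step
micro_batch_size = 256      # sequences per micro-batch (defaults to batch_size)
block_size = 128            # tokens per sequence

gradient_accumulation_steps = total_batch_size // (micro_batch_size * block_size)
print(gradient_accumulation_steps)  # 524288 // 32768 = 16
```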
inference.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from config import ModelArgs
4
+ from model import DeepSeekV3
5
+ from tokenizer import Tokenizer
6
+
7
+ def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0, tokenizer=None, hf_token=None):
8
+ if tokenizer is None:
9
+ # Use default tokenizer if none provided
10
+ tokenizer_instance = Tokenizer(hf_token=hf_token)
11
+ tokenizer = tokenizer_instance.ready_tokenizer()
12
+
13
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
14
+ generated_tokens = []
15
+
16
+ if(len(input_ids[0]) < max_length):
17
+ max_length -= len(input_ids[0]) # Input is shorter than max_length: generate only the remaining tokens
18
+ else:
19
+ max_length = len(input_ids[0]) - max_length
20
+ for _ in range(max_length):
21
+ with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
22
+ # Pass inference=True to use the inference path in the model
23
+ outputs = model(input_ids, inference=True)
24
+ logits = outputs[:, -1, :]
25
+ logits = logits / temperature
26
+ probs = F.softmax(logits, dim=-1)
27
+
28
+ # Top-k filtering
29
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
30
+
31
+ # Sample from top-k
32
+ next_token = torch.multinomial(top_k_probs, num_samples=1)
33
+
34
+ xcol = torch.gather(top_k_indices, -1, next_token)
35
+ input_ids = torch.cat([input_ids, xcol], dim=1) # dim=1 because it is the sequence dimension
36
+
37
+ if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id and xcol.item() == tokenizer.eos_token_id:
38
+ break
39
+
40
+
41
+ return tokenizer.decode(input_ids[0])
42
+
43
+
44
+ def save_text(file_path, step, text):
45
+ with open(file_path, 'w') as f:
46
+ f.write(f"Step {step}: {text}\n")
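Before the model definition, a hedged usage sketch for `topk_sampling` above. It mirrors what `generate_text()` in app.py does and assumes access to the gated Llama-2 tokenizer, a CUDA device (the autocast path uses `device_type='cuda'`), and the `liger_kernel` package since `use_liger` defaults to True:

```python
# Illustrative sketch only: mirrors the call made by generate_text() in app.py.
import torch
from config import ModelArgs
from model import DeepSeekV3, initialize_tokenizer
from tokenizer import Tokenizer
from inference import topk_sampling

args = ModelArgs()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tok = Tokenizer(hf_token=None).ready_tokenizer()
initialize_tokenizer(hf_token=None)  # model.py reads the global tokenizer when use_liger is enabled

model = DeepSeekV3(
    embeddings_dims=args.embeddings_dims,
    block_size=args.block_size,
    vocab_size=args.vocab_size,
    dropout=args.dropout,
    device=device,
).to(device)
# model.load_state_dict(torch.load("./checkpoint_2000.pt", map_location=device))  # if the checkpoint is present
model.eval()

print(topk_sampling(model, "Once upon a time", device,
                    max_length=50, top_k=50, temperature=0.9, tokenizer=tok))
```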
model.py ADDED
@@ -0,0 +1,589 @@
1
+ import time
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from torch.nn import RMSNorm
7
+ from config import ModelArgs
8
+ from tokenizer import Tokenizer
9
+
10
+ # Initialize tokenizer globally as None - will be set later
11
+ tokenizer = None
12
+ model_args = ModelArgs()
13
+
14
+
15
+ def initialize_tokenizer(hf_token=None):
16
+ """Initialize the global tokenizer with the provided HF token"""
17
+ global tokenizer
18
+ if tokenizer is None:
19
+ tokenizer_instance = Tokenizer(hf_token=hf_token)
20
+ tokenizer = tokenizer_instance.ready_tokenizer()
21
+ return tokenizer
22
+
23
+ class Normalization(nn.Module):
24
+ def __init__(
25
+ self,
26
+ embeddings_dims: int = model_args.embeddings_dims
27
+ ):
28
+ super().__init__()
29
+ self.rmsnorm_layer = RMSNorm(embeddings_dims)
30
+
31
+
32
+ def forward(self, x):
33
+
34
+ x = self.rmsnorm_layer(x)
35
+ return x
36
+
37
+
38
+
39
+ class Swish(nn.Module):
40
+ def __init__(
41
+ self,
42
+ block_size: int = model_args.block_size,
43
+ embeddings_dims: int = model_args.embeddings_dims,
44
+ device = model_args.device
45
+ ):
46
+ super().__init__()
47
+
48
+ self.sig = torch.nn.Sigmoid()
49
+
50
+
51
+ def forward(self, x):
52
+ swish = x * self.sig(x)
53
+
54
+ return swish
55
+
56
+
57
+
58
+ class SWiGLUExpertMoE(nn.Module):
59
+ def __init__(
60
+ self,
61
+ block_size: int = model_args.block_size,
62
+ embeddings_dims: int = model_args.embeddings_dims,
63
+ device = model_args.device
64
+ ):
65
+ super().__init__()
66
+
67
+ self.hidden_dims = (embeddings_dims * 2)
68
+ self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
69
+ self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, device = device)
70
+ self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, device = device)
71
+ self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, device = device)
72
+
73
+
74
+
75
+
76
+ def forward(self, x):
77
+ swish_res = self.swish(self.linear_layer1(x))
78
+ x_V = self.linear_layer2(x)
79
+ res = torch.mul(swish_res, x_V)
80
+ out = self.linear_layer3(res)
81
+ return out
82
+
83
+
84
+
85
+ class MoeLayer(nn.Module):
86
+ def __init__(
87
+ self,
88
+ dropout = model_args.dropout,
89
+ embeddings_size = model_args.embeddings_dims,
90
+ device = model_args.device,
91
+ # inner_dimensional_states: int = 3072
92
+ ):
93
+ super().__init__()
94
+
95
+ self.heads = nn.ModuleList([SWiGLUExpertMoE() for _ in range(model_args.experts)])
96
+ self.gate = nn.Linear(in_features=embeddings_size, out_features=model_args.experts, device=device, bias=False)
97
+
98
+ # Only create shared expert if enabled
99
+ if model_args.use_shared_expert:
100
+ self.shared_expert = SWiGLUExpertMoE()
101
+ else:
102
+ self.shared_expert = None
103
+
104
+ if(model_args.noisy_topk is True and model_args.use_checkpointing == False):
105
+ self.noise = nn.Linear(in_features=embeddings_size, out_features=model_args.experts, device=device, bias=False)
106
+ self.noisy_router = None
107
+ # self.outputs = torch.zeros((batch_size,block_size, embeddings_size), device=device) #batch size needs to be defined because we are accessing it explicitly
108
+ self.device = device
109
+ # self.shared_expert_out = torch.zeros((model_args.batch_size, model_args.embeddings_dims), device=device)
110
+ # self.b = torch.zeros((model_args.batch_size, model_args.block_size, model_args.experts), device=device)
111
+
112
+ if model_args.useauxFreeLoadBalancingLoss:
113
+ self.register_buffer('routing_bias', torch.zeros(model_args.experts, device=self.device))
114
+ # self.routing_bias = torch.zeros(model_args.experts, device=self.device)
115
+ self.bias_update_speed = model_args.aux_free_bias_update_rate
116
+
117
+
118
+ def forward(self, x):
119
+ # mlp_weights_init = self.mlp.apply(weights_init)
120
+ self.gate_out = self.gate(x) #[bz, seq, num_experts]
121
+
122
+
123
+ if(model_args.noisy_topk == True and model_args.use_checkpointing == False):
124
+ noise = self.noise(x)
125
+ gaussian_noise = torch.normal(0, 1, size=self.gate_out.shape, device=self.device)
126
+ self.noisy_router = F.softplus(noise) * gaussian_noise
127
+ self.gate_out += self.noisy_router
128
+
129
+
130
+
131
+ shared_output = 0
132
+ out = 0
133
+
134
+
135
+
136
+ if model_args.useauxFreeLoadBalancingLoss:
137
+
138
+ self.gate_out += self.routing_bias
139
+
140
+
141
+
142
+
143
+ # Adjust top_k based on whether shared expert is used
144
+ top_k = model_args.top_experts
145
+ top_k_values, top_k_indices = torch.topk(self.gate_out, k=top_k) #[bs, seq len, top k]
146
+ # topkmask = torch.ones_like(top_k_values, device=self.device) # [bs, seq len, experts]
147
+ # indices = torch.arange(top_k_values.size(0), device=self.device).unsqueeze(1).unsqueeze(2) # [bs, 1, 1]
148
+ # topkvaluesMasked = top_k_values.masked_fill(indices != top_k_indices, float('-inf')) # Mask out negative values
149
+ masked = torch.full_like(self.gate_out, float('-1e20'), device=self.device)
150
+ masked_values = masked.scatter_(-1, top_k_indices, top_k_values)
151
+ probs = torch.nn.functional.softmax(masked_values, dim=-1) #[bs, seq len, top k]
152
+
153
+ out = torch.zeros_like(x)
154
+ if model_args.use_shared_expert and self.shared_expert is not None:
155
+ shared_output += self.shared_expert(x)
156
+
157
+ flat_x = x.view(-1, x.size(-1)) # Flatten the input for easier processing
158
+
159
+ for i in range(model_args.experts): # Iterate through each expert index (0 to num_experts-1)
160
+ # Determine which tokens routed to this expert 'i'
161
+ # top_k_indices is [bs, seq_len, self.top_k]
162
+ # We want a mask of shape [bs, seq_len] where True if expert 'i' is in the top_k for that token
163
+ expert_i_is_chosen_mask = (top_k_indices == i).any(dim=-1) # Check along the top_k dimension
164
+ # expert_i_is_chosen_mask has shape [bs, seq_len]
165
+
166
+ if not expert_i_is_chosen_mask.any(): # If expert 'i' was not chosen by any token
167
+ continue
168
+
169
+ # Flatten the mask to apply to flat_x
170
+ flat_expert_i_is_chosen_mask = expert_i_is_chosen_mask.reshape(-1) # Shape: [bs * seq_len]
171
+
172
+ # Select input tokens for this expert
173
+ selected_input_tokens = flat_x[flat_expert_i_is_chosen_mask] # Shape: [num_active_for_expert_i, embed_dim]
174
+
175
+ if selected_input_tokens.numel() == 0: # Should be caught by .any() above, but good check
176
+ continue
177
+
178
+ # Process through the expert
179
+ expert_output_for_selected = self.heads[i](selected_input_tokens)
180
+
181
+ # Get the routing probabilities for these chosen tokens specifically for expert 'i'
182
+ # routing_probs is [bs, seq_len, num_experts]
183
+ # expert_i_probs_original_shape = routing_probs[:, :, i] # Probabilities for expert 'i', shape [bs, seq_len]
184
+ # flat_expert_i_probs = expert_i_probs_original_shape.reshape(-1) # Shape [bs * seq_len]
185
+ # active_token_weights = flat_expert_i_probs[flat_expert_i_is_chosen_mask] # Shape: [num_active_for_expert_i]
186
+
187
+ # Alternative way to get weights directly using the mask on routing_probs for expert i:
188
+ # Get the [bs, seq_len] slice of probabilities for the current expert 'i'
189
+ probs_for_expert_i = probs[:, :, i] # Shape: [bs, seq_len]
190
+ # Now use the expert_i_is_chosen_mask (which is also [bs, seq_len]) to select the relevant weights
191
+ active_token_weights = probs_for_expert_i[expert_i_is_chosen_mask] # Shape: [num_active_for_expert_i]
192
+
193
+
194
+ weighted_expert_output = expert_output_for_selected * active_token_weights.unsqueeze(-1)
195
+
196
+ # Add this expert's contribution
197
+ temp_contribution_for_expert_i = torch.zeros_like(x) # Initialize with zeros
198
+ temp_contribution_for_expert_i.masked_scatter_(
199
+ expert_i_is_chosen_mask.unsqueeze(-1).expand_as(x), # Use the original 2D mask, expanded
200
+ weighted_expert_output
201
+ )
202
+ out = out + temp_contribution_for_expert_i
203
+
204
+
205
+ # for expert_idx in range(model_args.experts):
206
+ # # Create mask for current expert across all top_k positions
207
+ # expert_mask = (top_k_indices == expert_idx)
208
+
209
+ # # Sum probabilities for current expert
210
+ # expert_weights = (probs * expert_mask).sum(dim=-1) # [batch, seq_len]
211
+
212
+ # # Get inputs where expert is used
213
+ # selected = expert_weights > 0
214
+ # if not selected.any():
215
+ # continue
216
+ # # print(expert_weights.shape)
217
+ # # print(x[selected].shape)
218
+
219
+ # # Process all selected inputs through expert
220
+ # expert_out = self.heads[expert_idx](x[selected])
221
+
222
+
223
+
224
+ # # Weight and accumulate outputs
225
+ # out[selected] += expert_out * expert_weights[selected].unsqueeze(-1)
226
+
227
+ out = out + shared_output # Add shared expert output if enabled
228
+
229
+ if model_args.useauxFreeLoadBalancingLoss and self.training:
230
+
231
+ with torch.no_grad():
232
+ ci = probs.sum(dim=(0,1)) # Sum of routing probabilities for each expert
233
+ ci_avg = ci.mean()
234
+
235
+
236
+ error_i = ci_avg - ci
237
+
238
+ self.update = self.bias_update_speed * torch.sign(error_i) # Update routing bias
239
+ self.routing_bias.add_(self.update)
240
+ # self.routing_bias = self.routing_bias + self.update
241
+
242
+ return out
243
+
244
+
245
+ # import numpy as np
246
+ class SinusoidalPositionalEmbeddings(nn.Module):
247
+ def __init__(
248
+ self,
249
+ device,
250
+ embeddings_dims: int = model_args.embeddings_dims,
251
+ block_size: int = model_args.block_size,
252
+ batch_size: int = model_args.batch_size,
253
+ ):
254
+ super().__init__()
255
+
256
+ self.embeddings_dims = embeddings_dims
257
+ self.block_size = block_size
258
+ self.batch_size = batch_size
259
+ self.device = device
260
+
261
+ # Create positional encoding matrix
262
+ pe = torch.zeros(block_size, embeddings_dims)
263
+ position = torch.arange(0, block_size, dtype=torch.float).unsqueeze(1)
264
+ div_term = torch.exp(torch.arange(0, embeddings_dims, 2).float() * (-math.log(10000.0) / embeddings_dims))
265
+
266
+ pe[:, 0::2] = torch.sin(position * div_term)
267
+ pe[:, 1::2] = torch.cos(position * div_term)
268
+
269
+ # Register as buffer so it's not a parameter but moves with the model
270
+ self.register_buffer('pe', pe.unsqueeze(0)) # Shape: [1, block_size, embeddings_dims]
271
+
272
+ def forward(self, x):
273
+ # x shape: [batch_size, seq_len, embeddings_dims]
274
+ batch_size, seq_len, _ = x.shape
275
+
276
+ # Add positional embeddings
277
+ # pe[:, :seq_len] ensures we only use the positional embeddings up to the sequence length
278
+ pos_emb = self.pe[:, :seq_len].to(x.device)
279
+ return pos_emb
280
+
281
+
282
+
283
+ class LatentAttention(nn.Module):
284
+ def __init__(
285
+ self,
286
+ attn_dropout = model_args.attn_dropout,
287
+ embeddings_dims = model_args.embeddings_dims,
288
+ no_of_heads = model_args.no_of_heads,
289
+ device = model_args.device
290
+ ):
291
+ super().__init__()
292
+ self.head_size = embeddings_dims // no_of_heads
293
+ self.no_of_heads = no_of_heads
294
+ # if(model_args.use_flash_attention==False):
295
+ self.latent_dim = model_args.latent_dim
296
+ self.W_k = nn.Linear(in_features=self.latent_dim, out_features=self.head_size, device=device, bias=False)
297
+ self.W_v = nn.Linear(in_features=self.latent_dim, out_features=self.head_size, device=device, bias=False)
298
+ self.W_dkv = nn.Linear(in_features=model_args.embeddings_dims, out_features=self.latent_dim, device=device, bias=False) # 3 for query, key and value
299
+ self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=model_args.device, bias=False)
300
+ # self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,device=model_args.device, bias=False)
301
+ # self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=model_args.device,bias=False)
302
+ # self.dropout = nn.Dropout(p = attn_dropout)
303
+
304
+
305
+ self.dropout = nn.Dropout(p = attn_dropout)
306
+ self.device = device
307
+
308
+ # Use sinusoidal positional embeddings instead of rotary
309
+ self.pos_embeddings = SinusoidalPositionalEmbeddings(embeddings_dims=self.head_size, device=device)
310
+ # self.register_buffer('absorbed_q', None)
311
+ # self.absorbed_q = None
312
+
313
+ def forward(self, x, kv_cache=None, mask=None):
314
+ batch_size, block_size, embd_dims = x.shape
315
+
316
+ # k = self.keys(x)
317
+ # q = self.query(x)
318
+ # v = self.values(x)
319
+
320
+ self.latent_matrix = self.W_dkv(x)
321
+
322
+ # print("q shape: ", q.shape)
323
+
324
+ # print("Shape of latent mat: ", self.query.weight.shape)
325
+ # print("Shape of compressed_k: ", self.W_k.weight.shape)
326
+
327
+ # if(self.absorbed_q is None):
328
+ self.absorbed_q = torch.matmul(self.query.weight.T , self.W_k.weight)
329
+
330
+
331
+ # weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
332
+
333
+ # if kv_cache is None:
334
+ if kv_cache is None:
335
+ kv_cache = self.latent_matrix
336
+ else:
337
+ # print(kv_cache)
338
+ # print("Shape of latent matrix: ", self.latent_matrix.shape)
339
+ # print("Shape of kv_cache: ", kv_cache.shape)
340
+ kv_cache = torch.cat([kv_cache, self.latent_matrix], dim=1)
341
+
342
+ self.compressed_k = self.W_k(kv_cache)
343
+ self.compressed_v = self.W_v(kv_cache)
344
+
345
+ q_res = torch.matmul(x , self.absorbed_q)
346
+ weights = q_res @ torch.transpose(kv_cache, dim0=-2, dim1=-1) * (self.head_size ** -0.5) # [batch_size, block_size, block_size]
347
+ # print("Shape of weights: ", weights.shape)
348
+ # print("Shape of kv_cache: ", kv_cache.shape)
349
+ if(mask is not None):
350
+ weights = weights.masked_fill(mask == 0, float('-1e20')) #Masking the attention weights
351
+
352
+ masked_table = torch.tril(torch.ones(q_res.shape[1], kv_cache.shape[1], device=model_args.device))
353
+
354
+ masked_values = weights.masked_fill(masked_table[: q_res.shape[1], : kv_cache.shape[1]] == 0, float('-1e20'))
355
+ weights_normalized = nn.functional.softmax(masked_values, dim=-1) #Normalize along the embeddings dimension for all the tokens
356
+ weights_normalized = self.dropout(weights_normalized)
357
+
358
+ # print("Shape of weights_normalized: ", weights_normalized.shape)
359
+ # Apply positional embeddings to the output
360
+
361
+
362
+
363
+
364
+ # print("Shape of compressed_v: ", self.compressed_v.shape)
365
+ out = weights_normalized @ self.compressed_v
366
+
367
+ # out = self.pos_embeddings(out)
368
+ return out, kv_cache
369
+
370
+ # MHA
371
+
372
+
373
+ class MHLA(nn.Module):
374
+ def __init__(
375
+ self,
376
+ device,
377
+ attn_dropout = model_args.attn_dropout,
378
+ embeddings_dims = model_args.embeddings_dims,
379
+ no_of_heads = model_args.no_of_heads,
380
+ ):
381
+ super().__init__()
382
+ self.heads = nn.ModuleList([LatentAttention(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads) for _ in range(no_of_heads)])
383
+ self.dropout = nn.Dropout(p = attn_dropout)
384
+ self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False) # 12 (no of heads) * (batch_size) 64 = 768 -> gives out the text embeddings
385
+
386
+ def forward(self, x, kv_cache=None, mask=None):
387
+ # concat = torch.cat([head(x, kv_cache=kv_cache, mask=mask) for head in self.heads], dim=-1)
388
+ res = []
389
+ for head in self.heads:
390
+ head_out, kv_cache = head(x, kv_cache=kv_cache, mask=mask)
391
+ res.append(head_out)
392
+ concat = torch.cat(res, dim=-1) # Concatenate along the last dimension
393
+ linear_layer = self.linear(concat)
394
+ out = self.dropout(linear_layer)
395
+ return out, kv_cache
396
+
397
+ class FFN(nn.Module):
398
+ def __init__(self,
399
+ device,
400
+ embeddings_dims: int = model_args.embeddings_dims,
401
+ block_size: int = model_args.block_size,
402
+ vocab_size: int = model_args.vocab_size,
403
+ dropout = model_args.dropout
404
+
405
+ ):
406
+ super().__init__()
407
+
408
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
409
+ self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
410
+
411
+ self.dropout = nn.Dropout(p = dropout) # Uncommenting the dropout line
412
+ def forward(self, x):
413
+
414
+ x = self.linear_layer(x)
415
+ x = F.gelu(x)
416
+ x = self.linear_layer2(x)
417
+ x = F.gelu(x)
418
+ # x = self.dropout(x) # Uncommenting the dropout line
419
+ return x
420
+
421
+
422
+
423
+
424
+
425
+
426
+
427
+ class DecoderLayer(nn.Module):
428
+ def __init__(self,
429
+ device,
430
+ attn_dropout: float = model_args.attn_dropout,
431
+ no_of_heads: int = model_args.no_of_heads,
432
+ embeddings_dims: int = model_args.embeddings_dims,
433
+ dropout = model_args.dropout,
434
+ block_size: int = model_args.block_size,
435
+ vocab_size: int = model_args.vocab_size,
436
+
437
+ ) :
438
+ super().__init__()
439
+
440
+ # self.base_freq = model_args.base_freq
441
+ # self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device)
442
+ self.mha = MHLA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, device=device)
443
+ self.layer_norm1 = Normalization(embeddings_dims=embeddings_dims)
444
+ self.layer_norm2 = Normalization(embeddings_dims=embeddings_dims)
445
+ # self.layer_norm3 = Normalization(embeddings_dims=embeddings_dims)
446
+ self.dropout = nn.Dropout(p = dropout)
447
+
448
+ self.moe_block = MoeLayer(dropout=dropout, embeddings_size=embeddings_dims)
449
+
450
+ def forward(self, x, kv_cache=None, ffn=None, mask=None):
451
+
452
+ out, kv_cache = self.mha(self.layer_norm1(x), kv_cache=kv_cache, mask=mask) #Very important step -> Layer Norm on input and then passes it to the subsequent blocks
453
+ x = x + out # Fixed: removed in-place operation
454
+ x = x + self.moe_block(self.layer_norm2(x)) #Very important step
455
+
456
+ return x, kv_cache
457
+
458
+
459
+ class Block(nn.Module):
460
+ def __init__(self,
461
+ device,
462
+ embeddings_dims: int = model_args.embeddings_dims,
463
+ no_of_decoder_layers: int = model_args.no_of_decoder_layers,
464
+ block_size: int = model_args.block_size,
465
+ vocab_size: int = model_args.vocab_size,
466
+ dropout = model_args.dropout
467
+
468
+ ) :
469
+ super().__init__()
470
+ self.base_freq = model_args.base_freq
471
+ # self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device)
472
+ self.decoder = nn.ModuleList(DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers))
473
+ # self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device)
474
+ self.dropout = nn.Dropout(p = dropout)
475
+ self.norm = Normalization(embeddings_dims)
476
+
477
+ #weight tying
478
+ # self.embeddings.weight = self.linear_layer.weight
479
+
480
+ self.apply(self._init_weights)
481
+
482
+ def _init_weights(self, module):
483
+ if isinstance(module, nn.Linear):
484
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
485
+
486
+ if module.bias is not None:
487
+ nn.init.zeros_(module.bias)
488
+ elif isinstance(module, nn.Embedding):
489
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
490
+
491
+
492
+
493
+ def forward(self, x, mask=None, actual_labels = None, inference=False):
494
+ index = 0
495
+ no_of_layers = 0
496
+ # x = self.embeddings(x)
497
+ # # x = self.dropout(x)
498
+ # if(mask is not None):
499
+ kv_cache = None
500
+ # x = x * mask
501
+ # # mask = mask.unsqueeze(-1)
502
+ # x = self.decoder(x)
503
+ for layer in self.decoder:
504
+ # if no_of_layers % 2 == 0:
505
+ # if no_of_layers % 4 == 0:
506
+ # # print("x shape: ", x.shape)
507
+ # x = layer(x, rope=False, ffn=True, mask=mask)
508
+ # x = layer(x, rope=True, ffn=True, mask=mask)
509
+
510
+ # # print("x shape: ", x.shape)
511
+ # else:
512
+ # # print("x shape local: ", x.shape)
513
+ # if no_of_layers % 4 == 0:
514
+ # # print("x shape: ", x.shape)
515
+ # x = layer(x, rope=False, ffn=False, mask=mask)
516
+ x, kv_cache = layer(x, kv_cache=kv_cache, ffn=None, mask=mask)
517
+ # print("x shape local: ", x.shape)
518
+ # no_of_layers += 1
519
+ # print(x.shape)
520
+ x = self.dropout(x)
521
+ x = 2 * ((model_args.no_of_decoder_layers) ** -0.5) * x
522
+ x = self.norm(x)
523
+
524
+ # if(inference):
525
+ # out = self.linear_layer(x)
526
+ # return out
527
+ # if(model_args.use_liger):
528
+ # # print("yo")
529
+ # y = x.contiguous().view(-1, model_args.embeddings_dims)
530
+ # if(actual_labels is not None):
531
+ # labels = actual_labels.contiguous().view(-1)
532
+
533
+ # # Pass linear layer weights FIRST as required [2][5]
534
+ # # ignore_index is already set during initialization
535
+ # loss = self.le_loss(self.linear_layer.weight, y, labels)
536
+ # return loss
537
+ # else:
538
+ # # print("Hi")
539
+ # out = self.linear_layer(x)
540
+ # return out
541
+
542
+ return x
543
+
544
+
545
+
546
+ class DeepSeekV3(nn.Module):
547
+ def __init__(self,
548
+ device,
549
+ embeddings_dims: int = model_args.embeddings_dims,
550
+ block_size: int = model_args.block_size,
551
+ vocab_size: int = model_args.vocab_size,
552
+ dropout = model_args.dropout
553
+ ):
554
+ super().__init__()
555
+ self.decoder = Block(device=device, embeddings_dims=embeddings_dims, no_of_decoder_layers=model_args.no_of_decoder_layers, block_size=block_size, vocab_size=vocab_size, dropout=dropout)
556
+ self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
557
+ self.pos_embeddings = SinusoidalPositionalEmbeddings(embeddings_dims=embeddings_dims, device=device)
558
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device, bias=False)
559
+ # Weight tying - tie embedding and output projection weights
560
+ self.embedding.weight = self.linear_layer.weight
561
+
562
+ # Initialize the LigerFusedLinearCrossEntropyLoss for optimized training
563
+ if model_args.use_liger:
564
+ from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
565
+ # Initialize with ignore_index for padding tokens if enabled
566
+ if model_args.ignore_pad_token_in_loss:
567
+ self.le_loss = LigerFusedLinearCrossEntropyLoss(
568
+ ignore_index=tokenizer.pad_token_id
569
+ )
570
+ else:
571
+ self.le_loss = LigerFusedLinearCrossEntropyLoss()
572
+
573
+ def forward(self, x, inference=False, mask=None):
574
+ if(mask is not None):
575
+ x = x * mask
576
+
577
+ x = self.embedding(x)
578
+ x = x + self.pos_embeddings(x) # Add positional embeddings
579
+ B, T, C = x.shape
580
+
581
+ if inference:
582
+ # For inference, we only need the last token prediction
583
+ decoder_out = self.decoder(x, mask=mask)
584
+ logits = self.linear_layer(decoder_out)
585
+ return logits
586
+ else:
587
+ decoder_out = self.decoder(x, mask=mask)
588
+ logits = self.linear_layer(decoder_out)
589
+ return logits
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ spaces
2
+ torch>=2.1.2
3
+ transformers>=4.36.0
4
+ datasets
5
+ tqdm
6
+ huggingface_hub
7
+ gradio
8
+ numpy
9
+ safetensors
tokenizer.py ADDED
@@ -0,0 +1,18 @@
1
+ from transformers import AutoTokenizer
2
+
3
+ class Tokenizer:
4
+
5
+ def __init__(self, hf_token=None) -> None:
6
+ # Use the provided HF token if available (gated models like Llama-2 require it)
7
+
8
+ if hf_token:
9
+ print("[INFO] Using HF token for model access")
10
+ else:
11
+ print("[INFO] No HF token provided - using public models only")
12
+
13
+ self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=hf_token)
14
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
15
+
16
+ def ready_tokenizer(self):
17
+
18
+ return self.tokenizer
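One closing note on the tokenizer: adding the `[PAD]` special token grows the Llama-2 vocabulary from 32,000 to 32,001 entries, which is why `--vocab_size` defaults to `32000 + 1` in config.py. A hedged sketch of that bookkeeping (requires access to the gated `meta-llama/Llama-2-7b-hf` repo; the token string is a placeholder):

```python
# Sanity check of the vocab-size bookkeeping after adding [PAD].
from tokenizer import Tokenizer

tok = Tokenizer(hf_token="YOUR_HF_TOKEN").ready_tokenizer()  # placeholder token
print(len(tok))          # expected: 32001
print(tok.pad_token_id)  # expected: 32000 (the newly added [PAD] token)
```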