MegaTronX committed · verified
Commit aa3631e · 1 Parent(s): 51ae273

Update app.py

Files changed (1): app.py (+118 -70)

app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import os
+from compressed_tensors import load_compressed_model
 
 # Set cache directory for Spaces
 os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
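One structural risk in this hunk: the new dependency is imported at module scope, so if `compressed-tensors` is missing from the Space's requirements.txt the app dies with an ImportError before the fallback `try/except` inside `_load_model()` ever gets a chance to run. A guarded import keeps the standard-loading path reachable (a sketch; `HAS_COMPRESSED_TENSORS` is a hypothetical name, not in the commit):

```python
# Hypothetical guard: defer the hard dependency so an ImportError
# degrades to the standard loading path instead of killing the app.
try:
    from compressed_tensors import load_compressed_model
    HAS_COMPRESSED_TENSORS = True
except ImportError:
    load_compressed_model = None
    HAS_COMPRESSED_TENSORS = False
```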
@@ -14,32 +15,45 @@ class HunyuanTranslator:
         self._load_model()
 
     def _load_model(self):
-        """Load the pre-quantized FP8 model"""
-        print("Loading pre-quantized Hunyuan-MT FP8 model...")
+        """Load the pre-quantized FP8 model using Compressed Tensors"""
+        print("Loading Hunyuan-MT FP8 model with Compressed Tensors...")
 
         try:
+            # Load tokenizer first
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.model_name,
                 cache_dir='/tmp/cache',
                 trust_remote_code=True
             )
 
-            # Load the pre-quantized FP8 model - let transformers handle the quantization automatically
-            self.model = AutoModelForCausalLM.from_pretrained(
+            # Load model with Compressed Tensors
+            print("Loading model with compressed_tensors...")
+            self.model = load_compressed_model(
                 self.model_name,
-                device_map="auto",
-                trust_remote_code=True,  # Important for custom models
-                cache_dir='/tmp/cache',
-                torch_dtype=torch.float16,  # Use fp16 as fallback, model will use its native fp8 where available
+                device="auto",  # Automatically use GPU if available
+                torch_dtype=torch.float16,
+                trust_remote_code=True
             )
 
-            print("FP8 model loaded successfully!")
+            print("FP8 model loaded successfully with Compressed Tensors!")
             print(f"Model device: {self.model.device}")
             print(f"Model dtype: {next(self.model.parameters()).dtype}")
 
         except Exception as e:
-            print(f"Error loading model: {e}")
-            raise
+            print(f"Error loading model with Compressed Tensors: {e}")
+            # Fallback to standard loading without compression
+            try:
+                print("Trying standard loading as fallback...")
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                    trust_remote_code=True,
+                    cache_dir='/tmp/cache'
+                )
+                print("Model loaded successfully with standard method!")
+            except Exception as e2:
+                raise Exception(f"Both Compressed Tensors and standard loading failed: {e2}")
 
     def translate_ja_to_en(self, input_text: str) -> str:
         """Translate Japanese to English using FP8 model"""
@@ -47,63 +61,84 @@
             return "Please enter some Japanese text to translate."
 
         # Limit input length for Spaces
-        if len(input_text) > 2000:
-            return "Input too long. Please keep under 2000 characters for this demo."
+        if len(input_text) > 1500:
+            return "Input too long. Please keep under 1500 characters for this demo."
 
         try:
-            # Japanese to English specific prompt
-            prompt = f"Translate the following Japanese text to English. Provide only the translation without additional explanations:\n\nJapanese: {input_text}\nEnglish:"
-
-            messages = [{"role": "user", "content": prompt}]
-
-            # Apply chat template
-            tokenized_chat = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=True,
-                add_generation_prompt=True,
-                return_tensors="pt",
+            # Clean and prepare the input text
+            input_text = input_text.strip()
+
+            # Create a clear translation prompt
+            prompt = f"""Translate the following Japanese text to English. Provide only the translation without any additional explanations or notes.
+
+Japanese: {input_text}
+
+English:"""
+
+            # Tokenize the input
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=1024
             )
 
-            # Generate with FP8 model
+            # Move inputs to the same device as model
+            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+            # Generate translation
             with torch.no_grad():
                 outputs = self.model.generate(
-                    tokenized_chat.to(self.model.device),
+                    **inputs,
                     max_new_tokens=512,
                     temperature=0.7,
                     do_sample=True,
                     top_p=0.9,
                     repetition_penalty=1.1,
                     pad_token_id=self.tokenizer.eos_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    num_return_sequences=1
                 )
 
-            # Decode output
-            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # Decode the output
+            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            # Extract translation (remove prompt and get only the English part)
-            if "English:" in output_text:
-                output_text = output_text.split("English:")[-1].strip()
+            # Extract just the translation part (remove the prompt)
+            if prompt in generated_text:
+                translation = generated_text.replace(prompt, "").strip()
+            else:
+                # If prompt isn't found, try to extract after "English:"
+                if "English:" in generated_text:
+                    translation = generated_text.split("English:")[-1].strip()
+                else:
+                    translation = generated_text.strip()
 
-            # Clean up any remaining special tokens or markers
-            output_text = output_text.replace("<|endoftext|>", "").strip()
-            output_text = output_text.replace("</s>", "").strip()
+            # Clean up the translation
+            translation = translation.split('\n')[0].strip()  # Take first line only
+            translation = translation.replace('"', '').strip()
 
-            return output_text if output_text else "No translation generated. Please try again."
+            return translation if translation else "No translation generated. Please try again."
 
         except Exception as e:
             return f"Error during translation: {str(e)}"
 
 def create_translation_interface():
-    """Create the Gradio interface optimized for Spaces"""
+    """Create the Gradio interface for Japanese to English translation"""
 
     # Initialize translator
-    translator = HunyuanTranslator()
-
-    def translate_function(input_text):
-        """Wrapper function for Gradio"""
-        return translator.translate_ja_to_en(input_text)
+    try:
+        translator = HunyuanTranslator()
+
+        def translate_function(input_text):
+            return translator.translate_ja_to_en(input_text)
+
+    except Exception as e:
+        print(f"Failed to initialize translator: {e}")
+
+        def translate_function(input_text):
+            return f"Model initialization failed: {str(e)}\n\nPlease check that 'compressed-tensors' is installed and try again."
 
-    # Custom CSS for better appearance on Spaces
+    # Custom CSS for better appearance
     custom_css = """
     .gradio-container {
         max-width: 900px !important;
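One bug worth flagging in the new initialization fallback: Python unbinds the `except ... as e` variable when the handler exits, so the inner `translate_function` raises a `NameError` the first time a user actually clicks Translate. A corrected sketch that captures the message first (`init_error` is our name, not the commit's):

```python
    # Initialize translator
    try:
        translator = HunyuanTranslator()

        def translate_function(input_text):
            return translator.translate_ja_to_en(input_text)

    except Exception as e:
        init_error = str(e)  # capture now: `e` is unbound once the handler exits
        print(f"Failed to initialize translator: {init_error}")

        def translate_function(input_text):
            return (f"Model initialization failed: {init_error}\n\n"
                    "Please check that 'compressed-tensors' is installed and try again.")
```

Separately, note the size mismatch this hunk introduces: the UI accepts up to 1500 characters, but the tokenizer truncates at max_length=1024 tokens, and Japanese often tokenizes at roughly a token per character or more, so long inputs can be silently cut mid-sentence.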
@@ -114,13 +149,16 @@ def create_translation_interface():
         margin: auto;
         padding: 20px;
     }
-    .example-text {
-        font-size: 0.9em;
-        color: #666;
+    .japanese-text {
+        font-family: "Hiragino Sans", "Yu Gothic", "Meiryo", sans-serif;
+    }
+    .translation-box {
+        border-left: 3px solid #4CAF50;
+        padding-left: 15px;
     }
     """
 
-    # Create Gradio interface optimized for Spaces
+    # Create Gradio interface
     with gr.Blocks(
         title="Japanese to English Translation - Hunyuan-MT FP8",
         theme=gr.themes.Soft(),
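For readers unfamiliar with how these selectors bind: the classes declared here only match because the textboxes later opt in through `elem_classes`, and the stylesheet itself is supplied via `gr.Blocks(css=custom_css)`. A self-contained sketch (the descendant selector is an assumption about the rendered DOM; some Gradio versions need it to reach the inner textarea):

```python
import gradio as gr

# Stylesheet goes to gr.Blocks(css=...); components opt in via elem_classes.
css = """
.japanese-text textarea { font-family: "Yu Gothic", "Meiryo", sans-serif; }
.translation-box { border-left: 3px solid #4CAF50; }
"""

with gr.Blocks(css=css) as demo:
    box = gr.Textbox(elem_classes=["japanese-text"])
```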
@@ -130,22 +168,22 @@
         gr.Markdown(
             """
             # 🇯🇵 → 🇺🇸 Japanese to English Translation
-            **Model:** `tencent/Hunyuan-MT-7B-fp8` (7B parameters, pre-quantized FP8)
-            **Specialization:** High-quality Japanese → English translation
+            **Model:** `tencent/Hunyuan-MT-7B-fp8` **Technology:** Compressed Tensors FP8 Quantization
 
-            *Enter Japanese text below and click Translate*
+            *Fast, high-quality Japanese to English translation using optimized FP8 model*
             """
         )
 
         with gr.Row(equal_height=False):
             with gr.Column(scale=1):
+                gr.Markdown("### 📥 Japanese Input")
                 input_text = gr.Textbox(
-                    label="Japanese Text Input",
-                    placeholder="日本語のテキストを入力してください... (Enter Japanese text here)",
-                    lines=5,
+                    label="",
+                    placeholder="日本語のテキストを入力してください...\n(Enter Japanese text here)",
+                    lines=6,
                     max_lines=8,
                     show_copy_button=True,
-                    elem_id="input-text"
+                    elem_classes=["japanese-text"]
                 )
 
                 with gr.Row():
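A minor API point in this hunk: now that the heading comes from `gr.Markdown`, the idiomatic way to suppress the built-in label is `show_label=False` rather than `label=""`, which can still reserve empty label markup in some Gradio versions. A sketch of the same component:

```python
import gradio as gr

# Equivalent input box with the built-in label suppressed explicitly.
input_text = gr.Textbox(
    show_label=False,   # the 📥 heading above already labels the field
    placeholder="日本語のテキストを入力してください...",
    lines=6,
    max_lines=8,
    show_copy_button=True,
    elem_classes=["japanese-text"],
)
```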
@@ -163,17 +201,18 @@
                 )
 
             with gr.Column(scale=1):
+                gr.Markdown("### 📤 English Translation")
                 output_text = gr.Textbox(
-                    label="English Translation",
+                    label="",
                     placeholder="Translation will appear here...",
-                    lines=5,
+                    lines=6,
                     max_lines=8,
                     show_copy_button=True,
-                    elem_id="output-text"
+                    elem_classes=["translation-box"]
                 )
 
         # Examples section
-        gr.Markdown("### 💡 Try these examples:")
+        gr.Markdown("### 💡 Example Translations")
         examples = gr.Examples(
             examples=[
                 ["こんにちは、元気ですか?"],
@@ -181,7 +220,9 @@
                 ["機械学習と人工知能は現代技術の重要な分野です。"],
                 ["このレストランの料理はとても美味しいです。"],
                 ["明日の会議は午後二時から始まります。"],
-                ["日本の文化は非常に興味深いと思います。"]
+                ["日本の文化は非常に興味深いと思います。"],
+                ["新しいプロジェクトの提案書を作成しました。"],
+                ["電車の遅延により、到着が30分ほど遅れます。"]
             ],
             inputs=input_text,
             outputs=output_text,
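The event wiring that consumes these components falls between hunks and is unchanged by this commit; for orientation, a self-contained sketch of the usual pattern (component names match the diff, the identity lambda stands in for translate_function, which is defined earlier in app.py):

```python
import gradio as gr

with gr.Blocks() as demo:
    input_text = gr.Textbox(lines=6)
    translate_btn = gr.Button("Translate")
    output_text = gr.Textbox(lines=6)

    # Button click and Enter key both route input -> handler -> output.
    translate_btn.click(fn=lambda t: t, inputs=input_text, outputs=output_text)
    input_text.submit(fn=lambda t: t, inputs=input_text, outputs=output_text)
```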
@@ -211,22 +252,29 @@
             outputs=output_text
         )
 
-        # Additional info
+        # Technical details
         gr.Markdown(
             """
             ---
-            ### ℹ️ Usage Notes:
-            - **Model**: tencent/Hunyuan-MT-7B-fp8 (7B parameters, FP8 quantized)
-            - **Optimized** specifically for Japanese → English translation
-            - **Max input length**: ~2000 characters
-            - **Translation time**: Usually 10-30 seconds
-            - **Memory efficient**: Uses FP8 quantization for faster inference
-
-            ### 🛠️ Technical Details:
-            - Pre-quantized to FP8 (8-bit floating point)
-            - ~3-4GB memory footprint
-            - Optimized for GPU inference
-            - Supports long-form translation
+            ### 🛠️ Technical Information
+
+            **Model Details:**
+            - **Base Model**: Hunyuan-MT 7B
+            - **Quantization**: FP8 (8-bit floating point) via Compressed Tensors
+            - **Memory Usage**: ~3-4GB
+            - **Specialization**: Japanese ↔ English translation
+
+            **Optimization Features:**
+            - FP8 quantization for faster inference
+            - Compressed Tensors for efficient storage
+            - GPU acceleration support
+            - ✅ Batch processing capable
+
+            **Usage Tips:**
+            - Keep inputs under 1500 characters for best results
+            - Translation takes 5-15 seconds typically
+            - Model works best with complete sentences
+            - Handles technical and casual Japanese well
             """
         )
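Back on the extraction logic in translate_ja_to_en: decoding the full sequence and string-replacing the prompt is fragile (a decode round-trip rarely reproduces the prompt byte for byte), and `translation.split('\n')[0]` silently drops every line after the first of a multi-sentence translation. Slicing the prompt tokens off before decoding avoids both problems; a sketch reusing the `inputs`/`outputs` names from the hunk:

```python
# Decode only the freshly generated tokens: the prompt never enters the
# output, so no string replacement or first-line truncation is needed.
prompt_len = inputs["input_ids"].shape[1]
translation = self.tokenizer.decode(
    outputs[0][prompt_len:],
    skip_special_tokens=True,
).strip()
```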
 
 
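A last thought on the generation settings carried through this commit: `do_sample=True` with temperature 0.7 returns a different translation on every click, which is confusing in an MT demo, and translation quality is usually evaluated with deterministic decoding. A drop-in variant assuming the same `inputs` dict (whether Hunyuan-MT recommends specific sampling parameters is worth checking against its model card):

```python
# Deterministic decoding variant for reproducible translations.
outputs = self.model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=False,        # greedy (or set num_beams>1 for beam search)
    repetition_penalty=1.1,
    pad_token_id=self.tokenizer.eos_token_id,
    eos_token_id=self.tokenizer.eos_token_id,
)
```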