fereen5 commited on
Commit
5a9ae47
·
verified ·
1 Parent(s): 1d433ea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -0
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import os
4
+ from datetime import datetime
5
+ from functools import lru_cache
6
+ import torch
7
+ import numpy as np
8
+
9
# Mapping of human-readable language names to NLLB (FLORES-200) language codes.
LANGUAGE_CODES = dict(
    English="eng_Latn",
    Korean="kor_Hang",
    Japanese="jpn_Jpan",
    Chinese="zho_Hans",
    Spanish="spa_Latn",
    French="fra_Latn",
    German="deu_Latn",
    Russian="rus_Cyrl",
    Portuguese="por_Latn",
    Italian="ita_Latn",
    Burmese="mya_Mymr",
    Thai="tha_Thai",
)
24
+
25
class TranslationHistory:
    """In-memory, bounded log of recent translations (newest first)."""

    def __init__(self):
        # Newest entry lives at index 0; the list is capped at max_entries.
        self.history = []
        self.max_entries = 100

    def add(self, source, translated, src_lang, tgt_lang):
        """Record one translation and return the new entry dict.

        The timestamp is stored as an ISO-8601 string: the UI renders the
        history through a JSON component, and raw datetime objects are not
        JSON-serializable.
        """
        entry = {
            "source": source,
            "translated": translated,
            "src_lang": src_lang,
            "tgt_lang": tgt_lang,
            # isoformat() keeps the entry JSON-serializable.
            "timestamp": datetime.now().isoformat(),
        }
        self.history.insert(0, entry)
        if len(self.history) > self.max_entries:
            self.history.pop()  # drop the oldest entry
        return entry

    def get_history(self):
        """Return the history list, newest first."""
        return self.history

    def clear(self):
        """Empty the history and return it.

        Returning the (now empty) list lets a UI callback wired with an
        output component display the cleared state instead of None.
        """
        self.history = []
        return self.history
48
+
49
# Module-level singleton history shared by all translation calls and UI handlers.
translation_history = TranslationHistory()
51
+
52
# Load model and tokenizer with error handling.
try:
    # NLLB-200 distilled 600M: multilingual seq2seq translation model.
    model_name = "facebook/nllb-200-distilled-600M"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Move the model to GPU when available; `device` is read later by the
    # translation function to place input tensors on the same device.
    if torch.cuda.is_available():
        model = model.to("cuda")
        device = "cuda"
    else:
        device = "cpu"

except Exception as e:
    # The model is essential: log the failure and re-raise so the app
    # fails fast at startup instead of limping along without a model.
    print(f"Error loading model: {str(e)}")
    raise
67
+
68
@lru_cache(maxsize=1000)
def _generate_translation(text, src_code, tgt_code, max_length, temperature):
    """Run the NLLB model on `text`; results are cached on exact arguments.

    `src_code` / `tgt_code` are FLORES-200 codes (e.g. "eng_Latn").
    Raises on tokenizer/model errors, so failures are never cached.
    """
    # Tell the tokenizer the source language so its language token is
    # prepended to the encoded input. (Previously the source language was
    # silently ignored and always defaulted to the tokenizer's default.)
    tokenizer.src_lang = src_code

    inputs = tokenizer(text, return_tensors="pt", padding=True)
    if device == "cuda":
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Force decoding to start with the target-language token; the tokenizer
    # already knows every NLLB language token, so a hand-built lookup table
    # is unnecessary.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

    outputs = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length,
        # NOTE(review): temperature has no effect without do_sample=True;
        # kept for interface compatibility with the UI sliders.
        temperature=temperature,
        num_beams=5,
        length_penalty=0.6,
        early_stopping=True
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate `text` from `src_lang` to `tgt_lang`.

    Languages may be friendly names ("English") or FLORES-200 codes.
    Returns the translation, "" for blank input, or an error message string
    on failure. Every successful call is recorded in the shared history —
    including cache hits, which the previous implementation skipped because
    the history side effect lived inside the lru_cached function.
    """
    try:
        if not text.strip():
            return ""

        # Convert friendly names to codes; pass unknown values through so
        # raw FLORES-200 codes keep working.
        src_code = LANGUAGE_CODES.get(src_lang, src_lang)
        tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)

        translated = _generate_translation(
            text, src_code, tgt_code, max_length, temperature
        )

        # Add to history
        translation_history.add(text, translated, src_lang, tgt_lang)

        return translated

    except Exception as e:
        return f"Translation error: {str(e)}"
119
+
120
def translate_file_with_progress(file, src_lang, tgt_lang, max_length=128, temperature=0.7, progress=gr.Progress()):
    """Translate an uploaded text file line by line.

    `file` is the value of a gr.File input: an object exposing `.name`,
    the path of the uploaded temp file (Gradio provides only the path).
    Returns a (status message, translated text) pair; the text is None on
    failure.

    `progress` is declared as a default-valued parameter so Gradio injects
    a tracked progress bar — a gr.Progress() created inside the body (as
    the previous version did) is not tracked by the UI.
    """
    try:
        # Guard the no-upload case instead of raising AttributeError.
        if file is None:
            return "No file uploaded.", None

        file_path = file.name

        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        lines = content.split('\n')
        translated_lines = []

        for line in progress.tqdm(lines):
            if line.strip():
                translated = cached_translate(
                    line, src_lang, tgt_lang,
                    max_length=max_length,
                    temperature=temperature
                )
                translated_lines.append(translated)
            else:
                # Keep blank lines so the output preserves the file's layout.
                translated_lines.append("")

        output = '\n'.join(translated_lines)

        # Save the result under translated/ next to the app.
        os.makedirs("translated", exist_ok=True)
        output_path = os.path.join("translated", f"translated_{os.path.basename(file_path)}")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(output)

        return f"Translation saved to {output_path}", output

    except Exception as e:
        return f"File translation error: {str(e)}", None
156
+
157
+
158
def swap_languages(src, tgt):
    """Return the two language selections in exchanged order (tgt, src)."""
    return (tgt, src)
160
+
161
# ---------------------------------------------------------------------------
# Gradio UI: two tabs (interactive text translation and whole-file
# translation), plus a translation history view and advanced generation
# options. Event wiring is at the bottom of the block.
# ---------------------------------------------------------------------------
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("""
    # Enhanced NLLB Translator
    Translate text between multiple languages using Facebook's NLLB model
    """)

    with gr.Tab("Text Translation"):
        # Language pickers with a swap button between them.
        with gr.Row():
            src_lang = gr.Dropdown(
                choices=sorted(LANGUAGE_CODES.keys()),
                value="English",
                label="Source Language"
            )
            # NOTE(review): recent Gradio releases require an int `scale`;
            # 0.15 may be rejected — confirm against the pinned version.
            swap_btn = gr.Button("⇄", scale=0.15)
            tgt_lang = gr.Dropdown(
                choices=sorted(LANGUAGE_CODES.keys()),
                value="Korean",
                label="Target Language"
            )

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to translate...",
                    label="Input Text"
                )

            with gr.Column():
                output_text = gr.Textbox(
                    lines=5,
                    label="Translated Text",
                    interactive=False  # read-only result box
                )

        with gr.Row():
            translate_btn = gr.Button("Translate", variant="primary")
            clear_btn = gr.Button("Clear")

        # Generation parameters forwarded to model.generate().
        with gr.Accordion("Advanced Options", open=False):
            max_length = gr.Slider(
                minimum=10,
                maximum=512,
                value=128,
                step=1,
                label="Maximum Length"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )

        with gr.Accordion("Translation History", open=False):
            # Passing the bound method (a callable) makes Gradio call it to
            # compute the component value.
            # NOTE(review): gr.JSON needs JSON-serializable entries —
            # confirm the stored history timestamps serialize.
            history_list = gr.JSON(translation_history.get_history)
            refresh_btn = gr.Button("Refresh History")
            clear_history_btn = gr.Button("Clear History")

    with gr.Tab("File Translation"):
        with gr.Row():
            file_input = gr.File(label="Upload file to translate")

        with gr.Row():
            file_src_lang = gr.Dropdown(
                choices=sorted(LANGUAGE_CODES.keys()),
                value="English",
                label="Source Language"
            )
            file_tgt_lang = gr.Dropdown(
                choices=sorted(LANGUAGE_CODES.keys()),
                value="Korean",
                label="Target Language"
            )

        with gr.Row():
            file_output_status = gr.Textbox(label="Translation Status")
            # Hidden textbox that receives the full translated text.
            file_output_text = gr.Textbox(
                label="Translated Text",
                visible=False,
                interactive=False
            )

        # Same generation knobs as the text tab, scoped to file translation.
        with gr.Accordion("Advanced Options", open=False):
            file_max_length = gr.Slider(
                minimum=10,
                maximum=512,
                value=128,
                step=1,
                label="Maximum Length"
            )
            file_temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )

        file_translate_btn = gr.Button("Translate File", variant="primary")

    # Event handlers
    translate_btn.click(
        fn=cached_translate,
        inputs=[input_text, src_lang, tgt_lang, max_length, temperature],
        outputs=output_text
    )

    file_translate_btn.click(
        fn=translate_file_with_progress,
        inputs=[file_input, file_src_lang, file_tgt_lang, file_max_length, file_temperature],
        outputs=[file_output_status, file_output_text]
    )

    swap_btn.click(
        fn=swap_languages,
        inputs=[src_lang, tgt_lang],
        outputs=[src_lang, tgt_lang]
    )

    # Reset both text boxes to empty strings.
    clear_btn.click(
        lambda: ["", ""],
        outputs=[input_text, output_text]
    )

    refresh_btn.click(
        fn=translation_history.get_history,
        outputs=history_list
    )

    # NOTE(review): the component displays whatever clear() returns —
    # confirm it yields the emptied history rather than None.
    clear_history_btn.click(
        fn=translation_history.clear,
        outputs=history_list
    )

    # Static footer; the f-string is evaluated once at UI build time.
    gr.Markdown(f"""
    ### Model Information
    - Using {model_name}
    - Running on {device}
    - Cache size: 1000 entries
    """)
303
+
304
if __name__ == "__main__":
    # share=True asks Gradio to create a temporary public share URL in
    # addition to the local server.
    demo.launch(share=True)