Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| AutoModelForSeq2SeqLM, | |
| AutoTokenizer, | |
| ) | |
| from IndicTransToolkit import IndicProcessor | |
| import os | |
| import subprocess | |
# Absolute paths of `huggingface_interface` directories whose install.sh has
# already been run by this process.
_INSTALLED_INTERFACES = set()


def setup_repo(
    repo_url="https://github.com/AI4Bharat/IndicTrans2",
    repo_dir="IndicTrans2",
):
    """Clone the IndicTrans2 repository if needed and run its HF-interface installer.

    The clone is skipped when ``repo_dir`` already exists. The installer runs at
    most once per process for a given checkout; later calls return immediately.
    The process working directory is never changed.

    Args:
        repo_url: Git URL to clone from.
        repo_dir: Directory to clone into, resolved against the current working
            directory at call time.

    Returns:
        Absolute path of the checkout's ``huggingface_interface`` directory.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` or ``install.sh`` exits
            with a non-zero status.
        OSError: If ``huggingface_interface`` does not exist in the checkout.
    """
    repo_path = os.path.abspath(repo_dir)
    interface_dir = os.path.join(repo_path, "huggingface_interface")
    if interface_dir in _INSTALLED_INTERFACES:
        return interface_dir

    if not os.path.exists(repo_path):
        subprocess.run(["git", "clone", repo_url, repo_path], check=True)

    # Execute the script with bash and cwd= rather than `source` + os.chdir():
    # with shell=True a list only treats its first item as the command (so
    # install.sh never ran), and a process-wide chdir made the relative
    # repo_dir check fail on the next call, re-cloning into a nested directory.
    subprocess.run(["bash", "install.sh"], cwd=interface_dir, check=True)
    _INSTALLED_INTERFACES.add(interface_dir)
    return interface_dir
# Function to process translation

# Language code for English; every other entry in `languages` is an Indic language.
_ENGLISH_CODE = "eng_Latn"
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# model_name -> (tokenizer, model). Holds at most one entry so only one
# checkpoint is resident in memory at a time.
_MODEL_CACHE = {}
# Whether setup_repo() has already been called by translate() in this process.
_REPO_SET_UP = False


def _to_code(lang):
    """Return the code for a display name like "Hindi"; codes pass through unchanged."""
    return dict(languages).get(lang, lang)


def _model_name_for(src_code, tgt_code):
    """Pick the IndicTrans2 checkpoint covering the src -> tgt direction.

    The indic-indic checkpoint does not cover English, so English pairs use the
    dedicated en-indic / indic-en checkpoints.
    """
    if src_code == _ENGLISH_CODE:
        return "ai4bharat/indictrans2-en-indic-1B"
    if tgt_code == _ENGLISH_CODE:
        return "ai4bharat/indictrans2-indic-en-1B"
    return "ai4bharat/indictrans2-indic-indic-1B"


def _load_model(model_name):
    """Return a cached (tokenizer, model) pair, loading it onto _DEVICE on first use."""
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        # Drop the previously cached checkpoint first so two large models are
        # never held at the same time.
        _MODEL_CACHE.clear()
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
        # The tokenized inputs are moved to _DEVICE below, so the model must be too.
        model.to(_DEVICE)
        model.eval()
        cached = (tokenizer, model)
        _MODEL_CACHE[model_name] = cached
    return cached


def translate(input_text, src_lang, tgt_lang):
    """Translate ``input_text`` from ``src_lang`` to ``tgt_lang`` with IndicTrans2.

    Args:
        input_text: Text to translate.
        src_lang: Source language, either a code such as "hin_Deva" or a display
            name from ``languages`` such as "Hindi".
        tgt_lang: Target language, in the same forms as ``src_lang``.

    Returns:
        The translated text. Returns "" for empty or whitespace-only input, and
        returns the input unchanged when source and target are the same language.

    Raises:
        gr.Error: If either language is missing (e.g. a dropdown left unselected).
    """
    global _REPO_SET_UP
    if not input_text or not input_text.strip():
        return ""
    if not src_lang or not tgt_lang:
        raise gr.Error("Please select both a source and a target language.")

    src_code = _to_code(src_lang)
    tgt_code = _to_code(tgt_lang)
    if src_code == tgt_code:
        return input_text

    if not _REPO_SET_UP:
        setup_repo()  # Ensure the repo is set up; once per process is enough.
        _REPO_SET_UP = True

    tokenizer, model = _load_model(_model_name_for(src_code, tgt_code))

    # Created per call, as before, rather than shared across requests.
    # NOTE(review): it may carry state between preprocess_batch and
    # postprocess_batch — confirm before caching it globally.
    ip = IndicProcessor(inference=True)
    batch = ip.preprocess_batch([input_text], src_lang=src_code, tgt_lang=tgt_code)
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(_DEVICE)
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )
    # as_target_tokenizer() puts the tokenizer in its target-side mode for decoding.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    # Reverse preprocess_batch's transformations for the target language, as in
    # the IndicTransToolkit inference flow.
    return ip.postprocess_batch(decoded, lang=tgt_code)[0]
# Supported languages as (display name, IndicTrans2 language code) pairs,
# listed alphabetically by display name so the UI dropdowns are easy to scan.
languages = [
    ("Assamese", "asm_Beng"),
    ("Bengali", "ben_Beng"),
    ("Bodo", "brx_Deva"),
    ("Dogri", "doi_Deva"),
    ("English", "eng_Latn"),
    ("Gujarati", "guj_Gujr"),
    ("Hindi", "hin_Deva"),
    ("Kannada", "kan_Knda"),
    ("Kashmiri (Arabic)", "kas_Arab"),
    ("Kashmiri (Devanagari)", "kas_Deva"),
    ("Konkani", "gom_Deva"),
    ("Maithili", "mai_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Manipuri (Bengali)", "mni_Beng"),
    ("Manipuri (Meitei)", "mni_Mtei"),
    ("Marathi", "mar_Deva"),
    ("Nepali", "npi_Deva"),
    ("Odia", "ory_Orya"),
    ("Punjabi", "pan_Guru"),
    ("Sanskrit", "san_Deva"),
    ("Santali", "sat_Olck"),
    ("Sindhi (Arabic)", "snd_Arab"),
    ("Sindhi (Devanagari)", "snd_Deva"),
    ("Tamil", "tam_Taml"),
    ("Telugu", "tel_Telu"),
    ("Urdu", "urd_Arab"),
]
# Gradio interface

# The dropdowns show display names, but translate() needs language codes
# such as "hin_Deva".
_NAME_TO_CODE = dict(languages)
_LANGUAGE_NAMES = [name for name, _ in languages]


def _translate_from_names(text, src_name, tgt_name):
    """Gradio callback: map the dropdowns' display names to codes, then translate.

    Values that are not known display names (including None when nothing is
    selected) are passed through unchanged so translate() can handle them.
    """
    return translate(
        text,
        _NAME_TO_CODE.get(src_name, src_name),
        _NAME_TO_CODE.get(tgt_name, tgt_name),
    )


with gr.Blocks() as demo:
    gr.Markdown("# IndicTrans2 Translation")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text")
            src_lang = gr.Dropdown(label="Source Language", choices=_LANGUAGE_NAMES, type="value")
            tgt_lang = gr.Dropdown(label="Target Language", choices=_LANGUAGE_NAMES, type="value")
            translate_button = gr.Button("Translate")
        output_text = gr.Textbox(label="Translated Output")
    # Call translate function when button is clicked
    translate_button.click(
        fn=_translate_from_names,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()