Spaces:
Runtime error
Runtime error
# Version two of the app: adds output-flagging features.
import os

import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Initialize the environment.
model_name = 'anugrahap/gpt2-indo-textgen'

# SECURITY: never hardcode an access token in source. The token is read from
# the HF_TOKEN environment variable (on Hugging Face Spaces, define it as a
# repository secret). Any token that was ever committed in plain text should
# be considered leaked and revoked.
# NOTE(review): when HF_TOKEN is unset this passes None to the dataset saver;
# confirm that the resulting behaviour is acceptable for this deployment.
HF_TOKEN = os.environ.get("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "gpt2-output")

# Define the tokenization method.
tokenizer = AutoTokenizer.from_pretrained(model_name,
                                          model_max_length=1e30,
                                          padding_side='right',
                                          return_tensors='pt')

# Use the EOS token as the PAD token to avoid warnings.
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)

# Text-generation pipeline shared by every request handled by the app.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
| # create the decoder parameter to generate the text | |
| def single_generation(text,min_length,max_length,temperature,top_k,top_p,num_beams,repetition_penalty,do_sample): | |
| # create local variable for error parameter | |
| error_rep=gr.Error(f"ERROR: repetition penalty cannot be lower than one! Given rep penalty = {repetition_penalty}") | |
| error_temp=gr.Error(f"ERROR: temperature cannot be zero or lower! Given temperature = {temperature}") | |
| error_minmax=gr.Error(f"ERROR: min length must be lower than or equal to max length! Given min length = {min_length}") | |
| error_numbeams_type=gr.Error(f"ERROR: number of beams must be an integer not {type(num_beams)}") | |
| error_topk_type=gr.Error(f"ERROR: top k must be an integer not {type(top_k)}") | |
| error_minmax_type=gr.Error(f"ERROR: min length and max length must be an integer not {type(min_length)} and {type(max_length)}") | |
| error_empty_temprep=gr.Error("ERROR: temperature and repetition penalty cannot be empty!") | |
| error_empty_text=gr.Error("ERROR: Input Text cannot be empty!") | |
| error_unknown=gr.Error("Unknown Error.") | |
| if text != '': | |
| if type(min_length) == int and type(max_length) == int: | |
| if type(top_k) == int: | |
| if type(num_beams) == int: | |
| if min_length <= max_length: | |
| if temperature > 0: | |
| if repetition_penalty >= 1: | |
| if temperature and repetition_penalty is not None: | |
| result = generator(text, | |
| min_length=min_length, | |
| max_length=max_length, | |
| temperature=temperature, | |
| top_k=top_k, | |
| top_p=top_p, | |
| num_beams=num_beams, | |
| repetition_penalty=repetition_penalty, | |
| do_sample=do_sample, | |
| no_repeat_ngram_size=2, | |
| num_return_sequences=1) | |
| return result[0]["generated_text"] | |
| elif temperature or repetition_penalty is None: | |
| raise error_empty_temprep | |
| elif repetition_penalty < 1: | |
| raise error_rep | |
| elif temperature <= 0: | |
| raise error_temp | |
| elif min_length > max_length: | |
| raise error_minmax | |
| elif type(num_beams) != int: | |
| raise error_numbeams_type | |
| elif type(top_k) != int: | |
| raise error_topk_type | |
| elif type(min_length) != int or type(max_length) != int: | |
| raise error_minmax_type | |
| elif text == '': | |
| raise error_empty_text | |
| else: | |
| raise error_unknown | |
# Widgets and sample rows for the Gradio app.
# The components in `forinput` are listed in the same order as the parameters
# of single_generation: text, min length, max length, temperature, top k,
# top p, number of beams, repetition penalty, do sample.
forinput = [
    gr.Textbox(label="Input Text", lines=5),
    gr.Slider(label="Min Length", minimum=10, maximum=50, step=5, value=10),
    gr.Slider(label="Max Length", minimum=10, maximum=100, step=10, value=30),
    gr.Number(label="Temperature Sampling", value=1.5),
    gr.Slider(label="Top K Sampling", minimum=0, maximum=100, step=5, value=30),
    gr.Slider(label="Top P Sampling", minimum=0.01, maximum=1, value=0.93),
    gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=5),
    gr.Number(label="Repetition Penalty", value=2.0),
    gr.Dropdown(label="Do Sample?", choices=[True, False], value=True, multiselect=False),
]

# Single text area that displays the generated continuation.
foroutput = gr.Textbox(
    label="Generated Text with Greedy/Beam Search Decoding",
    lines=5,
    max_lines=50,
)

# Each example row supplies one value per input component, in the same order.
examples = [
    ["Indonesia adalah negara kepulauan", 10, 30, 1.0, 25, 0.92, 5, 2.0, True],
    ["Indonesia adalah negara kepulauan", 10, 30, 1.0, 25, 0.92, 5, 1.0, False],
    ["Skripsi merupakan tugas akhir mahasiswa", 20, 40, 1.0, 50, 0.92, 1, 2.0, True],
    ["Skripsi merupakan tugas akhir mahasiswa", 20, 40, 1.0, 50, 0.92, 1, 1.0, False],
    ["Pemandangan di pantai kuta Bali sangatlah indah.", 30, 50, 0.5, 40, 0.98, 10, 1.0, True],
    ["Pemandangan di pantai kuta Bali sangatlah indah.", 10, 30, 1.5, 30, 0.93, 5, 2.0, True],
]
# HTML shown at the top of the page: logo plus heading.
# The bottom margin uses `margin-bottom`; `margin-down` is not a CSS property
# and would be ignored by browsers.
title = """
<style>
.center {
display: block;
margin-top: 20px;
margin-bottom: 0px;
margin-left: auto;
margin-right: auto;
}
</style>
<style>
h1 {
text-align: center;
margin-top: 0px;
}
</style>
<img src="https://i.postimg.cc/cHPVPSfH/Q-GEN-logo.png"
alt="Q-GEN Logo"
border="0"
class="center"
style="height: 100px; width: 100px;"/>
<h1>GPT-2 Indonesian Text Generation Playground</h1>"""

# Short HTML note displayed under the title.
description = "<p><i>This project is a part of thesis requirement of Anugrah Akbar Praramadhan</i></p>"

# HTML footer: links to the model, the project repository, the autosaved
# outputs and the original paper, followed by training and copyright notes.
article = """<p style='text-align: center'>
<a href='https://huggingface.co/anugrahap/gpt2-indo-textgen' target='_blank'>Link to the Trained Model<b> |</b></a>
<a href='https://huggingface.co/spaces/anugrahap/gpt2-indo-text-gen/tree/main' target='_blank'>Link to the Project Repository<b> |</b></a>
<a href='https://huggingface.co/datasets/anugrahap/gpt2-output/' target='_blank'>Link to the Autosaved Generated Output<b> |</b></a>
<a href='https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf' target='_blank'>Original Paper</a><br></p>
<p style='text-align: center'> Trained on Indo4B Benchmark Dataset of Indonesian language Wikipedia with a Causal Language Modeling (CLM) objective<br></p>
<p style='text-align: center'>Copyright Anugrah Akbar Praramadhan 2023</p>
"""
# Assemble the Gradio interface around single_generation.
app = gr.Interface(
    fn=single_generation,
    inputs=forinput,
    outputs=foroutput,
    examples=examples,
    # Page text: heading, short description and footer.
    title=title,
    description=description,
    article=article,
    # Manual flagging: the user picks one of the options below, and the
    # flagged sample is handed to the hf_writer callback.
    allow_flagging='manual',
    flagging_options=[
        'Well Performed',
        'Inappropriate Word Selection',
        'Wordy',
        'Strange Word',
        'Others',
    ],
    flagging_callback=hf_writer,
)

# Start the app only when this file is executed as a script.
if __name__ == '__main__':
    app.launch()