druvx13 commited on
Commit
08437dc
·
verified ·
1 Parent(s): d10d67d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import random
5
+ from transformers import pipeline, set_seed
6
+ import gradio as gr
7
+
8
+ # Model setup
9
+ CACHE_DIR = "./model_cache"
10
+ os.makedirs(CACHE_DIR, exist_ok=True)
11
+ generator = pipeline(
12
+ "text-generation",
13
+ model="openai-community/openai-gpt",
14
+ cache_dir=CACHE_DIR,
15
+ device_map="cpu",
16
+ torch_dtype=torch.float32
17
+ )
18
+
19
+ # State for chat history
20
+ def init_history():
21
+ return []
22
+
23
+ def generate_and_record(prompt, max_length, temperature, top_k, top_p, repetition_penalty, seed, num_return_sequences, history):
24
+ if seed is not None:
25
+ set_seed(seed)
26
+ outputs = generator(
27
+ prompt,
28
+ max_length=max_length,
29
+ temperature=temperature,
30
+ top_k=top_k,
31
+ top_p=top_p,
32
+ repetition_penalty=repetition_penalty,
33
+ num_return_sequences=num_return_sequences,
34
+ do_sample=True
35
+ )
36
+ texts = [out["generated_text"] for out in outputs]
37
+ entry = {"prompt": prompt, "results": texts}
38
+ history.append(entry)
39
+ return "\n\n---\n\n".join(texts), history
40
+
41
+ def export_history(history):
42
+ path = "chat_history.json"
43
+ with open(path, "w", encoding="utf-8") as f:
44
+ json.dump(history, f, ensure_ascii=False, indent=2)
45
+ return path
46
+
47
+ with gr.Blocks(title="GPT Text Generation") as demo:
48
+ gr.Markdown("## Text Generation with openai-community/openai-gpt (CPU)")
49
+
50
+ with gr.Row():
51
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=2)
52
+ max_length = gr.Slider(32, 1024, value=128, step=32, label="Max Length")
53
+ with gr.Row():
54
+ temperature = gr.Slider(0.1, 1.5, value=1.0, step=0.1, label="Temperature")
55
+ top_k = gr.Slider(0, 100, value=50, step=1, label="Top-K Sampling")
56
+ with gr.Row():
57
+ top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-P (Nucleus) Sampling")
58
+ repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
59
+ seed_input = gr.Number(value=None, precision=0, label="Seed (optional)")
60
+ num_seq = gr.Dropdown(choices=[1, 2, 3, 5], value=1, label="Number of Generations")
61
+
62
+ generate_btn = gr.Button("Generate")
63
+ clear_btn = gr.Button("Clear History")
64
+ export_btn = gr.Button("Export History")
65
+
66
+ output_text = gr.TextArea(label="Generated Text", interactive=False, lines=10)
67
+ history_state = gr.State(init_history())
68
+
69
+ generate_btn.click(
70
+ fn=generate_and_record,
71
+ inputs=[prompt_input, max_length, temperature, top_k, top_p, repetition_penalty, seed_input, num_seq, history_state],
72
+ outputs=[output_text, history_state]
73
+ )
74
+ clear_btn.click(lambda _: [], inputs=[history_state], outputs=[history_state, output_text], _js="() => {document.querySelectorAll('textarea')[1].value='';}")
75
+ export_btn.click(fn=export_history, inputs=[history_state], outputs=[])
76
+
77
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)