druvx13 committed
Commit 2fbabfb · verified · 1 Parent(s): 7d80cba

Update app.py

Files changed (1)
  1. app.py +18 -18
app.py CHANGED
@@ -1,44 +1,44 @@
 import os
 import json
 import torch
-import random
 from transformers import pipeline, set_seed
 import gradio as gr
 
 # Model setup
 CACHE_DIR = "./model_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
+
 generator = pipeline(
     "text-generation",
     model="openai-community/openai-gpt",
     cache_dir=CACHE_DIR,
-    device_map="cpu",
-    torch_dtype=torch.float32
+    device=-1,  # CPU
 )
 
-# State for chat history
-def init_history():
+# Chat history state
+def init_history():
     return []
 
-def generate_and_record(prompt, max_length, temperature, top_k, top_p, repetition_penalty, seed, num_return_sequences, history):
+def generate_and_record(
+    prompt, max_length, temperature, top_k, top_p, repetition_penalty, seed, num_return_sequences, history
+):
     if seed is not None:
-        set_seed(seed)
+        set_seed(int(seed))
     outputs = generator(
         prompt,
-        max_length=max_length,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        num_return_sequences=num_return_sequences,
-        do_sample=True
+        max_length=int(max_length),
+        temperature=float(temperature),
+        top_k=int(top_k),
+        top_p=float(top_p),
+        repetition_penalty=float(repetition_penalty),
+        num_return_sequences=int(num_return_sequences),
+        do_sample=True,
     )
     texts = [out["generated_text"] for out in outputs]
-    entry = {"prompt": prompt, "results": texts}
-    history.append(entry)
+    history.append({"prompt": prompt, "results": texts})
     return "\n\n---\n\n".join(texts), history
 
-def export_history(history):
+def export_history(history):
     path = "chat_history.json"
     with open(path, "w", encoding="utf-8") as f:
         json.dump(history, f, ensure_ascii=False, indent=2)
@@ -71,7 +71,7 @@ with gr.Blocks(title="GPT Text Generation") as demo:
         inputs=[prompt_input, max_length, temperature, top_k, top_p, repetition_penalty, seed_input, num_seq, history_state],
         outputs=[output_text, history_state]
     )
-    clear_btn.click(lambda _: [], inputs=[history_state], outputs=[history_state, output_text], _js="() => {document.querySelectorAll('textarea')[1].value='';}")
+    clear_btn.click(lambda _: ([], ""), inputs=[history_state], outputs=[history_state, output_text])
    export_btn.click(fn=export_history, inputs=[history_state], outputs=[])
 
 demo.queue().launch(server_name="0.0.0.0", server_port=7860)
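
A note on the clear-button rewrite: the old handler blanked the textbox through an inline _js DOM snippet, which is brittle and version-dependent; the new handler instead returns one value per output component and lets Gradio do the resetting. A minimal standalone sketch of that pattern (component names follow the app; the surrounding wiring here is illustrative only):

import gradio as gr

with gr.Blocks() as demo:
    history_state = gr.State([])               # server-side chat history
    output_text = gr.Textbox(label="Output")   # assumed label, for illustration
    clear_btn = gr.Button("Clear")

    # A handler wired to two outputs must return a 2-tuple,
    # one value per component in `outputs` order:
    # [] resets the State, "" blanks the Textbox.
    clear_btn.click(
        lambda _: ([], ""),
        inputs=[history_state],
        outputs=[history_state, output_text],
    )

demo.launch()

The int()/float() casts added in generate_and_record follow the same defensive logic: Gradio numeric widgets deliver plain Python numbers (often floats), so casting before handing them to the transformers pipeline avoids type errors for integer-only arguments such as top_k.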