edbeeching committed on
Commit
8854100
·
1 Parent(s): b6d1901

bugfixes and improved validation

Browse files
Files changed (1) hide show
  1. app.py +53 -44
app.py CHANGED
@@ -8,6 +8,13 @@ from enum import Enum
8
  from datasets import get_dataset_infos
9
  from transformers import AutoConfig
10
 
 
 
 
 
 
 
 
11
 
12
  class GenerationStatus(Enum):
13
  PENDING = "PENDING"
@@ -27,6 +34,7 @@ class GenerationRequest:
27
  input_dataset_name: str
28
  input_dataset_config: str
29
  input_dataset_split: str
 
30
  prompt_column: str
31
  model_name_or_path: str
32
  model_revision: str
@@ -55,7 +63,7 @@ def validate_request(request: GenerationRequest):
55
  raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")
56
 
57
  # check that the number of samples is less than MAX_SAMPLES
58
- if input_dataset_info.splits[request.input_dataset_split].num_samples > MAX_SAMPLES:
59
  raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")
60
 
61
  # check the prompt column exists in the dataset
@@ -67,6 +75,7 @@ def validate_request(request: GenerationRequest):
67
  try:
68
  model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, token=request.model_token)
69
  except Exception as e:
 
70
  raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.")
71
 
72
  # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
@@ -91,36 +100,42 @@ def validate_request(request: GenerationRequest):
91
  def add_request_to_db(request: GenerationRequest):
92
  url: str = os.getenv("SUPABASE_URL")
93
  key: str = os.getenv("SUPABASE_KEY")
94
- options: ClientOptions = {
95
- "schema": "public"
96
- }
97
- supabase: Client = create_client(url, key, options)
98
-
99
- data = {
100
- "status": request.status.value,
101
- "input_dataset_name": request.input_dataset_name,
102
- "input_dataset_config": request.input_dataset_config,
103
- "input_dataset_split": request.input_dataset_split,
104
- "prompt_column": request.prompt_column,
105
- "model_name_or_path": request.model_name_or_path,
106
- "model_revision": request.model_revision,
107
- "model_token": request.model_token,
108
- "system_prompt": request.system_prompt,
109
- "max_tokens": request.max_tokens,
110
- "temperature": request.temperature,
111
- "top_k": request.top_k,
112
- "top_p": request.top_p,
113
- "input_dataset_token": request.input_dataset_token,
114
- "output_dataset_token": request.output_dataset_token,
115
- "username": request.username,
116
- "email": request.email
117
- }
118
-
119
- response = supabase.table("generation-requests").insert(data).execute()
120
- if response.status_code != 201:
121
- raise Exception(f"Failed to add request to database: {response.data}")
122
 
123
- return response.data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  def create_gradio_interface():
@@ -130,11 +145,7 @@ def create_gradio_interface():
130
  gr.Markdown("# Synthetic Data Generation Request")
131
  with gr.Row():
132
  gr.Markdown("""
133
- Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models.
134
-
135
- Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n
136
-
137
-
138
  """)
139
  with gr.Group():
140
  with gr.Row():
@@ -153,8 +164,6 @@ def create_gradio_interface():
153
  - Model must be accessible (public or with valid token)
154
  - Maximum 10,000 samples per dataset
155
  - Maximum of 32k generation tokens
156
-
157
- **Note:** Generation requests are processed asynchronously. You will be notified via email when your request is complete.
158
  """)
159
 
160
  with gr.Row():
@@ -184,7 +193,7 @@ def create_gradio_interface():
184
  temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
185
  with gr.Row():
186
  top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
187
- top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.1)
188
  with gr.Column():
189
  system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")
190
 
@@ -202,15 +211,16 @@ def create_gradio_interface():
202
  submit_btn = gr.Button("Submit Generation Request", variant="primary")
203
  output_status = gr.Textbox(label="Status", interactive=False)
204
 
205
- def submit_request(input_ds, input_split, prompt_col, model_name, model_rev, model_token, sys_prompt,
206
- max_tok, temp, top_k_val, top_p_val, output_ds, user, email_addr, input_dataset_token, output_dataset_token):
207
  try:
208
  request = GenerationRequest(
209
  id="", # Will be generated when adding to the database
210
  status=GenerationStatus.PENDING,
211
- input_dataset_name=input_ds,
212
  input_dataset_split=input_split,
213
  input_dataset_config=input_dataset_config,
 
214
  prompt_column=prompt_col,
215
  model_name_or_path=model_name,
216
  model_revision=model_rev,
@@ -220,7 +230,6 @@ def create_gradio_interface():
220
  temperature=temp,
221
  top_k=int(top_k_val),
222
  top_p=top_p_val,
223
- output_dataset_name=output_ds,
224
  input_dataset_token=input_dataset_token if input_dataset_token else None,
225
  output_dataset_token=output_dataset_token,
226
  username=user,
@@ -237,9 +246,9 @@ def create_gradio_interface():
237
 
238
  submit_btn.click(
239
  submit_request,
240
- inputs=[input_dataset_name, input_dataset_split, prompt_column, model_name_or_path,
241
  model_revision, model_token, system_prompt, max_tokens, temperature, top_k, top_p,
242
- output_dataset_name, username, email, input_dataset_token, output_dataset_token],
243
  outputs=output_status
244
  )
245
 
 
8
  from datasets import get_dataset_infos
9
  from transformers import AutoConfig
10
 
11
+ """
12
+ Still TODO:
13
+ - validate the user is PRO
14
+ - check the output dataset token is valid
15
+
16
+ """
17
+
18
 
19
  class GenerationStatus(Enum):
20
  PENDING = "PENDING"
 
34
  input_dataset_name: str
35
  input_dataset_config: str
36
  input_dataset_split: str
37
+ output_dataset_name: str
38
  prompt_column: str
39
  model_name_or_path: str
40
  model_revision: str
 
63
  raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")
64
 
65
  # check that the number of samples is less than MAX_SAMPLES
66
+ if input_dataset_info.splits[request.input_dataset_split].num_examples > MAX_SAMPLES:
67
  raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")
68
 
69
  # check the prompt column exists in the dataset
 
75
  try:
76
  model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, token=request.model_token)
77
  except Exception as e:
78
+ print(e)
79
  raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.")
80
 
81
  # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
 
100
  def add_request_to_db(request: GenerationRequest):
101
  url: str = os.getenv("SUPABASE_URL")
102
  key: str = os.getenv("SUPABASE_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ try:
105
+ supabase: Client = create_client(
106
+ url,
107
+ key,
108
+ options=ClientOptions(
109
+ postgrest_client_timeout=10,
110
+ storage_client_timeout=10,
111
+ schema="public",
112
+ )
113
+ )
114
+
115
+ data = {
116
+ "status": request.status.value,
117
+ "input_dataset_name": request.input_dataset_name,
118
+ "input_dataset_config": request.input_dataset_config,
119
+ "input_dataset_split": request.input_dataset_split,
120
+ "output_dataset_name": request.output_dataset_name,
121
+ "prompt_column": request.prompt_column,
122
+ "model_name_or_path": request.model_name_or_path,
123
+ "model_revision": request.model_revision,
124
+ "model_token": request.model_token,
125
+ "system_prompt": request.system_prompt,
126
+ "max_tokens": request.max_tokens,
127
+ "temperature": request.temperature,
128
+ "top_k": request.top_k,
129
+ "top_p": request.top_p,
130
+ "input_dataset_token": request.input_dataset_token,
131
+ "output_dataset_token": request.output_dataset_token,
132
+ "username": request.username,
133
+ "email": request.email
134
+ }
135
+
136
+ supabase.table("gen-requests").insert(data).execute()
137
+ except Exception as e:
138
+ raise Exception("Failed to add request to database")
139
 
140
 
141
  def create_gradio_interface():
 
145
  gr.Markdown("# Synthetic Data Generation Request")
146
  with gr.Row():
147
  gr.Markdown("""
148
+ Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models. Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n
 
 
 
 
149
  """)
150
  with gr.Group():
151
  with gr.Row():
 
164
  - Model must be accessible (public or with valid token)
165
  - Maximum 10,000 samples per dataset
166
  - Maximum of 32k generation tokens
 
 
167
  """)
168
 
169
  with gr.Row():
 
193
  temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
194
  with gr.Row():
195
  top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
196
+ top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
197
  with gr.Column():
198
  system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")
199
 
 
211
  submit_btn = gr.Button("Submit Generation Request", variant="primary")
212
  output_status = gr.Textbox(label="Status", interactive=False)
213
 
214
+ def submit_request(input_dataset_name, input_split, input_dataset_config, output_dataset_name, prompt_col, model_name, model_rev, model_token, sys_prompt,
215
+ max_tok, temp, top_k_val, top_p_val, user, email_addr, input_dataset_token, output_dataset_token):
216
  try:
217
  request = GenerationRequest(
218
  id="", # Will be generated when adding to the database
219
  status=GenerationStatus.PENDING,
220
+ input_dataset_name=input_dataset_name,
221
  input_dataset_split=input_split,
222
  input_dataset_config=input_dataset_config,
223
+ output_dataset_name=output_dataset_name,
224
  prompt_column=prompt_col,
225
  model_name_or_path=model_name,
226
  model_revision=model_rev,
 
230
  temperature=temp,
231
  top_k=int(top_k_val),
232
  top_p=top_p_val,
 
233
  input_dataset_token=input_dataset_token if input_dataset_token else None,
234
  output_dataset_token=output_dataset_token,
235
  username=user,
 
246
 
247
  submit_btn.click(
248
  submit_request,
249
+ inputs=[input_dataset_name, input_dataset_split, input_dataset_config, output_dataset_name, prompt_column, model_name_or_path,
250
  model_revision, model_token, system_prompt, max_tokens, temperature, top_k, top_p,
251
+ username, email, input_dataset_token, output_dataset_token],
252
  outputs=output_status
253
  )
254