Spaces:
Sleeping
Sleeping
edbeeching
committed on
Commit
·
8854100
1
Parent(s):
b6d1901
bugfixes and improved validation
Browse files
app.py
CHANGED
|
@@ -8,6 +8,13 @@ from enum import Enum
|
|
| 8 |
from datasets import get_dataset_infos
|
| 9 |
from transformers import AutoConfig
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
class GenerationStatus(Enum):
|
| 13 |
PENDING = "PENDING"
|
|
@@ -27,6 +34,7 @@ class GenerationRequest:
|
|
| 27 |
input_dataset_name: str
|
| 28 |
input_dataset_config: str
|
| 29 |
input_dataset_split: str
|
|
|
|
| 30 |
prompt_column: str
|
| 31 |
model_name_or_path: str
|
| 32 |
model_revision: str
|
|
@@ -55,7 +63,7 @@ def validate_request(request: GenerationRequest):
|
|
| 55 |
raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")
|
| 56 |
|
| 57 |
# check that the number of samples is less than MAX_SAMPLES
|
| 58 |
-
if input_dataset_info.splits[request.input_dataset_split].
|
| 59 |
raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")
|
| 60 |
|
| 61 |
# check the prompt column exists in the dataset
|
|
@@ -67,6 +75,7 @@ def validate_request(request: GenerationRequest):
|
|
| 67 |
try:
|
| 68 |
model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, token=request.model_token)
|
| 69 |
except Exception as e:
|
|
|
|
| 70 |
raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.")
|
| 71 |
|
| 72 |
# check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
|
|
@@ -91,36 +100,42 @@ def validate_request(request: GenerationRequest):
|
|
| 91 |
def add_request_to_db(request: GenerationRequest):
|
| 92 |
url: str = os.getenv("SUPABASE_URL")
|
| 93 |
key: str = os.getenv("SUPABASE_KEY")
|
| 94 |
-
options: ClientOptions = {
|
| 95 |
-
"schema": "public"
|
| 96 |
-
}
|
| 97 |
-
supabase: Client = create_client(url, key, options)
|
| 98 |
-
|
| 99 |
-
data = {
|
| 100 |
-
"status": request.status.value,
|
| 101 |
-
"input_dataset_name": request.input_dataset_name,
|
| 102 |
-
"input_dataset_config": request.input_dataset_config,
|
| 103 |
-
"input_dataset_split": request.input_dataset_split,
|
| 104 |
-
"prompt_column": request.prompt_column,
|
| 105 |
-
"model_name_or_path": request.model_name_or_path,
|
| 106 |
-
"model_revision": request.model_revision,
|
| 107 |
-
"model_token": request.model_token,
|
| 108 |
-
"system_prompt": request.system_prompt,
|
| 109 |
-
"max_tokens": request.max_tokens,
|
| 110 |
-
"temperature": request.temperature,
|
| 111 |
-
"top_k": request.top_k,
|
| 112 |
-
"top_p": request.top_p,
|
| 113 |
-
"input_dataset_token": request.input_dataset_token,
|
| 114 |
-
"output_dataset_token": request.output_dataset_token,
|
| 115 |
-
"username": request.username,
|
| 116 |
-
"email": request.email
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
response = supabase.table("generation-requests").insert(data).execute()
|
| 120 |
-
if response.status_code != 201:
|
| 121 |
-
raise Exception(f"Failed to add request to database: {response.data}")
|
| 122 |
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
def create_gradio_interface():
|
|
@@ -130,11 +145,7 @@ def create_gradio_interface():
|
|
| 130 |
gr.Markdown("# Synthetic Data Generation Request")
|
| 131 |
with gr.Row():
|
| 132 |
gr.Markdown("""
|
| 133 |
-
Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models.
|
| 134 |
-
|
| 135 |
-
Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n
|
| 136 |
-
|
| 137 |
-
|
| 138 |
""")
|
| 139 |
with gr.Group():
|
| 140 |
with gr.Row():
|
|
@@ -153,8 +164,6 @@ def create_gradio_interface():
|
|
| 153 |
- Model must be accessible (public or with valid token)
|
| 154 |
- Maximum 10,000 samples per dataset
|
| 155 |
- Maximum of 32k generation tokens
|
| 156 |
-
|
| 157 |
-
**Note:** Generation requests are processed asynchronously. You will be notified via email when your request is complete.
|
| 158 |
""")
|
| 159 |
|
| 160 |
with gr.Row():
|
|
@@ -184,7 +193,7 @@ def create_gradio_interface():
|
|
| 184 |
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
|
| 185 |
with gr.Row():
|
| 186 |
top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
|
| 187 |
-
top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.
|
| 188 |
with gr.Column():
|
| 189 |
system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")
|
| 190 |
|
|
@@ -202,15 +211,16 @@ def create_gradio_interface():
|
|
| 202 |
submit_btn = gr.Button("Submit Generation Request", variant="primary")
|
| 203 |
output_status = gr.Textbox(label="Status", interactive=False)
|
| 204 |
|
| 205 |
-
def submit_request(
|
| 206 |
-
max_tok, temp, top_k_val, top_p_val,
|
| 207 |
try:
|
| 208 |
request = GenerationRequest(
|
| 209 |
id="", # Will be generated when adding to the database
|
| 210 |
status=GenerationStatus.PENDING,
|
| 211 |
-
input_dataset_name=
|
| 212 |
input_dataset_split=input_split,
|
| 213 |
input_dataset_config=input_dataset_config,
|
|
|
|
| 214 |
prompt_column=prompt_col,
|
| 215 |
model_name_or_path=model_name,
|
| 216 |
model_revision=model_rev,
|
|
@@ -220,7 +230,6 @@ def create_gradio_interface():
|
|
| 220 |
temperature=temp,
|
| 221 |
top_k=int(top_k_val),
|
| 222 |
top_p=top_p_val,
|
| 223 |
-
output_dataset_name=output_ds,
|
| 224 |
input_dataset_token=input_dataset_token if input_dataset_token else None,
|
| 225 |
output_dataset_token=output_dataset_token,
|
| 226 |
username=user,
|
|
@@ -237,9 +246,9 @@ def create_gradio_interface():
|
|
| 237 |
|
| 238 |
submit_btn.click(
|
| 239 |
submit_request,
|
| 240 |
-
inputs=[input_dataset_name, input_dataset_split, prompt_column, model_name_or_path,
|
| 241 |
model_revision, model_token, system_prompt, max_tokens, temperature, top_k, top_p,
|
| 242 |
-
|
| 243 |
outputs=output_status
|
| 244 |
)
|
| 245 |
|
|
|
|
| 8 |
from datasets import get_dataset_infos
|
| 9 |
from transformers import AutoConfig
|
| 10 |
|
| 11 |
+
"""
|
| 12 |
+
Still TODO:
|
| 13 |
+
- validate the user is PRO
|
| 14 |
+
- check the output dataset token is valid
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
|
| 19 |
class GenerationStatus(Enum):
|
| 20 |
PENDING = "PENDING"
|
|
|
|
| 34 |
input_dataset_name: str
|
| 35 |
input_dataset_config: str
|
| 36 |
input_dataset_split: str
|
| 37 |
+
output_dataset_name: str
|
| 38 |
prompt_column: str
|
| 39 |
model_name_or_path: str
|
| 40 |
model_revision: str
|
|
|
|
| 63 |
raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")
|
| 64 |
|
| 65 |
# check that the number of samples is less than MAX_SAMPLES
|
| 66 |
+
if input_dataset_info.splits[request.input_dataset_split].num_examples > MAX_SAMPLES:
|
| 67 |
raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")
|
| 68 |
|
| 69 |
# check the prompt column exists in the dataset
|
|
|
|
| 75 |
try:
|
| 76 |
model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, token=request.model_token)
|
| 77 |
except Exception as e:
|
| 78 |
+
print(e)
|
| 79 |
raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.")
|
| 80 |
|
| 81 |
# check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
|
|
|
|
| 100 |
def add_request_to_db(request: GenerationRequest):
    """Insert a generation request into the Supabase "gen-requests" table.

    Connection credentials are read from the SUPABASE_URL and SUPABASE_KEY
    environment variables.

    Args:
        request: the validated GenerationRequest to persist.

    Raises:
        Exception: if the client cannot be created or the insert fails.
            The original error is chained as the cause so failures are
            debuggable (the previous version discarded it).
    """
    url: str = os.getenv("SUPABASE_URL")
    key: str = os.getenv("SUPABASE_KEY")

    try:
        supabase: Client = create_client(
            url,
            key,
            options=ClientOptions(
                postgrest_client_timeout=10,
                storage_client_timeout=10,
                schema="public",
            ),
        )

        # NOTE(review): user tokens and email are stored verbatim in the
        # table — confirm access to "gen-requests" is restricted.
        data = {
            "status": request.status.value,
            "input_dataset_name": request.input_dataset_name,
            "input_dataset_config": request.input_dataset_config,
            "input_dataset_split": request.input_dataset_split,
            "output_dataset_name": request.output_dataset_name,
            "prompt_column": request.prompt_column,
            "model_name_or_path": request.model_name_or_path,
            "model_revision": request.model_revision,
            "model_token": request.model_token,
            "system_prompt": request.system_prompt,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "input_dataset_token": request.input_dataset_token,
            "output_dataset_token": request.output_dataset_token,
            "username": request.username,
            "email": request.email,
        }

        supabase.table("gen-requests").insert(data).execute()
    except Exception as e:
        # Chain the underlying error instead of swallowing it: `e` was
        # previously captured but never used, hiding the real failure.
        raise Exception("Failed to add request to database") from e
|
| 139 |
|
| 140 |
|
| 141 |
def create_gradio_interface():
|
|
|
|
| 145 |
gr.Markdown("# Synthetic Data Generation Request")
|
| 146 |
with gr.Row():
|
| 147 |
gr.Markdown("""
|
| 148 |
+
Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models. Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
""")
|
| 150 |
with gr.Group():
|
| 151 |
with gr.Row():
|
|
|
|
| 164 |
- Model must be accessible (public or with valid token)
|
| 165 |
- Maximum 10,000 samples per dataset
|
| 166 |
- Maximum of 32k generation tokens
|
|
|
|
|
|
|
| 167 |
""")
|
| 168 |
|
| 169 |
with gr.Row():
|
|
|
|
| 193 |
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
|
| 194 |
with gr.Row():
|
| 195 |
top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
|
| 196 |
+
top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
|
| 197 |
with gr.Column():
|
| 198 |
system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")
|
| 199 |
|
|
|
|
| 211 |
submit_btn = gr.Button("Submit Generation Request", variant="primary")
|
| 212 |
output_status = gr.Textbox(label="Status", interactive=False)
|
| 213 |
|
| 214 |
+
def submit_request(input_dataset_name, input_split, input_dataset_config, output_dataset_name, prompt_col, model_name, model_rev, model_token, sys_prompt,
|
| 215 |
+
max_tok, temp, top_k_val, top_p_val, user, email_addr, input_dataset_token, output_dataset_token):
|
| 216 |
try:
|
| 217 |
request = GenerationRequest(
|
| 218 |
id="", # Will be generated when adding to the database
|
| 219 |
status=GenerationStatus.PENDING,
|
| 220 |
+
input_dataset_name=input_dataset_name,
|
| 221 |
input_dataset_split=input_split,
|
| 222 |
input_dataset_config=input_dataset_config,
|
| 223 |
+
output_dataset_name=output_dataset_name,
|
| 224 |
prompt_column=prompt_col,
|
| 225 |
model_name_or_path=model_name,
|
| 226 |
model_revision=model_rev,
|
|
|
|
| 230 |
temperature=temp,
|
| 231 |
top_k=int(top_k_val),
|
| 232 |
top_p=top_p_val,
|
|
|
|
| 233 |
input_dataset_token=input_dataset_token if input_dataset_token else None,
|
| 234 |
output_dataset_token=output_dataset_token,
|
| 235 |
username=user,
|
|
|
|
| 246 |
|
| 247 |
submit_btn.click(
|
| 248 |
submit_request,
|
| 249 |
+
inputs=[input_dataset_name, input_dataset_split, input_dataset_config, output_dataset_name, prompt_column, model_name_or_path,
|
| 250 |
model_revision, model_token, system_prompt, max_tokens, temperature, top_k, top_p,
|
| 251 |
+
username, email, input_dataset_token, output_dataset_token],
|
| 252 |
outputs=output_status
|
| 253 |
)
|
| 254 |
|