codemichaeld commited on
Commit
9310eed
·
verified ·
1 Parent(s): 6fb7520

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -268,10 +268,10 @@ def find_sharded_safetensors_files(repo_id, subfolder=None, hf_token=None, max_s
268
 
269
  # Filter for safetensors files in the subfolder
270
  if subfolder:
271
- pattern = f"{subfolder}/"
272
  safetensors_files = [f for f in repo_files if f.endswith('.safetensors') and f.startswith(pattern)]
273
  # Remove subfolder prefix
274
- safetensors_files = [f[len(pattern):] for f in safetensors_files]
275
  else:
276
  safetensors_files = [f for f in repo_files if f.endswith('.safetensors')]
277
 
@@ -280,21 +280,28 @@ def find_sharded_safetensors_files(repo_id, subfolder=None, hf_token=None, max_s
280
  single_files = []
281
 
282
  for f in safetensors_files:
283
- if "-of-" in f:
 
 
284
  sharded_files.append(f)
285
  else:
286
  single_files.append(f)
287
 
288
- # Return sharded files if found, otherwise single files
289
  if sharded_files:
290
  # Sort by shard number for consistent ordering
291
- sharded_files.sort(key=lambda x: int(re.search(r'-(\d+)-of-', x).group(1)))
 
 
 
 
292
  # Limit number of shards to prevent accidental downloads of huge models
293
  if len(sharded_files) > max_shards:
294
  raise ValueError(f"Too many shards found ({len(sharded_files)}). Maximum allowed is {max_shards}. "
295
  f"Please specify a more specific pattern.")
296
  return sharded_files
297
  elif single_files:
 
298
  return single_files
299
  else:
300
  return []
@@ -372,7 +379,7 @@ def download_model_files(source_type, repo_url, filename_pattern, model_format,
372
 
373
  return model_files, temp_dir
374
  else:
375
- # Single file
376
  progress(0.2, desc=f"Downloading {filename_pattern}...")
377
  model_path = hf_hub_download(
378
  repo_id=repo_id,
 
268
 
269
  # Filter for safetensors files in the subfolder
270
  if subfolder:
271
+ pattern = f"{subfolder}/" if not subfolder.endswith("/") else subfolder
272
  safetensors_files = [f for f in repo_files if f.endswith('.safetensors') and f.startswith(pattern)]
273
  # Remove subfolder prefix
274
+ safetensors_files = [f[len(pattern):] for f in safetensors_files if len(f) > len(pattern)]
275
  else:
276
  safetensors_files = [f for f in repo_files if f.endswith('.safetensors')]
277
 
 
280
  single_files = []
281
 
282
  for f in safetensors_files:
283
+ # Check for sharding pattern: model-XXXXX-of-YYYYY.safetensors
284
+ match = re.search(r'-\d{5}-of-\d{5}\.safetensors$', f)
285
+ if match:
286
  sharded_files.append(f)
287
  else:
288
  single_files.append(f)
289
 
290
+ # If we have sharded files, return them sorted by shard number
291
  if sharded_files:
292
  # Sort by shard number for consistent ordering
293
+ def extract_shard_num(filename):
294
+ match = re.search(r'-(\d{5})-of-\d{5}\.safetensors$', filename)
295
+ return int(match.group(1)) if match else 0
296
+
297
+ sharded_files.sort(key=extract_shard_num)
298
  # Limit number of shards to prevent accidental downloads of huge models
299
  if len(sharded_files) > max_shards:
300
  raise ValueError(f"Too many shards found ({len(sharded_files)}). Maximum allowed is {max_shards}. "
301
  f"Please specify a more specific pattern.")
302
  return sharded_files
303
  elif single_files:
304
+ # Return single files (non-sharded)
305
  return single_files
306
  else:
307
  return []
 
379
 
380
  return model_files, temp_dir
381
  else:
382
+ # SINGLE FILE SAFETENSORS - separate from shard discovery
383
  progress(0.2, desc=f"Downloading {filename_pattern}...")
384
  model_path = hf_hub_download(
385
  repo_id=repo_id,