sudanl commited on
Commit
e3c211e
·
1 Parent(s): 9709b91

feat: Add OSS modules to SAGE-Bench for HuggingFace Space compatibility

Browse files
requirements.txt CHANGED
@@ -5,4 +5,7 @@ numpy
5
  pandas
6
  python-dateutil
7
  openai>=1.0.0
8
- aiohttp
 
 
 
 
5
  pandas
6
  python-dateutil
7
  openai>=1.0.0
8
+ aiohttp
9
+ oss2
10
+ loguru
11
+ tqdm
src/oss/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # OSS module for SAGE-Bench
src/oss/oss_file_manager.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import datetime as dt
3
+ import os
4
+ import re
5
+ from logging import Logger
6
+ from multiprocessing import Pool
7
+ from multiprocessing.pool import ThreadPool
8
+ from pathlib import Path
9
+ from typing import List, Union
10
+
11
+ import oss2
12
+ from loguru import logger
13
+ from oss2.credentials import EnvironmentVariableCredentialsProvider
14
+ from tqdm import tqdm
15
+
16
+ from compassflow.constants import DATADIR
17
+ from compassflow.oss.oss import OssBucket
18
+ from compassflow.utils import starstarmap
19
+
20
+
21
+ class OSSFileManager:
22
+ def __init__(
23
+ self,
24
+ oss_access_key_id: str = None,
25
+ oss_access_key_secret: str = None,
26
+ region: str = "http://oss-cn-shanghai.aliyuncs.com",
27
+ bucket_name: str = "opencompass",
28
+ oss_block_name: str = None,
29
+ logger: Logger = logger,
30
+ ) -> None:
31
+ """OSS File Manager
32
+
33
+ Args:
34
+ oss_access_key_id (str, optional): _description_. Defaults to None.
35
+ oss_access_key_secret (str, optional): _description_. Defaults to None.
36
+ region (_type_, optional): _description_. Defaults to 'http://oss-cn-shanghai.aliyuncs.com'.
37
+ bucket_name (str, optional): _description_. Defaults to 'opencompass'.
38
+ oss_block_name: oss_block_name which is defined in the prefect
39
+ logger (Logger, optional): _description_. Defaults to logger.
40
+ """
41
+ self.logger = logger
42
+ if oss_block_name is not None:
43
+ oss_bucket = OssBucket.load(oss_block_name)
44
+ self._bucket = oss_bucket._get_bucket()
45
+ return
46
+
47
+ # 阿里云账号AccessKey拥有所有API的访问权限,风险很高。强烈建议您创建并使用RAM账号进行API访问或日常运维,请登录RAM控制台创建RAM账号。
48
+ if oss_access_key_id is not None and oss_access_key_secret is not None:
49
+ os.environ["OSS_ACCESS_KEY_ID"] = oss_access_key_id
50
+ os.environ["OSS_ACCESS_KEY_SECRET"] = oss_access_key_secret
51
+
52
+ if (
53
+ os.getenv("OSS_ACCESS_KEY_ID") is None
54
+ or os.getenv("OSS_ACCESS_KEY_SECRET") is None
55
+ ):
56
+ raise ValueError("Access Key ID and Access Key Secret cannot be empty.")
57
+
58
+ auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
59
+
60
+ # Endpoint以杭州为例,其它Region请按实际情况填写。
61
+ # 填写Bucket名称,例如examplebucket。
62
+ bucket = oss2.Bucket(
63
+ auth=auth,
64
+ endpoint=region,
65
+ bucket_name=bucket_name,
66
+ )
67
+
68
+ self._bucket = bucket
69
+
70
+ def list_latest_files_by_date(
71
+ self,
72
+ object_dir: str = "",
73
+ delimiter: str = "/",
74
+ start_date: Union[str, dt.date, dt.datetime] = None,
75
+ end_date: Union[str, dt.date, dt.datetime] = None,
76
+ max_num_files: int = 10,
77
+ date_pattern: str = r"([0-9]{4}-[0-9]{2}-[0-9]{2})",
78
+ file_date_format: str = "%Y-%m-%d",
79
+ suffix: str = "",
80
+ ) -> List[Union[str, Path]]:
81
+ """List the latest files by date in an OSS bucket directory
82
+
83
+ Args:
84
+ object_dir (str, optional): _description_. Defaults to ''.
85
+ delimiter (str, optional): _description_. Defaults to '/'.
86
+ start_date (Union[str, dt.date], optional): _description_. Defaults to None.
87
+ end_date (Union[str, dt.date], optional): _description_. Defaults to None.
88
+ max_num_files (int, optional): _description_. Defaults to 10.
89
+ date_pattern (str, optional): _description_. Defaults to r'^([0-9]{4}-[0-9]{2}-[0-9]{2})'.
90
+ suffix (str, optional): _description_. Defaults to ''.
91
+
92
+ Returns:
93
+ List[Union[str, Path]]: _description_
94
+ """
95
+ if start_date is not None:
96
+ if isinstance(start_date, str):
97
+ start_date = dt.datetime.strptime(start_date.replace("-", ""), "%Y%m%d")
98
+
99
+ if isinstance(start_date, dt.date):
100
+ start_date = dt.datetime(
101
+ start_date.year, start_date.month, start_date.day, 0, 0, 0
102
+ )
103
+
104
+ if end_date is not None:
105
+ if isinstance(end_date, str):
106
+ end_date = dt.datetime.strptime(end_date.replace("-", ""), "%Y%m%d")
107
+
108
+ if isinstance(end_date, dt.date):
109
+ end_date = dt.datetime(
110
+ end_date.year, end_date.month, end_date.day, 0, 0, 0
111
+ )
112
+
113
+ object_iter = oss2.ObjectIterator(
114
+ bucket=self._bucket, prefix=object_dir, delimiter=delimiter
115
+ )
116
+
117
+ root_dir = Path(object_dir)
118
+ filenames = []
119
+ for filename in object_iter:
120
+ filename = filename.key.replace(object_dir, "")
121
+ # print(filename)
122
+
123
+ if filename.endswith(suffix):
124
+ # Match date pattern in filename
125
+ date_search = re.search(date_pattern, filename)
126
+ if date_search:
127
+ file_date = dt.datetime.strptime(
128
+ date_search.group(1), file_date_format
129
+ )
130
+ else:
131
+ self.logger.warning(
132
+ f"date pattern doesn't match, skipping file {filename}"
133
+ )
134
+ continue
135
+
136
+ # Check if file date within specified range
137
+ if start_date is not None:
138
+ if start_date > file_date:
139
+ continue
140
+
141
+ if end_date is not None:
142
+ if end_date < file_date:
143
+ continue
144
+
145
+ filepath = root_dir / filename
146
+ # name_tstamp_tuple = (filepath, os.path.getmtime(filepath))
147
+ # filenames.append(name_tstamp_tuple)
148
+ filenames.append(str(filepath))
149
+ # sort by tstamp
150
+ # filenames = sorted(filenames, key=lambda x: x[1])
151
+ filenames = sorted(filenames)
152
+ # filenames = [x[0] for x in filenames]
153
+
154
+ max_num_files = max_num_files or len(filenames)
155
+ filenames = filenames[-max_num_files:]
156
+
157
+ self.logger.info(f"{filenames=}")
158
+
159
+ return filenames
160
+
161
+ def download_object_to_file(
162
+ self,
163
+ oss_file_path: str | Path,
164
+ local_file_path: str | Path,
165
+ replace: bool = False,
166
+ make_dir: bool = False,
167
+ print_logs: bool = True,
168
+ ):
169
+ """Download a single OSS object to local file.
170
+
171
+ Args:
172
+ oss_file_path (str): _description_
173
+ local_file_path (str): _description_
174
+ replace (bool, optional): _description_. Defaults to False.
175
+ make_dir (bool, optional): Whether to create intermediate dirs if they don't exist. Defaults to False.
176
+ print_logs: bool, optional): Whether to print logs. Defaults to True.
177
+ """
178
+ if isinstance(local_file_path, str):
179
+ local_file_path = Path(local_file_path)
180
+
181
+ if not replace:
182
+ if local_file_path.exists():
183
+ if print_logs:
184
+ err_msg = f"{local_file_path} already exists, skipping file..."
185
+ self.logger.info(err_msg)
186
+
187
+ return
188
+
189
+ if print_logs:
190
+ if local_file_path.exists():
191
+ err_msg = f"{local_file_path} already exists, replacing file..."
192
+ self.logger.info(err_msg)
193
+
194
+ if make_dir:
195
+ os.makedirs(local_file_path.parent, exist_ok=True)
196
+
197
+ self._bucket.get_object_to_file(
198
+ key=str(oss_file_path),
199
+ filename=local_file_path,
200
+ )
201
+
202
+ def download_objects_to_files(
203
+ self,
204
+ file_download_mapping: list[tuple[str | Path, str | Path]],
205
+ oss_base_dir: str = None,
206
+ local_base_dir: str = None,
207
+ replace: bool = True,
208
+ num_threads: int = 1,
209
+ **kwargs,
210
+ ) -> None:
211
+ """Download objects from OSS to local storage.
212
+
213
+ Args:
214
+ file_download_mapping (list[tuple[str | Path, str | Path]]): A list of file path pairs that maps the OSS file path (to download) to the local file path (download location).
215
+ oss_base_dir (str): OSS directory path to be prepended to all OSS file paths.
216
+ local_base_dir (str, optional): Base directory path to be prepended to all local file paths.
217
+ replace (bool, optional): _description_. Defaults to True.
218
+ num_threads (int, optional): _description_. Defaults to 1.
219
+ **kwargs: Additional keyword arguments passed to `list_latest_files_by_date`
220
+ """
221
+ if isinstance(oss_base_dir, str):
222
+ oss_base_dir = Path(oss_base_dir)
223
+
224
+ if isinstance(local_base_dir, str):
225
+ local_base_dir = Path(local_base_dir)
226
+
227
+ if not isinstance(file_download_mapping, list):
228
+ raise TypeError("file_download_mapping must be a list of 2 value tuples.")
229
+
230
+ for item in file_download_mapping:
231
+ if not isinstance(item, tuple):
232
+ raise TypeError(
233
+ "Each item in the file_download_mapping list must be a 2 value tuple."
234
+ )
235
+
236
+ if len(item) != 2:
237
+ raise ValueError(
238
+ "Each tuple in the file_download_mapping list must be length 2."
239
+ )
240
+
241
+ if num_threads == 1:
242
+ for local_file_path, oss_file_path in file_download_mapping:
243
+ if local_base_dir is not None:
244
+ local_file_path = local_base_dir / local_file_path
245
+
246
+ if oss_base_dir is not None:
247
+ oss_file_path = oss_base_dir / oss_file_path
248
+
249
+ self.download_object_to_file(
250
+ oss_file_path=oss_file_path,
251
+ local_file_path=local_file_path,
252
+ replace=replace,
253
+ )
254
+
255
+ return
256
+
257
+ # Start multithreaded process if num_threads > 1
258
+ with ThreadPool(num_threads) as p:
259
+ pool_args_list = []
260
+ for local_file_path, oss_file_path in file_download_mapping:
261
+ if local_base_dir is not None:
262
+ local_file_path = local_base_dir / local_file_path
263
+
264
+ if oss_base_dir is not None:
265
+ oss_file_path = oss_base_dir / oss_file_path
266
+
267
+ args_dict = dict(
268
+ oss_file_path=oss_file_path,
269
+ local_file_path=local_file_path,
270
+ replace=replace,
271
+ print_logs=True,
272
+ )
273
+ pool_args_list.append(args_dict)
274
+
275
+ ret_all = list(
276
+ tqdm(
277
+ starstarmap(
278
+ pool=p,
279
+ fn=self.download_object_to_file,
280
+ kwargs_iter=pool_args_list,
281
+ ),
282
+ total=len(file_download_mapping),
283
+ )
284
+ )
285
+
286
+ def download_latest_objects_to_dir(
287
+ self,
288
+ oss_object_dir: str,
289
+ local_dir: str,
290
+ start_date: Union[str, dt.date, dt.datetime] = None,
291
+ end_date: Union[str, dt.date, dt.datetime] = None,
292
+ date_pattern: str = r"([0-9]{4}-[0-9]{2}-[0-9]{2})",
293
+ file_date_format: str = "%Y-%m-%d",
294
+ max_num_files: int = 5,
295
+ replace: bool = True,
296
+ make_dir: bool = True,
297
+ delimiter: str = "/",
298
+ num_threads: int = 1,
299
+ suffix: str = "",
300
+ **kwargs,
301
+ ) -> None:
302
+ """Download the latest objects from oss to local directory
303
+
304
+ Args:
305
+ oss_object_dir (str): _description_
306
+ local_dir (str): _description_
307
+ start_date (Union[str, dt.date, dt.datetime], optional): _description_. Defaults to None.
308
+ end_date (Union[str, dt.date, dt.datetime], optional): _description_. Defaults to None.
309
+ max_num_files (int, optional): _description_. Defaults to 5.
310
+ replace (bool, optional): _description_. Defaults to True.
311
+ delimiter (str, optional): _description_. Defaults to '/'.
312
+ num_threads (int, optional): _description_. Defaults to 1.
313
+ suffix (str, optional): _description_. Defaults to ''.
314
+ **kwargs: Additional keyword arguments passed to `list_latest_files_by_date`
315
+ """
316
+ if isinstance(local_dir, str):
317
+ local_dir = Path(local_dir)
318
+
319
+ oss_file_list = self.list_latest_files_by_date(
320
+ object_dir=oss_object_dir,
321
+ delimiter=delimiter,
322
+ start_date=start_date,
323
+ end_date=end_date,
324
+ date_pattern=date_pattern,
325
+ file_date_format=file_date_format,
326
+ max_num_files=max_num_files,
327
+ suffix=suffix,
328
+ **kwargs,
329
+ )
330
+
331
+ if num_threads == 1:
332
+ for oss_file_path in oss_file_list:
333
+ file_name = Path(oss_file_path).name
334
+ local_file_path = local_dir / file_name
335
+
336
+ self.download_object_to_file(
337
+ oss_file_path=oss_file_path,
338
+ local_file_path=local_file_path,
339
+ replace=replace,
340
+ make_dir=make_dir,
341
+ )
342
+
343
+ return
344
+
345
+ # Start multithreaded process if num_threads > 1
346
+ with ThreadPool(num_threads) as p:
347
+ pool_args_list = []
348
+ for oss_file_path in oss_file_list:
349
+ file_name = Path(oss_file_path).name
350
+ local_file_path = local_dir / file_name
351
+
352
+ args_dict = dict(
353
+ oss_file_path=str(oss_file_path),
354
+ local_file_path=str(local_file_path),
355
+ replace=replace,
356
+ make_dir=make_dir,
357
+ print_logs=True,
358
+ )
359
+ pool_args_list.append(args_dict)
360
+
361
+ ret_all = list(
362
+ tqdm(
363
+ starstarmap(
364
+ pool=p,
365
+ fn=self.download_object_to_file,
366
+ kwargs_iter=pool_args_list,
367
+ ),
368
+ total=len(oss_file_list),
369
+ )
370
+ )
371
+
372
+ def upload_file_to_object(
373
+ self,
374
+ local_file_path: str,
375
+ oss_file_path: str | Path,
376
+ replace: bool = False,
377
+ print_logs: bool = True,
378
+ ):
379
+ """Upload a single local file to OSS
380
+
381
+ Args:
382
+ oss_file_path (str): _description_
383
+ local_file_path (str): _description_
384
+ replace (bool, optional): _description_. Defaults to False.
385
+ """
386
+ if isinstance(local_file_path, Path):
387
+ local_file_path = str(local_file_path)
388
+
389
+ if isinstance(oss_file_path, Path):
390
+ oss_file_path = str(oss_file_path)
391
+
392
+ # Check if file already exists
393
+ is_file_exists = self._bucket.object_exists(
394
+ key=oss_file_path,
395
+ )
396
+
397
+ if is_file_exists:
398
+ if replace:
399
+ if print_logs:
400
+ err_msg = f"{oss_file_path} already exists, replacing file..."
401
+ self.logger.info(err_msg)
402
+
403
+ self._bucket.put_object_from_file(
404
+ key=str(oss_file_path),
405
+ filename=local_file_path,
406
+ )
407
+
408
+ else:
409
+ if print_logs:
410
+ err_msg = f"{oss_file_path} already exists, skipping file..."
411
+ self.logger.info(err_msg)
412
+
413
+ return
414
+
415
+ self._bucket.put_object_from_file(
416
+ key=oss_file_path,
417
+ filename=local_file_path,
418
+ )
419
+
420
+ def upload_files_to_objects(
421
+ self,
422
+ file_upload_mapping: list[tuple[str | Path, str | Path]],
423
+ local_base_dir: str = None,
424
+ oss_base_dir: str = None,
425
+ replace: bool = True,
426
+ num_threads: int = 1,
427
+ **kwargs,
428
+ ) -> None:
429
+ """Upload files from local storage to OSS.
430
+
431
+ Args:
432
+ file_upload_mapping (list[tuple[str | Path, str | Path]]): A list of file path pairs that maps the local file path (to upload) to the OSS file path (upload location).
433
+ oss_base_dir (str): OSS directory path to be prepended to all OSS file paths.
434
+ local_base_dir (str, optional): Base directory path to be prepended to all local file paths.
435
+ replace (bool, optional): _description_. Defaults to True.
436
+ num_threads (int, optional): _description_. Defaults to 1.
437
+ **kwargs: Additional keyword arguments passed to `list_latest_files_by_date`
438
+ """
439
+ if isinstance(oss_base_dir, str):
440
+ oss_base_dir = Path(oss_base_dir)
441
+
442
+ if isinstance(local_base_dir, str):
443
+ local_base_dir = Path(local_base_dir)
444
+
445
+ if not isinstance(file_upload_mapping, list):
446
+ raise TypeError("file_upload_mapping must be a list of 2 value tuples.")
447
+
448
+ for item in file_upload_mapping:
449
+ if not isinstance(item, tuple):
450
+ raise TypeError(
451
+ "Each item in the file_upload_mapping list must be a 2 value tuple."
452
+ )
453
+
454
+ if len(item) != 2:
455
+ raise ValueError(
456
+ "Each tuple in the file_upload_mapping list must be length 2."
457
+ )
458
+
459
+ if num_threads == 1:
460
+ for local_file_path, oss_file_path in file_upload_mapping:
461
+ if local_base_dir is not None:
462
+ local_file_path = local_base_dir / local_file_path
463
+
464
+ if oss_base_dir is not None:
465
+ oss_file_path = oss_base_dir / oss_file_path
466
+
467
+ self.upload_file_to_object(
468
+ oss_file_path=oss_file_path,
469
+ local_file_path=local_file_path,
470
+ replace=replace,
471
+ )
472
+
473
+ return
474
+
475
+ # Start multithreaded process if num_threads > 1
476
+ with ThreadPool(num_threads) as p:
477
+ pool_args_list = []
478
+ for local_file_path, oss_file_path in file_upload_mapping:
479
+ if local_base_dir is not None:
480
+ local_file_path = local_base_dir / local_file_path
481
+
482
+ if oss_base_dir is not None:
483
+ oss_file_path = oss_base_dir / oss_file_path
484
+
485
+ args_dict = dict(
486
+ oss_file_path=oss_file_path,
487
+ local_file_path=local_file_path,
488
+ replace=replace,
489
+ print_logs=True,
490
+ )
491
+ pool_args_list.append(args_dict)
492
+
493
+ ret_all = list(
494
+ tqdm(
495
+ starstarmap(
496
+ pool=p,
497
+ fn=self.upload_file_to_object,
498
+ kwargs_iter=pool_args_list,
499
+ ),
500
+ total=len(file_upload_mapping),
501
+ )
502
+ )
503
+
504
+ @property
505
+ def bucket(self) -> oss2.Bucket:
506
+ return self._bucket
507
+
508
+
509
+ # %%
510
+ if __name__ == "__main__":
511
+ # %% Initialize
512
+ oss_file_manager = OSSFileManager(logger=logger)
513
+
514
+ # %% List the latest files by date (based on file suffix) in an OSS directory
515
+ oss_file_manager.list_latest_files_by_date(
516
+ "compass-arena/dev/data/conversations/", max_num_files=5
517
+ )
src/oss/oss_submission_handler.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OSS提交处理器 - 替换原有的git/http提交方式
4
+ 在HuggingFace Spaces中直接将提交文件上传到OSS
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import json
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Dict, Any
13
+
14
+ # 添加上级目录到路径以导入oss_file_manager
15
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'spaces'))
16
+ from oss_file_manager import OSSFileManager
17
+
18
+ class OSSSubmissionHandler:
19
+ """OSS提交处理器 - 将用户提交直接上传到OSS"""
20
+
21
+ def __init__(self, oss_submission_path: str = "atlas_eval/submissions/"):
22
+ """
23
+ 初始化OSS提交处理器
24
+
25
+ Args:
26
+ oss_submission_path: OSS中存储提交文件的路径
27
+ """
28
+ self.oss_path = oss_submission_path
29
+ self.oss_manager = OSSFileManager()
30
+
31
+ print(f"📁 OSS提交路径: oss://opencompass/{oss_submission_path}")
32
+
33
+ def format_error(self, msg: str) -> str:
34
+ """格式化错误消息"""
35
+ return f"<p style='color: red; font-size: 16px;'>{msg}</p>"
36
+
37
+ def format_success(self, msg: str) -> str:
38
+ """格式化成功消息"""
39
+ return f"<p style='color: green; font-size: 16px;'>{msg}</p>"
40
+
41
+ def format_warning(self, msg: str) -> str:
42
+ """格式化警告消息"""
43
+ return f"<p style='color: orange; font-size: 16px;'>{msg}</p>"
44
+
45
+ def validate_sage_submission(self, submission_data: Dict[str, Any]) -> tuple[bool, str]:
46
+ """验证SAGE基准提交格式"""
47
+
48
+ # 检查必需的顶级字段
49
+ required_fields = ["submission_org", "submission_email", "predictions"]
50
+ for field in required_fields:
51
+ if field not in submission_data:
52
+ return False, f"缺少必需字段: {field}"
53
+
54
+ # 验证邮箱格式(基本验证)
55
+ email = submission_data["submission_email"]
56
+ if "@" not in email or "." not in email:
57
+ return False, "邮箱格式无效"
58
+
59
+ # 验证predictions
60
+ predictions = submission_data["predictions"]
61
+ if not isinstance(predictions, list) or len(predictions) == 0:
62
+ return False, "predictions必须是非空列表"
63
+
64
+ for i, prediction in enumerate(predictions):
65
+ # 检查必需的prediction字段
66
+ pred_required_fields = ["original_question_id", "content", "reasoning_content"]
67
+ for field in pred_required_fields:
68
+ if field not in prediction:
69
+ return False, f"预测{i}中缺少字段: {field}"
70
+
71
+ # 验证content数组
72
+ content = prediction["content"]
73
+ reasoning_content = prediction["reasoning_content"]
74
+
75
+ if not isinstance(content, list) or len(content) != 4:
76
+ return False, f"预测{i}的content必须是包含4个项目的列表"
77
+
78
+ if not isinstance(reasoning_content, list) or len(reasoning_content) != 4:
79
+ return False, f"预测{i}的reasoning_content必须是包含4个项目的列表"
80
+
81
+ # 验证question ID
82
+ if not isinstance(prediction["original_question_id"], int):
83
+ return False, f"预测{i}的question ID必须是整数"
84
+
85
+ return True, "提交格式有效"
86
+
87
+ def generate_submission_filename(self, submission_data: Dict[str, Any]) -> str:
88
+ """生成提交文件名"""
89
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
90
+ org_name = submission_data["submission_org"].replace(" ", "_").replace("/", "_").replace("\\", "_")
91
+ return f"submission_{org_name}_{timestamp}.json"
92
+
93
+ def upload_to_oss(self, submission_data: Dict[str, Any], filename: str) -> tuple[bool, str]:
94
+ """上传提交文件到OSS"""
95
+ try:
96
+ # 创建临时本地文件
97
+ temp_file = f"/tmp/{filename}"
98
+ with open(temp_file, 'w', encoding='utf-8') as f:
99
+ json.dump(submission_data, f, indent=2, ensure_ascii=False)
100
+
101
+ # 上传到OSS
102
+ oss_file_path = f"{self.oss_path}{filename}"
103
+
104
+ print(f"⬆️ 上传到OSS: {oss_file_path}")
105
+ self.oss_manager.upload_file_to_object(
106
+ local_file_path=temp_file,
107
+ oss_file_path=oss_file_path,
108
+ replace=True
109
+ )
110
+
111
+ # 清理临时文件
112
+ os.remove(temp_file)
113
+
114
+ print(f"✅ OSS上传成功: {oss_file_path}")
115
+ return True, f"oss://opencompass/{oss_file_path}"
116
+
117
+ except Exception as e:
118
+ print(f"❌ OSS上传失败: {e}")
119
+ return False, str(e)
120
+
121
+ def process_sage_submission(self, submission_file_or_data, org_name=None, email=None) -> str:
122
+ """
123
+ 处理SAGE基准提交文件 - OSS模式
124
+ 替换原有的git/http方式,直接上传到OSS
125
+ """
126
+
127
+ try:
128
+ # 处理输入参数 - 可能是文件路径或者已经的数据
129
+ if submission_file_or_data is None:
130
+ return self.format_error("❌ 没有提供提交数据。")
131
+
132
+ # 如果是字符串,认为是文件路径
133
+ if isinstance(submission_file_or_data, str):
134
+ try:
135
+ with open(submission_file_or_data, 'r', encoding='utf-8') as f:
136
+ content = f.read()
137
+ # 解析JSON
138
+ submission_data = json.loads(content)
139
+ except Exception as e:
140
+ return self.format_error(f"❌ 读取文件时出错: {str(e)}")
141
+ # 如果是字典,直接使用
142
+ elif isinstance(submission_file_or_data, dict):
143
+ submission_data = submission_file_or_data
144
+ else:
145
+ return self.format_error("❌ 无效的提交数据格式。")
146
+
147
+ # 如果表单提供了组织名和邮箱,使用表单数据
148
+ if org_name and email:
149
+ submission_data["submission_org"] = org_name.strip()
150
+ submission_data["submission_email"] = email.strip()
151
+
152
+ # 验证提交格式
153
+ is_valid, message = self.validate_sage_submission(submission_data)
154
+ if not is_valid:
155
+ return self.format_error(f"❌ 提交验证失败: {message}")
156
+
157
+ # 生成文件名
158
+ filename = self.generate_submission_filename(submission_data)
159
+
160
+ # 上传到OSS
161
+ success, result = self.upload_to_oss(submission_data, filename)
162
+
163
+ if not success:
164
+ return self.format_error(f"❌ 上传到OSS失败: {result}")
165
+
166
+ # 生成成功消息
167
+ org = submission_data["submission_org"]
168
+ email_addr = submission_data["submission_email"]
169
+ num_predictions = len(submission_data["predictions"])
170
+
171
+ success_msg = self.format_success(f"""
172
+ 🎉 <strong>提交成功!</strong><br><br>
173
+ 📋 <strong>提交信息:</strong><br>
174
+ • 组织: {org}<br>
175
+ • 邮箱: {email_addr}<br>
176
+ • 预测数量: {num_predictions} 个问题<br>
177
+ • 文件名: {filename}<br><br>
178
+ 🚀 <strong>存储位置:</strong><br>
179
+ {result}<br><br>
180
+ ⚡ <strong>评测状态:</strong><br>
181
+ 您的提交已成功上传到云存储,自动评测系统将在5-15分钟内开始处理。<br><br>
182
+ ⏳ <strong>评测流程:</strong><br>
183
+ 1. 🔍 系统自动检测到新提交<br>
184
+ 2. ⬇️ 下载并验证提交格式<br>
185
+ 3. 🔬 使用LLM-as-Judge进行全面评估<br>
186
+ 4. 📊 计算各科目及总体准确率<br>
187
+ 5. 🏆 自动更新到排行榜<br><br>
188
+ 🕐 <strong>预计时间:</strong><br>
189
+ 评测完成时间约5-15分钟,取决于当前队列长度。<br>
190
+ 请稍后刷新排行榜查看结果。<br><br>
191
+ 🧪 感谢您参与SAGE科学推理基准测试!
192
+ """)
193
+
194
+ return success_msg
195
+
196
+ except Exception as e:
197
+ return self.format_error(f"❌ 提交处理失败: {str(e)}")
198
+
199
+ # 兼容性函数 - 保持与原有代码的接口一致
200
+ def process_sage_submission_simple(submission_file, org_name=None, email=None) -> str:
201
+ """
202
+ 处理SAGE基准提交文件 - OSS模式
203
+ 这是一个兼容性函数,保持与原有simple_submit.py的接口一致
204
+ """
205
+ handler = OSSSubmissionHandler()
206
+ return handler.process_sage_submission(submission_file, org_name, email)
207
+
208
+ def format_error(msg):
209
+ return f"<p style='color: red; font-size: 16px;'>{msg}</p>"
210
+
211
+ def format_success(msg):
212
+ return f"<p style='color: green; font-size: 16px;'>{msg}</p>"
213
+
214
+ def format_warning(msg):
215
+ return f"<p style='color: orange; font-size: 16px;'>{msg}</p>"
216
+
217
+ if __name__ == "__main__":
218
+ # 测试代码
219
+ print("🧪 测试OSS提交处理器")
220
+
221
+ # 检查环境变量
222
+ required_env_vars = ["OSS_ACCESS_KEY_ID", "OSS_ACCESS_KEY_SECRET"]
223
+ missing_vars = [var for var in required_env_vars if not os.getenv(var)]
224
+
225
+ if missing_vars:
226
+ print(f"❌ 缺少必需的环境变量: {missing_vars}")
227
+ exit(1)
228
+
229
+ handler = OSSSubmissionHandler()
230
+ print("✅ OSS提交处理器初始化成功")
src/submission/simple_submit.py CHANGED
@@ -12,9 +12,8 @@ from typing import Dict, Any
12
  from pathlib import Path
13
 
14
  # 导入OSS提交处理器
15
- sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'oss_sage_evaluator'))
16
  try:
17
- from oss_submission_handler import OSSSubmissionHandler
18
  OSS_AVAILABLE = True
19
  except ImportError as e:
20
  print(f"⚠️ OSS模块不可用,将使用备用模式: {e}")
 
12
  from pathlib import Path
13
 
14
  # 导入OSS提交处理器
 
15
  try:
16
+ from src.oss.oss_submission_handler import OSSSubmissionHandler
17
  OSS_AVAILABLE = True
18
  except ImportError as e:
19
  print(f"⚠️ OSS模块不可用,将使用备用模式: {e}")