Spaces:
Running
Running
File size: 1,696 Bytes
acd7cf4 43d27f2 acd7cf4 43d27f2 acd7cf4 43d27f2 acd7cf4 8c66169 acd7cf4 43d27f2 acd7cf4 43d27f2 acd7cf4 43d27f2 acd7cf4 43d27f2 acd7cf4 43d27f2 acd7cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import asyncio
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.bases.datatypes import QAPair
from graphgen.utils import create_event_loop
class BaseEvaluator:
    """Base class for evaluators that assign one float score per QA pair.

    Subclasses implement :meth:`evaluate_single`; this class provides the
    concurrent fan-out over a batch of pairs and simple aggregates
    (average, minimum, maximum) over the resulting scores.
    """

    def __init__(self, max_concurrent: int = 100):
        """Store the concurrency limit and initialise the score cache.

        Args:
            max_concurrent: Maximum number of ``evaluate_single`` calls that
                may run at the same time inside ``async_evaluate``.
        """
        self.max_concurrent = max_concurrent
        # Scores stored by the most recent get_average_score() call;
        # stays None until that method has been called.
        self.results: list[float] | None = None

    def evaluate(self, pairs: list[QAPair]) -> list[float]:
        """Score every pair, blocking until all evaluations have finished.

        Runs :meth:`async_evaluate` to completion on the loop returned by
        ``create_event_loop()``.

        Args:
            pairs: The QA pairs to score.

        Returns:
            One score per pair, in the same order as ``pairs``.
        """
        return create_event_loop().run_until_complete(self.async_evaluate(pairs))

    async def async_evaluate(self, pairs: list[QAPair]) -> list[float]:
        """Score all pairs concurrently, at most ``max_concurrent`` at a time.

        Args:
            pairs: The QA pairs to score.

        Returns:
            One score per pair; ``result[i]`` is the score of ``pairs[i]``.

        Raises:
            Exception: Whatever ``evaluate_single`` raises is propagated as
                soon as the failing evaluation is awaited.
        """
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def evaluate_with_semaphore(index, pair):
            # Hold a semaphore slot for the duration of one evaluation.
            async with semaphore:
                return index, await self.evaluate_single(pair)

        # as_completed yields in completion order, not submission order, so
        # every task carries its input index and the score is written back
        # into the matching slot. Each slot is overwritten before returning
        # unless an evaluation raises, in which case nothing is returned.
        results: list[float] = [0.0] * len(pairs)
        pending = [
            evaluate_with_semaphore(index, pair) for index, pair in enumerate(pairs)
        ]
        for finished in tqdm_async(
            asyncio.as_completed(pending),
            total=len(pairs),
        ):
            index, score = await finished
            results[index] = score
        return results

    async def evaluate_single(self, pair: QAPair) -> float:
        """Score a single QA pair; must be overridden by subclasses.

        Args:
            pair: The QA pair to score.

        Returns:
            The score for ``pair``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_average_score(self, pairs: list[QAPair]) -> float:
        """Evaluate ``pairs`` and return the mean of their scores.

        The per-pair scores are also stored on ``self.results``.

        Args:
            pairs: The QA pairs to score.

        Returns:
            The arithmetic mean of the scores.

        Raises:
            ZeroDivisionError: If ``pairs`` is empty.
        """
        results = self.evaluate(pairs)
        self.results = results
        return sum(self.results) / len(pairs)

    def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
        """Return the lowest and highest score of a batch.

        NOTE: if ``self.results`` has already been populated by an earlier
        ``get_average_score`` call, those cached scores are reused and
        ``pairs`` is not evaluated again, even if it is a different batch.

        Args:
            pairs: The QA pairs to score when no cached scores exist yet.

        Returns:
            A ``(minimum, maximum)`` tuple of scores.

        Raises:
            ValueError: If the cached scores are an empty list.
            ZeroDivisionError: If no scores are cached and ``pairs`` is empty.
        """
        if self.results is None:
            # Populates self.results as a side effect; the mean is discarded.
            self.get_average_score(pairs)
        return min(self.results), max(self.results)
|