Spaces:
Sleeping
Sleeping
| import asyncio | |
| import gradio as gr | |
| import numpy as np | |
| import time | |
| import json | |
| import os | |
| import tempfile | |
| import requests | |
| import logging | |
| from aiohttp import ClientSession | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from datasets import Dataset, load_dataset | |
| from tqdm import tqdm | |
| from tqdm.asyncio import tqdm_asyncio | |
# Token sent as a bearer credential to the Hub and the TEI endpoint
# (None when the HF_TOKEN environment variable is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of concurrent TEI requests. Kept as a string here;
# consumers convert it with int().
SEMAPHORE_BOUND = os.environ.get("SEMAPHORE_BOUND", "5")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Chunker:
    """Split text into chunks using one of three named strategies.

    Strategies (selected by ``strategy``):

    * ``"recursive"`` -- LangChain ``RecursiveCharacterTextSplitter`` with
      ``split_seq`` as its only separator and ``chunk_len`` as ``chunk_size``
      (see https://huggingface.co/spaces/m-ric/chunk_visualizer).
    * ``"sequence"`` -- plain ``str.split(split_seq)``; ``chunk_len`` is unused.
    * ``"constant"`` -- consecutive slices of ``chunk_len`` characters; the
      last slice may be shorter.

    The selected splitting function is bound to ``self.split``
    (``str -> list[str]``).

    Args:
        strategy: one of ``"recursive"``, ``"sequence"``, ``"constant"``.
        split_seq: separator used by the "recursive" and "sequence" strategies.
        chunk_len: chunk size in characters. Accepts an int or a numeric
            string (the Gradio UI passes raw textbox contents); a blank string
            or ``None`` falls back to ``DEFAULT_CHUNK_LEN``.

    Raises:
        ValueError: for an unknown ``strategy``, or a non-integer /
            non-positive ``chunk_len`` when the strategy needs one.
    """

    DEFAULT_CHUNK_LEN = 512

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        self.split_seq = split_seq
        self.chunk_len = chunk_len
        if strategy == "recursive":
            self.chunk_len = self._parse_chunk_len(chunk_len)
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_len,
                separators=[split_seq],
            ).split_text
        elif strategy == "sequence":
            # chunk_len is irrelevant for this strategy, so it is not validated
            # (the hidden UI textbox may hold anything).
            self.split = self.seq_splitter
        elif strategy == "constant":
            self.chunk_len = self._parse_chunk_len(chunk_len)
            self.split = self.const_splitter
        else:
            # Fail fast here instead of leaving ``self.split`` undefined and
            # crashing later with an AttributeError.
            raise ValueError(
                f"Unknown chunking strategy {strategy!r}; "
                "expected 'recursive', 'sequence' or 'constant'"
            )

    @classmethod
    def _parse_chunk_len(cls, chunk_len):
        """Coerce ``chunk_len`` to a positive int; blank/None -> default."""
        if chunk_len is None or (isinstance(chunk_len, str) and not chunk_len.strip()):
            return cls.DEFAULT_CHUNK_LEN
        try:
            value = int(chunk_len)
        except (TypeError, ValueError) as err:
            raise ValueError(f"chunk_len must be an integer, got {chunk_len!r}") from err
        if value <= 0:
            raise ValueError(f"chunk_len must be positive, got {value}")
        return value

    def seq_splitter(self, text):
        """Split ``text`` on every occurrence of ``split_seq``."""
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Return consecutive ``chunk_len``-character slices of ``text`` ([] for "")."""
        step = self.chunk_len
        return [text[start:start + step] for start in range(0, len(text), step)]
def generator(input_ds, input_text_col, chunker):
    """Yield ``{input_text_col: chunk}`` for every non-empty chunk of every row.

    Intended as the ``Dataset.from_generator`` source: each row's text column
    is split with ``chunker.split`` and empty pieces are dropped. Progress over
    rows is reported with tqdm.
    """
    for row in tqdm(input_ds):
        for piece in chunker.split(row[input_text_col]):
            if not piece:
                continue
            yield {input_text_col: piece}
def _tei_headers():
    """Return the headers sent with every TEI request (JSON body + bearer token)."""
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {HF_TOKEN}",
    }


async def _post_embed(session, tei_url, payload):
    """POST ``payload`` to ``tei_url`` and return the decoded JSON response.

    Raises:
        RuntimeError: if the endpoint answers with a non-200 status; the
            message is the response body.
    """
    async with session.post(tei_url, json=payload) as resp:
        if resp.status != 200:
            raise RuntimeError(await resp.text())
        return await resp.json()


async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file, session=None):
    """Embed one text via TEI and append the result as a JSON line to ``tmp_file``.

    Each written line is ``{"vector": <first element of the JSON response>,
    <embed_in_text_col>: sentence}``. The request is sent with
    ``truncate=True``.

    Args:
        sentence: text to embed.
        embed_in_text_col: key under which ``sentence`` is stored in the line.
        semaphore: bounds the number of in-flight requests.
        tei_url: TEI endpoint URL to POST to.
        tmp_file: open text file the JSON line is appended to.
        session: optional shared ``aiohttp.ClientSession`` (already carrying
            the auth headers). When omitted, a one-off session is created for
            this request, as before.

    Raises:
        RuntimeError: on a non-200 response from the endpoint.
    """
    payload = {"inputs": sentence, "truncate": True}
    async with semaphore:
        if session is None:
            async with ClientSession(headers=_tei_headers()) as own_session:
                result = await _post_embed(own_session, tei_url, payload)
        else:
            result = await _post_embed(session, tei_url, payload)
        tmp_file.write(
            json.dumps({"vector": result[0], embed_in_text_col: sentence}) + "\n"
        )


async def embed_ds(input_ds, tei_url, embed_in_text_col, temp_file):
    """Embed every non-blank ``embed_in_text_col`` value of ``input_ds`` concurrently.

    At most ``int(SEMAPHORE_BOUND)`` requests are in flight at once, and all of
    them share a single HTTP session (one connection pool) instead of opening
    a new session per row. Results are appended to ``temp_file`` as JSON lines
    via ``embed_sent``.

    Raises:
        Whatever the first failing request raises; the remaining requests are
        cancelled before the error propagates.
    """
    semaphore = asyncio.BoundedSemaphore(int(SEMAPHORE_BOUND))
    async with ClientSession(headers=_tei_headers()) as session:
        jobs = [
            asyncio.create_task(
                embed_sent(
                    row[embed_in_text_col],
                    embed_in_text_col,
                    semaphore,
                    tei_url,
                    temp_file,
                    session=session,
                )
            )
            for row in input_ds
            if row[embed_in_text_col].strip()
        ]
        logger.info("num chunks to embed: %d", len(jobs))
        tic = time.time()
        try:
            await tqdm_asyncio.gather(*jobs)
        except Exception:
            # Stop the remaining requests before the shared session is closed.
            for job in jobs:
                job.cancel()
            await asyncio.gather(*jobs, return_exceptions=True)
            raise
        logger.info("embed time: %s", time.time() - tic)
def wake_up_endpoint(url, max_retries=40, interval=2, timeout=10):
    """Poll ``url`` with authenticated GET requests until it answers HTTP 200.

    The endpoint is tried ``max_retries + 1`` times, sleeping ``interval``
    seconds after each failed attempt (defaults match the previous hard-coded
    behaviour: 41 attempts, 2 s apart). Connection errors and timeouts count as
    failed attempts instead of aborting the wait, since an endpoint that is
    still starting may not accept connections yet.

    Args:
        url: endpoint URL to poll.
        max_retries: number of retries after the first failed attempt.
        interval: seconds to sleep between attempts.
        timeout: per-request timeout in seconds, so one hung request cannot
            block forever.

    Raises:
        gr.Error: if the endpoint never answered 200.
    """
    logger.info("Starting up TEI endpoint")
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    for attempt in range(max_retries + 1):
        try:
            status = requests.get(url=url, headers=headers, timeout=timeout).status_code
        except requests.RequestException as err:
            logger.info("TEI endpoint not reachable yet (attempt %d): %s", attempt + 1, err)
            status = None
        if status == 200:
            logger.info("TEI endpoint is up")
            return
        time.sleep(interval)
    raise gr.Error("TEI endpoint is unavailable")
def _chunk_dataset(input_ds, input_splits, input_text_col, strategy, split_seq, chunk_len):
    """Load the comma-separated splits of ``input_ds`` and return its chunked version."""
    # Strip before filtering so inputs like "train, " don't yield an empty split name.
    split_names = [spl.strip() for spl in input_splits.split(",") if spl.strip()]
    source_ds = load_dataset(input_ds, split="+".join(split_names), token=HF_TOKEN)
    chunker = Chunker(strategy, split_seq, chunk_len)
    gen_kwargs = {
        "input_ds": source_ds,
        "input_text_col": input_text_col,
        "chunker": chunker,
    }
    return Dataset.from_generator(generator, gen_kwargs=gen_kwargs)


def _embed_dataset(chunked_ds, tei_url, text_col):
    """Embed every chunk via TEI into a temporary JSONL file and load it as a Dataset."""
    with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
        asyncio.run(embed_ds(chunked_ds, tei_url, text_col, temp_file))
        # The JSON lines are still in this handle's write buffer; flush so that
        # reading the file by name sees every line.
        temp_file.flush()
        return Dataset.from_json(temp_file.name)


def chunk_embed(input_ds, input_splits, input_text_col, chunk_out_ds,
                strategy, split_seq, chunk_len, embed_out_ds, tei_url, private):
    """Chunk a Hub dataset, push the chunks, embed them with TEI, push the embeddings.

    Args:
        input_ds: Hub name of the source dataset.
        input_splits: comma-separated split names, concatenated with "+".
        input_text_col: text column to chunk; also the text column of both
            output datasets.
        chunk_out_ds: Hub name the chunked dataset is pushed to.
        strategy, split_seq, chunk_len: forwarded to ``Chunker``.
        embed_out_ds: Hub name the embedded dataset is pushed to.
        tei_url: TEI endpoint URL (woken up before embedding).
        private: whether both pushed datasets are private.

    Raises:
        gr.Error: wrapping any failure in either phase, so Gradio shows the
            message to the user.
    """
    gr.Info("Started chunking")
    try:
        chunked_ds = _chunk_dataset(
            input_ds, input_splits, input_text_col, strategy, split_seq, chunk_len
        )
        chunked_ds.push_to_hub(chunk_out_ds, private=private, token=HF_TOKEN)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
    gr.Info("Done chunking")
    logger.info("Done chunking")

    try:
        wake_up_endpoint(tei_url)
        embedded_ds = _embed_dataset(chunked_ds, tei_url, input_text_col)
        embedded_ds.push_to_hub(embed_out_ds, private=private, token=HF_TOKEN)
    except gr.Error:
        # Already user-facing (e.g. from wake_up_endpoint); pass it through as-is.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
    gr.Info("Done embedding")
    logger.info("Done embedding")
def change_dropdown(choice):
    """Return visibility updates for the ``(split_seq, chunk_len)`` textboxes.

    "recursive" shows both boxes, "sequence" shows only the separator box, and
    any other choice (e.g. "constant") shows only the length box.
    """
    show_seq, show_len = {
        "recursive": (True, True),
        "sequence": (True, False),
    }.get(choice, (False, True))
    return [gr.Textbox(visible=show_seq), gr.Textbox(visible=show_len)]
# Gradio UI: collect the dataset / chunking / TEI parameters and run chunk_embed.
# NOTE(review): the source lost its indentation; Row membership below is
# reconstructed (splits + text column share a row; strategy, separator and
# length share a row; clear + submit share a row) -- confirm against the
# deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Chunk and embed
        """
    )
    input_ds = gr.Textbox(lines=1, label="Input dataset name")
    with gr.Row():
        input_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
        input_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
    chunk_out_ds = gr.Textbox(lines=1, label="Chunked dataset name")
    with gr.Row():
        dropdown = gr.Dropdown(
            ["recursive", "sequence", "constant"], label="Chunking strategy",
            info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
                 "'constant' makes chunks of the constant size",
            scale=2
        )
        # Both parameter boxes start hidden; change_dropdown reveals the ones
        # the selected strategy uses.
        split_seq = gr.Textbox(
            lines=1,
            interactive=True,
            visible=False,
            label="Sequence",
            info="A text sequence to split on",
            placeholder="\n\n"
        )
        chunk_len = gr.Textbox(
            lines=1,
            interactive=True,
            visible=False,
            label="Length",
            info="The length of chunks to split into in characters",
            placeholder="512"
        )
        dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
    embed_out_ds = gr.Textbox(lines=1, label="Embedded dataset name")
    private = gr.Checkbox(label="Make output datasets private")
    tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
    with gr.Row():
        gr.ClearButton(
            components=[input_ds, input_splits, input_text_col, chunk_out_ds,
                        dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
        )
        embed_btn = gr.Button("Submit")
        # Input order must match chunk_embed's positional parameters.
        embed_btn.click(
            fn=chunk_embed,
            inputs=[input_ds, input_splits, input_text_col, chunk_out_ds,
                    dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
        )

# Start the server only when run as a script (as `python app.py` does), not
# as a side effect of importing this module.
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)