File size: 1,868 Bytes
ce11ffc
 
 
f20ab91
119215e
 
26647e2
119215e
2e1c80b
119215e
d6ff8e0
119215e
9c931ea
119215e
744da58
119215e
26647e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce11ffc
 
119215e
 
744da58
 
ea1ca1e
119215e
 
 
ce11ffc
 
119215e
 
ae0e08b
 
ce11ffc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import random

from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient
import time

# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
# Pool of hosted model IDs to sample completions from; add more IDs to
# battle several models against the dataset's reference answers.
MODEL_IDS = [
    "microsoft/Phi-3-mini-4k-instruct"
]
# One Inference API client per model. NOTE: reads HF_TOKEN from the
# environment at import time and raises KeyError if it is not set.
CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]

# Source dataset; each row's "chosen" field appears to be an OpenAI-style
# chat message list whose last entry is the preferred answer (see next_input).
dataset = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train")


def get_response(messages, clients=None, max_retries=3, retry_delay=3):
    """Request a chat completion from a randomly chosen inference client.

    Retries on any error for up to ``max_retries`` total attempts, sleeping
    ``retry_delay`` seconds between attempts. If the final attempt fails,
    the last exception is re-raised to the caller.

    Args:
        messages: Chat history as OpenAI-style ``[{"role": ..., "content": ...}]``
            dicts, as accepted by ``InferenceClient.chat_completion``.
        clients: Optional pool of client objects exposing ``chat_completion``;
            defaults to the module-level ``CLIENTS`` pool.
        max_retries: Total number of attempts before giving up.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        str: The generated completion text.

    Raises:
        Exception: Whatever the client raised on the last failed attempt.
    """
    pool = clients if clients is not None else CLIENTS
    for attempt in range(max_retries):
        try:
            # Pick a client at random so requests spread across the model pool.
            client = random.choice(pool)
            message = client.chat_completion(
                messages=messages,
                stream=False,
                max_tokens=2000
            )
            return message.choices[0].message.content
        except Exception as e:
            if attempt < max_retries - 1:
                print(f"An error occurred: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print(f"Max retries reached. Last error: {e}")
                raise
    # Unreachable: the final attempt either returns or re-raises.

def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next annotation item: a random row's conversation plus two
    candidate answers (the dataset's preferred one and a freshly generated
    one) in shuffled order. The incoming arguments are ignored."""
    row = dataset.shuffle()[0]
    conversation = row["chosen"][:-1]
    reference_answer = row["chosen"][-1]["content"]
    generated_answer = get_response(conversation)
    candidates = [reference_answer, generated_answer]
    random.shuffle(candidates)
    first, second = candidates
    return conversation, second, first


if __name__ == "__main__":
    # Launch the annotation UI: each item pairs a prompt with two completions
    # supplied by next_input, and the annotator picks the preferred one.
    interface = AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        # presumably [prompt, completion_a, completion_b]: prompt is fixed,
        # the two completion fields are editable — TODO confirm against
        # dataset_viber's for_chat_generation_preference signature.
        interactive=[False, True, True],
        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",
    )
    interface.launch()