"""This module contains functionalities for running inference on Gemma 2 model
finetuned for urgency detection using the HuggingFace library.
"""

# Standard Library
import ast

from textwrap import dedent
from typing import Any, Optional

# Third Party Library
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


def _construct_prompt(*, rules_list: list[str]) -> str:
    """Construct the prompt for the finetuned model.

    Parameters
    ----------
    rules_list
        The list of urgency rules to match against the user message.

    Returns
    -------
    str
        The prompt for the finetuned model.
    """

    _prompt_base: str = dedent(
        """
        You are a highly sensitive urgency detector. Score if ANY part of the
        user message corresponds to any part of the urgency rules provided below.
        Ignore any part of the user message that does not correspond to the rules.
        Respond with (a) the rule that is most consistent with the user message,
        (b) the probability between 0 and 1 with increments of 0.1 that ANY part of
        the user message matches the rule, and (c) the reason for the probability.


        Respond in json string:

        {
           best_matching_rule: str
           probability: float
           reason: str
        }
        """
    ).strip()
    _prompt_rules: str = dedent(
        """
        Urgency Rules:
        {urgency_rules}
        """
    ).strip()
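    # Number the rules (1-based) so the model can cite the matched rule.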
    urgency_rules_str = "\n".join(
        [f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)]
    )
    prompt = (
        _prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
    )
    return prompt


def get_completions(
    *,
    model: PreTrainedModel,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool = False,
    text_generation_params: Optional[dict[str, Any]] = None,
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> dict[str, Any]:
    """Get completions from the model for the given data.

    Parameters
    ----------
    model
        The model for inference.
    rules_list
        The list of urgency rules to match against the user message.
    skip_special_tokens_during_decode
        Specifies whether to skip special tokens during the decoding process.
    text_generation_params
        Dictionary containing text generation parameters for the LLM model. If not
        specified, then default values will be used.
    tokenizer
        The tokenizer for the model.
    user_message
        The user message to match against the urgency rules.

    Returns
    -------
    dict[str, Any]
        The completion from the model. If the model output does not produce a valid
        JSON string, then the original output is returned in the "generated_json" key.
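
        For example, a successfully parsed completion might look like (values
        are illustrative)::

            {
                "user_message": "...",
                "generated_json": {
                    "best_matching_rule": "Trouble breathing",
                    "probability": 0.9,
                    "reason": "...",
                },
            }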
    """

    assert rules_list and all(
        rules_list
    ), "Rules must be a non-empty list of non-empty strings!"
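    # Default decoding parameters: with do_sample=True but a near-zero
    # temperature, generation is effectively deterministic (greedy).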
    text_generation_params = text_generation_params or {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }
    # Special tokens are inserted by the chat template itself, so
    # add_special_tokens=False is passed when encoding the prompt below.

    # Gemma's chat format frames the model's reply between
    # "<start_of_turn>model" and "<end_of_turn><eos>"; these markers are used
    # to slice the generated response out of the decoded output.
    start_of_turn, end_of_turn = tokenizer.additional_special_tokens
    eos = tokenizer.eos_token
    start_of_turn_model = f"{start_of_turn}model"
    end_of_turn_model = f"{end_of_turn}{eos}"
    input_ = (
        _construct_prompt(rules_list=rules_list) + f"\n\nUser Message:\n{user_message}"
    )
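    # Wrap the full prompt as a single user turn; the chat template adds the
    # special tokens, hence add_special_tokens=False when encoding below.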
    chat = [{"role": "user", "content": input_}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.to(model.device), **text_generation_params
    )
    decoded_output = tokenizer.decode(
        outputs[0], skip_special_tokens=skip_special_tokens_during_decode
    )
    completion_dict = {"user_message": user_message, "generated_json": decoded_output}
    try:
        start_of_turn_model_index = decoded_output.index(start_of_turn_model)
        end_of_turn_model_index = decoded_output.index(end_of_turn_model)
        generated_response = decoded_output[
            start_of_turn_model_index
            + len(start_of_turn_model) : end_of_turn_model_index
        ].strip()
        completion_dict["generated_json"] = ast.literal_eval(generated_response)
    except (SyntaxError, ValueError):
        pass
    return completion_dict


if __name__ == "__main__":
    DTYPE = torch.bfloat16
    MODEL_ID = "idinsight/gemma-2-2b-it-ud"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # device_map="auto" lets transformers place the model on the available
    # device(s) automatically; this requires the `accelerate` package.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
    )

    text_generation_params = {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }

    response = get_completions(
        model=model,
        rules_list=[
            "NOT URGENT",
            "Bleeding from the vagina",
            "Bad tummy pain",
            "Bad headache that won’t go away",
            "Bad headache that won’t go away",
            "Changes to vision",
            "Trouble breathing",
            "Hot or very cold, and very weak",
            "Fits or uncontrolled shaking",
            "Baby moves less",
            "Fluid from the vagina",
            "Feeding problems",
            "Fits or uncontrolled shaking",
            "Fast, slow or difficult breathing",
            "Too hot or cold",
            "Baby’s colour changes",
            "Vomiting and watery poo",
            "Infected belly button",
            "Swollen or infected eyes",
            "Bulging or sunken soft spot",
        ],
        skip_special_tokens_during_decode=False,
        text_generation_params=text_generation_params,
        tokenizer=tokenizer,
        user_message="If my newborn can't able to breathe what can i do",
    )
    print(f"{response = }")