"""This module contains functionalities for running inference on Gemma 2 model
finetuned for urgency detection using the HuggingFace library.
"""
# Standard Library
import ast
from textwrap import dedent
from typing import Any, Optional

# Third Party Library
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


def _construct_prompt(*, rules_list: list[str]) -> str:
"""Construct the prompt for the finetuned model.
Parameters
----------
rules_list
The list of urgency rules to match against the user message.
Returns
-------
str
The prompt for the finetuned model.
"""
_prompt_base: str = dedent(
"""
You are a highly sensitive urgency detector. Score if ANY part of the
user message corresponds to any part of the urgency rules provided below.
Ignore any part of the user message that does not correspond to the rules.
Respond with (a) the rule that is most consistent with the user message,
(b) the probability between 0 and 1 with increments of 0.1 that ANY part of
the user message matches the rule, and (c) the reason for the probability.
Respond in json string:
{
best_matching_rule: str
probability: float
reason: str
}
"""
).strip()
_prompt_rules: str = dedent(
"""
Urgency Rules:
{urgency_rules}
"""
).strip()
urgency_rules_str = "\n".join(
[f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)]
)
prompt = (
_prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
)
    return prompt


def get_completions(
*,
model,
rules_list: list[str],
skip_special_tokens_during_decode: bool = False,
text_generation_params: Optional[dict[str, Any]] = None,
tokenizer: PreTrainedTokenizerBase,
user_message: str,
) -> dict[str, Any]:
"""Get completions from the model for the given data.
Parameters
----------
model
The model for inference.
rules_list
The list of urgency rules to match against the user message.
skip_special_tokens_during_decode
Specifies whether to skip special tokens during the decoding process.
text_generation_params
Dictionary containing text generation parameters for the LLM model. If not
specified, then default values will be used.
tokenizer
The tokenizer for the model.
user_message
The user message to match against the urgency rules.
Returns
-------
dict[str, Any]
The completion from the model. If the model output does not produce a valid
JSON string, then the original output is returned in the "generated_json" key.
"""
assert all(x for x in rules_list), "Rules must be non-empty strings!"
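    # Default generation parameters: a near-zero temperature with do_sample=True
    # makes sampling effectively greedy while still applying the repetition penalty.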
text_generation_params = text_generation_params or {
"do_sample": True,
"eos_token_id": tokenizer.eos_token_id,
"max_new_tokens": 1024,
"num_return_sequences": 1,
"repetition_penalty": 1.1,
"temperature": 1e-6,
"top_p": 0.9,
}
tokenizer.add_special_tokens = False # Because we are using the chat template
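    # The tokenizer's two additional special tokens mark turn boundaries in the chat
    # template; the model's reply is delimited by the start-of-turn marker followed
    # by "model" and the end-of-turn marker followed by the EOS token (see below).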
start_of_turn, end_of_turn = tokenizer.additional_special_tokens
eos = tokenizer.eos_token
start_of_turn_model = f"{start_of_turn}model"
end_of_turn_model = f"{end_of_turn}{eos}"
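    # The urgency rules and the user message are combined into a single user turn.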
input_ = (
_construct_prompt(rules_list=rules_list) + f"\n\nUser Message:\n{user_message}"
)
chat = [{"role": "user", "content": input_}]
prompt = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, tokenize=False
)
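    # add_special_tokens=False: the rendered chat template already contains the
    # special tokens the model expects, so none are added again here.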
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
outputs = model.generate(
input_ids=inputs.to(model.device), **text_generation_params
)
decoded_output = tokenizer.decode(
outputs[0], skip_special_tokens=skip_special_tokens_during_decode
)
completion_dict = {"user_message": user_message, "generated_json": decoded_output}
try:
start_of_turn_model_index = decoded_output.index(start_of_turn_model)
end_of_turn_model_index = decoded_output.index(end_of_turn_model)
generated_response = decoded_output[
start_of_turn_model_index
+ len(start_of_turn_model) : end_of_turn_model_index
].strip()
completion_dict["generated_json"] = ast.literal_eval(generated_response)
except (SyntaxError, ValueError):
pass
    return completion_dict


if __name__ == "__main__":
DTYPE = torch.bfloat16
MODEL_ID = "idinsight/gemma-2-2b-it-ud"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
)
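    # Generation parameters mirroring the defaults in get_completions, passed
    # explicitly here for clarity.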
text_generation_params = {
"do_sample": True,
"eos_token_id": tokenizer.eos_token_id,
"max_new_tokens": 1024,
"num_return_sequences": 1,
"repetition_penalty": 1.1,
"temperature": 1e-6,
"top_p": 0.9,
}
response = get_completions(
model=model,
rules_list=[
"NOT URGENT",
"Bleeding from the vagina",
"Bad tummy pain",
"Bad headache that won’t go away",
"Bad headache that won’t go away",
"Changes to vision",
"Trouble breathing",
"Hot or very cold, and very weak",
"Fits or uncontrolled shaking",
"Baby moves less",
"Fluid from the vagina",
"Feeding problems",
"Fits or uncontrolled shaking",
"Fast, slow or difficult breathing",
"Too hot or cold",
"Baby’s colour changes",
"Vomiting and watery poo",
"Infected belly button",
"Swollen or infected eyes",
"Bulging or sunken soft spot",
],
skip_special_tokens_during_decode=False,
text_generation_params=text_generation_params,
tokenizer=tokenizer,
user_message="If my newborn can't able to breathe what can i do",
)
print(f"{response = }")