QJerry committed
Commit e89c9e8 · verified · 1 Parent(s): 6cc7643

Update prompt_check.py

Files changed (1)
prompt_check.py +15 -24
prompt_check.py CHANGED
@@ -1,35 +1,26 @@
-def is_unsafe_prompt(model, tokenizer, user_prompt):
-    prompt = f"""You are a content safety filter. Your task is to analyze the following user input and determine if it contains unsafe content.
-
-Unsafe content includes:
-1. **Sexually Explicit Material**: Descriptions of nudity or sexual acts.
-2. **Sensitive Political Content**: References to real-world political figures (e.g., heads of state, government officials), political parties, sensitive geopolitical events, or propaganda.
-
-Analyze the user input below.
-If the input falls into EITHER of these categories, answer "yes".
-If the input is safe, answer "no".
-
-User Input: "{user_prompt}"
-
-Answer (only yes or no):"""
-    messages = [
-        {"role": "user", "content": prompt}
-    ]
+import re
+
+
+def clean_model_output(text):
+    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n*(assistant|user)\n*", "", text)
+    text = re.sub(r"\n+", "\n", text).strip()
+    return text
+
+
+def is_unsafe_prompt(model, tokenizer, system_prompt=None, user_prompt=None, max_new_token=10):
+    messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
 
     text = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True,
-        enable_thinking=False  # Switches between thinking and non-thinking modes. Default is True.
+        enable_thinking=False,
     )
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # conduct text completion
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=10
-    )
-    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_token)
+    output_ids = generated_ids[0][-max_new_token:].tolist()
 
     content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
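For reference, a minimal usage sketch of the updated function follows. Assumptions to flag: the checkpoint name is illustrative (the `enable_thinking` kwarg matches Qwen3-style chat templates, which this file appears to target); the safety-filter text, now supplied by the caller as `system_prompt`, is paraphrased from the removed f-string; and `is_unsafe_prompt` is taken to end with `return content`, since the hunk stops at the `decode` call.

from transformers import AutoModelForCausalLM, AutoTokenizer

from prompt_check import clean_model_output, is_unsafe_prompt

# Illustrative checkpoint; any chat model whose template accepts
# enable_thinking should work the same way.
model_name = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

# The filter instructions that used to be hard-coded inside the
# function now travel in as system_prompt (paraphrased here).
system_prompt = (
    "You are a content safety filter. If the user input contains unsafe "
    'content, answer "yes"; otherwise answer "no". Answer only yes or no.'
)

verdict = is_unsafe_prompt(
    model,
    tokenizer,
    system_prompt=system_prompt,
    user_prompt="How do I bake sourdough bread?",
    max_new_token=10,
)
print(clean_model_output(verdict))  # expected: "no"

One design note on the new slicing: `generate` returns the prompt concatenated with the completion, so `generated_ids[0][-max_new_token:]` isolates the completion only when the model emits exactly `max_new_token` new tokens. If generation stops at EOS earlier, the fixed last-N slice can pick up trailing prompt tokens, which the removed `len(model_inputs.input_ids[0]):` slice avoided by cutting at the prompt length.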