Spaces:
Running
Running
hy
commited on
Commit
ยท
61d0a1d
1
Parent(s):
1ae484c
- aggro_model.py +82 -19
aggro_model.py
CHANGED
|
@@ -152,48 +152,111 @@ except Exception as e:
|
|
| 152 |
# 3. ๋ฉ์ธ ํจ์
|
| 153 |
# =============================================================================
|
| 154 |
def get_aggro_score(title: str) -> dict:
|
| 155 |
-
|
|
|
|
| 156 |
# 1. ๊ท์น ๊ธฐ๋ฐ ์ ์
|
| 157 |
rule_score = 0.0
|
| 158 |
rule_pattern = "๋ถ์ ๋ถ๊ฐ"
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
| 168 |
|
| 169 |
# 2. KoBERT ์ ์
|
| 170 |
bert_score = 0.0
|
| 171 |
if aggro_model and tokenizer:
|
| 172 |
try:
|
| 173 |
inputs = tokenizer(
|
| 174 |
-
title,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
input_ids = inputs['input_ids'].to(device)
|
| 177 |
mask = inputs['attention_mask'].to(device)
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
with torch.no_grad():
|
| 180 |
outputs = aggro_model(input_ids, mask)
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
bert_score = probs[0][1].item() * 100
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
bert_score = 50.0
|
| 185 |
|
| 186 |
-
|
| 187 |
if rule_score < 5:
|
| 188 |
-
|
|
|
|
| 189 |
elif rule_score < 20:
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
w_bert = 1.0
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
final_score = (rule_score * w_rule) + (bert_score * w_bert)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
# 4. ๊ฒฐ๊ณผ
|
| 199 |
normalized_score = min(final_score / 100.0, 1.0)
|
|
|
|
| 152 |
# 3. ๋ฉ์ธ ํจ์
|
| 153 |
# =============================================================================
|
| 154 |
def get_aggro_score(title: str) -> dict:
|
| 155 |
+
print(f"\n[DEBUG] ๋ถ์ํ ์ ๋ชฉ: {title}") # 1. ์ ๋ชฉ์ด ์ ๋ค์ด์๋ ํ์ธ
|
| 156 |
+
|
| 157 |
# 1. ๊ท์น ๊ธฐ๋ฐ ์ ์
|
| 158 |
rule_score = 0.0
|
| 159 |
rule_pattern = "๋ถ์ ๋ถ๊ฐ"
|
| 160 |
|
| 161 |
+
try:
|
| 162 |
+
res = rule_scorer.get_score(title)
|
| 163 |
+
rule_score = res['score'] # 0~100์
|
| 164 |
+
rule_pattern = res.get('pattern_name', '์ ์ ์์')
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print(f"[DEBUG] ๊ท์น ๊ณ์ฐ ์๋ฌ: {e}")
|
| 167 |
+
rule_score = 0.0
|
| 168 |
+
|
| 169 |
+
print(f"[DEBUG] 1. ๊ท์น ์ ์: {rule_score}") # 2. ๊ท์น ์ ์ ํ์ธ
|
| 170 |
|
| 171 |
# 2. KoBERT ์ ์
|
| 172 |
bert_score = 0.0
|
| 173 |
if aggro_model and tokenizer:
|
| 174 |
try:
|
| 175 |
inputs = tokenizer(
|
| 176 |
+
title,
|
| 177 |
+
return_tensors='pt',
|
| 178 |
+
padding="max_length",
|
| 179 |
+
truncation=True,
|
| 180 |
+
max_length=64
|
| 181 |
)
|
| 182 |
input_ids = inputs['input_ids'].to(device)
|
| 183 |
mask = inputs['attention_mask'].to(device)
|
| 184 |
+
|
| 185 |
+
# ํ ํฐํ ๊ฒฐ๊ณผ ํ์ธ (์ ๋๋ก ์๋ ธ๋์ง)
|
| 186 |
+
# print(f"[DEBUG] ํ ํฐํ ๊ฒฐ๊ณผ: {inputs['input_ids'][0][:10]}")
|
| 187 |
|
| 188 |
with torch.no_grad():
|
| 189 |
outputs = aggro_model(input_ids, mask)
|
| 190 |
+
# ๐จ ์๋ณธ ๋ก์ง (Logits ๊ฐ ํ์ธ)
|
| 191 |
+
print(f"[DEBUG] ๋ชจ๋ธ ์ถ๋ ฅ๊ฐ(Logits): {outputs}")
|
| 192 |
+
|
| 193 |
+
# Temperature Scaling ์ ์ฉ ์ /ํ ๋น๊ต
|
| 194 |
+
probs = F.softmax(outputs / 2.0, dim=1)
|
| 195 |
bert_score = probs[0][1].item() * 100
|
| 196 |
+
|
| 197 |
+
print(f"[DEBUG] 2. BERT ์๋ณธ ์ ์: {bert_score}") # 3. AI ์ ์ ํ์ธ
|
| 198 |
+
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"[Aggro] KoBERT ์์ธก ์ค๋ฅ: {e}")
|
| 201 |
bert_score = 50.0
|
| 202 |
|
| 203 |
+
# 3. Safety Net (์ ์ ๊น๊ธฐ)
|
| 204 |
if rule_score < 5:
|
| 205 |
+
print("[DEBUG] Safety Net ๋ฐ๋! (๊ท์น ์ ์ ๋ฏธ๋ฌ -> AI ์ ์ 70% ์ญ๊ฐ)")
|
| 206 |
+
bert_score *= 0.3
|
| 207 |
elif rule_score < 20:
|
| 208 |
+
print("[DEBUG] Safety Net ๋ฐ๋! (๊ท์น ์ ์ ๋ฎ์ -> AI ์ ์ 20% ์ญ๊ฐ)")
|
| 209 |
+
bert_score *= 0.8
|
| 210 |
+
|
| 211 |
+
print(f"[DEBUG] 3. ๋ณด์ ๋ BERT ์ ์: {bert_score}") # 4. ๊น์ธ ์ ์ ํ์ธ
|
|
|
|
| 212 |
|
| 213 |
+
# 4. ์ต์ข
ํฉ์ฐ
|
| 214 |
+
w_rule = 0.4
|
| 215 |
+
w_bert = 0.6
|
| 216 |
+
|
| 217 |
final_score = (rule_score * w_rule) + (bert_score * w_bert)
|
| 218 |
+
print(f"[DEBUG] 4. ์ต์ข
ํฉ์ฐ ์ ์: {final_score}")
|
| 219 |
+
# # 1. ๊ท์น ๊ธฐ๋ฐ ์ ์
|
| 220 |
+
# rule_score = 0.0
|
| 221 |
+
# rule_pattern = "๋ถ์ ๋ถ๊ฐ"
|
| 222 |
+
|
| 223 |
+
# if rule_scorer:
|
| 224 |
+
# try:
|
| 225 |
+
# res = rule_scorer.get_score(title)
|
| 226 |
+
# rule_score = res['score']
|
| 227 |
+
# rule_pattern = res.get('pattern_name', '์ ์ ์์')
|
| 228 |
+
# except Exception as e:
|
| 229 |
+
# print(f"๊ท์น ๊ณ์ฐ ์๋ฌ: {e}")
|
| 230 |
+
# rule_score = 50.0
|
| 231 |
+
|
| 232 |
+
# # 2. KoBERT ์ ์
|
| 233 |
+
# bert_score = 0.0
|
| 234 |
+
# if aggro_model and tokenizer:
|
| 235 |
+
# try:
|
| 236 |
+
# inputs = tokenizer(
|
| 237 |
+
# title, return_tensors='pt', padding="max_length", truncation=True, max_length=64
|
| 238 |
+
# )
|
| 239 |
+
# input_ids = inputs['input_ids'].to(device)
|
| 240 |
+
# mask = inputs['attention_mask'].to(device)
|
| 241 |
+
|
| 242 |
+
# with torch.no_grad():
|
| 243 |
+
# outputs = aggro_model(input_ids, mask)
|
| 244 |
+
# probs = F.softmax(outputs / 2.0, dim=1)
|
| 245 |
+
# bert_score = probs[0][1].item() * 100
|
| 246 |
+
# except:
|
| 247 |
+
# bert_score = 50.0
|
| 248 |
+
|
| 249 |
+
# # Safety Net ์ ์ฉ (๊ท์น ์ ์๊ฐ ๋ฎ์ผ๋ฉด AI ์ ์๋ ๊น์)
|
| 250 |
+
# if rule_score < 5:
|
| 251 |
+
# bert_score *= 0.3 # ๊ท์น ์ ์๊ฐ ๊ฑฐ์ ์์ผ๋ฉด AI ์ ์ 70% ์ญ๊ฐ
|
| 252 |
+
# elif rule_score < 20:
|
| 253 |
+
# bert_score *= 0.8 # ๊ท์น ์ ์๊ฐ ๋ฎ์ผ๋ฉด AI ์ ์ 20% ์ญ๊ฐ
|
| 254 |
+
|
| 255 |
+
# #3. ํฉ์ฐ
|
| 256 |
+
# w_rule = 0.0
|
| 257 |
+
# w_bert = 1.0
|
| 258 |
+
|
| 259 |
+
# final_score = (rule_score * w_rule) + (bert_score * w_bert)
|
| 260 |
|
| 261 |
# 4. ๊ฒฐ๊ณผ
|
| 262 |
normalized_score = min(final_score / 100.0, 1.0)
|