File size: 12,869 Bytes
6a3bd1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
import re
from typing import Dict, List, Tuple, Optional
from prompt_library_manager import PromptLibraryManager
class OutputProcessingManager:
    """
    Output validation, formatting, and hashtag generation for generated captions.

    Hashtag generation delegates to a PromptLibraryManager. Its
    ``landmark_prompts``, ``brand_prompts`` and ``scene_prompts`` attributes
    must each expose ``get_hashtags(name, language)``, which returns a
    sliceable list of tags.
    """

    # Characters removed from a hashtag: anything that is not a word character
    # or a CJK Unified Ideograph. (Unicode ``\w`` already covers CJK; the
    # explicit range is kept for readability.)
    _HASHTAG_DISALLOWED = re.compile(r'[^\w\u4e00-\u9fff]')
    # Inline "#tag" tokens that must not remain inside caption text.
    _INLINE_HASHTAG = re.compile(r'#\w+')
    # Upper bound on hashtags returned by validation and generation.
    _MAX_HASHTAGS = 10

    def __init__(self, prompt_library: Optional["PromptLibraryManager"] = None):
        """
        Args:
            prompt_library: PromptLibraryManager instance to use. When omitted,
                a default PromptLibraryManager() is created.
        """
        self.profanity_filter = set()
        # Maximum caption length (characters) per platform. validate_output
        # falls back to 2200 for platforms not listed here.
        self.max_lengths = {
            'instagram': 2200,
            'tiktok': 100,
            'xiaohongshu': 500
        }
        # Use the supplied PromptLibraryManager, or build a default one.
        if prompt_library is None:
            self.prompt_library = PromptLibraryManager()
        else:
            self.prompt_library = prompt_library
        # Keyword map consumed by the heuristic landmark guesser.
        self.landmark_keywords = self._init_landmark_keywords()
        print("✓ OutputProcessingManager (with integrated PromptLibraryManager) initialized")

    def _init_landmark_keywords(self) -> Dict[str, List[str]]:
        """
        Build the landmark -> clue-keyword map used by detect_landmark.

        A clue (detected class name or scene label) counts toward a landmark
        when it contains any of that landmark's keywords as a substring.
        """
        return {
            'Big Ben': ['clock tower', 'tower', 'bridge', 'palace', 'gothic'],
            'Eiffel Tower': ['tower', 'iron', 'landmark', 'lattice'],
            'Statue of Liberty': ['statue', 'monument', 'harbor', 'torch'],
            'Golden Gate Bridge': ['bridge', 'suspension', 'orange', 'bay'],
            'Sydney Opera House': ['opera', 'building', 'harbor', 'shell'],
            'Taj Mahal': ['palace', 'dome', 'monument', 'marble'],
            'Colosseum': ['arena', 'amphitheater', 'ruins', 'ancient'],
            'Pyramids of Giza': ['pyramid', 'desert', 'ancient', 'monument'],
            'Burj Khalifa': ['skyscraper', 'tower', 'building', 'tall'],
            'Tokyo Tower': ['tower', 'lattice', 'red'],
            'Taipei 101': ['skyscraper', 'tower', 'building'],
            # Extend with more landmarks as needed.
        }

    @staticmethod
    def _urban_top(scene_info: Optional[Dict]) -> str:
        """
        Return ``scene_info['urban']['top']`` lower-cased, or '' when absent.

        Tolerates ``scene_info`` or its ``'urban'`` entry being None/missing so
        every caller compares against the same normalized label.
        """
        urban = (scene_info or {}).get('urban') or {}
        return (urban.get('top') or '').lower()

    @staticmethod
    def _pick_language(language: str, zh: List[str], en: List[str],
                       bilingual: List[str]) -> List[str]:
        """Return a copy of ``zh`` for 'zh', ``en`` for 'en', else ``bilingual``."""
        if language == 'zh':
            return list(zh)
        if language == 'en':
            return list(en)
        return list(bilingual)

    def detect_landmark(self, detections: List[Dict], scene_info: Dict) -> Optional[str]:
        """
        Guess a landmark from detection results and scene analysis.

        Each distinct clue (a detected class name, or the urban-scene label)
        that contains any of a landmark's keywords adds one point to that
        landmark. Repeated detections of the same class count only once, so
        several identical objects cannot reach the threshold on their own.

        Args:
            detections: Detection dicts; only ``'class_name'`` is read.
            scene_info: Scene analysis dict; only ``['urban']['top']`` is read.

        Returns:
            The highest-scoring landmark if it scored at least 2 (ties go to
            the earlier entry in ``landmark_keywords``), otherwise None.
        """
        clues = [(d.get('class_name') or '').lower() for d in detections]
        urban_scene = self._urban_top(scene_info)
        if urban_scene:
            clues.append(urban_scene)
        # Deduplicate while preserving order; empty names carry no evidence.
        clues = [clue for clue in dict.fromkeys(clues) if clue]

        scores = {}
        for landmark, keywords in self.landmark_keywords.items():
            match_count = sum(1 for clue in clues
                              if any(kw in clue for kw in keywords))
            if match_count > 0:
                scores[landmark] = match_count

        # Require at least two independent clues before naming a landmark.
        if scores:
            best_landmark, best_score = max(scores.items(), key=lambda item: item[1])
            if best_score >= 2:
                return best_landmark
        return None

    def generate_smart_hashtags(self, detections: List[Dict], scene_info: Dict,
                                brands: List, platform: str, language: str) -> List[str]:
        """
        Build hashtags from landmark, brand, scene, composition and platform cues.

        Args:
            detections: Detected-object dicts (``'class_name'`` is read).
            scene_info: Scene analysis dict.
            brands: Detected brands. Each item is either a name or a
                tuple/list whose first element is the name.
            platform: Platform name ('instagram', 'tiktok', 'xiaohongshu', ...).
            language: 'zh', 'en', or anything else for bilingual output.

        Returns:
            Up to 10 non-empty, de-duplicated tags, ordered by priority:
            landmark > brand > scene > composition > platform.
        """
        library = self.prompt_library
        hashtags: List[str] = []

        # 1. Landmark tags (highest priority, at most 5).
        landmark = self.detect_landmark(detections, scene_info)
        if landmark:
            hashtags.extend(library.landmark_prompts.get_hashtags(landmark, language)[:5])

        # 2. Brand tags: at most 3 brands, 3 tags each.
        for brand in (brands or [])[:3]:
            brand_name = brand[0] if isinstance(brand, (tuple, list)) else brand
            hashtags.extend(library.brand_prompts.get_hashtags(brand_name, language)[:3])

        # 3. Scene tags (at most 4).
        scene_category = self._detect_scene_category(scene_info, detections)
        if scene_category:
            hashtags.extend(library.scene_prompts.get_hashtags(scene_category, language)[:4])

        # 4-5. Composition and platform tags.
        hashtags.extend(self._get_composition_hashtags(scene_info, language))
        hashtags.extend(self._get_platform_hashtags(platform, language))

        # Drop empty tags and duplicates while keeping priority order.
        unique_hashtags = list(dict.fromkeys(tag for tag in hashtags if tag))
        return unique_hashtags[:self._MAX_HASHTAGS]

    def _detect_scene_category(self, scene_info: Dict, detections: List[Dict]) -> Optional[str]:
        """
        Classify the scene for hashtag lookup.

        Checks are made in order:
        1. food objects;
        2. nature objects;
        3. an urban-scene label mentioning canyon/street/building (case-insensitive);
        4. indoor furniture.
        Anything else falls back to 'urban'.

        Returns:
            One of 'food', 'nature', 'urban', 'indoor'.
        """
        object_classes = [(d.get('class_name') or '').lower() for d in detections]

        def has_any(keywords):
            # Substring match so e.g. 'dining table' matches 'table'.
            return any(kw in obj for obj in object_classes for kw in keywords)

        if has_any(('sandwich', 'pizza', 'cake', 'food', 'plate', 'bowl', 'cup', 'bottle')):
            return 'food'
        if has_any(('tree', 'mountain', 'water', 'sky', 'beach', 'ocean')):
            return 'nature'
        urban_scene = self._urban_top(scene_info)
        if any(word in urban_scene for word in ('canyon', 'street', 'building')):
            return 'urban'
        if has_any(('chair', 'table', 'couch', 'bed', 'desk')):
            return 'indoor'
        return 'urban'  # Default scene category.

    def _get_composition_hashtags(self, scene_info: Dict, language: str) -> List[str]:
        """
        Return composition tags (urban canyon / skyline) plus a photography tag.

        The urban-scene label is matched case-insensitively.
        """
        composition = self._urban_top(scene_info)
        hashtags: List[str] = []
        # Urban canyon / skyscraper compositions.
        if 'canyon' in composition or 'skyscraper' in composition:
            hashtags.extend(self._pick_language(
                language,
                ['城市峽谷', '城市風景'],
                ['UrbanCanyon', 'Cityscape'],
                ['城市峽谷', 'UrbanCanyon'],
            ))
        # Generic photography tag, always added.
        hashtags.extend(self._pick_language(
            language, ['攝影日常'], ['Photography'], ['攝影日常', 'Photography']
        ))
        return hashtags

    def _get_platform_hashtags(self, platform: str, language: str) -> List[str]:
        """Return platform-specific tags; unknown platforms get none."""
        if platform == 'instagram':
            return self._pick_language(language, ['IG日常'], ['InstaDaily'],
                                       ['IG日常', 'InstaDaily'])
        if platform == 'tiktok':
            return self._pick_language(language, ['抖音'], ['TikTok'],
                                       ['抖音', 'TikTok'])
        if platform == 'xiaohongshu':
            # Same tags regardless of language.
            return ['小紅書', '分享日常']
        return []

    def validate_output(self, output: Dict, platform: str,
                        detections: List[Dict] = None, scene_info: Dict = None,
                        brands: List = None, language: str = 'en') -> Tuple[bool, str]:
        """
        Validate and normalize a generated caption dict IN PLACE.

        Processing steps:
        1. Strip inline "#tag" tokens from the caption.
        2. Truncate the caption to the platform limit.
        3. Reject captions that match the profanity filter.
        4. Clean and de-duplicate hashtags (at most 10).
        5. If fewer than 5 hashtags remain and both ``detections`` and
           ``scene_info`` are given, top up with generated tags.

        Inline hashtags are removed before the length check so that their text
        cannot force truncation of real caption words.

        Args:
            output: Caption dict; must contain 'caption', 'hashtags', 'tone',
                'platform'. Modified in place.
            platform: Platform name used for the length limit and tags.
            detections: Detection results, used only for the hashtag top-up.
            scene_info: Scene info, used only for the hashtag top-up.
            brands: Brand list, used only for the hashtag top-up.
            language: Language code for generated tags.

        Returns:
            (passed, message).
        """
        # 1. Structure check.
        required_fields = ['caption', 'hashtags', 'tone', 'platform']
        if not all(field in output for field in required_fields):
            return False, "Missing required fields"

        # 2. Hashtags belong in the hashtag list, not the caption text.
        if '#' in output['caption']:
            output['caption'] = self._INLINE_HASHTAG.sub('', output['caption']).strip()

        # 3. Length limit (3 characters reserved for the ellipsis).
        max_length = self.max_lengths.get(platform, 2200)
        if len(output['caption']) > max_length:
            output['caption'] = output['caption'][:max_length - 3] + '...'

        # 4. Content filter on the final caption text.
        if self._contains_profanity(output['caption']):
            return False, "Contains inappropriate content"

        # 5. Hashtag cleanup.
        output['hashtags'] = self._validate_hashtags(output['hashtags'])

        # 6. Top up when below the minimum and detection context is available.
        min_hashtags = 5
        original_count = len(output['hashtags'])
        if original_count < min_hashtags and detections is not None and scene_info is not None:
            additional_tags = self.generate_smart_hashtags(
                detections, scene_info, brands or [], platform, language
            )
            # Clean generated tags with the same rules; dedupes and caps at 10.
            output['hashtags'] = self._validate_hashtags(output['hashtags'] + additional_tags)
            print(f" [AUTO-補充] 標籤數量不足 ({original_count} < {min_hashtags}),已自動補充至 {len(output['hashtags'])} 個")

        return True, "Validation passed"

    def _contains_profanity(self, text: str) -> bool:
        """Return True if any non-empty filter word occurs in ``text`` (case-insensitive)."""
        text_lower = text.lower()
        return any(word and word.lower() in text_lower for word in self.profanity_filter)

    def _validate_hashtags(self, hashtags: List[str]) -> List[str]:
        """
        Clean and de-duplicate hashtags.

        Each tag has leading '#' removed and any character that is not a word
        character or CJK ideograph stripped. Tags that become empty, are not
        strings, or repeat an earlier tag are dropped.

        Args:
            hashtags: Raw hashtag list (None is treated as empty).

        Returns:
            At most 10 cleaned tags, in first-seen order.
        """
        cleaned: List[str] = []
        seen = set()
        for tag in hashtags or []:
            if not isinstance(tag, str):
                continue
            tag = self._HASHTAG_DISALLOWED.sub('', tag.lstrip('#'))
            if tag and tag not in seen:
                seen.add(tag)
                cleaned.append(tag)
                if len(cleaned) == self._MAX_HASHTAGS:
                    break
        return cleaned

    def format_for_platform(self, caption: Dict, platform: str) -> str:
        """
        Render a caption dict as a single post string.

        The caption is followed by a blank line and then the hashtag line. On
        Xiaohongshu the hashtags follow that blank line directly; on all other
        platforms one extra newline is inserted first. Tags are prefixed with
        '#' and separated by spaces.

        Args:
            caption: Dict with 'caption' and 'hashtags' keys.
            platform: Platform name.

        Returns:
            The formatted post text.
        """
        tag_line = ' '.join(f"#{tag}" for tag in caption['hashtags'])
        separator = "\n\n" if platform == 'xiaohongshu' else "\n\n\n"
        return f"{caption['caption']}{separator}{tag_line}"
# Module-level statement: prints a confirmation line whenever this file is
# imported or executed, after the class above has been defined.
print("✓ OutputProcessingManager (V3 with PromptLibraryManager integration) defined")
|