import gradio as gr
from PIL import Image
import torch
import io
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
    AutoTokenizer,
    AutoModelForCausalLM,
)
import os
import json
# Global state for lazily loaded models
blip_model = None
blip_processor = None
clip_model = None
clip_processor = None
llm_tokenizer = None
llm_model = None
def device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'
def load_models():
    global blip_model, blip_processor, clip_model, clip_processor, llm_model, llm_tokenizer
    if blip_model is None:
        print('Loading BLIP image captioning model...')
        blip_processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
        blip_model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base').to(device())
        blip_model.eval()
    if clip_model is None:
        print('Loading CLIP model...')
        clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
        clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(device())
        clip_model.eval()
    if os.environ.get('USE_TINY_LLAMA', '1') == '1' and llm_model is None:
        try:
            print('Loading TinyLlama chat model (CPU)...')
            llm_tokenizer = AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
            llm_model = AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').to('cpu')
            llm_model.eval()
        except Exception as e:
            print(f'Failed to load TinyLlama: {e}')
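# Optional (assumption, not in the original flow): warm the models once at
# import time so the first request does not pay the full load cost, at the
# price of a slower cold start on Spaces.
# load_models()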
def analyze_fashion_image(pil_image: Image.Image):
    load_models()
    # BLIP caption
    inputs = blip_processor(pil_image, return_tensors='pt').to(device())
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_length=64)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    # Lightweight fashion prompts via BLIP conditional generation
    fashion_prompts = [
        'Describe the clothing type and key attributes.',
        'List the dominant colors.',
        'What style or vibe does this outfit convey?',
        'Suggest an occasion where this outfit fits well.',
    ]
    analysis = {}
    for prompt in fashion_prompts:
        try:
            cond_inputs = blip_processor(pil_image, prompt, return_tensors='pt').to(device())
            with torch.no_grad():
                out2 = blip_model.generate(**cond_inputs, max_length=48)
            resp = blip_processor.decode(out2[0], skip_special_tokens=True)
            analysis[prompt] = resp
        except Exception:
            analysis[prompt] = ''
    return {
        'caption': caption,
        'fashion_analysis': analysis,
    }
def simple_rule_based_advice(text: str) -> str:
    t = (text or '').lower()
    advice = []
    if 'black' in t:
        advice.append('Black is versatile; add a contrasting accessory (silver, gold, or a bold color) to elevate the look.')
    if 'white' in t:
        advice.append('White gives a clean aesthetic; consider layering with textured fabrics or a light jacket for depth.')
    if 'blue' in t:
        advice.append('Blue pairs well with neutrals (tan, white, grey). Try denim-on-denim or navy with beige for a classic combo.')
    if 'dress' in t:
        advice.append('Consider heel height and bag size to match the dress formality. A belt can define the silhouette if needed.')
    if 'shirt' in t or 'top' in t:
        advice.append('Balance the silhouette: pair fitted tops with relaxed bottoms or vice versa. Tuck/half-tuck for shape.')
    if not advice:
        advice.append('Consider color harmony, silhouette balance, and occasion. Add one statement accessory to complete the look.')
    return ' '.join(advice)
def llm_advise(system_prompt: str, user_prompt: str) -> str:
    # Try TinyLlama; fall back to rule-based advice if it is not loaded
    if llm_model is None or llm_tokenizer is None:
        return simple_rule_based_advice(user_prompt)
    # TinyLlama-Chat uses the Zephyr prompt format; each turn ends with </s>
    prompt = (
        f"<|system|>\n{system_prompt}</s>\n<|user|>\n{user_prompt}</s>\n<|assistant|>\n"
    )
    inputs = llm_tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        out = llm_model.generate(
            **inputs,
            max_new_tokens=220,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=llm_tokenizer.eos_token_id,
        )
    text = llm_tokenizer.decode(out[0], skip_special_tokens=True)
    # Keep only the assistant part of the decoded transcript
    if '<|assistant|>' in text:
        text = text.split('<|assistant|>')[-1].strip()
    return text
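# Untested alternative sketch: newer transformers tokenizers expose
# `apply_chat_template`, which builds the prompt from the model's own chat
# template instead of the hand-rolled string above. Assumes the TinyLlama
# tokenizer ships a chat template (the official repo does).
def llm_advise_via_template(system_prompt: str, user_prompt: str) -> str:
    if llm_model is None or llm_tokenizer is None:
        return simple_rule_based_advice(user_prompt)
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    prompt = llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = llm_tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        out = llm_model.generate(**inputs, max_new_tokens=220, do_sample=True, temperature=0.7, top_p=0.9)
    # Decode only the newly generated tokens, skipping the echoed prompt
    return llm_tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()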
def analyze_fashion(image):
    # Standalone analysis entry point (not currently wired into the UI below)
    if image is None:
        return "Please upload an image first!"
    analysis = analyze_fashion_image(image)
    sys_prompt = (
        'You are a professional fashion stylist. Provide constructive, specific, and friendly advice about fit, colors, style, and accessories.'
    )
    user_msg = f"Analyze and advise based on: {json.dumps(analysis)}"
    advice = llm_advise(sys_prompt, user_msg)
    return advice
# --- Try-On API Integration ---
import requests

def tryon_api(clothing_img, avatar_img):
    # JPEG has no alpha channel, so normalize to RGB before encoding
    clothing_bytes = io.BytesIO()
    clothing_img.convert('RGB').save(clothing_bytes, format='JPEG')
    clothing_bytes.seek(0)
    avatar_bytes = io.BytesIO()
    avatar_img.convert('RGB').save(avatar_bytes, format='JPEG')
    avatar_bytes.seek(0)
    files = {
        'clothing_image': ('clothing.jpg', clothing_bytes, 'image/jpeg'),
        'avatar_image': ('avatar.jpg', avatar_bytes, 'image/jpeg'),
    }
    # Read the RapidAPI key from the environment instead of hardcoding it
    # (env var name is a convention chosen here, not mandated by the API)
    headers = {
        'X-RapidAPI-Key': os.environ.get('RAPIDAPI_KEY', ''),
        'X-RapidAPI-Host': 'try-on-diffusion.p.rapidapi.com',
    }
    response = requests.post(
        'https://try-on-diffusion.p.rapidapi.com/try-on-file',
        files=files,
        headers=headers,
        timeout=120,
    )
    if response.status_code == 200:
        return Image.open(io.BytesIO(response.content))
    else:
        raise Exception('Try-on generation failed: ' + response.text)
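# Local smoke test for the try-on call (hypothetical file names; requires a
# valid RAPIDAPI_KEY in the environment):
#   result = tryon_api(Image.open('clothing.jpg'), Image.open('avatar.jpg'))
#   result.save('tryon_result.jpg')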
import asyncio
import re
from urllib.parse import quote_plus

try:
    from playwright.async_api import async_playwright
except ImportError:
    async_playwright = None
FASHION_SITES = {
    'zara': {
        'url': 'https://www.zara.com/us/en/search',
        'search_param': 'searchTerm',
        'selectors': {
            'products': '.product-item',
            'title': '.product-link',
            'price': '.price',
            'image': '.media-image img',
            'link': '.product-link',
        },
    },
    'hm': {
        'url': 'https://www2.hm.com/en_us/search-results.html',
        'search_param': 'q',
        'selectors': {
            'products': '.item-link',
            'title': '.item-heading',
            'price': '.item-price',
            'image': '.item-image img',
            'link': '.item-link',
        },
    },
    'asos': {
        'url': 'https://www.asos.com/us/search/',
        'search_param': 'q',
        'selectors': {
            'products': '[data-testid="product-tile"]',
            'title': '[data-testid="product-title"]',
            'price': '[data-testid="current-price"]',
            'image': 'img',
            'link': 'a',
        },
    },
}
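# Note: the CSS selectors above mirror each storefront's markup at the time of
# writing and will break silently whenever the sites change their frontend;
# treat the per-site configs as examples rather than stable contracts.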
async def search_fashion_items(query, max_results=3):
    results = []
    if async_playwright is None:
        return results
    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            context = await browser.new_context(
                user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
            )
            for site_name, site_config in FASHION_SITES.items():
                try:
                    page = await context.new_page()
                    # URL-encode the query so multi-word searches survive the round trip
                    search_url = f"{site_config['url']}?{site_config['search_param']}={quote_plus(query)}"
                    await page.goto(search_url, wait_until='networkidle', timeout=10000)
                    await page.wait_for_timeout(2000)
                    products = await page.query_selector_all(site_config['selectors']['products'])
                    site_results = []
                    for product in products[:max_results]:
                        try:
                            title_elem = await product.query_selector(site_config['selectors']['title'])
                            price_elem = await product.query_selector(site_config['selectors']['price'])
                            image_elem = await product.query_selector(site_config['selectors']['image'])
                            link_elem = await product.query_selector(site_config['selectors']['link'])
                            title = await title_elem.inner_text() if title_elem else 'N/A'
                            price = await price_elem.inner_text() if price_elem else 'N/A'
                            image_src = await image_elem.get_attribute('src') if image_elem else ''
                            link_href = await link_elem.get_attribute('href') if link_elem else ''
                            title = title.strip()[:100]
                            # Keep only digits, separators, and currency symbols
                            price = re.sub(r'[^\d.,$€£]', '', price) if price != 'N/A' else 'N/A'
                            if link_href and not link_href.startswith('http'):
                                base_url = f"https://{site_name}.com" if site_name != 'hm' else 'https://www2.hm.com'
                                link_href = base_url + link_href
                            site_results.append({
                                'title': title,
                                'price': price,
                                'image': image_src,
                                'link': link_href,
                                'store': site_name.upper(),
                                'query': query,
                            })
                        except Exception:
                            pass
                    results.extend(site_results)
                    await page.close()
                except Exception:
                    pass
            await browser.close()
    except Exception:
        pass
    return results
def shopping_search_sync(query):
    # Gradio invokes this handler from a worker thread with no running event
    # loop, so spin up a private loop for the async Playwright scraper
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        results = loop.run_until_complete(search_fashion_items(query, max_results=3))
    finally:
        loop.close()
    if not results:
        return "No products found. Try a different keyword."
    out = ""
    for r in results:
        out += f"\n**{r['title']}**\nStore: {r['store']}\nPrice: {r['price']}\n[View Product]({r['link']})\n\n"
    return out
# --- Gradio UI: Try-On Only ---
with gr.Blocks() as demo:
    gr.Markdown("# AI Virtual Try-On\nUpload clothing and avatar images, then generate a try-on result!")
    with gr.Row():
        clothing_input = gr.Image(type="pil", label="Clothing Image")
        avatar_input = gr.Image(type="pil", label="Avatar Image")
    tryon_btn = gr.Button("Generate Try-On & Analyze")
    tryon_output = gr.Image(label="Try-On Result")
    analysis_output = gr.Textbox(label="AI Analysis")
    advice_output = gr.Textbox(label="AI Advice")

    # Store the latest analysis for chatbot context
    latest_analysis = {"caption": "", "fashion_analysis": {}}

    def tryon_workflow(cloth, avatar):
        img = tryon_api(cloth, avatar)
        analysis = analyze_fashion_image(img)
        # Save for chatbot context
        latest_analysis["caption"] = analysis["caption"]
        latest_analysis["fashion_analysis"] = analysis["fashion_analysis"]
        sys_prompt = 'You are a professional fashion stylist. Provide constructive, specific, and friendly advice about fit, colors, style, and accessories.'
        user_msg = f"Analyze and advise based on: {json.dumps(analysis)}"
        advice = llm_advise(sys_prompt, user_msg)
        pretty = f"Caption: {analysis['caption']}\n"
        for q, a in analysis['fashion_analysis'].items():
            pretty += f"\n{q}\n➡️ {a}\n"
        return img, pretty, advice

    # tryon_workflow returns three values, so wire three output components
    tryon_btn.click(
        tryon_workflow,
        inputs=[clothing_input, avatar_input],
        outputs=[tryon_output, analysis_output, advice_output],
    )
    # --- Recommend Similar Outfits ---
    def recommend_similar(pil_image):
        load_models()
        inputs = clip_processor(images=pil_image, return_tensors='pt').to(device())
        with torch.no_grad():
            # The CLIP embedding is computed but not yet used; the
            # recommendations below are static placeholders
            _ = clip_model.get_image_features(**inputs)
        recs = [
            {
                'title': 'Similar silhouette',
                'description': 'Find items with a matching cut and color palette.',
                'stores': ['Zara', 'H&M', 'ASOS'],
            },
            {
                'title': 'Complementary accessories',
                'description': 'Belts, bags, and shoes that elevate this outfit.',
                'stores': ['Uniqlo', 'Mango', 'Amazon'],
            },
            {
                'title': 'Alternative colors',
                'description': 'Same style in seasonal colorways to suit your palette.',
                'stores': ['Nordstrom', "Macy's", 'Urban Outfitters'],
            },
        ]
        rec_text = 'Here are some similar outfit ideas and recommendations:\n'
        for idx, rec in enumerate(recs):
            rec_text += f"\n{idx + 1}. {rec['title']}: {rec['description']} (Stores: {', '.join(rec['stores'])})"
        return rec_text
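    # Sketch (assumption, not wired into the UI): how the CLIP embedding above
    # could rank a small candidate catalog by cosine similarity instead of
    # returning canned text. `catalog_embeddings` would be precomputed with the
    # same CLIP model over catalog images.
    def rank_by_clip_similarity(pil_image, catalog_embeddings, top_k=3):
        load_models()
        inputs = clip_processor(images=pil_image, return_tensors='pt').to(device())
        with torch.no_grad():
            query = clip_model.get_image_features(**inputs)
        # Normalize both sides so the dot product equals cosine similarity
        query = query / query.norm(dim=-1, keepdim=True)
        catalog = catalog_embeddings / catalog_embeddings.norm(dim=-1, keepdim=True)
        scores = (query @ catalog.T).squeeze(0)
        return torch.topk(scores, k=min(top_k, scores.numel())).indices.tolist()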
    # --- Chatbot ---
    def chatbot(message, history=None):
        sys_prompt = 'You are a helpful, up-to-date fashion stylist assistant.'
        # Add latest analysis as context if available
        context = ""
        if latest_analysis["caption"] or latest_analysis["fashion_analysis"]:
            context = f"\n\nCurrent try-on analysis: {json.dumps(latest_analysis)}"
        return llm_advise(sys_prompt, message + context)
    # --- Gradio UI: Tabs for all features ---
    with gr.Tab("Chatbot"):
        chatbox = gr.ChatInterface(chatbot, title="Fashion Chatbot (TinyLlama)")
    with gr.Tab("Find Similar"):
        sim_input = gr.Image(type="pil", label="Upload Try-On or Fashion Image")
        sim_btn = gr.Button("Find Similar Outfits")
        sim_output = gr.Textbox(label="Similar Outfits & Recommendations")
        sim_btn.click(recommend_similar, inputs=sim_input, outputs=sim_output)
    with gr.Tab("Shopping Automation"):
        shop_query = gr.Textbox(label="Describe the item or keywords to search (e.g. 'black dress', 'summer shirt')")
        shop_btn = gr.Button("Search Fashion Stores")
        shop_results = gr.Markdown(label="Shopping Results")
        shop_btn.click(shopping_search_sync, inputs=shop_query, outputs=shop_results)

demo.launch()