# Author: selinazarzour
# Commit dadf8b4: "Added the shopping automation logic in a new tab"
# (file view header: raw | history | blame | 14.7 kB)
import io
import json
import os
import re
from urllib.parse import quote_plus, urljoin

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
)
# Lazily populated model handles. All start as None and are filled in by
# load_models(); the llm_* pair may stay None (TinyLlama disabled or failed to
# load), in which case llm_advise falls back to rule-based advice.
blip_processor = blip_model = None
clip_processor = clip_model = None
llm_tokenizer = llm_model = None
def device():
    """Return the torch device string that models should run on.

    Returns:
        'cuda' when torch reports an available CUDA device, otherwise 'cpu'.
    """
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
# Set once a TinyLlama load has been attempted, so a failed download/load is
# not retried (and re-logged) on every later load_models() call.
_llm_load_attempted = False


def load_models():
    """Load the BLIP, CLIP and (optionally) TinyLlama models on first use.

    Each model is loaded only while its global handle is still ``None``, so
    calling this repeatedly is cheap. BLIP and CLIP are moved to ``device()``;
    TinyLlama is kept on the CPU. TinyLlama is loaded only when the
    ``USE_TINY_LLAMA`` environment variable is ``'1'`` (the default) and is
    attempted at most once per process: on failure the error is printed and
    both ``llm_tokenizer`` and ``llm_model`` stay ``None``.

    Raises:
        Whatever ``from_pretrained`` raises for BLIP or CLIP (those failures
        are not caught here).
    """
    global blip_model, blip_processor, clip_model, clip_processor, llm_model, llm_tokenizer
    global _llm_load_attempted
    target = device()
    if blip_model is None:
        print('Loading BLIP image captioning model...')
        blip_processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
        blip_model = BlipForConditionalGeneration.from_pretrained(
            'Salesforce/blip-image-captioning-base'
        ).to(target)
        blip_model.eval()
    if clip_model is None:
        print('Loading CLIP model...')
        clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
        clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(target)
        clip_model.eval()
    if (
        os.environ.get('USE_TINY_LLAMA', '1') == '1'
        and llm_model is None
        and not _llm_load_attempted
    ):
        _llm_load_attempted = True
        try:
            print('Loading TinyLlama chat model (CPU)...')
            tokenizer = AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
            model = AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').to('cpu')
            model.eval()
        except Exception as e:
            print(f'Failed to load TinyLlama: {e}')
        else:
            # Publish both handles together so callers never see a tokenizer
            # without its model.
            llm_tokenizer, llm_model = tokenizer, model
# Prompts passed to BLIP as conditioning text. analyze_fashion_image uses them
# verbatim, in this order, as the keys of its 'fashion_analysis' dict.
FASHION_PROMPTS = (
    'Describe the clothing type and key attributes.',
    'List the dominant colors.',
    'What style or vibe does this outfit convey?',
    'Suggest an occasion where this outfit fits well.',
)


def analyze_fashion_image(pil_image: Image.Image):
    """Caption an outfit image with BLIP and run a few fashion prompts on it.

    Args:
        pil_image: The outfit image.

    Returns:
        A dict with ``'caption'`` (BLIP's unconditional caption) and
        ``'fashion_analysis'`` (a mapping from each entry of
        ``FASHION_PROMPTS`` to BLIP's output when given that prompt as
        conditioning text, or ``''`` if that prompt failed).

    Raises:
        ValueError: if ``pil_image`` is ``None``.
    """
    if pil_image is None:
        raise ValueError('analyze_fashion_image() requires an image, got None.')
    load_models()
    target = device()

    inputs = blip_processor(pil_image, return_tensors='pt').to(target)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_length=64)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)

    # NOTE(review): BLIP captioning appears to continue the prompt as a
    # caption prefix, so answers likely start with the prompt text — confirm.
    analysis = {}
    for prompt in FASHION_PROMPTS:
        try:
            cond_inputs = blip_processor(pil_image, prompt, return_tensors='pt').to(target)
            with torch.no_grad():
                cond_out = blip_model.generate(**cond_inputs, max_length=48)
            analysis[prompt] = blip_processor.decode(cond_out[0], skip_special_tokens=True)
        except Exception as e:
            # Best effort: one failing prompt must not lose the caption or the
            # other answers, but the failure is reported instead of hidden.
            print(f'BLIP prompt failed ({prompt!r}): {e}')
            analysis[prompt] = ''
    return {
        'caption': caption,
        'fashion_analysis': analysis,
    }
# Ordered (pattern, advice) rules: every matching rule contributes its advice,
# in this order. Word boundaries keep 'top' from matching "stop"/"laptop" and
# 'dress' from matching "address"; the shirt pattern has no leading boundary
# so compounds such as "t-shirt", "tshirt" and "sweatshirt" still match.
_ADVICE_RULES = (
    (re.compile(r'\bblack\b'),
     'Black is versatile; add a contrasting accessory (silver, gold, or a bold color) to elevate the look.'),
    (re.compile(r'\bwhite\b'),
     'White gives a clean aesthetic; consider layering with textured fabrics or a light jacket for depth.'),
    (re.compile(r'\bblue\b'),
     'Blue pairs well with neutrals (tan, white, grey). Try denim-on-denim or navy with beige for a classic combo.'),
    (re.compile(r'\bdress(?:es)?\b'),
     'Consider heel height and bag size to match the dress formality. A belt can define the silhouette if needed.'),
    (re.compile(r'shirts?\b|\btops?\b'),
     'Balance the silhouette: pair fitted tops with relaxed bottoms or vice versa. Tuck/half-tuck for shape.'),
)
_DEFAULT_ADVICE = (
    'Consider color harmony, silhouette balance, and occasion. '
    'Add one statement accessory to complete the look.'
)


def simple_rule_based_advice(text: str) -> str:
    """Return canned styling tips triggered by keywords in ``text``.

    Args:
        text: Free text (a user message or JSON-encoded image analysis);
            ``None`` is treated as empty. Matching is case-insensitive.

    Returns:
        The advice of every matching rule joined by single spaces, or a
        generic tip when no rule matches.
    """
    lowered = (text or '').lower()
    advice = [tip for pattern, tip in _ADVICE_RULES if pattern.search(lowered)]
    return ' '.join(advice) if advice else _DEFAULT_ADVICE
def _build_chat_prompt(system_prompt: str, user_prompt: str) -> str:
    """Render one system + user exchange as a TinyLlama chat prompt string.

    Prefers the tokenizer's own chat template. Without one, it builds the same
    ``<|system|>`` / ``<|user|>`` / ``<|assistant|>`` layout by hand, ending
    each turn with the EOS token so the model can tell where turns stop.
    """
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    if getattr(llm_tokenizer, 'chat_template', None):
        return llm_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    eos = llm_tokenizer.eos_token or ''
    return (
        f"<|system|>\n{system_prompt}{eos}\n"
        f"<|user|>\n{user_prompt}{eos}\n"
        "<|assistant|>\n"
    )


def llm_advise(system_prompt: str, user_prompt: str) -> str:
    """Ask the TinyLlama chat model for advice, falling back to simple rules.

    Args:
        system_prompt: Instruction describing the assistant's role.
        user_prompt: The user's message (callers may append JSON context).

    Returns:
        The model's reply, or ``simple_rule_based_advice(user_prompt)`` when
        the model is unavailable, generation raises, or the reply is empty.
    """
    if llm_model is None:
        # The chatbot can run before any image was analysed, i.e. before
        # anything else triggered model loading; make sure a load is tried.
        try:
            load_models()
        except Exception as e:
            print(f'Model loading failed; using rule-based advice: {e}')
    if llm_model is None or llm_tokenizer is None:
        return simple_rule_based_advice(user_prompt)

    try:
        inputs = llm_tokenizer(_build_chat_prompt(system_prompt, user_prompt), return_tensors='pt')
        prompt_len = inputs['input_ids'].shape[-1]
        with torch.no_grad():
            out = llm_model.generate(
                **inputs,
                max_new_tokens=220,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=llm_tokenizer.eos_token_id,
                # The Llama tokenizer has no pad token; reuse EOS explicitly.
                pad_token_id=llm_tokenizer.eos_token_id,
            )
        # Decode only the generated continuation instead of splitting the full
        # decoded sequence on a marker the reply itself could contain.
        text = llm_tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        print(f'TinyLlama generation failed; using rule-based advice: {e}')
        return simple_rule_based_advice(user_prompt)

    # If the model starts writing a new user turn, cut the reply off there.
    text = text.split('<|user|>')[0].strip()
    return text or simple_rule_based_advice(user_prompt)
def analyze_fashion(image):
    """Produce stylist advice for an uploaded outfit image.

    Args:
        image: The outfit image, or ``None`` when nothing was uploaded.

    Returns:
        An upload reminder when ``image`` is ``None``; otherwise the text
        returned by ``llm_advise`` for the JSON-encoded output of
        ``analyze_fashion_image(image)``.
    """
    if image is not None:
        findings = analyze_fashion_image(image)
        stylist_role = (
            'You are a professional fashion stylist. Provide constructive, specific, '
            'and friendly advice about fit, colors, style, and accessories.'
        )
        request = f"Analyze and advise based on: {json.dumps(findings)}"
        return llm_advise(stylist_role, request)
    return "Please upload an image first!"
# --- Try-On API Integration ---
import requests
def _encode_jpeg(img, filename):
    """Return a ``(filename, buffer, mime)`` multipart tuple of ``img`` as JPEG.

    JPEG has no alpha channel, so RGBA/palette images (e.g. PNG uploads) are
    converted to RGB first; saving them directly raises ``OSError``.
    """
    buf = io.BytesIO()
    img.convert('RGB').save(buf, format='JPEG')
    buf.seek(0)
    return (filename, buf, 'image/jpeg')


def tryon_api(clothing_img, avatar_img, timeout=180):
    """Generate a virtual try-on image with the RapidAPI Try-On Diffusion service.

    The API key is read from the ``RAPIDAPI_KEY`` environment variable.
    NOTE(review): a key used to be hard-coded here and was committed to the
    repository — it should be revoked/rotated.

    Args:
        clothing_img: PIL image of the garment.
        avatar_img: PIL image of the person to dress.
        timeout: Seconds ``requests`` waits for the service before giving up.

    Returns:
        The generated try-on image, fully decoded.

    Raises:
        ValueError: if either image is missing.
        RuntimeError: if ``RAPIDAPI_KEY`` is not set.
        Exception: if the service responds with a non-200 status.
    """
    if clothing_img is None or avatar_img is None:
        raise ValueError('Both a clothing image and an avatar image are required.')
    api_key = os.environ.get('RAPIDAPI_KEY')
    if not api_key:
        raise RuntimeError('Set the RAPIDAPI_KEY environment variable to use the try-on service.')

    files = {
        'clothing_image': _encode_jpeg(clothing_img, 'clothing.jpg'),
        'avatar_image': _encode_jpeg(avatar_img, 'avatar.jpg'),
    }
    headers = {
        'X-RapidAPI-Key': api_key,
        'X-RapidAPI-Host': 'try-on-diffusion.p.rapidapi.com',
    }
    response = requests.post(
        'https://try-on-diffusion.p.rapidapi.com/try-on-file',
        files=files,
        headers=headers,
        timeout=timeout,
    )
    if response.status_code == 200:
        result = Image.open(io.BytesIO(response.content))
        result.load()  # decode now so a corrupt payload fails here, not downstream
        return result
    raise Exception('Try-on generation failed: ' + response.text)
import asyncio
import re
try:
from playwright.async_api import async_playwright
except ImportError:
async_playwright = None
def _site_config(url, search_param, *, products, title, price, image, link):
    """Bundle one store's search endpoint, query parameter and CSS selectors."""
    return {
        'url': url,
        'search_param': search_param,
        'selectors': {
            'products': products,
            'title': title,
            'price': price,
            'image': image,
            'link': link,
        },
    }


# Store key -> search endpoint and CSS selectors read by search_fashion_items,
# which visits the stores in this insertion order (zara, hm, asos).
FASHION_SITES = {
    'zara': _site_config(
        'https://www.zara.com/us/en/search', 'searchTerm',
        products='.product-item', title='.product-link', price='.price',
        image='.media-image img', link='.product-link',
    ),
    'hm': _site_config(
        'https://www2.hm.com/en_us/search-results.html', 'q',
        products='.item-link', title='.item-heading', price='.item-price',
        image='.item-image img', link='.item-link',
    ),
    'asos': _site_config(
        'https://www.asos.com/us/search/', 'q',
        products='[data-testid="product-tile"]', title='[data-testid="product-title"]',
        price='[data-testid="current-price"]', image='img', link='a',
    ),
}
# Every character that is NOT a digit, '.', ',' or a supported currency symbol.
# (The previous pattern r'[^\\d...]' escaped the backslash, so it removed all
# digits: "$29.99" became "$.".)
_PRICE_NOISE = re.compile(r'[^\d.,$€£]')
_SCRAPER_USER_AGENT = (
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
    '(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
)


async def _extract_product(product, selectors, page_url, store, query):
    """Read one product tile into a result dict; missing text becomes 'N/A', missing URLs ''."""
    title_elem = await product.query_selector(selectors['title'])
    price_elem = await product.query_selector(selectors['price'])
    image_elem = await product.query_selector(selectors['image'])
    link_elem = await product.query_selector(selectors['link'])
    title = await title_elem.inner_text() if title_elem else 'N/A'
    price = await price_elem.inner_text() if price_elem else 'N/A'
    # get_attribute returns None when the attribute is absent; normalise to ''.
    image_src = (await image_elem.get_attribute('src') if image_elem else None) or ''
    link_href = (await link_elem.get_attribute('href') if link_elem else None) or ''
    if price != 'N/A':
        price = _PRICE_NOISE.sub('', price)
    if link_href:
        # Resolves root-relative ("/p/1"), relative and protocol-relative
        # ("//host/p/1") hrefs against the page that was actually loaded.
        link_href = urljoin(page_url, link_href)
    return {
        'title': title.strip()[:100],
        'price': price,
        'image': image_src,
        'link': link_href,
        'store': store,
        'query': query,
    }


async def _scrape_site(context, site_name, site_config, query, max_results):
    """Search one store and return up to ``max_results`` product dicts; the page is always closed."""
    selectors = site_config['selectors']
    page = await context.new_page()
    try:
        search_url = f"{site_config['url']}?{site_config['search_param']}={quote_plus(query)}"
        # Playwright discourages wait_until='networkidle'; pages with ongoing
        # background requests may never settle, which made goto() time out and
        # drop the store. Load the DOM, then wait for product tiles.
        await page.goto(search_url, wait_until='domcontentloaded', timeout=10000)
        try:
            await page.wait_for_selector(selectors['products'], timeout=8000)
        except Exception:
            pass  # no tiles rendered in time; the query below simply finds none
        products = await page.query_selector_all(selectors['products'])
        site_results = []
        for product in products[:max_results]:
            try:
                site_results.append(
                    await _extract_product(product, selectors, page.url, site_name.upper(), query)
                )
            except Exception as e:
                print(f'[{site_name}] skipped unreadable product: {e}')
        return site_results
    finally:
        await page.close()


async def search_fashion_items(query, max_results=3):
    """Search every store in ``FASHION_SITES`` with headless Chromium.

    Args:
        query: Free-text search keywords (URL-encoded before use).
        max_results: Maximum number of products taken from each store.

    Returns:
        A list of dicts with keys ``title``, ``price``, ``image``, ``link``,
        ``store`` and ``query``, grouped by store in ``FASHION_SITES`` order.
        Stores that fail are skipped and reported on stdout; the list is empty
        when Playwright is not installed or the browser cannot be started.
    """
    results = []
    if async_playwright is None:
        print('Playwright is not installed; shopping search is unavailable.')
        return results
    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                context = await browser.new_context(user_agent=_SCRAPER_USER_AGENT)
                for site_name, site_config in FASHION_SITES.items():
                    try:
                        results.extend(
                            await _scrape_site(context, site_name, site_config, query, max_results)
                        )
                    except Exception as e:
                        print(f'[{site_name}] search failed: {e}')
            finally:
                await browser.close()
    except Exception as e:
        print(f'Shopping search failed: {e}')
    return results
def shopping_search_sync(query):
    """Run the async store search from a synchronous Gradio callback.

    Args:
        query: Search keywords from the UI; surrounding whitespace is ignored.

    Returns:
        Markdown with one section per product (bold title, then a bullet
        list with store, price and a link when one was found), or a short
        message when the query is blank or nothing was found.
    """
    query = (query or '').strip()
    if not query:
        return "Please enter a search keyword."
    # asyncio.run creates a fresh loop, runs the search and closes it without
    # leaving a closed loop installed as this thread's current event loop.
    results = asyncio.run(search_fashion_items(query, max_results=3))
    if not results:
        return "No products found. Try a different keyword."
    sections = []
    for item in results:
        # Markdown joins single newlines into one line, so fields are rendered
        # as a bullet list under the title instead of bare line breaks.
        lines = [
            f"**{item['title']}**",
            "",
            f"- Store: {item['store']}",
            f"- Price: {item['price']}",
        ]
        if item['link']:
            lines.append(f"- [View Product]({item['link']})")
        sections.append("\n".join(lines))
    return "\n\n".join(sections)
# --- Gradio UI: try-on panel plus Chatbot / Find Similar / Shopping tabs ---

# Most recent try-on analysis; the chatbot appends it to messages as context.
# NOTE(review): this module-level dict is shared by every concurrent user, so
# one visitor's try-on becomes another visitor's chat context. Per-session
# context would need gr.State wired into both handlers.
latest_analysis = {"caption": "", "fashion_analysis": {}}


def tryon_workflow(cloth, avatar):
    """Generate a try-on image, analyse it and ask for styling advice.

    Also stores the analysis in ``latest_analysis`` for the chatbot.

    Returns:
        ``(image, analysis_text, advice_text)`` — one value per output
        component wired to the try-on button.

    Raises:
        gr.Error: if an image is missing or the try-on call fails, so the
            message is shown in the UI.
    """
    if cloth is None or avatar is None:
        raise gr.Error("Please upload both a clothing image and an avatar image.")
    try:
        img = tryon_api(cloth, avatar)
    except Exception as e:
        raise gr.Error(str(e)) from e
    analysis = analyze_fashion_image(img)
    latest_analysis["caption"] = analysis["caption"]
    latest_analysis["fashion_analysis"] = analysis["fashion_analysis"]
    sys_prompt = 'You are a professional fashion stylist. Provide constructive, specific, and friendly advice about fit, colors, style, and accessories.'
    user_msg = f"Analyze and advise based on: {json.dumps(analysis)}"
    advice = llm_advise(sys_prompt, user_msg)
    pretty = f"📝 Caption: {analysis['caption']}\n" + "".join(
        f"\n{q}\n➑️ {a}\n" for q, a in analysis['fashion_analysis'].items()
    )
    return img, pretty, advice


def recommend_similar(pil_image):
    """Return generic "similar outfit" suggestions for an uploaded image.

    Returns:
        A numbered text list of recommendation ideas, or an upload reminder
        when no image was given.
    """
    if pil_image is None:
        return "Please upload an image first!"
    load_models()
    inputs = clip_processor(images=pil_image, return_tensors='pt').to(device())
    with torch.no_grad():
        # NOTE(review): the CLIP embedding is computed but not used yet; the
        # suggestions below are identical for every image.
        clip_model.get_image_features(**inputs)
    recs = [
        {
            'title': 'Similar silhouette',
            'description': 'Find items with a matching cut and color palette.',
            'stores': ['Zara', 'H&M', 'ASOS'],
        },
        {
            'title': 'Complementary accessories',
            'description': 'Belts, bags, and shoes that elevate this outfit.',
            'stores': ['Uniqlo', 'Mango', 'Amazon'],
        },
        {
            'title': 'Alternative colors',
            'description': 'Same style in seasonal colorways to suit your palette.',
            'stores': ['Nordstrom', "Macy's", 'Urban Outfitters'],
        },
    ]
    return 'Here are some similar outfit ideas and recommendations:\n' + "".join(
        f"\n{idx}. {rec['title']}: {rec['description']} (Stores: {', '.join(rec['stores'])})"
        for idx, rec in enumerate(recs, start=1)
    )


def chatbot(message, history=None):
    """Answer a chat message as a fashion stylist.

    The latest try-on analysis, if any, is appended to the message as JSON
    context. ``history`` is accepted for gr.ChatInterface but not used.
    """
    sys_prompt = 'You are a helpful, up-to-date fashion stylist assistant.'
    context = ""
    if latest_analysis["caption"] or latest_analysis["fashion_analysis"]:
        context = f"\n\nCurrent try-on analysis: {json.dumps(latest_analysis)}"
    return llm_advise(sys_prompt, message + context)


with gr.Blocks() as demo:
    gr.Markdown("# 👗 AI Virtual Try-On\nUpload clothing and avatar images, generate a try-on result!")
    with gr.Row():
        clothing_input = gr.Image(type="pil", label="Clothing Image")
        avatar_input = gr.Image(type="pil", label="Avatar Image")
    tryon_btn = gr.Button("Generate Try-On & Analyze")
    tryon_output = gr.Image(label="Try-On Result")
    analysis_output = gr.Textbox(label="AI Analysis")
    advice_output = gr.Textbox(label="AI Advice")
    # tryon_workflow returns three values; each needs its own output component
    # (with only two, the analysis text ended up in the "AI Advice" box).
    tryon_btn.click(
        tryon_workflow,
        inputs=[clothing_input, avatar_input],
        outputs=[tryon_output, analysis_output, advice_output],
    )

    with gr.Tab("Chatbot"):
        chatbox = gr.ChatInterface(chatbot, title="Fashion Chatbot (TinyLlama)")
    with gr.Tab("Find Similar"):
        sim_input = gr.Image(type="pil", label="Upload Try-On or Fashion Image")
        sim_btn = gr.Button("Find Similar Outfits")
        sim_output = gr.Textbox(label="Similar Outfits & Recommendations")
        sim_btn.click(recommend_similar, inputs=sim_input, outputs=sim_output)
    with gr.Tab("Shopping Automation"):
        shop_query = gr.Textbox(label="Describe the item or keywords to search (e.g. 'black dress', 'summer shirt')")
        shop_btn = gr.Button("Search Fashion Stores")
        shop_results = gr.Markdown(label="Shopping Results")
        shop_btn.click(shopping_search_sync, inputs=shop_query, outputs=shop_results)
# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests) does not launch a server.
if __name__ == "__main__":
    demo.launch()