mlbench123 committed
Commit 492772b · verified · 1 Parent(s): de778ac

Upload 9 files

Files changed (9)
  1. DEPLOYMENT.md +250 -0
  2. Dockerfile +28 -0
  3. README.md +135 -12
  4. app.py +400 -0
  5. binary_segmentation.py +398 -0
  6. client_examples.py +396 -0
  7. index.html +505 -0
  8. requirements.txt +13 -0
  9. test_api.py +225 -0
DEPLOYMENT.md ADDED
@@ -0,0 +1,250 @@
+ # Deployment Guide - Hugging Face Spaces
+
+ ## Quick Deployment to Hugging Face
+
+ ### Step 1: Prepare Files
+
+ Ensure you have these files:
+ ```
+ your-repo/
+ ├── app.py                    # FastAPI application
+ ├── binary_segmentation.py    # Core segmentation module
+ ├── requirements.txt          # Python dependencies
+ ├── Dockerfile                # Docker configuration
+ ├── README.md                 # This becomes your Space README
+ ├── static/
+ │   └── index.html            # Web interface
+ └── .model_cache/
+     └── u2netp.pth            # Model weights (IMPORTANT!)
+ ```
+
+ ### Step 2: Download U2NETP Weights
+
+ **CRITICAL**: You must download the U2NETP model weights:
+
+ 1. Visit: https://github.com/xuebinqin/U-2-Net/tree/master/saved_models
+ 2. Download: `u2netp.pth` (4.7 MB)
+ 3. Place it in: `.model_cache/u2netp.pth`
+
+ **OR** use this direct link:
+ ```bash
+ mkdir -p .model_cache
+ wget https://github.com/xuebinqin/U-2-Net/raw/master/saved_models/u2netp/u2netp.pth -O .model_cache/u2netp.pth
+ ```
+
+ ### Step 3: Create a Hugging Face Space
+
+ 1. Go to https://huggingface.co/new-space
+ 2. Fill in:
+    - **Space name**: `background-removal` (or your choice)
+    - **License**: Apache 2.0
+    - **SDK**: Docker
+    - **Hardware**: CPU Basic (the free tier works!)
+ 3. Click "Create Space"
+
+ ### Step 4: Upload Files
+
+ #### Option A: Using Git (Recommended)
+
+ ```bash
+ # Clone your new space
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+ cd YOUR_SPACE_NAME
+
+ # Copy all files
+ cp /path/to/app.py .
+ cp /path/to/binary_segmentation.py .
+ cp /path/to/requirements.txt .
+ cp /path/to/Dockerfile .
+ cp /path/to/README_HF.md ./README.md
+ cp -r /path/to/static .
+ cp -r /path/to/.model_cache .
+
+ # Commit and push
+ git add .
+ git commit -m "Initial commit"
+ git push
+ ```
+
+ #### Option B: Using the Web Interface
+
+ 1. Click "Files" → "Add file"
+ 2. Upload each file individually
+ 3. **Important**: Don't forget `.model_cache/u2netp.pth` (a binary file, ~4.7 MB)
+
+ ### Step 5: Wait for the Build
+
+ - The Space builds automatically (takes 3-5 minutes)
+ - Watch the "Logs" tab for build progress
+ - Once the build completes, your Space is live!
+
+ ### Step 6: Test Your Space
+
+ Visit your Space URL and try:
+ 1. Upload an image
+ 2. Click "Process Image"
+ 3. Download the result
+
+ ## Configuration Options
+
+ ### Use Different Models
+
+ To enable the BiRefNet or RMBG models, edit `requirements.txt`:
+
+ ```txt
+ # Uncomment these lines:
+ transformers>=4.30.0
+ huggingface-hub>=0.16.0
+ ```
+
+ **Note**: These models are larger and may require upgraded hardware (GPU).
+
+ ### Custom Port
+
+ The default port is 7860 (the Hugging Face standard). To change it, edit the `CMD` in the `Dockerfile`:
+ ```dockerfile
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+ ```
+
+ ### Environment Variables
+
+ Add secrets in the Space Settings, then read them in code:
+ ```python
+ import os
+ API_KEY = os.environ.get("API_KEY", "default")
+ ```
+
+ ## Hardware Requirements
+
+ ### CPU Basic (Free)
+ - ✅ U2NETP model
+ - ✅ Small to medium images (<5 MP)
+ - ⏱️ ~2-5 seconds per image
+
+ ### CPU Upgrade
+ - ✅ U2NETP model
+ - ✅ Large images
+ - ⏱️ ~1-3 seconds per image
+
+ ### GPU T4
+ - ✅ All models (U2NETP, BiRefNet, RMBG)
+ - ✅ Any image size
+ - ⏱️ <1 second per image
+
+ ## Troubleshooting
+
+ ### Build Fails
+
+ **Issue**: "No module named 'binary_segmentation'"
+ - **Fix**: Ensure `binary_segmentation.py` is in the root directory
+
+ **Issue**: "Model weights not found"
+ - **Fix**: Upload `u2netp.pth` to `.model_cache/u2netp.pth`
+
+ **Issue**: "OpenCV error"
+ - **Fix**: Check that the Dockerfile installs `libgl1-mesa-glx`
+
+ ### Runtime Errors
+
+ **Issue**: "Out of memory"
+ - **Fix**: Upgrade to GPU hardware OR reduce the image size
+
+ **Issue**: Slow processing
+ - **Fix**: Use CPU Upgrade or GPU hardware
+
+ **Issue**: Model not loading
+ - **Fix**: Check the logs and ensure the model file is in the correct location
+
+ ### API Not Working
+
+ **Issue**: 404 errors
+ - **Fix**: Check that the FastAPI routes are correct
+ - **Fix**: Ensure `app:app` in the Dockerfile `CMD` matches `app = FastAPI()` in the code
+
+ **Issue**: CORS errors
+ - **Fix**: CORS is enabled by default; check the browser console
+
+ ## File Structure Verification
+
+ Before deploying, verify:
+
+ ```bash
+ # Check that all files exist
+ ls -la
+
+ # Should see:
+ # app.py
+ # binary_segmentation.py
+ # requirements.txt
+ # Dockerfile
+ # README.md
+ # static/index.html
+ # .model_cache/u2netp.pth
+
+ # Check the model file size (should be ~4.7 MB)
+ ls -lh .model_cache/u2netp.pth
+ ```
+
+ ## Alternative: Deploy Without Docker
+
+ If you prefer not to use Docker, set the SDK in the YAML front matter at the top of `README.md`:
+
+ ```
+ ---
+ sdk: gradio
+ sdk_version: 4.0.0
+ ---
+ ```
+
+ Then modify the app to use Gradio instead of FastAPI. Docker remains the recommended route for FastAPI.
+
+ ## Post-Deployment
+
+ ### Monitor Usage
+ - Check the "Analytics" tab for usage stats
+ - Monitor the "Logs" tab for errors
+
+ ### Update Your Space
+ ```bash
+ git pull
+ # Make changes
+ git add .
+ git commit -m "Update"
+ git push
+ ```
+
+ ### Share Your Space
+ - Get a shareable link from the Space page
+ - Embed it in a website using an iframe
+ - Use the API endpoint in your apps
+
+ ## Example API Usage from External Apps
+
+ Once deployed, call your Space's API at its direct `*.hf.space` URL (not the huggingface.co page URL):
+
+ ```python
+ import requests
+
+ SPACE_URL = "https://YOUR_USERNAME-YOUR_SPACE_NAME.hf.space"
+
+ with open('image.jpg', 'rb') as f:
+     response = requests.post(
+         f"{SPACE_URL}/segment",
+         files={'file': f},
+         data={'model': 'u2netp', 'threshold': 0.5}
+     )
+
+ with open('result.png', 'wb') as out:
+     out.write(response.content)
+ ```
+
+ ## Need Help?
+
+ - Hugging Face Docs: https://huggingface.co/docs/hub/spaces
+ - Community Forum: https://discuss.huggingface.co/
+ - Discord: https://discord.gg/hugging-face
+
+ ---
+
+ **Pro Tip**: Start with CPU Basic (free), test your Space, then upgrade to GPU if needed!
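The `model` and `threshold` form fields in the example above can be validated client-side before uploading, mirroring the checks the API itself documents (valid models `u2netp`/`birefnet`/`rmbg`, threshold in 0.0-1.0). A minimal sketch; the helper name is hypothetical:

```python
# Hypothetical client-side helper; the accepted model names and threshold
# range mirror the API's documented validation.
VALID_MODELS = {"u2netp", "birefnet", "rmbg"}

def build_segment_payload(model="u2netp", threshold=0.5):
    """Validate form fields for POST /segment before sending."""
    if model not in VALID_MODELS:
        raise ValueError(f"model must be one of {sorted(VALID_MODELS)}")
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be within 0.0-1.0")
    # Form fields are sent as strings in multipart requests
    return {"model": model, "threshold": str(threshold)}
```

Failing fast on the client avoids a round trip that would only return a 400 from the server.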
Dockerfile ADDED
@@ -0,0 +1,28 @@
+ FROM python:3.10-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application files
+ COPY . .
+
+ # Create necessary directories
+ RUN mkdir -p .model_cache static
+
+ # Expose port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,135 @@
- ---
- title: Inspectech Segmentation
- emoji: 🐨
- colorFrom: purple
- colorTo: gray
- sdk: gradio
- sdk_version: 6.5.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Binary Image Segmentation - FastAPI Service
+
+ Professional background removal service with a web interface and REST API, ready for Hugging Face Spaces deployment.
+
+ ## 🚀 Quick Start
+
+ ### Local Development
+
+ ```bash
+ # 1. Install dependencies
+ pip install -r requirements.txt
+
+ # 2. Download the U2NETP model weights
+ mkdir -p .model_cache
+ wget https://github.com/xuebinqin/U-2-Net/raw/master/saved_models/u2netp/u2netp.pth -O .model_cache/u2netp.pth
+
+ # 3. Run the server
+ uvicorn app:app --host 0.0.0.0 --port 7860
+
+ # 4. Open a browser
+ # Visit: http://localhost:7860
+ ```
+
+ ### Test the API
+
+ ```bash
+ python test_api.py
+ ```
+
+ ## 📁 Project Structure
+
+ ```
+ .
+ ├── app.py                    # FastAPI application (main entry point)
+ ├── binary_segmentation.py    # Core segmentation module
+ ├── requirements.txt          # Python dependencies
+ ├── Dockerfile                # Docker configuration for deployment
+ ├── README_HF.md              # Hugging Face Space README
+ ├── DEPLOYMENT.md             # Detailed deployment guide
+ ├── client_examples.py        # API usage examples (Python, JS, curl)
+ ├── test_api.py               # Test script
+ ├── .gitignore                # Git ignore file
+ └── static/
+     └── index.html            # Web interface
+ ```
+
+ ## 🎨 Features
+
+ ### Web Interface
+ - Drag & drop image upload
+ - 3 AI model options (U2NETP, BiRefNet, RMBG)
+ - Adjustable threshold
+ - Multiple output formats (transparent PNG, binary mask, or both)
+ - Real-time preview
+ - Download results
+
+ ### REST API
+ - **POST /segment** - Segment an image → transparent PNG
+ - **POST /segment/mask** - Get the binary mask only
+ - **POST /segment/base64** - Get base64-encoded results
+ - **POST /segment/batch** - Process multiple images
+ - **GET /models** - List available models
+ - **GET /health** - Health check
+
+ ### Supported Models
+
+ | Model | Speed | Accuracy | Size | Best For |
+ |-------|-------|----------|------|----------|
+ | **U2NETP** | ⚡⚡⚡ | ⭐⭐ | 4.7 MB | Speed, simple objects |
+ | **BiRefNet** | ⚡ | ⭐⭐⭐ | ~400 MB | Best quality |
+ | **RMBG** | ⚡⚡ | ⭐⭐⭐ | ~200 MB | Balanced |
+
+ ## 🔧 API Usage Examples
+
+ ### Python
+
+ ```python
+ import requests
+
+ # Segment an image
+ with open('input.jpg', 'rb') as f:
+     response = requests.post(
+         'http://localhost:7860/segment',
+         files={'file': f},
+         data={'model': 'u2netp', 'threshold': 0.5}
+     )
+
+ # Save the result
+ with open('output.png', 'wb') as out:
+     out.write(response.content)
+ ```
+
+ ### JavaScript
+
+ ```javascript
+ async function removeBackground(file) {
+     const formData = new FormData();
+     formData.append('file', file);
+     formData.append('model', 'u2netp');
+     formData.append('threshold', '0.5');
+
+     const response = await fetch('/segment', {
+         method: 'POST',
+         body: formData
+     });
+
+     const blob = await response.blob();
+     return URL.createObjectURL(blob);
+ }
+ ```
+
+ ### cURL
+
+ ```bash
+ curl -X POST "http://localhost:7860/segment" \
+   -F "file=@input.jpg" \
+   -F "model=u2netp" \
+   -F "threshold=0.5" \
+   --output result.png
+ ```
+
+ See `client_examples.py` for more!
+
+ ## 🌐 Deploy to Hugging Face Spaces
+
+ See `DEPLOYMENT.md` for the complete guide!
+
+ ## 📝 License
+
+ Apache 2.0
+
+ ## 🙏 Credits
+
+ - U2-Net, BiRefNet, RMBG models
+ - FastAPI framework
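The adjustable threshold listed under Features decides which pixels of the model's soft probability map count as foreground. A minimal sketch of that step (NumPy assumed; the function name is hypothetical):

```python
import numpy as np

def binarize(prob_map, threshold=0.5):
    """Map a float probability mask (values 0.0-1.0) to a 0/255 binary mask.

    Pixels at or above the threshold become 255 (foreground),
    everything else becomes 0 (background).
    """
    return np.where(prob_map >= threshold, 255, 0).astype(np.uint8)

mask = binarize(np.array([[0.2, 0.5, 0.9]]), threshold=0.5)
```

Raising the threshold shrinks the foreground (stricter cut), lowering it grows the foreground but risks including background pixels.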
app.py ADDED
@@ -0,0 +1,400 @@
+ """
+ FastAPI Binary Segmentation Service
+ Hugging Face Space compatible
+ """
+
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+ from fastapi.responses import Response, JSONResponse, FileResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.staticfiles import StaticFiles
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import io
+ import logging
+ from typing import Literal, Optional
+ import base64
+ import os
+
+ from binary_segmentation import BinarySegmenter
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="Binary Segmentation API",
+     description="Remove background from images using AI models",
+     version="1.0.0"
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Mount static files
+ if os.path.exists("static"):
+     app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Global model cache (lazy loading)
+ segmenter_cache = {}
+
+
+ def get_segmenter(model_type: str = "u2netp") -> BinarySegmenter:
+     """Get or create a segmenter instance"""
+     if model_type not in segmenter_cache:
+         logger.info(f"Loading {model_type} model...")
+         segmenter_cache[model_type] = BinarySegmenter(model_type=model_type)
+         logger.info(f"{model_type} model loaded successfully")
+     return segmenter_cache[model_type]
+
+
+ @app.get("/")
+ async def root():
+     """Serve the web interface"""
+     if os.path.exists("static/index.html"):
+         return FileResponse("static/index.html")
+
+     # Fall back to API info
+     return {
+         "name": "Binary Segmentation API",
+         "version": "1.0.0",
+         "endpoints": {
+             "/segment": "POST - Segment image and return PNG with transparency",
+             "/segment/mask": "POST - Return binary mask only",
+             "/segment/base64": "POST - Return base64 encoded results",
+             "/health": "GET - Health check",
+             "/models": "GET - List available models"
+         }
+     }
+
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {
+         "status": "healthy",
+         "models_loaded": list(segmenter_cache.keys())
+     }
+
+
+ @app.get("/models")
+ async def list_models():
+     """List available segmentation models"""
+     return {
+         "models": [
+             {
+                 "name": "u2netp",
+                 "description": "Lightweight, fast model (1.1M params)",
+                 "speed": "⚡⚡⚡",
+                 "accuracy": "⭐⭐",
+                 "size": "4.7 MB"
+             },
+             {
+                 "name": "birefnet",
+                 "description": "High accuracy model",
+                 "speed": "⚡",
+                 "accuracy": "⭐⭐⭐",
+                 "size": "~400 MB",
+                 "requires": "transformers package"
+             },
+             {
+                 "name": "rmbg",
+                 "description": "Balanced model",
+                 "speed": "⚡⚡",
+                 "accuracy": "⭐⭐⭐",
+                 "size": "~200 MB",
+                 "requires": "transformers package"
+             }
+         ],
+         "default": "u2netp"
+     }
+
+
+ @app.post("/segment")
+ async def segment_image(
+     file: UploadFile = File(..., description="Image file to segment"),
+     model: str = Form("u2netp", description="Model to use: u2netp, birefnet, or rmbg"),
+     threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0)
+ ):
+     """
+     Segment an image and return a PNG with a transparent background.
+
+     Returns: PNG image with transparency
+     """
+     try:
+         # Validate model
+         if model not in ["u2netp", "birefnet", "rmbg"]:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
+             )
+
+         # Read image
+         contents = await file.read()
+         nparr = np.frombuffer(contents, np.uint8)
+         image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+         if image is None:
+             raise HTTPException(status_code=400, detail="Invalid image file")
+
+         # Get segmenter
+         segmenter = get_segmenter(model)
+
+         # Segment image
+         logger.info(f"Segmenting with model={model}, threshold={threshold}")
+         _, rgba = segmenter.segment(image, threshold=threshold, return_type="rgba")
+
+         if rgba is None:
+             raise HTTPException(status_code=500, detail="Segmentation failed")
+
+         # Convert to bytes
+         img_byte_arr = io.BytesIO()
+         rgba.save(img_byte_arr, format='PNG')
+         img_byte_arr.seek(0)
+
+         logger.info("Segmentation successful")
+         return Response(
+             content=img_byte_arr.getvalue(),
+             media_type="image/png",
+             headers={
+                 "Content-Disposition": f"attachment; filename=segmented_{file.filename}"
+             }
+         )
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Error in segmentation: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/segment/mask")
+ async def segment_mask(
+     file: UploadFile = File(..., description="Image file to segment"),
+     model: str = Form("u2netp", description="Model to use"),
+     threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0)
+ ):
+     """
+     Segment an image and return the binary mask only.
+
+     Returns: PNG image (binary mask - black and white)
+     """
+     try:
+         # Validate model
+         if model not in ["u2netp", "birefnet", "rmbg"]:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
+             )
+
+         # Read image
+         contents = await file.read()
+         nparr = np.frombuffer(contents, np.uint8)
+         image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+         if image is None:
+             raise HTTPException(status_code=400, detail="Invalid image file")
+
+         # Get segmenter
+         segmenter = get_segmenter(model)
+
+         # Segment image
+         logger.info(f"Generating mask with model={model}, threshold={threshold}")
+         mask, _ = segmenter.segment(image, threshold=threshold, return_type="mask")
+
+         if mask is None:
+             raise HTTPException(status_code=500, detail="Segmentation failed")
+
+         # Convert to PNG
+         _, buffer = cv2.imencode('.png', mask)
+
+         logger.info("Mask generation successful")
+         return Response(
+             content=buffer.tobytes(),
+             media_type="image/png",
+             headers={
+                 "Content-Disposition": f"attachment; filename=mask_{file.filename}"
+             }
+         )
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Error in mask generation: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/segment/base64")
+ async def segment_base64(
+     file: UploadFile = File(..., description="Image file to segment"),
+     model: str = Form("u2netp", description="Model to use"),
+     threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0),
+     return_type: str = Form("rgba", description="Return type: rgba, mask, or both")
+ ):
+     """
+     Segment an image and return base64-encoded results.
+
+     Returns: JSON with base64-encoded images
+     """
+     try:
+         # Validate inputs
+         if model not in ["u2netp", "birefnet", "rmbg"]:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
+             )
+
+         if return_type not in ["rgba", "mask", "both"]:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid return_type: {return_type}. Choose from: rgba, mask, both"
+             )
+
+         # Read image
+         contents = await file.read()
+         nparr = np.frombuffer(contents, np.uint8)
+         image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+         if image is None:
+             raise HTTPException(status_code=400, detail="Invalid image file")
+
+         # Get segmenter
+         segmenter = get_segmenter(model)
+
+         # Segment image
+         logger.info(f"Segmenting (base64) with model={model}, threshold={threshold}, return_type={return_type}")
+         mask, rgba = segmenter.segment(image, threshold=threshold, return_type=return_type)
+
+         # Prepare response
+         response = {
+             "success": True,
+             "model": model,
+             "threshold": threshold
+         }
+
+         # Encode mask if requested
+         if return_type in ["mask", "both"] and mask is not None:
+             _, buffer = cv2.imencode('.png', mask)
+             mask_base64 = base64.b64encode(buffer).decode('utf-8')
+             response["mask"] = f"data:image/png;base64,{mask_base64}"
+
+         # Encode RGBA if requested
+         if return_type in ["rgba", "both"] and rgba is not None:
+             img_byte_arr = io.BytesIO()
+             rgba.save(img_byte_arr, format='PNG')
+             rgba_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
+             response["rgba"] = f"data:image/png;base64,{rgba_base64}"
+
+         logger.info("Base64 encoding successful")
+         return JSONResponse(content=response)
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Error in base64 encoding: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/segment/batch")
+ async def segment_batch(
+     files: list[UploadFile] = File(..., description="Multiple image files"),
+     model: str = Form("u2netp", description="Model to use"),
+     threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0)
+ ):
+     """
+     Segment multiple images and return base64-encoded results.
+
+     Returns: JSON with an array of base64-encoded images
+     """
+     try:
+         # Validate model
+         if model not in ["u2netp", "birefnet", "rmbg"]:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
+             )
+
+         # Limit batch size
+         if len(files) > 10:
+             raise HTTPException(
+                 status_code=400,
+                 detail="Maximum batch size is 10 images"
+             )
+
+         # Get segmenter
+         segmenter = get_segmenter(model)
+
+         results = []
+
+         for idx, file in enumerate(files):
+             try:
+                 # Read image
+                 contents = await file.read()
+                 nparr = np.frombuffer(contents, np.uint8)
+                 image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+                 if image is None:
+                     results.append({
+                         "filename": file.filename,
+                         "success": False,
+                         "error": "Invalid image file"
+                     })
+                     continue
+
+                 # Segment
+                 logger.info(f"Processing batch image {idx+1}/{len(files)}: {file.filename}")
+                 _, rgba = segmenter.segment(image, threshold=threshold, return_type="rgba")
+
+                 if rgba is None:
+                     results.append({
+                         "filename": file.filename,
+                         "success": False,
+                         "error": "Segmentation failed"
+                     })
+                     continue
+
+                 # Encode to base64
+                 img_byte_arr = io.BytesIO()
+                 rgba.save(img_byte_arr, format='PNG')
+                 rgba_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
+
+                 results.append({
+                     "filename": file.filename,
+                     "success": True,
+                     "rgba": f"data:image/png;base64,{rgba_base64}"
+                 })
+
+             except Exception as e:
+                 logger.error(f"Error processing {file.filename}: {e}")
+                 results.append({
+                     "filename": file.filename,
+                     "success": False,
+                     "error": str(e)
+                 })
+
+         logger.info(f"Batch processing complete: {len(results)} images")
+         return JSONResponse(content={
+             "total": len(files),
+             "results": results,
+             "model": model,
+             "threshold": threshold
+         })
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Error in batch processing: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     # For local development
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=7860,
+         reload=True
+     )
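The `/segment/base64` endpoint above wraps each encoded PNG in a `data:image/png;base64,...` URL. The round trip a client performs to recover the raw bytes can be sketched as (helper names are hypothetical):

```python
import base64

def to_data_url(png_bytes: bytes) -> str:
    """Wrap raw PNG bytes in a data: URL, as /segment/base64 does."""
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("utf-8")

def from_data_url(url: str) -> bytes:
    """Recover raw bytes from a data: URL on the client side."""
    # Everything after the first comma is the base64 payload
    return base64.b64decode(url.split(",", 1)[1])
```

A browser can assign the data URL directly to an `<img src>`; a Python client would use `from_data_url` and write the bytes to a `.png` file.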
binary_segmentation.py ADDED
@@ -0,0 +1,398 @@
+ """
+ Binary Image Segmentation Tool
+ A lightweight, professional implementation for foreground object segmentation.
+
+ Supports multiple models:
+ - U2NETP (fastest, 1.1M params)
+ - BiRefNet (best accuracy, larger model)
+ - RMBG (good balance)
+ """
+
+ import os
+ import logging
+ from pathlib import Path
+ from typing import Literal, Tuple, Optional
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ import cv2
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Device configuration
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"Using device: {DEVICE}")
+
+
+ class U2NETP(torch.nn.Module):
+     """U2NETP (small U2-Net) - lightweight segmentation model"""
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         # Encoder
+         self.stage1 = self._make_stage(in_ch, 16, 64)
+         self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = self._make_stage(64, 16, 64)
+         self.pool23 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = self._make_stage(64, 16, 64)
+         self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = self._make_stage(64, 16, 64)
+
+         # Bridge
+         self.stage5 = self._make_stage(64, 16, 64)
+
+         # Decoder
+         self.stage4d = self._make_stage(128, 16, 64)
+         self.stage3d = self._make_stage(128, 16, 64)
+         self.stage2d = self._make_stage(128, 16, 64)
+         self.stage1d = self._make_stage(128, 16, 64)
+
+         # Side outputs
+         self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
+
+         # Output fusion
+         self.outconv = torch.nn.Conv2d(5 * out_ch, out_ch, 1)
+
+     def _make_stage(self, in_ch, mid_ch, out_ch):
+         return torch.nn.Sequential(
+             torch.nn.Conv2d(in_ch, mid_ch, 3, padding=1),
+             torch.nn.ReLU(inplace=True),
+             torch.nn.Conv2d(mid_ch, mid_ch, 3, padding=1),
+             torch.nn.ReLU(inplace=True),
+             torch.nn.Conv2d(mid_ch, out_ch, 3, padding=1),
+             torch.nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         hx = x
+
+         # Encoder
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         hx4 = self.stage4(hx)
+         hx5 = self.stage5(hx4)
+
+         # Decoder
+         hx4d = self.stage4d(torch.cat((hx5, hx4), 1))
+         hx4dup = torch.nn.functional.interpolate(hx4d, scale_factor=2, mode='bilinear', align_corners=True)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = torch.nn.functional.interpolate(hx3d, scale_factor=2, mode='bilinear', align_corners=True)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = torch.nn.functional.interpolate(hx2d, scale_factor=2, mode='bilinear', align_corners=True)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # Side outputs
+         d1 = self.side1(hx1d)
+         d2 = torch.nn.functional.interpolate(self.side2(hx2d), size=d1.shape[2:], mode='bilinear', align_corners=True)
+         d3 = torch.nn.functional.interpolate(self.side3(hx3d), size=d1.shape[2:], mode='bilinear', align_corners=True)
+         d4 = torch.nn.functional.interpolate(self.side4(hx4d), size=d1.shape[2:], mode='bilinear', align_corners=True)
+         d5 = torch.nn.functional.interpolate(self.side5(hx5), size=d1.shape[2:], mode='bilinear', align_corners=True)
+
+         # Fusion
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5), 1))
+
+         return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5)
+
+
+ class BinarySegmenter:
+     """
+     Professional binary segmentation tool with multiple model backends.
+
+     Args:
+         model_type: Choice of segmentation model
+         cache_dir: Directory to cache downloaded models
+     """
+
+     def __init__(
+         self,
+         model_type: Literal["u2netp", "birefnet", "rmbg"] = "u2netp",
+         cache_dir: str = "./.model_cache"
+     ):
+         self.model_type = model_type
+         self.cache_dir = Path(cache_dir)
+         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+         self.model = None
+         self.transform = None
+         self._load_model()
+
+     def _load_model(self):
+         """Load the specified segmentation model"""
+         logger.info(f"Loading {self.model_type} model...")
+
+         if self.model_type == "u2netp":
+             self._load_u2netp()
+         elif self.model_type == "birefnet":
+             self._load_birefnet()
+         elif self.model_type == "rmbg":
+             self._load_rmbg()
+         else:
+             raise ValueError(f"Unknown model type: {self.model_type}")
+
+         self.model.to(DEVICE)
+         self.model.eval()
+         logger.info(f"{self.model_type} loaded successfully")
+
+     def _load_u2netp(self):
+         """Load the U2NETP model (1.1M parameters, fastest)"""
+         self.model = U2NETP(3, 1)
+
+         # Try to load pretrained weights
+         model_path = self.cache_dir / "u2netp.pth"
+
+         if model_path.exists():
+             logger.info(f"Loading weights from {model_path}")
+             self.model.load_state_dict(
+                 torch.load(model_path, map_location=DEVICE)
+             )
+         else:
+             logger.warning(f"No pretrained weights found at {model_path}")
+             logger.warning("Download from: https://github.com/xuebinqin/U-2-Net")
+
+         # Standard ImageNet normalization
+         self.transform = transforms.Compose([
+             transforms.Resize((320, 320)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ])
+
+     def _load_birefnet(self):
+         """Load the BiRefNet model (best accuracy, larger)"""
+         try:
+             from transformers import AutoModelForImageSegmentation
+
+             self.model = AutoModelForImageSegmentation.from_pretrained(
+                 'ZhengPeng7/BiRefNet',
+                 trust_remote_code=True,
+                 cache_dir=str(self.cache_dir)
+             )
+
+             self.transform = transforms.Compose([
+                 transforms.Resize((1024, 1024)),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+             ])
+         except ImportError:
+             raise ImportError("BiRefNet requires: pip install transformers")
+
+     def _load_rmbg(self):
+         """Load the RMBG model (good balance)"""
+         try:
+             from transformers import AutoModelForImageSegmentation
+
+             self.model = AutoModelForImageSegmentation.from_pretrained(
+                 'briaai/RMBG-1.4',
+                 trust_remote_code=True,
+                 cache_dir=str(self.cache_dir)
+             )
+
+             self.transform = transforms.Compose([
+                 transforms.Resize((1024, 1024)),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+             ])
+         except ImportError:
+             raise ImportError("RMBG requires: pip install transformers")
+
+     def segment(
+         self,
+         image: np.ndarray,
+         threshold: float = 0.5,
+         return_type: Literal["mask", "rgba", "both"] = "mask"
+     ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
227
+ """
228
+ Segment foreground object from image.
229
+
230
+ Args:
231
+ image: Input image as numpy array (H, W, 3) in RGB or BGR
232
+ threshold: Threshold for binary mask (0-1)
233
+ return_type: What to return - "mask", "rgba", or "both"
234
+
235
+ Returns:
236
+ Tuple of (binary_mask, rgba_image) based on return_type
237
+ """
238
+ # Convert BGR to RGB if needed
239
+ if len(image.shape) == 3 and image.shape[2] == 3:
240
+ if image[0, 0, 0] != image[0, 0, 2]: # Simple heuristic
241
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
242
+ else:
243
+ image_rgb = image
244
+ else:
245
+ raise ValueError("Input must be a color image (H, W, 3)")
246
+
247
+ # Convert to PIL
248
+ image_pil = Image.fromarray(image_rgb)
249
+ original_size = image_pil.size
250
+
251
+ # Transform
252
+ input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
253
+
254
+ # Inference
255
+ with torch.no_grad():
256
+ if self.model_type == "u2netp":
257
+ outputs = self.model(input_tensor)
258
+ pred = outputs[0] # Main output
259
+ else: # birefnet or rmbg
260
+ pred = self.model(input_tensor)[-1].sigmoid()
261
+
262
+ # Post-process
263
+ pred = pred.squeeze().cpu().numpy()
264
+
265
+ # Resize to original
266
+ pred_resized = cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
267
+
268
+ # Normalize to 0-255
269
+ pred_normalized = ((pred_resized - pred_resized.min()) /
270
+ (pred_resized.max() - pred_resized.min() + 1e-8) * 255)
271
+
272
+ # Create binary mask
273
+ binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
274
+
275
+ # Optional: Morphological operations for cleaner mask
276
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
277
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
278
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
279
+
280
+ # Create RGBA if needed
281
+ rgba_image = None
282
+ if return_type in ["rgba", "both"]:
283
+ # Create 4-channel image
284
+ rgba = np.dstack([image_rgb, binary_mask])
285
+ rgba_image = Image.fromarray(rgba, mode='RGBA')
286
+
287
+ # Return based on type
288
+ if return_type == "mask":
289
+ return binary_mask, None
290
+ elif return_type == "rgba":
291
+ return None, rgba_image
292
+ else: # both
293
+ return binary_mask, rgba_image
294
+
295
+ def batch_segment(
296
+ self,
297
+ images: list[np.ndarray],
298
+ threshold: float = 0.5,
299
+ return_type: Literal["mask", "rgba", "both"] = "mask"
300
+ ) -> list:
301
+ """
302
+ Segment multiple images in batch.
303
+
304
+ Args:
305
+ images: List of input images
306
+ threshold: Threshold for binary masks
307
+ return_type: What to return for each image
308
+
309
+ Returns:
310
+ List of segmentation results
311
+ """
312
+ results = []
313
+ for i, img in enumerate(images):
314
+ logger.info(f"Processing image {i+1}/{len(images)}")
315
+ result = self.segment(img, threshold, return_type)
316
+ results.append(result)
317
+ return results
318
+
319
+
320
+ def segment_image_file(
321
+ input_path: str,
322
+ output_path: str,
323
+ model_type: str = "u2netp",
324
+ threshold: float = 0.5,
325
+ save_rgba: bool = True
326
+ ):
327
+ """
328
+ Convenience function to segment an image file.
329
+
330
+ Args:
331
+ input_path: Path to input image
332
+ output_path: Path to save output (mask or RGBA)
333
+ model_type: Model to use
334
+ threshold: Segmentation threshold
335
+ save_rgba: If True, save RGBA; if False, save binary mask
336
+ """
337
+ # Load image
338
+ image = cv2.imread(input_path)
339
+ if image is None:
340
+ raise FileNotFoundError(f"Could not load image: {input_path}")
341
+
342
+ # Create segmenter
343
+ segmenter = BinarySegmenter(model_type=model_type)
344
+
345
+ # Segment
346
+ return_type = "rgba" if save_rgba else "mask"
347
+ mask, rgba = segmenter.segment(image, threshold, return_type)
348
+
349
+ # Save
350
+ output_path = Path(output_path)
351
+ output_path.parent.mkdir(parents=True, exist_ok=True)
352
+
353
+ if save_rgba and rgba is not None:
354
+ rgba.save(output_path)
355
+ logger.info(f"Saved RGBA to: {output_path}")
356
+ elif mask is not None:
357
+ cv2.imwrite(str(output_path), mask)
358
+ logger.info(f"Saved mask to: {output_path}")
359
+
360
+ return str(output_path)
361
+
362
+
363
+ # Example usage
364
+ if __name__ == "__main__":
365
+ import argparse
366
+
367
+ parser = argparse.ArgumentParser(description="Binary image segmentation")
368
+ parser.add_argument("input", help="Input image path")
369
+ parser.add_argument("output", help="Output path")
370
+ parser.add_argument(
371
+ "--model",
372
+ choices=["u2netp", "birefnet", "rmbg"],
373
+ default="u2netp",
374
+ help="Segmentation model"
375
+ )
376
+ parser.add_argument(
377
+ "--threshold",
378
+ type=float,
379
+ default=0.5,
380
+ help="Segmentation threshold (0-1)"
381
+ )
382
+ parser.add_argument(
383
+ "--format",
384
+ choices=["mask", "rgba"],
385
+ default="rgba",
386
+ help="Output format"
387
+ )
388
+
389
+ args = parser.parse_args()
390
+
391
+ # Process
392
+ segment_image_file(
393
+ args.input,
394
+ args.output,
395
+ model_type=args.model,
396
+ threshold=args.threshold,
397
+ save_rgba=(args.format == "rgba")
398
+ )
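The mask post-processing in `segment()` above (min-max normalize the prediction to 0-255, then cut at `threshold * 255`) can be illustrated in isolation. The sketch below is not part of the module; `binarize` is a hypothetical helper that mirrors the same arithmetic on a flat list of raw prediction values:

```python
def binarize(pred, threshold=0.5):
    """Mirror of the normalize-then-threshold step in BinarySegmenter.segment()."""
    lo, hi = min(pred), max(pred)
    # Min-max normalize to the 0-255 range (epsilon guards against a flat prediction)
    norm = [(p - lo) / (hi - lo + 1e-8) * 255 for p in pred]
    # Values above threshold * 255 become foreground (255), the rest background (0)
    return [255 if n > threshold * 255 else 0 for n in norm]

print(binarize([0.1, 0.4, 0.9]))  # β†’ [0, 0, 255]
```

With the default threshold of 0.5, only values above the midpoint of the normalized range survive; raising the threshold shrinks the foreground region, which matches how the `threshold` form field behaves in the API.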
client_examples.py ADDED
"""
API Client Examples for Binary Segmentation Service

These examples show how to interact with the FastAPI service
from Python, JavaScript, and curl.
"""

import requests
import base64
import json
from pathlib import Path


# =============================================================================
# Python Client Examples
# =============================================================================

class SegmentationClient:
    """Python client for segmentation API"""

    def __init__(self, base_url: str = "http://localhost:7860"):
        self.base_url = base_url.rstrip('/')

    def segment_image(
        self,
        image_path: str,
        output_path: str,
        model: str = "u2netp",
        threshold: float = 0.5
    ):
        """
        Segment image and save as PNG with transparency

        Args:
            image_path: Path to input image
            output_path: Path to save output PNG
            model: Model to use (u2netp, birefnet, rmbg)
            threshold: Segmentation threshold (0.0-1.0)
        """
        with open(image_path, 'rb') as f:
            files = {'file': f}
            data = {
                'model': model,
                'threshold': threshold
            }

            response = requests.post(
                f"{self.base_url}/segment",
                files=files,
                data=data
            )

        response.raise_for_status()

        with open(output_path, 'wb') as out:
            out.write(response.content)

        print(f"βœ“ Saved to: {output_path}")

    def get_mask(
        self,
        image_path: str,
        output_path: str,
        model: str = "u2netp",
        threshold: float = 0.5
    ):
        """Get binary mask only"""
        with open(image_path, 'rb') as f:
            files = {'file': f}
            data = {
                'model': model,
                'threshold': threshold
            }

            response = requests.post(
                f"{self.base_url}/segment/mask",
                files=files,
                data=data
            )

        response.raise_for_status()

        with open(output_path, 'wb') as out:
            out.write(response.content)

        print(f"βœ“ Mask saved to: {output_path}")

    def segment_base64(
        self,
        image_path: str,
        model: str = "u2netp",
        threshold: float = 0.5,
        return_type: str = "both"
    ):
        """
        Get segmentation results as base64

        Returns:
            dict with 'mask' and/or 'rgba' as base64 strings
        """
        with open(image_path, 'rb') as f:
            files = {'file': f}
            data = {
                'model': model,
                'threshold': threshold,
                'return_type': return_type
            }

            response = requests.post(
                f"{self.base_url}/segment/base64",
                files=files,
                data=data
            )

        response.raise_for_status()
        return response.json()

    def batch_segment(
        self,
        image_paths: list[str],
        model: str = "u2netp",
        threshold: float = 0.5
    ):
        """
        Segment multiple images

        Args:
            image_paths: List of paths to images (max 10)

        Returns:
            dict with results for each image
        """
        files = [
            ('files', open(path, 'rb'))
            for path in image_paths
        ]

        data = {
            'model': model,
            'threshold': threshold
        }

        try:
            response = requests.post(
                f"{self.base_url}/segment/batch",
                files=files,
                data=data
            )

            response.raise_for_status()
            return response.json()
        finally:
            # Close all file handles
            for _, f in files:
                f.close()

    def list_models(self):
        """List available models"""
        response = requests.get(f"{self.base_url}/models")
        response.raise_for_status()
        return response.json()

    def health_check(self):
        """Check service health"""
        response = requests.get(f"{self.base_url}/health")
        response.raise_for_status()
        return response.json()


# =============================================================================
# Usage Examples
# =============================================================================

def example_basic():
    """Basic usage"""
    client = SegmentationClient("http://localhost:7860")

    # Segment image
    client.segment_image(
        image_path="input.jpg",
        output_path="output.png",
        model="u2netp",
        threshold=0.5
    )


def example_mask():
    """Get binary mask"""
    client = SegmentationClient("http://localhost:7860")

    client.get_mask(
        image_path="input.jpg",
        output_path="mask.png",
        model="u2netp",
        threshold=0.5
    )


def example_base64():
    """Get base64 results"""
    client = SegmentationClient("http://localhost:7860")

    result = client.segment_base64(
        image_path="input.jpg",
        return_type="both"
    )

    # Save base64 images
    if 'rgba' in result:
        # Remove data URL prefix
        rgba_data = result['rgba'].split(',')[1]
        with open('output_rgba.png', 'wb') as f:
            f.write(base64.b64decode(rgba_data))

    if 'mask' in result:
        mask_data = result['mask'].split(',')[1]
        with open('output_mask.png', 'wb') as f:
            f.write(base64.b64decode(mask_data))


def example_batch():
    """Process multiple images"""
    client = SegmentationClient("http://localhost:7860")

    results = client.batch_segment(
        image_paths=["image1.jpg", "image2.jpg", "image3.jpg"],
        model="u2netp",
        threshold=0.5
    )

    # Save results
    for i, result in enumerate(results['results']):
        if result['success']:
            rgba_data = result['rgba'].split(',')[1]
            with open(f'output_{i}.png', 'wb') as f:
                f.write(base64.b64decode(rgba_data))


def example_models():
    """List available models"""
    client = SegmentationClient("http://localhost:7860")

    models = client.list_models()
    print(json.dumps(models, indent=2))


# =============================================================================
# JavaScript Examples (for frontend)
# =============================================================================

JAVASCRIPT_EXAMPLES = """
// Example 1: Basic fetch
async function segmentImage(file) {
    const formData = new FormData();
    formData.append('file', file);
    formData.append('model', 'u2netp');
    formData.append('threshold', '0.5');

    const response = await fetch('/segment', {
        method: 'POST',
        body: formData
    });

    const blob = await response.blob();
    return URL.createObjectURL(blob);
}

// Example 2: Get base64
async function segmentBase64(file) {
    const formData = new FormData();
    formData.append('file', file);
    formData.append('model', 'u2netp');
    formData.append('threshold', '0.5');
    formData.append('return_type', 'rgba');

    const response = await fetch('/segment/base64', {
        method: 'POST',
        body: formData
    });

    const data = await response.json();
    return data.rgba; // data:image/png;base64,...
}

// Example 3: Batch processing
async function segmentBatch(files) {
    const formData = new FormData();

    for (const file of files) {
        formData.append('files', file);
    }
    formData.append('model', 'u2netp');
    formData.append('threshold', '0.5');

    const response = await fetch('/segment/batch', {
        method: 'POST',
        body: formData
    });

    return await response.json();
}

// Example 4: With progress
async function segmentWithProgress(file, onProgress) {
    const formData = new FormData();
    formData.append('file', file);
    formData.append('model', 'u2netp');
    formData.append('threshold', '0.5');

    const xhr = new XMLHttpRequest();

    return new Promise((resolve, reject) => {
        xhr.upload.addEventListener('progress', (e) => {
            if (e.lengthComputable) {
                onProgress(e.loaded / e.total);
            }
        });

        xhr.addEventListener('load', () => {
            if (xhr.status === 200) {
                const blob = xhr.response;
                resolve(URL.createObjectURL(blob));
            } else {
                reject(new Error('Upload failed'));
            }
        });

        xhr.addEventListener('error', () => reject(new Error('Upload failed')));

        xhr.open('POST', '/segment');
        xhr.responseType = 'blob';
        xhr.send(formData);
    });
}
"""


# =============================================================================
# cURL Examples
# =============================================================================

CURL_EXAMPLES = """
# Example 1: Basic segmentation
curl -X POST "http://localhost:7860/segment" \\
  -F "file=@input.jpg" \\
  -F "model=u2netp" \\
  -F "threshold=0.5" \\
  --output result.png

# Example 2: Get mask
curl -X POST "http://localhost:7860/segment/mask" \\
  -F "file=@input.jpg" \\
  -F "model=u2netp" \\
  -F "threshold=0.5" \\
  --output mask.png

# Example 3: Get base64 JSON
curl -X POST "http://localhost:7860/segment/base64" \\
  -F "file=@input.jpg" \\
  -F "model=u2netp" \\
  -F "threshold=0.5" \\
  -F "return_type=both"

# Example 4: Batch processing
curl -X POST "http://localhost:7860/segment/batch" \\
  -F "files=@image1.jpg" \\
  -F "files=@image2.jpg" \\
  -F "files=@image3.jpg" \\
  -F "model=u2netp" \\
  -F "threshold=0.5"

# Example 5: List models
curl -X GET "http://localhost:7860/models"

# Example 6: Health check
curl -X GET "http://localhost:7860/health"
"""


if __name__ == "__main__":
    print("API Client Examples")
    print("=" * 50)
    print("\nPython Examples:")
    print("  example_basic()  - Basic segmentation")
    print("  example_mask()   - Get binary mask")
    print("  example_base64() - Get base64 results")
    print("  example_batch()  - Batch processing")
    print("  example_models() - List models")
    print("\nUncomment the example you want to run!")

    # Uncomment to run:
    # example_basic()
    # example_mask()
    # example_base64()
    # example_batch()
    # example_models()
index.html ADDED
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Background Removal - AI Segmentation</title>
7
+ <style>
8
+ * {
9
+ margin: 0;
10
+ padding: 0;
11
+ box-sizing: border-box;
12
+ }
13
+
14
+ body {
15
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
16
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
17
+ min-height: 100vh;
18
+ padding: 20px;
19
+ }
20
+
21
+ .container {
22
+ max-width: 1200px;
23
+ margin: 0 auto;
24
+ background: white;
25
+ border-radius: 20px;
26
+ box-shadow: 0 20px 60px rgba(0,0,0,0.3);
27
+ overflow: hidden;
28
+ }
29
+
30
+ header {
31
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
32
+ color: white;
33
+ padding: 30px;
34
+ text-align: center;
35
+ }
36
+
37
+ header h1 {
38
+ font-size: 2.5em;
39
+ margin-bottom: 10px;
40
+ }
41
+
42
+ header p {
43
+ font-size: 1.1em;
44
+ opacity: 0.9;
45
+ }
46
+
47
+ .content {
48
+ padding: 40px;
49
+ }
50
+
51
+ .upload-section {
52
+ text-align: center;
53
+ margin-bottom: 40px;
54
+ }
55
+
56
+ .upload-zone {
57
+ border: 3px dashed #667eea;
58
+ border-radius: 15px;
59
+ padding: 60px 40px;
60
+ background: #f8f9ff;
61
+ cursor: pointer;
62
+ transition: all 0.3s;
63
+ position: relative;
64
+ }
65
+
66
+ .upload-zone:hover {
67
+ border-color: #764ba2;
68
+ background: #f0f2ff;
69
+ }
70
+
71
+ .upload-zone.dragover {
72
+ border-color: #764ba2;
73
+ background: #e8ebff;
74
+ transform: scale(1.02);
75
+ }
76
+
77
+ .upload-icon {
78
+ font-size: 4em;
79
+ color: #667eea;
80
+ margin-bottom: 20px;
81
+ }
82
+
83
+ .upload-text {
84
+ font-size: 1.2em;
85
+ color: #333;
86
+ margin-bottom: 10px;
87
+ }
88
+
89
+ .upload-hint {
90
+ color: #666;
91
+ font-size: 0.9em;
92
+ }
93
+
94
+ input[type="file"] {
95
+ display: none;
96
+ }
97
+
98
+ .controls {
99
+ display: grid;
100
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
101
+ gap: 20px;
102
+ margin-bottom: 30px;
103
+ }
104
+
105
+ .control-group {
106
+ display: flex;
107
+ flex-direction: column;
108
+ }
109
+
110
+ .control-group label {
111
+ font-weight: 600;
112
+ margin-bottom: 8px;
113
+ color: #333;
114
+ }
115
+
116
+ select, input[type="range"] {
117
+ padding: 10px;
118
+ border: 2px solid #ddd;
119
+ border-radius: 8px;
120
+ font-size: 1em;
121
+ transition: border-color 0.3s;
122
+ }
123
+
124
+ select:focus, input[type="range"]:focus {
125
+ outline: none;
126
+ border-color: #667eea;
127
+ }
128
+
129
+ .threshold-value {
130
+ display: inline-block;
131
+ background: #667eea;
132
+ color: white;
133
+ padding: 4px 12px;
134
+ border-radius: 20px;
135
+ font-size: 0.9em;
136
+ margin-left: 10px;
137
+ }
138
+
139
+ .btn {
140
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
141
+ color: white;
142
+ border: none;
143
+ padding: 15px 40px;
144
+ font-size: 1.1em;
145
+ font-weight: 600;
146
+ border-radius: 10px;
147
+ cursor: pointer;
148
+ transition: all 0.3s;
149
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
150
+ }
151
+
152
+ .btn:hover {
153
+ transform: translateY(-2px);
154
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6);
155
+ }
156
+
157
+ .btn:active {
158
+ transform: translateY(0);
159
+ }
160
+
161
+ .btn:disabled {
162
+ opacity: 0.5;
163
+ cursor: not-allowed;
164
+ }
165
+
166
+ .results {
167
+ display: grid;
168
+ grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
169
+ gap: 30px;
170
+ margin-top: 40px;
171
+ }
172
+
173
+ .result-card {
174
+ background: #f8f9ff;
175
+ border-radius: 15px;
176
+ padding: 20px;
177
+ box-shadow: 0 4px 10px rgba(0,0,0,0.1);
178
+ }
179
+
180
+ .result-card h3 {
181
+ color: #333;
182
+ margin-bottom: 15px;
183
+ font-size: 1.2em;
184
+ }
185
+
186
+ .result-card img {
187
+ width: 100%;
188
+ border-radius: 10px;
189
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1);
190
+ }
191
+
192
+ .download-btn {
193
+ display: block;
194
+ width: 100%;
195
+ margin-top: 15px;
196
+ background: #10b981;
197
+ color: white;
198
+ padding: 10px;
199
+ text-align: center;
200
+ border-radius: 8px;
201
+ text-decoration: none;
202
+ font-weight: 600;
203
+ transition: background 0.3s;
204
+ }
205
+
206
+ .download-btn:hover {
207
+ background: #059669;
208
+ }
209
+
210
+ .loading {
211
+ text-align: center;
212
+ padding: 40px;
213
+ display: none;
214
+ }
215
+
216
+ .loading.active {
217
+ display: block;
218
+ }
219
+
220
+ .spinner {
221
+ border: 4px solid #f3f4f6;
222
+ border-top: 4px solid #667eea;
223
+ border-radius: 50%;
224
+ width: 50px;
225
+ height: 50px;
226
+ animation: spin 1s linear infinite;
227
+ margin: 0 auto 20px;
228
+ }
229
+
230
+ @keyframes spin {
231
+ 0% { transform: rotate(0deg); }
232
+ 100% { transform: rotate(360deg); }
233
+ }
234
+
235
+ .error {
236
+ background: #fee;
237
+ color: #c33;
238
+ padding: 15px;
239
+ border-radius: 8px;
240
+ margin-top: 20px;
241
+ display: none;
242
+ }
243
+
244
+ .error.active {
245
+ display: block;
246
+ }
247
+
248
+ .model-info {
249
+ background: #e8f4f8;
250
+ padding: 15px;
251
+ border-radius: 8px;
252
+ margin-top: 10px;
253
+ font-size: 0.9em;
254
+ color: #555;
255
+ }
256
+ </style>
257
+ </head>
258
+ <body>
259
+ <div class="container">
260
+ <header>
261
+ <h1>🎨 AI Background Removal</h1>
262
+ <p>Remove backgrounds from images using advanced AI models</p>
263
+ </header>
264
+
265
+ <div class="content">
266
+ <div class="upload-section">
267
+ <div class="upload-zone" id="uploadZone">
268
+ <div class="upload-icon">πŸ“</div>
269
+ <div class="upload-text">Click to upload or drag & drop</div>
270
+ <div class="upload-hint">Supports: JPG, PNG, WEBP (Max 10MB)</div>
271
+ <input type="file" id="fileInput" accept="image/*">
272
+ </div>
273
+ </div>
274
+
275
+ <div class="controls">
276
+ <div class="control-group">
277
+ <label for="modelSelect">AI Model</label>
278
+ <select id="modelSelect">
279
+ <option value="u2netp" selected>U2NETP (Fast & Lightweight)</option>
280
+ <option value="birefnet">BiRefNet (Best Quality)</option>
281
+ <option value="rmbg">RMBG (Balanced)</option>
282
+ </select>
283
+ <div class="model-info" id="modelInfo">
284
+ ⚑⚑⚑ Speed | ⭐⭐ Quality | 4.7 MB
285
+ </div>
286
+ </div>
287
+
288
+ <div class="control-group">
289
+ <label for="thresholdRange">
290
+ Threshold <span class="threshold-value" id="thresholdValue">0.5</span>
291
+ </label>
292
+ <input type="range" id="thresholdRange" min="0" max="1" step="0.1" value="0.5">
293
+ </div>
294
+
295
+ <div class="control-group">
296
+ <label for="outputType">Output Type</label>
297
+ <select id="outputType">
298
+ <option value="rgba" selected>Transparent PNG</option>
299
+ <option value="mask">Binary Mask</option>
300
+ <option value="both">Both</option>
301
+ </select>
302
+ </div>
303
+ </div>
304
+
305
+ <button class="btn" id="processBtn" disabled>Process Image</button>
306
+
307
+ <div class="loading" id="loading">
308
+ <div class="spinner"></div>
309
+ <p>Processing your image...</p>
310
+ </div>
311
+
312
+ <div class="error" id="error"></div>
313
+
314
+ <div class="results" id="results"></div>
315
+ </div>
316
+ </div>
317
+
318
+ <script>
319
+ const uploadZone = document.getElementById('uploadZone');
320
+ const fileInput = document.getElementById('fileInput');
321
+ const processBtn = document.getElementById('processBtn');
322
+ const loading = document.getElementById('loading');
323
+ const error = document.getElementById('error');
324
+ const results = document.getElementById('results');
325
+ const modelSelect = document.getElementById('modelSelect');
326
+ const modelInfo = document.getElementById('modelInfo');
327
+ const thresholdRange = document.getElementById('thresholdRange');
328
+ const thresholdValue = document.getElementById('thresholdValue');
329
+ const outputType = document.getElementById('outputType');
330
+
331
+ let selectedFile = null;
332
+
333
+ // Model information
334
+ const modelData = {
335
+ u2netp: { speed: '⚑⚑⚑', quality: '⭐⭐', size: '4.7 MB' },
336
+ birefnet: { speed: '⚑', quality: '⭐⭐⭐', size: '~400 MB' },
337
+ rmbg: { speed: '⚑⚑', quality: '⭐⭐⭐', size: '~200 MB' }
338
+ };
339
+
340
+ // Update model info
341
+ modelSelect.addEventListener('change', () => {
342
+ const model = modelData[modelSelect.value];
343
+ modelInfo.textContent = `${model.speed} Speed | ${model.quality} Quality | ${model.size}`;
344
+ });
345
+
346
+ // Update threshold value
347
+ thresholdRange.addEventListener('input', () => {
348
+ thresholdValue.textContent = thresholdRange.value;
349
+ });
350
+
351
+ // Upload zone click
352
+ uploadZone.addEventListener('click', () => {
353
+ fileInput.click();
354
+ });
355
+
356
+ // Drag and drop
357
+ uploadZone.addEventListener('dragover', (e) => {
358
+ e.preventDefault();
359
+ uploadZone.classList.add('dragover');
360
+ });
361
+
362
+ uploadZone.addEventListener('dragleave', () => {
363
+ uploadZone.classList.remove('dragover');
364
+ });
365
+
366
+ uploadZone.addEventListener('drop', (e) => {
367
+ e.preventDefault();
368
+ uploadZone.classList.remove('dragover');
369
+
370
+ if (e.dataTransfer.files.length > 0) {
371
+ handleFile(e.dataTransfer.files[0]);
372
+ }
373
+ });
374
+
375
+ // File input change
376
+ fileInput.addEventListener('change', (e) => {
377
+ if (e.target.files.length > 0) {
378
+ handleFile(e.target.files[0]);
379
+ }
380
+ });
381
+
382
+ function handleFile(file) {
383
+ if (!file.type.startsWith('image/')) {
384
+ showError('Please select an image file');
385
+ return;
386
+ }
387
+
388
+ if (file.size > 10 * 1024 * 1024) {
389
+ showError('File size must be less than 10MB');
390
+ return;
391
+ }
392
+
393
+ selectedFile = file;
394
+ processBtn.disabled = false;
395
+ uploadZone.querySelector('.upload-text').textContent = `Selected: ${file.name}`;
396
+ uploadZone.querySelector('.upload-icon').textContent = 'βœ…';
397
+ hideError();
398
+ }
399
+
400
+ // Process button
401
+ processBtn.addEventListener('click', async () => {
402
+ if (!selectedFile) return;
403
+
404
+ const formData = new FormData();
405
+ formData.append('file', selectedFile);
406
+ formData.append('model', modelSelect.value);
407
+ formData.append('threshold', thresholdRange.value);
+
+            processBtn.disabled = true;
+            loading.classList.add('active');
+            results.innerHTML = '';
+            hideError();
+
+            try {
+                let response;
+
+                if (outputType.value === 'both') {
+                    // Use base64 endpoint for both outputs
+                    response = await fetch('/segment/base64', {
+                        method: 'POST',
+                        body: formData
+                    });
+
+                    const data = await response.json();
+
+                    if (!response.ok) {
+                        throw new Error(data.detail || 'Processing failed');
+                    }
+
+                    // Display results
+                    results.innerHTML = '';
+
+                    if (data.rgba) {
+                        results.innerHTML += `
+                            <div class="result-card">
+                                <h3>Transparent PNG</h3>
+                                <img src="${data.rgba}" alt="Transparent result">
+                                <a href="${data.rgba}" download="transparent.png" class="download-btn">
+                                    Download PNG
+                                </a>
+                            </div>
+                        `;
+                    }
+
+                    if (data.mask) {
+                        results.innerHTML += `
+                            <div class="result-card">
+                                <h3>Binary Mask</h3>
+                                <img src="${data.mask}" alt="Mask result">
+                                <a href="${data.mask}" download="mask.png" class="download-btn">
+                                    Download Mask
+                                </a>
+                            </div>
+                        `;
+                    }
+
+                } else {
+                    // Use appropriate endpoint
+                    const endpoint = outputType.value === 'mask' ? '/segment/mask' : '/segment';
+                    response = await fetch(endpoint, {
+                        method: 'POST',
+                        body: formData
+                    });
+
+                    if (!response.ok) {
+                        const errorData = await response.json();
+                        throw new Error(errorData.detail || 'Processing failed');
+                    }
+
+                    // Get blob
+                    const blob = await response.blob();
+                    const url = URL.createObjectURL(blob);
+
+                    // Display result
+                    const title = outputType.value === 'mask' ? 'Binary Mask' : 'Transparent PNG';
+                    results.innerHTML = `
+                        <div class="result-card">
+                            <h3>${title}</h3>
+                            <img src="${url}" alt="Result">
+                            <a href="${url}" download="result.png" class="download-btn">
+                                Download Image
+                            </a>
+                        </div>
+                    `;
+                }
+
+            } catch (err) {
+                showError(err.message);
+            } finally {
+                loading.classList.remove('active');
+                processBtn.disabled = false;
+            }
+        });
+
+        function showError(message) {
+            error.textContent = message;
+            error.classList.add('active');
+        }
+
+        function hideError() {
+            error.classList.remove('active');
+        }
+    </script>
+</body>
+</html>
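The script above renders `data.rgba` and `data.mask` directly as `<img src>` values, which works because the `/segment/base64` endpoint returns them as data URIs. A minimal sketch of how such a field can be produced server-side, assuming you already have the encoded PNG bytes in hand (illustrative only; the actual implementation lives in `app.py`):

```python
import base64

def png_to_data_uri(png_bytes: bytes) -> str:
    # Wrap raw PNG bytes in a data URI so a browser <img src="..."> can display them
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")

# PNG files always begin with this 8-byte signature
uri = png_to_data_uri(b"\x89PNG\r\n\x1a\n")
print(uri)  # data:image/png;base64,iVBORw0KGgo=
```

The same string can be dropped into both the `src` and the download link's `href`, which is exactly how the result cards above use it.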
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ fastapi==0.109.0
+ uvicorn[standard]==0.27.0
+ python-multipart==0.0.6
+ torch>=2.0.0
+ torchvision>=0.15.0
+ numpy>=1.24.0
+ opencv-python-headless>=4.8.0
+ Pillow>=10.0.0
+
+ # Optional: For BiRefNet and RMBG models
+ # Uncomment if you want to use these models
+ # transformers>=4.30.0
+ # huggingface-hub>=0.16.0
test_api.py ADDED
@@ -0,0 +1,225 @@
+ """
+ Test script for the Binary Segmentation API.
+
+ Run this to verify the API is working correctly.
+ """
+
+ import requests
+ import sys
+ import time
+ from pathlib import Path
+
+
+ def test_api(base_url: str = "http://localhost:7860"):
+     """Run basic API tests"""
+
+     print("=" * 60)
+     print("Binary Segmentation API - Test Suite")
+     print("=" * 60)
+     print(f"\nTesting API at: {base_url}\n")
+
+     # Test 1: Health Check
+     print("Test 1: Health Check")
+     try:
+         response = requests.get(f"{base_url}/health", timeout=5)
+         if response.status_code == 200:
+             print("βœ“ Health check passed")
+             print(f"  Response: {response.json()}")
+         else:
+             print(f"βœ— Health check failed: {response.status_code}")
+             return False
+     except Exception as e:
+         print(f"βœ— Health check failed: {e}")
+         print("\n  Make sure the API is running:")
+         print("    python app.py")
+         print("  or")
+         print("    uvicorn app:app --host 0.0.0.0 --port 7860")
+         return False
+
+     print()
+
+     # Test 2: List Models
+     print("Test 2: List Models")
+     try:
+         response = requests.get(f"{base_url}/models", timeout=5)
+         if response.status_code == 200:
+             print("βœ“ Models endpoint working")
+             data = response.json()
+             print(f"  Available models: {len(data.get('models', []))}")
+             for model in data.get('models', []):
+                 print(f"    - {model['name']}: {model['description']}")
+         else:
+             print(f"βœ— Models endpoint failed: {response.status_code}")
+     except Exception as e:
+         print(f"βœ— Models endpoint failed: {e}")
+
+     print()
+
+     # Test 3: Create test image
+     print("Test 3: Create Test Image")
+     try:
+         import numpy as np
+         from PIL import Image
+
+         # Create a simple test image (100x100 red square on a 200x200 white background)
+         img = np.ones((200, 200, 3), dtype=np.uint8) * 255
+         img[50:150, 50:150] = [255, 0, 0]  # Red square
+
+         test_img = Image.fromarray(img)
+         test_path = Path("test_image.jpg")
+         test_img.save(test_path)
+
+         print(f"βœ“ Test image created: {test_path}")
+     except Exception as e:
+         print(f"βœ— Failed to create test image: {e}")
+         return False
+
+     print()
+
+     # Test 4: Segmentation (if test image exists)
+     if test_path.exists():
+         print("Test 4: Image Segmentation")
+         try:
+             with open(test_path, 'rb') as f:
+                 files = {'file': f}
+                 data = {
+                     'model': 'u2netp',
+                     'threshold': '0.5'
+                 }
+
+                 start_time = time.time()
+                 response = requests.post(
+                     f"{base_url}/segment",
+                     files=files,
+                     data=data,
+                     timeout=30
+                 )
+                 elapsed = time.time() - start_time
+
+             if response.status_code == 200:
+                 output_path = Path("test_output.png")
+                 with open(output_path, 'wb') as out:
+                     out.write(response.content)
+
+                 print(f"βœ“ Segmentation successful ({elapsed:.2f}s)")
+                 print(f"  Output saved to: {output_path}")
+                 print(f"  Output size: {len(response.content)} bytes")
+             else:
+                 print(f"βœ— Segmentation failed: {response.status_code}")
+                 print(f"  Response: {response.text}")
+         except Exception as e:
+             print(f"βœ— Segmentation failed: {e}")
+
+         print()
+
+     # Test 5: Mask endpoint
+     if test_path.exists():
+         print("Test 5: Binary Mask")
+         try:
+             with open(test_path, 'rb') as f:
+                 files = {'file': f}
+                 data = {
+                     'model': 'u2netp',
+                     'threshold': '0.5'
+                 }
+
+                 response = requests.post(
+                     f"{base_url}/segment/mask",
+                     files=files,
+                     data=data,
+                     timeout=30
+                 )
+
+             if response.status_code == 200:
+                 mask_path = Path("test_mask.png")
+                 with open(mask_path, 'wb') as out:
+                     out.write(response.content)
+
+                 print("βœ“ Mask generation successful")
+                 print(f"  Mask saved to: {mask_path}")
+             else:
+                 print(f"βœ— Mask generation failed: {response.status_code}")
+         except Exception as e:
+             print(f"βœ— Mask generation failed: {e}")
+
+         print()
+
+     # Test 6: Base64 endpoint
+     if test_path.exists():
+         print("Test 6: Base64 Output")
+         try:
+             with open(test_path, 'rb') as f:
+                 files = {'file': f}
+                 data = {
+                     'model': 'u2netp',
+                     'threshold': '0.5',
+                     'return_type': 'both'
+                 }
+
+                 response = requests.post(
+                     f"{base_url}/segment/base64",
+                     files=files,
+                     data=data,
+                     timeout=30
+                 )
+
+             if response.status_code == 200:
+                 result = response.json()
+                 print("βœ“ Base64 output successful")
+                 print(f"  Has RGBA: {'rgba' in result}")
+                 print(f"  Has Mask: {'mask' in result}")
+             else:
+                 print(f"βœ— Base64 output failed: {response.status_code}")
+         except Exception as e:
+             print(f"βœ— Base64 output failed: {e}")
+
+         print()
+
+     # Cleanup
+     print("Cleanup:")
+     try:
+         if test_path.exists():
+             test_path.unlink()
+             print(f"  Removed: {test_path}")
+
+         output_path = Path("test_output.png")
+         if output_path.exists():
+             output_path.unlink()
+             print(f"  Removed: {output_path}")
+
+         mask_path = Path("test_mask.png")
+         if mask_path.exists():
+             mask_path.unlink()
+             print(f"  Removed: {mask_path}")
+     except Exception as e:
+         print(f"  Warning: Cleanup failed: {e}")
+
+     print()
+     print("=" * 60)
+     print("Test Suite Complete!")
+     print("=" * 60)
+
+     return True
+
+
+ if __name__ == "__main__":
+     # Get base URL from command line or use default
+     base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"
+
+     success = test_api(base_url)
+
+     if success:
+         print("\nβœ“ All critical tests passed!")
+         print("\nNext steps:")
+         print("1. Open http://localhost:7860 in your browser")
+         print("2. Upload an image and test the web interface")
+         print("3. Deploy to Hugging Face Spaces (see DEPLOYMENT.md)")
+         sys.exit(0)
+     else:
+         print("\nβœ— Some tests failed!")
+         print("\nTroubleshooting:")
+         print("1. Make sure the server is running:")
+         print("   uvicorn app:app --host 0.0.0.0 --port 7860")
+         print("2. Check that u2netp.pth is in .model_cache/")
+         print("3. Check logs for errors")
+         sys.exit(1)
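The tests above pass `threshold='0.5'` as a form field; conceptually, the server compares each pixel's predicted foreground probability against that value to produce the binary mask. A pure-Python illustration of that idea, under the assumption that the model output is a 2-D grid of probabilities in [0, 1] (the real implementation in `binary_segmentation.py` operates on NumPy arrays and may differ in detail):

```python
def binarize(probs, threshold=0.5):
    # Pixels at or above the threshold become foreground (255); the rest become background (0)
    return [[255 if p >= threshold else 0 for p in row] for row in probs]

print(binarize([[0.2, 0.7], [0.5, 0.9]]))  # [[0, 255], [255, 255]]
```

Raising the threshold shrinks the foreground region (fewer pixels qualify), which is why the endpoints expose it as a tunable parameter.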