Karan6933 committed on
Commit 538c943 · verified · 1 Parent(s): 2069ca1

Upload 7 files

Files changed (7)
  1. Dockerfile +33 -0
  2. app/main.py +89 -0
  3. app/model.py +199 -0
  4. app/prompt.py +23 -0
  5. app/schemas.py +41 -0
  6. requirements.txt +11 -0
  7. run.sh +12 -0
Dockerfile ADDED
@@ -0,0 +1,33 @@
+ # Dockerfile
+ FROM python:3.11-slim
+
+ # Set environment variables for Hugging Face cache optimization
+ ENV PYTHONUNBUFFERED=1 \
+     PYTHONDONTWRITEBYTECODE=1 \
+     HF_HOME=/tmp/.huggingface \
+     TRANSFORMERS_CACHE=/tmp/.cache/huggingface \
+     HF_HUB_CACHE=/tmp/.cache/huggingface/hub
+
+ # Install minimal system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements first for layer caching
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY app/ ./app/
+
+ # Create cache directories
+ RUN mkdir -p /tmp/.cache/huggingface
+
+ # Expose Hugging Face Spaces default port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
app/main.py ADDED
@@ -0,0 +1,89 @@
+ # app/main.py
+ """
+ FastAPI application for serving the Nanbeige4.1-3B model.
+ Optimized for Hugging Face Spaces (CPU, Docker).
+ """
+
+ import asyncio
+ from contextlib import asynccontextmanager
+
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse
+
+ from app.model import load_model, generate_stream, generate
+ from app.prompt import build_prompt
+ from app.schemas import GenerationRequest, GenerationResponse
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """
+     Lifespan context manager for startup/shutdown events.
+     Loads the model on startup so it is ready before the first request.
+     """
+     # Startup: load model
+     print("Loading model...")
+     load_model()
+     print("Model loaded successfully")
+     yield
+     # Shutdown: cleanup (if needed)
+     print("Shutting down...")
+
+
+ app = FastAPI(
+     title="Nanbeige4.1-3B API",
+     description="FastAPI wrapper for Nanbeige4.1-3B with streaming support",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+
+ @app.get("/")
+ async def health_check():
+     """Health check endpoint."""
+     return {"status": "ok", "model": "Nanbeige4.1-3B"}
+
+
+ @app.post("/generate")
+ async def generate_text(request: GenerationRequest):
+     """
+     Generate text from a prompt.
+     Supports both streaming (SSE) and non-streaming responses.
+     """
+     # Build the final prompt with system instructions
+     final_prompt = build_prompt(request.prompt)
+
+     if request.stream:
+         # Streaming response
+         async def stream_generator():
+             sync_gen = generate_stream(
+                 final_prompt,
+                 temperature=request.temperature,
+                 max_tokens=request.max_tokens
+             )
+             # Pull from the sync generator in a worker thread so token
+             # generation does not block the event loop
+             sentinel = object()
+             while True:
+                 chunk = await asyncio.to_thread(next, sync_gen, sentinel)
+                 if chunk is sentinel:
+                     break
+                 if chunk:
+                     # SSE format
+                     yield f"data: {chunk}\n\n"
+
+             yield "data: [DONE]\n\n"
+
+         return StreamingResponse(
+             stream_generator(),
+             media_type="text/event-stream",
+             headers={
+                 "Cache-Control": "no-cache",
+                 "Connection": "keep-alive",
+             }
+         )
+     else:
+         # Non-streaming response: run the blocking generate() call in a
+         # worker thread so it does not stall the event loop
+         result = await asyncio.to_thread(
+             generate,
+             final_prompt,
+             temperature=request.temperature,
+             max_tokens=request.max_tokens
+         )
+         return GenerationResponse(text=result)
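
A minimal client for this API might look like the sketch below (not part of the commit). The endpoint path and request fields come from app/main.py and app/schemas.py; the base URL and prompts are illustrative, and the requests library is an extra dependency not listed in requirements.txt.

# client_example.py (illustrative)
import requests

BASE_URL = "http://localhost:7860"  # assumption: server running locally on the default port

# Non-streaming request
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Delhi ka mausam kaisa hai?", "stream": False, "max_tokens": 100},
    timeout=300,
)
print(resp.json()["text"])

# Streaming request: consume the SSE lines emitted by /generate
with requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Ek chhoti si kahani sunao", "stream": True},
    stream=True,
    timeout=300,
) as stream_resp:
    for line in stream_resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        print(payload, end="", flush=True)
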
app/model.py ADDED
@@ -0,0 +1,199 @@
+ # app/model.py
+ """
+ Model loading and inference utilities for Nanbeige/Nanbeige4.1-3B.
+ Implements a singleton pattern so the model is loaded only once.
+ """
+
+ import gc
+ from threading import Thread
+ from typing import Generator, Optional
+
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     TextIteratorStreamer,
+ )
+
+ # Global singleton instances
+ _tokenizer: Optional[AutoTokenizer] = None
+ _model: Optional[AutoModelForCausalLM] = None
+
+
+ def get_quantization_config() -> Optional[BitsAndBytesConfig]:
+     """
+     Configure 4-bit quantization for memory efficiency.
+     Returns None on CPU-only hosts (bitsandbytes 4-bit requires CUDA)
+     or if the config cannot be built.
+     """
+     if not torch.cuda.is_available():
+         return None
+     try:
+         # 4-bit quantization config for minimal memory footprint
+         return BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+         )
+     except Exception:
+         return None
+
+
+ def load_model() -> tuple[AutoTokenizer, AutoModelForCausalLM]:
+     """
+     Load tokenizer and model with a singleton pattern.
+     Loads on the first call and returns the cached instances thereafter.
+
+     Returns:
+         Tuple of (tokenizer, model)
+     """
+     global _tokenizer, _model
+
+     if _tokenizer is not None and _model is not None:
+         return _tokenizer, _model
+
+     model_name = "Nanbeige/Nanbeige4.1-3B"
+
+     # Load tokenizer
+     _tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         use_fast=False,
+         trust_remote_code=True
+     )
+
+     # Configure model loading for CPU
+     # Use torch.float16 for memory efficiency
+     model_kwargs = {
+         "torch_dtype": torch.float16,
+         "trust_remote_code": True,
+         "low_cpu_mem_usage": True,
+     }
+
+     # Use quantization if available, otherwise fall back to standard loading
+     quant_config = get_quantization_config()
+     if quant_config is not None:
+         model_kwargs["quantization_config"] = quant_config
+
+     # Load model
+     _model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         **model_kwargs
+     )
+
+     # Ensure the model is in eval mode
+     _model.eval()
+
+     # Clear caches to free memory
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return _tokenizer, _model
+
+
+ def generate_stream(
+     prompt: str,
+     temperature: float = 0.7,
+     max_tokens: int = 200
+ ) -> Generator[str, None, None]:
+     """
+     Generate text in a streaming fashion.
+
+     Args:
+         prompt: Input prompt text
+         temperature: Sampling temperature
+         max_tokens: Maximum tokens to generate
+
+     Yields:
+         Text chunks as they are generated
+     """
+     tokenizer, model = load_model()
+
+     # Tokenize input
+     inputs = tokenizer(
+         prompt,
+         return_tensors="pt",
+         add_special_tokens=False
+     )
+
+     # Move to the same device as the model
+     input_ids = inputs.input_ids.to(model.device)
+
+     # Generation parameters
+     generation_kwargs = {
+         "input_ids": input_ids,
+         "max_new_tokens": max_tokens,
+         "temperature": temperature,
+         "top_p": 0.95,
+         # Fall back to greedy decoding when temperature is 0
+         "do_sample": temperature > 0,
+         "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
+         "eos_token_id": tokenizer.eos_token_id,
+     }
+
+     # Stream generation through a TextIteratorStreamer
+     streamer = TextIteratorStreamer(
+         tokenizer,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+     generation_kwargs["streamer"] = streamer
+
+     # Run generation in a separate thread to enable streaming
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     for text in streamer:
+         yield text
+
+     thread.join()
+
+     # Cleanup
+     gc.collect()
+
+
+ def generate(
+     prompt: str,
+     temperature: float = 0.7,
+     max_tokens: int = 200
+ ) -> str:
+     """
+     Generate text non-streaming (full response).
+
+     Args:
+         prompt: Input prompt text
+         temperature: Sampling temperature
+         max_tokens: Maximum tokens to generate
+
+     Returns:
+         Complete generated text
+     """
+     tokenizer, model = load_model()
+
+     # Tokenize input
+     inputs = tokenizer(
+         prompt,
+         return_tensors="pt",
+         add_special_tokens=False
+     )
+
+     input_ids = inputs.input_ids.to(model.device)
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=0.95,
+             do_sample=temperature > 0,
+             pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     # Decode only the newly generated tokens
+     new_tokens = output_ids[0][len(input_ids[0]):]
+     response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+     # Cleanup
+     gc.collect()
+
+     return response
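
For a quick local sanity check of these helpers outside the API, something like the following works (illustrative only; the first call downloads the model weights from the Hub, which can take a while on a CPU Space):

# smoke_test.py (illustrative)
from app.model import generate, generate_stream

if __name__ == "__main__":
    # Non-streaming: returns the full completion as a string
    print(generate("User: 2 + 2 kitna hota hai?\nAssistant:", max_tokens=32))

    # Streaming: chunks are printed as the model produces them
    for chunk in generate_stream("User: Ek shayari sunao\nAssistant:", max_tokens=64):
        print(chunk, end="", flush=True)
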
app/prompt.py ADDED
@@ -0,0 +1,23 @@
+ # app/prompt.py
+ """
+ Prompt building utilities for the Nanbeige model.
+ """
+
+ # The system prompt is intentionally written in Hinglish. Roughly: "You are a
+ # helpful assistant. Always answer concisely and accurately. If you don't know
+ # something, clearly say 'I don't know'. Never say that you are an AI model.
+ # Share only verified information. Converse in Hinglish."
+ SYSTEM_PROMPT = """Tu ek helpful assistant hai. Hamesha concise aur accurate jawab de.
+ - Agar kuch pata nahi, toh clearly bol "I don't know"
+ - Kabhi bhi mat bol ki tu ek AI model hai
+ - Sirf verified information share kar
+ - Hinglish mein baat kar"""
+
+
+ def build_prompt(user_input: str) -> str:
+     """
+     Build the final prompt by combining the system prompt with the user input.
+
+     Args:
+         user_input: Raw user query/input
+
+     Returns:
+         Formatted prompt string ready for model inference
+     """
+     return f"{SYSTEM_PROMPT}\n\nUser: {user_input}\nAssistant:"
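
build_prompt uses a plain system/User/Assistant string layout. If the Nanbeige4.1-3B tokenizer ships a chat template (common for instruction-tuned checkpoints, though not verified here), the same messages could be rendered through it instead; build_chat_prompt below is a hypothetical sketch that falls back to build_prompt when no template is present.

# Hypothetical alternative to build_prompt(); uses only what app/model.py
# and app/prompt.py already provide.
from app.model import load_model
from app.prompt import SYSTEM_PROMPT, build_prompt


def build_chat_prompt(user_input: str) -> str:
    tokenizer, _ = load_model()
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_input},
    ]
    if getattr(tokenizer, "chat_template", None):
        # Render the conversation with the model's own template and leave
        # the assistant turn open for generation
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # No chat template available: fall back to the plain-string prompt
    return build_prompt(user_input)
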
app/schemas.py ADDED
@@ -0,0 +1,41 @@
+ # app/schemas.py
+ """
+ Pydantic schemas for API request/response validation.
+ """
+
+ from pydantic import BaseModel, Field
+
+
+ class GenerationRequest(BaseModel):
+     """Request schema for text generation endpoint."""
+
+     prompt: str = Field(
+         ...,
+         min_length=1,
+         description="Input prompt text"
+     )
+     temperature: float = Field(
+         default=0.7,
+         ge=0.0,
+         le=2.0,
+         description="Sampling temperature"
+     )
+     max_tokens: int = Field(
+         default=200,
+         ge=1,
+         le=2048,
+         description="Maximum tokens to generate"
+     )
+     stream: bool = Field(
+         default=True,
+         description="Whether to stream the response"
+     )
+
+
+ class GenerationResponse(BaseModel):
+     """Response schema for non-streaming generation."""
+
+     text: str = Field(
+         ...,
+         description="Generated text response"
+     )
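
Example request objects built from the schemas above (prompts are illustrative):

from app.schemas import GenerationRequest

# Defaults apply: temperature=0.7, max_tokens=200, stream=True
streaming_req = GenerationRequest(prompt="Mumbai ke baare mein batao")

non_streaming_req = GenerationRequest(
    prompt="Ek short poem likho",
    temperature=0.5,
    max_tokens=120,
    stream=False,
)

# Pydantic v2 serialization helpers
print(streaming_req.model_dump())
print(non_streaming_req.model_dump_json())
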
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ # requirements.txt
+ fastapi==0.115.0
+ uvicorn[standard]==0.32.0
+ pydantic==2.9.0
+ transformers==4.46.0
+ torch==2.5.0
+ accelerate==1.0.0
+ sentencepiece==0.2.0
+ bitsandbytes==0.44.0
+ huggingface-hub==0.26.0
+ python-multipart==0.0.12
run.sh ADDED
@@ -0,0 +1,12 @@
+ #!/bin/bash
+ # run.sh
+ # Production startup script for uvicorn server
+
+ exec uvicorn app.main:app \
+     --host 0.0.0.0 \
+     --port 7860 \
+     --workers 1 \
+     --loop uvloop \
+     --http httptools \
+     --proxy-headers \
+     --forwarded-allow-ips '*'