Commit fd09229 · "fix the bug"
Parent: d958a06

Files changed:
- classifier.py  +10 -32
- config.py      +5 -6
classifier.py CHANGED

@@ -1,4 +1,4 @@
-from transformers import AutoProcessor,
+from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import torch
 import logging
@@ -34,11 +34,11 @@ class GarbageClassifier:
         )

         # Load model
-        self.model =
+        self.model = AutoModelForImageTextToText.from_pretrained(
             self.config.MODEL_NAME,
             torch_dtype=self.config.TORCH_DTYPE,
             device_map=self.config.DEVICE_MAP,
-        )
+        )

         self.logger.info("Model loaded successfully")

@@ -138,40 +138,18 @@ class GarbageClassifier:
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
-        )
-
-        # Move inputs to model device and set dtype
-        inputs = inputs.to(self.model.device, dtype=self.model.dtype)
+        ).to(self.model.device, dtype=self.model.dtype)
         input_len = inputs["input_ids"].shape[-1]

-
-
-
-
-
-            "disable_compile": True,  # Important for stability
-        }
-
-        if self.config.DO_SAMPLE:
-            generation_kwargs.update(
-                {
-                    "do_sample": True,
-                    "temperature": self.config.TEMPERATURE,
-                    "top_p": self.config.TOP_P,
-                    "top_k": self.config.TOP_K,
-                }
-            )
-        else:
-            generation_kwargs["do_sample"] = False
-
-        outputs = self.model.generate(**inputs, **generation_kwargs)
-
-        # Decode response
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=self.config.MAX_NEW_TOKENS,
+            disable_compile=True,
+        )
         response = self.processor.batch_decode(
             outputs[:, input_len:],
             skip_special_tokens=True,
-
-        )[0]
+        )

         # Extract classification from response
         classification = self._extract_classification(response)
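For reference, the inference path that results from this commit reduces to the sketch below. This is a hedged reconstruction, not the repo file: the message structure and prompt text are illustrative assumptions, and the surrounding GarbageClassifier class, logging, and error handling are omitted.

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

MODEL_NAME = "google/gemma-3n-E2B-it"

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0" if torch.cuda.is_available() else "cpu",
)

# Illustrative chat message; the real prompt lives elsewhere in classifier.py
# and is not visible in this diff.
messages = [{"role": "user", "content": [
    {"type": "image", "image": Image.open("waste.jpg")},
    {"type": "text", "text": "Classify the garbage in this image."},
]}]

# Tokenize and move to the model's device/dtype in one chained call, as the
# new line 141 does (the old code reassigned `inputs` in a separate step).
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=model.dtype)
input_len = inputs["input_ids"].shape[-1]

# Greedy decoding: DO_SAMPLE, TEMPERATURE, TOP_P, and TOP_K are gone, and
# the old generation_kwargs dict is folded directly into generate().
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    disable_compile=True,  # carried over from the old kwargs ("important for stability")
)

# batch_decode returns a list of strings. Note the commit drops the [0] the
# old code used, so the caller now receives a one-element list.
response = processor.batch_decode(outputs[:, input_len:], skip_special_tokens=True)

Collapsing generation into a single generate() call removes the DO_SAMPLE branch entirely, which is why the sampling fields disappear from config.py in the next file.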
config.py CHANGED

@@ -9,15 +9,14 @@ class Config:
     MODEL_NAME: str = "google/gemma-3n-E2B-it"

     # Generation parameters
-    MAX_NEW_TOKENS: int =
-    TEMPERATURE: float = 0.3
-    DO_SAMPLE: bool = True
-    TOP_P: float = 0.8
-    TOP_K: int = 40
+    MAX_NEW_TOKENS: int = 512

     # Device configuration
     TORCH_DTYPE: str = torch.bfloat16
-
+    if torch.cuda.is_available():
+        DEVICE_MAP: str = "cuda:0"  # Use first GPU if available
+    else:
+        DEVICE_MAP: str = "cpu"

     # Image preprocessing
     IMAGE_SIZE: int = 512
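Taken together, config.py after this commit is roughly the class below. This is a reconstruction from the visible hunk only: anything outside lines 9-22, including any decorator on Config (e.g. @dataclass), does not appear in the diff and is assumed absent here.

import torch

class Config:
    MODEL_NAME: str = "google/gemma-3n-E2B-it"

    # Generation parameters. TEMPERATURE, DO_SAMPLE, TOP_P, and TOP_K were
    # removed by this commit, so decoding is greedy.
    MAX_NEW_TOKENS: int = 512

    # Device configuration. Annotated str but assigned a torch.dtype; that
    # mismatch predates the commit and is left unchanged by it.
    TORCH_DTYPE: str = torch.bfloat16

    # Class-level branching runs once, at import time; only the branch that
    # executes defines the attribute, and both branches define DEVICE_MAP.
    if torch.cuda.is_available():
        DEVICE_MAP: str = "cuda:0"  # Use first GPU if available
    else:
        DEVICE_MAP: str = "cpu"

    # Image preprocessing
    IMAGE_SIZE: int = 512

Because the check runs at import time, DEVICE_MAP is fixed when the module loads rather than per request; a conditional expression (DEVICE_MAP: str = "cuda:0" if torch.cuda.is_available() else "cpu") would behave identically.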