yichuan-huang committed
Commit fd09229 · Parent(s): d958a06

fix the bug

Files changed (2):
  1. classifier.py +10 -32
  2. config.py +5 -6
classifier.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import torch
 import logging
@@ -34,11 +34,11 @@ class GarbageClassifier:
         )
 
         # Load model
-        self.model = Gemma3nForConditionalGeneration.from_pretrained(
+        self.model = AutoModelForImageTextToText.from_pretrained(
             self.config.MODEL_NAME,
             torch_dtype=self.config.TORCH_DTYPE,
             device_map=self.config.DEVICE_MAP,
-        ).eval()
+        )
 
         self.logger.info("Model loaded successfully")
 
@@ -138,40 +138,18 @@
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
-        )
-
-        # Move inputs to model device and set dtype
-        inputs = inputs.to(self.model.device, dtype=self.model.dtype)
+        ).to(self.model.device, dtype=self.model.dtype)
         input_len = inputs["input_ids"].shape[-1]
 
-        # Generate response
-        with torch.no_grad():
-            generation_kwargs = {
-                "max_new_tokens": self.config.MAX_NEW_TOKENS,
-                "pad_token_id": self.processor.tokenizer.eos_token_id,
-                "disable_compile": True,  # Important for stability
-            }
-
-            if self.config.DO_SAMPLE:
-                generation_kwargs.update(
-                    {
-                        "do_sample": True,
-                        "temperature": self.config.TEMPERATURE,
-                        "top_p": self.config.TOP_P,
-                        "top_k": self.config.TOP_K,
-                    }
-                )
-            else:
-                generation_kwargs["do_sample"] = False
-
-            outputs = self.model.generate(**inputs, **generation_kwargs)
-
-        # Decode response
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=self.config.MAX_NEW_TOKENS,
+            disable_compile=True,
+        )
         response = self.processor.batch_decode(
             outputs[:, input_len:],
             skip_special_tokens=True,
-            clean_up_tokenization_spaces=True,
-        )[0]
+        )
 
         # Extract classification from response
         classification = self._extract_classification(response)
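Two of the removals in this file are behavior-neutral: from_pretrained already returns the model in eval mode, so dropping the explicit .eval() changes nothing, and generate() disables gradients internally, so the removed torch.no_grad() was redundant. For context, below is a minimal sketch of the post-commit inference path; the model and processor calls mirror the diff, while the image and chat prompt are hypothetical stand-ins for code not shown here.

# Minimal sketch of the inference path after this commit.
# The prompt and image below are illustrative assumptions, not the repo's code.
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch

MODEL_NAME = "google/gemma-3n-E2B-it"

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0" if torch.cuda.is_available() else "cpu",
)  # already in eval mode; no .eval() needed

messages = [  # hypothetical prompt; the real one lives outside this diff
    {
        "role": "user",
        "content": [
            {"type": "image", "image": Image.open("item.jpg")},  # hypothetical path
            {"type": "text", "text": "Classify this piece of garbage."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=model.dtype)
input_len = inputs["input_ids"].shape[-1]

outputs = model.generate(  # generate() already runs without gradients
    **inputs,
    max_new_tokens=512,
    disable_compile=True,
)
response = processor.batch_decode(
    outputs[:, input_len:],
    skip_special_tokens=True,
)[0]  # batch_decode returns a list of strings
print(response)

One detail worth noting: with the )[0] index removed in this commit, the classifier's response is now the one-element list returned by batch_decode rather than a string, so _extract_classification receives a list; the sketch above keeps the explicit index.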
config.py CHANGED
@@ -9,15 +9,14 @@ class Config:
     MODEL_NAME: str = "google/gemma-3n-E2B-it"
 
     # Generation parameters
-    MAX_NEW_TOKENS: int = 256
-    TEMPERATURE: float = 0.3
-    DO_SAMPLE: bool = True
-    TOP_P: float = 0.8
-    TOP_K: int = 40
+    MAX_NEW_TOKENS: int = 512
 
     # Device configuration
     TORCH_DTYPE: str = torch.bfloat16
-    DEVICE_MAP: str = "auto"
+    if torch.cuda.is_available():
+        DEVICE_MAP: str = "cuda:0"  # Use first GPU if available
+    else:
+        DEVICE_MAP: str = "cpu"
 
     # Image preprocessing
     IMAGE_SIZE: int = 512
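For reference, here is a sketch of how the Config fields visible in this diff read after the commit (anything else in the file is omitted). The class-scope if/else runs once at import time, so the one-liner below is equivalent; note also that TORCH_DTYPE holds a torch.dtype despite its str annotation, and that with the sampling knobs gone, decoding falls back to the model's default generation config.

# Sketch of the post-commit Config (only fields shown in this diff).
import torch

class Config:
    MODEL_NAME: str = "google/gemma-3n-E2B-it"

    # Generation parameters: sampling knobs removed; generation now uses the
    # model's default generation config plus MAX_NEW_TOKENS
    MAX_NEW_TOKENS: int = 512

    # Device configuration
    TORCH_DTYPE = torch.bfloat16  # a torch.dtype, despite the upstream str annotation
    # One-line equivalent of the diff's class-scope if/else, chosen at import time
    DEVICE_MAP: str = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Image preprocessing
    IMAGE_SIZE: int = 512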