pyedward commited on
Commit
d7810bb
·
verified ·
1 Parent(s): 025dfa3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -21
app.py CHANGED
@@ -3,48 +3,33 @@ from PIL import Image
3
  import torch
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
- # ----------------------
7
- # Load model + processor
8
- # ----------------------
9
  processor = AutoImageProcessor.from_pretrained("prithivMLmods/Weather-Image-Classification")
10
  model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Weather-Image-Classification")
11
 
12
- # ----------------------
13
  # Inference function
14
- # ----------------------
15
- def classify_weather(image_file):
16
  try:
17
- # Open the uploaded file
18
- image = Image.open(image_file).convert("RGB")
19
 
20
- # Preprocess
21
- inputs = processor(images=[image], return_tensors="pt")
22
-
23
- # Inference
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
  logits = outputs.logits.squeeze()
27
  probs = torch.softmax(logits, dim=-1).tolist()
28
  labels = [model.config.id2label[i] for i in range(len(probs))]
29
-
30
- # Return label -> probability dictionary
31
  return dict(zip(labels, probs))
32
-
33
- except Exception as e:
34
- # Safe fallback if something unexpected happens
35
  return {"Error": 1.0}
36
 
37
- # ----------------------
38
  # Gradio interface
39
- # ----------------------
40
  iface = gr.Interface(
41
  fn=classify_weather,
42
- inputs=gr.File(file_types=[".jpg", ".png"]), # Accept uploaded files
43
  outputs=gr.Label(num_top_classes=5, label="Weather Condition"),
44
  title="Weather Image Classification",
45
  description="Upload an image to classify the weather condition (sun, rain, snow, fog, or clouds)."
46
  )
47
 
48
- # Launch the Space with error reporting
49
  if __name__ == "__main__":
50
  iface.launch(show_error=True)
 
3
  import torch
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
+ # Load model
 
 
7
  processor = AutoImageProcessor.from_pretrained("prithivMLmods/Weather-Image-Classification")
8
  model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Weather-Image-Classification")
9
 
 
10
  # Inference function
11
+ def classify_weather(image_input):
 
12
  try:
13
+ # PIL image guaranteed by Gradio
14
+ inputs = processor(images=[image_input], return_tensors="pt")
15
 
 
 
 
 
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
  logits = outputs.logits.squeeze()
19
  probs = torch.softmax(logits, dim=-1).tolist()
20
  labels = [model.config.id2label[i] for i in range(len(probs))]
 
 
21
  return dict(zip(labels, probs))
22
+ except Exception:
 
 
23
  return {"Error": 1.0}
24
 
 
25
  # Gradio interface
 
26
  iface = gr.Interface(
27
  fn=classify_weather,
28
+ inputs=gr.Image(type="pil"), # PIL input
29
  outputs=gr.Label(num_top_classes=5, label="Weather Condition"),
30
  title="Weather Image Classification",
31
  description="Upload an image to classify the weather condition (sun, rain, snow, fog, or clouds)."
32
  )
33
 
 
34
  if __name__ == "__main__":
35
  iface.launch(show_error=True)