pyedward commited on
Commit
025dfa3
·
verified ·
1 Parent(s): 268f498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -1,49 +1,50 @@
1
  import gradio as gr
2
- import numpy as np
3
  from PIL import Image
4
  import torch
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
 
7
  # Load model + processor
 
8
  processor = AutoImageProcessor.from_pretrained("prithivMLmods/Weather-Image-Classification")
9
  model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Weather-Image-Classification")
10
 
11
- def classify_weather(image_input):
 
 
 
12
  try:
13
- if isinstance(image_input, np.ndarray):
14
- image = Image.fromarray(image_input.astype("uint8")).convert("RGB")
15
- else:
16
- raise TypeError("Only NumPy array input is supported.")
17
 
18
- # preprocess as batch
19
  inputs = processor(images=[image], return_tensors="pt")
20
 
21
- # inference
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
- logits = outputs.logits
25
- predicted_class_id = logits.argmax(-1).item()
26
- predicted_label = model.config.id2label[predicted_class_id]
27
-
28
- # optional: return probabilities for Label(num_top_classes=5)
29
- probs = torch.softmax(logits, dim=-1).squeeze().tolist()
30
  labels = [model.config.id2label[i] for i in range(len(probs))]
31
- output_dict = dict(zip(labels, probs))
32
 
33
- return output_dict
 
34
 
35
  except Exception as e:
36
- return {"Error": str(e)}
37
-
38
 
 
39
  # Gradio interface
 
40
  iface = gr.Interface(
41
  fn=classify_weather,
42
- inputs=gr.Image(type="numpy"),
43
  outputs=gr.Label(num_top_classes=5, label="Weather Condition"),
44
  title="Weather Image Classification",
45
  description="Upload an image to classify the weather condition (sun, rain, snow, fog, or clouds)."
46
  )
47
 
 
48
  if __name__ == "__main__":
49
  iface.launch(show_error=True)
 
1
  import gradio as gr
 
2
  from PIL import Image
3
  import torch
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
+ # ----------------------
7
  # Load model + processor
8
+ # ----------------------
9
  processor = AutoImageProcessor.from_pretrained("prithivMLmods/Weather-Image-Classification")
10
  model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Weather-Image-Classification")
11
 
12
+ # ----------------------
13
+ # Inference function
14
+ # ----------------------
15
+ def classify_weather(image_file):
16
  try:
17
+ # Open the uploaded file
18
+ image = Image.open(image_file).convert("RGB")
 
 
19
 
20
+ # Preprocess
21
  inputs = processor(images=[image], return_tensors="pt")
22
 
23
+ # Inference
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
+ logits = outputs.logits.squeeze()
27
+ probs = torch.softmax(logits, dim=-1).tolist()
 
 
 
 
28
  labels = [model.config.id2label[i] for i in range(len(probs))]
 
29
 
30
+ # Return label -> probability dictionary
31
+ return dict(zip(labels, probs))
32
 
33
  except Exception as e:
34
+ # Safe fallback if something unexpected happens
35
+ return {"Error": 1.0}
36
 
37
+ # ----------------------
38
  # Gradio interface
39
+ # ----------------------
40
  iface = gr.Interface(
41
  fn=classify_weather,
42
+ inputs=gr.File(file_types=[".jpg", ".png"]), # Accept uploaded files
43
  outputs=gr.Label(num_top_classes=5, label="Weather Condition"),
44
  title="Weather Image Classification",
45
  description="Upload an image to classify the weather condition (sun, rain, snow, fog, or clouds)."
46
  )
47
 
48
+ # Launch the Space with error reporting
49
  if __name__ == "__main__":
50
  iface.launch(show_error=True)