pyedward commited on
Commit
268f498
·
verified ·
1 Parent(s): b1e7cc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -49
app.py CHANGED
@@ -1,49 +1,49 @@
1
-
2
- import gradio as gr
3
- import numpy as np
4
- from PIL import Image
5
- import torch
6
- from transformers import AutoImageProcessor, AutoModelForImageClassification
7
-
8
- # ----------------------
9
- # Load model + processor
10
- # ----------------------
11
- processor = AutoImageProcessor.from_pretrained("prithivMLmods/Weather-Image-Classification")
12
- model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Weather-Image-Classification")
13
-
14
- # ----------------------
15
- # Inference function
16
- # ----------------------
17
- def classify_weather(image_input):
18
- # Only NumPy array supported for Gradio input
19
- if isinstance(image_input, np.ndarray):
20
- image = Image.fromarray(image_input.astype('uint8')).convert("RGB")
21
- else:
22
- raise TypeError("Only NumPy array input is supported for this Gradio interface.")
23
-
24
- # Preprocess
25
- inputs = processor(images=image, return_tensors="pt")
26
-
27
- # Inference
28
- with torch.no_grad():
29
- outputs = model(**inputs)
30
- logits = outputs.logits
31
- predicted_class_id = logits.argmax(-1).item()
32
- predicted_label = model.config.id2label[predicted_class_id]
33
-
34
- return predicted_label
35
-
36
- # ----------------------
37
- # Gradio interface
38
- # ----------------------
39
- iface = gr.Interface(
40
- fn=classify_weather,
41
- inputs=gr.Image(type="numpy"), # NumPy array input
42
- outputs=gr.Label(num_top_classes=5, label="Weather Condition"),
43
- title="Weather Image Classification",
44
- description="Upload an image to classify the weather condition (sun, rain, snow, fog, or clouds)."
45
- )
46
-
47
- # Launch the app
48
- if __name__ == "__main__":
49
- iface.launch()
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
+
7
+ # Load model + processor
8
+ processor = AutoImageProcessor.from_pretrained("prithivMLmods/Weather-Image-Classification")
9
+ model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Weather-Image-Classification")
10
+
11
+ def classify_weather(image_input):
12
+ try:
13
+ if isinstance(image_input, np.ndarray):
14
+ image = Image.fromarray(image_input.astype("uint8")).convert("RGB")
15
+ else:
16
+ raise TypeError("Only NumPy array input is supported.")
17
+
18
+ # preprocess as batch
19
+ inputs = processor(images=[image], return_tensors="pt")
20
+
21
+ # inference
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+ logits = outputs.logits
25
+ predicted_class_id = logits.argmax(-1).item()
26
+ predicted_label = model.config.id2label[predicted_class_id]
27
+
28
+ # optional: return probabilities for Label(num_top_classes=5)
29
+ probs = torch.softmax(logits, dim=-1).squeeze().tolist()
30
+ labels = [model.config.id2label[i] for i in range(len(probs))]
31
+ output_dict = dict(zip(labels, probs))
32
+
33
+ return output_dict
34
+
35
+ except Exception as e:
36
+ return {"Error": str(e)}
37
+
38
+
39
+ # Gradio interface
40
+ iface = gr.Interface(
41
+ fn=classify_weather,
42
+ inputs=gr.Image(type="numpy"),
43
+ outputs=gr.Label(num_top_classes=5, label="Weather Condition"),
44
+ title="Weather Image Classification",
45
+ description="Upload an image to classify the weather condition (sun, rain, snow, fog, or clouds)."
46
+ )
47
+
48
+ if __name__ == "__main__":
49
+ iface.launch(show_error=True)