khalednabawi11 committed on
Commit
a1a3a62
·
verified ·
1 Parent(s): ee43c87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -8
app.py CHANGED
@@ -13,17 +13,46 @@ model.eval()
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def generate_caption(image):
17
- # Preprocess
18
- inputs = processor(image, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Generate caption
21
- with torch.no_grad():
22
- output = model.generate(**inputs, max_new_tokens=50)
23
 
24
- # Decode
25
- caption = processor.decode(output[0], skip_special_tokens=True)
26
- return caption
27
 
28
  # Gradio UI
29
  demo = gr.Interface(
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
16
+ # def generate_caption(image):
17
+ # # Preprocess
18
+ # inputs = processor(image, return_tensors="pt").to(device)
19
+
20
+ # # Generate caption
21
+ # with torch.no_grad():
22
+ # output = model.generate(**inputs, max_new_tokens=50)
23
+
24
+ # # Decode
25
+ # caption = processor.decode(output[0], skip_special_tokens=True)
26
+ # return caption
27
+
28
def generate_caption(image):
    """Generate a formatted radiology report for an input image.

    Args:
        image: Input image accepted by the Hugging Face ``processor``
            (presumably a PIL image — confirm against the Gradio input type).

    Returns:
        str: A markdown-formatted radiology report built around the
        model-generated caption.
    """
    prompt = "Radiology report:"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    # Inference only: no_grad avoids building the autograd graph during
    # beam search, cutting memory use. (The pre-rewrite version of this
    # function used it; the rewrite dropped it — restored here.)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=256,
            num_beams=5,
            repetition_penalty=1.2,
            length_penalty=1.0,
            early_stopping=True,
        )

    caption = processor.batch_decode(output, skip_special_tokens=True)[0]

    # NOTE(review): Findings and Impression both reuse the same caption —
    # the model emits a single text, so this duplication is presentational.
    # Optional: format as a full report
    report = f"""
**Radiology Report**

**Findings**: {caption}

**Impression**: {caption}

**Recommendation**: Clinical correlation advised. Consider MRI if necessary.
"""
    return report.strip()
54
 
 
 
 
55
 
 
 
 
56
 
57
  # Gradio UI
58
  demo = gr.Interface(