youssef committed
Commit: d200533
Parent(s): bd727fa
remove flash attn
src/video_processor/processor.py
CHANGED
@@ -34,7 +34,7 @@ class VideoAnalyzer:
         self.model = AutoModelForImageTextToText.from_pretrained(
             self.model_path,
             torch_dtype=torch.bfloat16,
-            _attn_implementation="flash_attention_2"
+            # _attn_implementation="flash_attention_2"
         ).to(DEVICE)
         logger.info(f"Model loaded on device: {self.model.device} using attention implementation: flash_attention_2")
 
@@ -70,6 +70,11 @@ class VideoAnalyzer:
             return_tensors="pt"
         ).to(self.model.device)
 
+        # Convert inputs to bfloat16 before moving to GPU
+        #for key in inputs:
+        #    if torch.is_tensor(inputs[key]):
+        #        inputs[key] = inputs[key].to(dtype=torch.bfloat16, device=self.model.device)
+
         # Generate description with increased token limit
         generated_ids = self.model.generate(
             **inputs,
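Note that the diff only comments out the flash attention kwarg: the log line still reports "flash_attention_2" even though the model now falls back to the default attention implementation, and the bfloat16 input cast stays disabled. The sketch below shows one way both points could be handled; it assumes a transformers version whose from_pretrained accepts the attn_implementation kwarg, and the pick_attn_implementation and cast_float_inputs helpers are illustrative, not part of this repository.

import importlib.util

import torch
from transformers import AutoModelForImageTextToText

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def pick_attn_implementation() -> str:
    # Hypothetical helper: request FlashAttention 2 only when the flash_attn
    # package is importable; otherwise fall back to PyTorch's SDPA kernels.
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"

def cast_float_inputs(inputs, dtype=torch.bfloat16):
    # Hypothetical helper: cast only floating-point tensors (e.g. pixel values);
    # integer tensors such as input_ids keep their original dtype.
    return {k: (v.to(dtype) if torch.is_tensor(v) and v.is_floating_point() else v)
            for k, v in inputs.items()}

def load_model(model_path: str):
    attn_impl = pick_attn_implementation()
    model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,  # assumes a transformers version that accepts this kwarg
    ).to(DEVICE)
    # Report the implementation that was actually requested, not a fixed string.
    print(f"Model loaded on device: {model.device} using attention implementation: {attn_impl}")
    return model

If the commented-out cast were re-enabled as written, it would also move integer tensors such as input_ids to bfloat16, which breaks the embedding lookup; restricting the cast to floating-point tensors, as in cast_float_inputs above, avoids that.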