aeb56 committed
Commit 2f60fd7 · 1 parent: 74fe23d
Fix flash attention error by patching model config to use eager attention
app.py CHANGED
@@ -40,6 +40,13 @@ class ChatBot:
         )
 
         self.model.eval()
+
+        # Patch model config to avoid flash attention issues
+        if hasattr(self.model.config, '_attn_implementation'):
+            self.model.config._attn_implementation = "eager"
+        if hasattr(self.model.config, 'attn_implementation'):
+            self.model.config.attn_implementation = "eager"
+
         self.loaded = True
 
         # Get GPU distribution info
@@ -85,7 +92,7 @@ class ChatBot:
         inputs = self.tokenizer(prompt, return_tensors="pt")
         inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
-        # Generate
+        # Generate with explicit attention settings
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
@@ -94,6 +101,7 @@ class ChatBot:
                 top_p=top_p,
                 do_sample=temperature > 0,
                 pad_token_id=self.tokenizer.eos_token_id,
+                use_cache=True,  # Enable KV caching
             )
 
         # Decode
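The patch flips the config to eager attention after the model is already loaded. A more direct route, supported by recent transformers releases, is to request eager attention at load time through the attn_implementation keyword of from_pretrained, so flash-attention kernels are never selected in the first place. A minimal sketch, assuming a hypothetical model_id and float16 weights (the commit does not name the model):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-model"  # hypothetical; not taken from the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="eager",  # select eager attention at load time
)
model.eval()

Patching both _attn_implementation and attn_implementation, as the commit does, is a defensive catch-all: which of the two attributes a config object exposes has varied across transformers versions, so checking for each with hasattr covers both cases.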
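The second half of the diff passes use_cache=True to generate. KV caching is already the default during transformers generation, so the explicit flag mainly documents intent. A usage sketch with hypothetical prompt and sampling values, reusing the tokenizer and model from the loading sketch above, shows the resulting call shape:

prompt = "Hello, world"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,  # reuse cached key/value states across decoding steps
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))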