Update modeling_telechat.py
modeling_telechat.py  +21 -6  CHANGED
@@ -43,6 +43,8 @@ except ImportError:
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func # flashattn2
     print("# FLASH ATTENTION 2 DETECTED #")
+r
+r
 except ImportError:
     print("# NO FLASH ATTENTION DETECTED #")
     flash_attn_unpadded_func = None
@@ -110,10 +112,11 @@ logger = logging.get_logger(__name__)
 def exists(v):
     return v is not None
 
-
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim, use_xpos=False, xpos_scale_base=512, theta=10000):
         super().__init__()
+        self.theta = theta
+        self.dim = dim
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)
         self.cache = dict()
@@ -125,13 +128,25 @@ class RotaryEmbedding(nn.Module):
         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
         self.register_buffer('scale', scale)
         self.scale_base = xpos_scale_base
+    def get_ntk_alpha(self, true_seq_len):
+        context_value = math.log(true_seq_len / 4096, 2) + 1
+        # ntk_alpha = 2 ** context_value - 1
+        ntk_alpha = 2 ** math.ceil(context_value) - 1
+        ntk_alpha = max(ntk_alpha, 1)
+        return ntk_alpha
+
 
     def forward(self, seq, cache_key=None):
 
         if cache_key is not None and cache_key in self.cache:
             return self.cache[cache_key]
-
-
+
+
+        ntk_alpha = self.get_ntk_alpha(int(cache_key.split(":")[1]))
+        theta = self.theta * ntk_alpha
+        #print("theta",theta)
+        inv_freq = 1.0 / (theta ** (torch.arange(0, self.dim, 2).float() / self.dim))
+        inv_freq = inv_freq.to(device=seq.device)
         freqs = einsum('i , j -> i j', seq, inv_freq)
         # first part even vector components, second part odd vector components,
         # 2 * dim in dimension size
@@ -257,7 +272,7 @@ class TELECHATAttention(nn.Module):
         self.pruned_heads = set()
 
         self.use_flash_attn = False
-
+
 
 
     def set_max_positions(self, max_positions, device='cuda'):
@@ -1085,8 +1100,8 @@ class TELECHAT(TELECHATPretrainedModel):
         input_ids = tokenizer.encode(inputs,
                                      return_tensors="pt"
                                      )
-        if len(input_ids[0]) >= 2000:
-            input_ids = input_ids[:, -2000:]
+        #if len(input_ids[0]) >= 2000:
+        #    input_ids = input_ids[:, -2000:]
         input_ids = input_ids.to(0)
         output = self.generate(input_ids, generation_config)
         response = tokenizer.decode(output[0].cpu().numpy().tolist()).split('<_bot>')[-1].split('</s>')[0]
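For reference, below is a minimal standalone sketch of the NTK-aware rotary scaling this commit adds to RotaryEmbedding. The 4096-token base window, the power-of-two rounding in get_ntk_alpha, and the inv_freq formula are taken from the diff above; the helper name scaled_inv_freq, the base_len parameter, and the demo values are illustrative assumptions, not part of the model code.

# Sketch (not part of the commit) of the NTK-aware rotary scaling added above:
# alpha grows in powers of two once the sequence exceeds the assumed 4096-token
# base window, and a larger theta stretches the RoPE frequencies accordingly.
import math

import torch


def get_ntk_alpha(true_seq_len, base_len=4096):
    # number of doublings past the base window, rounded up
    context_value = math.log(true_seq_len / base_len, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)  # never scale below the trained setting


def scaled_inv_freq(dim, true_seq_len, theta=10000):
    # same formula as RotaryEmbedding, with theta enlarged by the NTK factor
    theta = theta * get_ntk_alpha(true_seq_len)
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))


if __name__ == "__main__":
    for seq_len in (2048, 4096, 8192, 16384):
        print(seq_len, get_ntk_alpha(seq_len))  # -> 1, 1, 3, 7
    print(scaled_inv_freq(dim=128, true_seq_len=16384)[:4])

In the patched forward(), the true sequence length is read from everything after the colon in cache_key (cache_key.split(":")[1]), so callers are apparently expected to pass keys of the form "<prefix>:<seq_len>".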