shunxing1234 committed
Commit 9de2925 · verified · 1 Parent(s): 59f0181

Update modeling_telechat.py

Files changed (1)
  1. modeling_telechat.py +21 -6
modeling_telechat.py CHANGED
@@ -43,6 +43,8 @@ except ImportError:
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func # flashattn2
     print("# FLASH ATTENTION 2 DETECTED #")
+    r
+    r
 except ImportError:
     print("# NO FLASH ATTENTION DETECTED #")
     flash_attn_unpadded_func = None
@@ -110,10 +112,11 @@ logger = logging.get_logger(__name__)
 def exists(v):
     return v is not None
 
-
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim, use_xpos=False, xpos_scale_base=512, theta=10000):
         super().__init__()
+        self.theta = theta
+        self.dim = dim
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)
         self.cache = dict()
@@ -125,13 +128,25 @@ class RotaryEmbedding(nn.Module):
         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
         self.register_buffer('scale', scale)
         self.scale_base = xpos_scale_base
+    def get_ntk_alpha(self, true_seq_len):
+        context_value = math.log(true_seq_len / 4096, 2) + 1
+        # ntk_alpha = 2 ** context_value - 1
+        ntk_alpha = 2 ** math.ceil(context_value) - 1
+        ntk_alpha = max(ntk_alpha, 1)
+        return ntk_alpha
+
 
     def forward(self, seq, cache_key=None):
 
         if cache_key is not None and cache_key in self.cache:
             return self.cache[cache_key]
-
-        inv_freq = self.inv_freq.to(device=seq.device)
+
+
+        ntk_alpha = self.get_ntk_alpha(int(cache_key.split(":")[1]))
+        theta = self.theta * ntk_alpha
+        #print("theta",theta)
+        inv_freq = 1.0 / (theta ** (torch.arange(0, self.dim, 2).float() / self.dim))
+        inv_freq = inv_freq.to(device=seq.device)
         freqs = einsum('i , j -> i j', seq, inv_freq)
         # first part even vector components, second part odd vector components,
         # 2 * dim in dimension size
@@ -257,7 +272,7 @@ class TELECHATAttention(nn.Module):
         self.pruned_heads = set()
 
         self.use_flash_attn = False
-        self.is_cross_attention = False
+
 
 
     def set_max_positions(self, max_positions, device='cuda'):
@@ -1085,8 +1100,8 @@ class TELECHAT(TELECHATPretrainedModel):
         input_ids = tokenizer.encode(inputs,
                                      return_tensors="pt"
                                      )
-        if len(input_ids[0]) >= 2000:
-            input_ids = input_ids[:, -2000:]
+        #if len(input_ids[0]) >= 2000:
+        #    input_ids = input_ids[:, -2000:]
         input_ids = input_ids.to(0)
         output = self.generate(input_ids,generation_config)
         response = tokenizer.decode(output[0].cpu().numpy().tolist()).split('<_bot>')[-1].split('</s>')[0]
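
The substance of this commit is a dynamic NTK-aware adjustment of the rotary position embedding: RotaryEmbedding.forward now rebuilds inv_freq from theta * ntk_alpha, where get_ntk_alpha derives the scaling factor from the true sequence length parsed out of cache_key (the patched code assumes a "<prefix>:<seq_len>" format), and the hard 2000-token truncation in the chat helper is commented out accordingly. The sketch below is a minimal standalone illustration of that scaling rule, not the model code itself; the helper names (rotary_inv_freq, train_ctx) and the sample sequence lengths are assumptions for illustration, while the 4096 baseline and the power-of-two rounding mirror get_ntk_alpha above.

import math

import torch


def get_ntk_alpha(true_seq_len, train_ctx=4096):
    # Same rule as the committed get_ntk_alpha: take log2(len / 4096) + 1,
    # round up, compute 2**ceil - 1, and clamp to 1 so sequences within the
    # training window leave the rotary base unchanged.
    context_value = math.log(true_seq_len / train_ctx, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)


def rotary_inv_freq(dim, true_seq_len, theta=10000):
    # Rebuild inv_freq from the NTK-scaled base, as the patched forward() now
    # does instead of reusing the inv_freq buffer registered in __init__.
    theta = theta * get_ntk_alpha(true_seq_len)
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))


if __name__ == "__main__":
    for seq_len in (1024, 4096, 8192, 16384, 32768):
        alpha = get_ntk_alpha(seq_len)
        lowest = rotary_inv_freq(64, seq_len)[-1].item()
        print(f"seq_len={seq_len:6d}  ntk_alpha={alpha}  lowest inv_freq={lowest:.3e}")

Past the 4096-token training window the effective base grows (alpha becomes 3 at 8192 tokens and 7 at 16384), which stretches the lowest rotary frequencies and is what makes it reasonable to drop the old 2000-token input cutoff.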