Update inference_fine_tune.py
inference_fine_tune.py (+3 -1)
@@ -47,6 +47,7 @@ def generate_response(prompt:str):
     temperature = 0.7
     top_k = 50
     i = 0
+    print("Output : ",end="")
     while decoder_input.shape[1] < 2000:
         # Apply causal mask based on current decoder_input length
         # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
@@ -59,11 +60,12 @@ def generate_response(prompt:str):
         next_token = torch.multinomial(probs, num_samples=1)
         next_token = top_k_indices.gather(-1, next_token)
         word += tokenizer.decode([next_token.item()])
+        print(word,end="")
         i+=1
         decoder_input = torch.cat([decoder_input, next_token], dim=1)
         if decoder_input.shape[1] > config['seq_len']:
             decoder_input = decoder_input[:,-config['seq_len']:]
         if next_token.item() == eos_token_id or i >= 1024:
             break
-    print(
+    print()
     return word
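The second hunk samples from probs and maps the draw back through top_k_indices, but the lines that build those tensors sit between the two hunks and are elided from this diff. Below is a minimal sketch of what a temperature + top-k sampling step of this shape typically looks like; only temperature, top_k, probs, top_k_indices, and the multinomial/gather calls come from the visible code, while the function name sample_next_token and the exact construction are assumptions.

import torch

# Hedged sketch of the sampling step the visible hunk depends on.
# The elided lines are not shown in this diff, so the construction
# below is an assumption, not the script's actual code.
def sample_next_token(logits: torch.Tensor,
                      temperature: float = 0.7,
                      top_k: int = 50) -> torch.Tensor:
    logits = logits / temperature                           # sharpen/flatten the distribution
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    probs = torch.softmax(top_k_values, dim=-1)             # probabilities over the k candidates
    next_token = torch.multinomial(probs, num_samples=1)    # index into the top-k list
    return top_k_indices.gather(-1, next_token)             # map back to a vocabulary id

Usage: logits of shape (1, vocab_size) yield a (1, 1) next-token id tensor, matching how next_token is used in the loop.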
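One behavior worth noting in the new streaming output: word accumulates every decoded token, so print(word,end="") re-emits all previously printed text on each iteration. A self-contained sketch of printing only the newly decoded piece instead; the toy decode below is a stand-in for tokenizer.decode, and flush=True is an addition so output appears immediately.

# Toy stand-in for tokenizer.decode, just for this illustration.
def decode(ids):
    return "".join(chr(i) for i in ids)

word = ""
for next_id in [72, 101, 108, 108, 111]:   # stand-in for sampled token ids
    token_text = decode([next_id])
    word += token_text
    # Print only the new token; printing the accumulated `word` here
    # would repeat earlier text on every step.
    print(token_text, end="", flush=True)
print()   # final newline, as in the commit
# Prints: Hello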
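The truncation in the loop keeps generation within the model's context window: once decoder_input grows past config['seq_len'], only the most recent seq_len positions are kept. A small runnable sketch of that slice, where 8 is an arbitrary stand-in for config['seq_len']:

import torch

seq_len = 8                                    # stand-in for config['seq_len']
decoder_input = torch.arange(12).unsqueeze(0)  # running sequence, shape (1, 12)

# Keep only the last seq_len positions, dropping the oldest tokens.
if decoder_input.shape[1] > seq_len:
    decoder_input = decoder_input[:, -seq_len:]

print(decoder_input)  # tensor([[ 4,  5,  6,  7,  8,  9, 10, 11]])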