Spaces:
Running
Running
Update models/transformer/text_generator.py
Browse files
models/transformer/text_generator.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from transformers import GPT2LMHeadModel
|
| 2 |
from pathlib import Path
|
| 3 |
from .utils import modified_tokenizer
|
| 4 |
-
from .constants import CHECKPOINT_PATH
|
| 5 |
|
| 6 |
|
| 7 |
class TextGenerator:
|
|
@@ -12,7 +12,7 @@ class TextGenerator:
|
|
| 12 |
"""
|
| 13 |
model_path = Path(data_path) / model_name
|
| 14 |
self.tokenizer = modified_tokenizer(model_path, None, data_path)
|
| 15 |
-
self.model = GPT2LMHeadModel.from_pretrained(str(model_path), device_map="auto")
|
| 16 |
self.model.eval()
|
| 17 |
|
| 18 |
def generate_text(self,
|
|
|
|
| 1 |
from transformers import GPT2LMHeadModel
|
| 2 |
from pathlib import Path
|
| 3 |
from .utils import modified_tokenizer
|
| 4 |
+
from .constants import CHECKPOINT_PATH, HF_TOKEN
|
| 5 |
|
| 6 |
|
| 7 |
class TextGenerator:
|
|
|
|
| 12 |
"""
|
| 13 |
model_path = Path(data_path) / model_name
|
| 14 |
self.tokenizer = modified_tokenizer(model_path, None, data_path)
|
| 15 |
+
self.model = GPT2LMHeadModel.from_pretrained(str(model_path), device_map="auto", token=HF_TOKEN)
|
| 16 |
self.model.eval()
|
| 17 |
|
| 18 |
def generate_text(self,
|