Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,12 +13,12 @@ print(f"Loading model: {model_name}") # モデル読み込み開始ログ
|
|
| 13 |
|
| 14 |
# --- Tokenizerの読み込み ---
|
| 15 |
try:
|
| 16 |
-
|
| 17 |
-
|
|
|
|
| 18 |
print("Tokenizers loaded successfully.")
|
| 19 |
except Exception as e:
|
| 20 |
print(f"Error loading tokenizers: {e}")
|
| 21 |
-
# エラーが発生した場合、Gradioインターフェースでエラーを表示するなどの処理を追加できます
|
| 22 |
raise # ここではエラーを再発生させて、起動を停止させます
|
| 23 |
|
| 24 |
# decoder_tokenizerのpad_token設定
|
|
@@ -35,6 +35,8 @@ if decoder_tokenizer.pad_token is None:
|
|
| 35 |
|
| 36 |
# --- モデルの読み込み ---
|
| 37 |
try:
|
|
|
|
|
|
|
| 38 |
model = EncoderDecoderModel.from_pretrained(model_name).to(device)
|
| 39 |
model.eval() # 評価モードに設定
|
| 40 |
print("Model loaded successfully and moved to device.")
|
|
@@ -62,16 +64,18 @@ def generate_docstring(code: str) -> str:
|
|
| 62 |
|
| 63 |
# 生成実行
|
| 64 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
| 65 |
output_ids = model.generate(
|
| 66 |
input_ids=inputs.input_ids,
|
| 67 |
attention_mask=inputs.attention_mask,
|
| 68 |
max_length=256, # 生成するDocstringの最大長
|
| 69 |
num_beams=5, # ビームサーチのビーム数
|
| 70 |
early_stopping=True, # 早く停止させるか
|
| 71 |
-
# decoder_start_token_idは通常model.config
|
| 72 |
-
# decoder_start_token_id=model.config.decoder_start_token_id,
|
| 73 |
eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
|
| 74 |
-
pad_token_id=
|
| 75 |
no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
|
| 76 |
)
|
| 77 |
|
|
|
|
| 13 |
|
| 14 |
# --- Tokenizerの読み込み ---
|
| 15 |
try:
|
| 16 |
+
# subfolder引数を使用してサブディレクトリを指定
|
| 17 |
+
encoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="encoder_tokenizer")
|
| 18 |
+
decoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="decoder_tokenizer")
|
| 19 |
print("Tokenizers loaded successfully.")
|
| 20 |
except Exception as e:
|
| 21 |
print(f"Error loading tokenizers: {e}")
|
|
|
|
| 22 |
raise # ここではエラーを再発生させて、起動を停止させます
|
| 23 |
|
| 24 |
# decoder_tokenizerのpad_token設定
|
|
|
|
| 35 |
|
| 36 |
# --- モデルの読み込み ---
|
| 37 |
try:
|
| 38 |
+
# モデルの読み込みは通常通りリポジトリ名を指定すればOK
|
| 39 |
+
# config.jsonが適切に設定されていれば、エンコーダー/デコーダー部分は自動的に読み込まれる
|
| 40 |
model = EncoderDecoderModel.from_pretrained(model_name).to(device)
|
| 41 |
model.eval() # 評価モードに設定
|
| 42 |
print("Model loaded successfully and moved to device.")
|
|
|
|
| 64 |
|
| 65 |
# 生成実行
|
| 66 |
with torch.no_grad():
|
| 67 |
+
# pad_token_idを明示的に指定 (重要: Noneでないことを確認)
|
| 68 |
+
pad_token_id = decoder_tokenizer.pad_token_id if decoder_tokenizer.pad_token_id is not None else decoder_tokenizer.eos_token_id
|
| 69 |
+
|
| 70 |
output_ids = model.generate(
|
| 71 |
input_ids=inputs.input_ids,
|
| 72 |
attention_mask=inputs.attention_mask,
|
| 73 |
max_length=256, # 生成するDocstringの最大長
|
| 74 |
num_beams=5, # ビームサーチのビーム数
|
| 75 |
early_stopping=True, # 早く停止させるか
|
| 76 |
+
# decoder_start_token_idは通常model.configから自動設定される
|
|
|
|
| 77 |
eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
|
| 78 |
+
pad_token_id=pad_token_id, # PADトークンID (Noneでないことを保証)
|
| 79 |
no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
|
| 80 |
)
|
| 81 |
|