import os
import subprocess
import math
import difflib
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from sacremoses import MosesDetokenizer
from flask import Flask, request, jsonify, render_template
import traceback # Import traceback for better error logging
import re # Import the regular expression module

# --- Constants and Paths ---
# Ensure these files are in the same directory as app.py or provide correct paths
FINETUNED_MODEL_PATH = "hoc_best.pt"
BPE_CODES_PATH = "bpecodes"
DICT_TXT_PATH = "dict.txt"
FASTBPE_BIN_PATH = "./fastbpe_exec"  # Assumes the compiled fastBPE binary sits alongside app.py

HALLMARKS = [ # Keep this consistent with training/evaluation
    "activating invasion and metastasis", "avoiding immune destruction",
    "cellular energetics", "enabling replicative immortality",
    "evading growth suppressors", "genomic instability and mutation",
    "inducing angiogenesis", "resisting cell death",
    "sustaining proliferative signaling", "tumor promoting inflammation",
]

# --- Model Architecture Definitions (Copy from your notebook) ---
# NOTE: Make sure these classes are IDENTICAL to the ones used for training
#       including GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT, GPTWithSoftPrompt

@dataclass
class GPTConfig:
    block_size: int
    vocab_size: int
    n_layer: int
    n_head: int
    n_embd: int
    dropout: float = 0.0
    bias: bool = True

class LayerNorm(nn.Module):
    # (Copied from notebook)
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
    def forward(self, x):
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):
    # (Copied from notebook)
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention') # Check for flash attention
        if not self.flash:
            # Flash Attention unavailable: fall back to a manual causal mask.
            # Register the mask as a persistent buffer so it moves with the module across devices.
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size), persistent=True)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            # Ensure bias buffer is used correctly
            # Check if bias buffer exists before using it
            if hasattr(self, 'bias'):
                att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            else:
                 # Fallback if somehow bias wasn't registered (shouldn't happen with persistent=True)
                 mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
                 att = att.masked_fill(mask == 0, float('-inf'))

            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
     # (Copied from notebook)
    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
     # (Copied from notebook)
    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
     # (Copied from notebook - simplified _init_weights and removed generate)
    def __init__(self, config):
        super().__init__()
        #assert config.vocab_size is not None
        #assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        if t > self.config.block_size:
             # Crop sequence if longer than block size
             print(f"Warning: Input sequence length ({t}) > block size ({self.config.block_size}). Cropping.")
             idx = idx[:, -self.config.block_size:]
             t = self.config.block_size
        #assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            # Check for NaN/Inf in logits before returning
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print("WARNING: NaN or Inf detected in logits during inference.")
                # Handle appropriately - maybe return an error indicator or zero logits?
                # For now, just print warning.
            return logits, None

class GPTWithSoftPrompt(nn.Module):
     # (Copied from notebook - simplified)
    def __init__(self, base_gpt: GPT, prompt_len=1):
        super().__init__()
        self.config = base_gpt.config
        self.transformer = base_gpt.transformer
        self.lm_head = base_gpt.lm_head
        C = self.config.n_embd
        self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C)) # Keep on CPU first
        nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        device = idx.device # Get device from input tensor

        # Make sure soft_prompt is on the same device as input
        soft_prompt_on_device = self.soft_prompt.to(device)

        # token + pos
        tok_emb = self.transformer.wte(idx) # (B,T,C)
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        pos_emb = self.transformer.wpe(pos) # (T,C)
        x_tokens = tok_emb + pos_emb

        # prepend soft prompt
        soft = soft_prompt_on_device.expand(B, -1, -1) # (B,P,C)

        # Soft prompt length (needed by both the inference and training branches below)
        P = soft.size(1)

        x = torch.cat([soft, x_tokens], dim=1) # (B,P+T,C)

        # --- Standard Transformer forward pass ---
        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B,P+T,V)
        # --- End Standard ---

        if targets is None:
            # Inference: return the logits that predict the token *after* the last input token.
            # With the soft prompt prepended, that position is index P + T - 1 in the full
            # sequence; the guard below catches it somehow falling outside the logits tensor.
            target_logit_index = P + T - 1
            if target_logit_index >= logits.size(1):
                 print(f"Warning: Calculated logit index {target_logit_index} out of bounds for logits shape {logits.shape}. Returning last logit.")
                 target_logit_index = -1 # Fallback to last logit

            final_logits = logits[:, target_logit_index, :]
            # Check for NaN/Inf
            if torch.isnan(final_logits).any() or torch.isinf(final_logits).any():
                 print(f"WARNING: NaN or Inf detected in final_logits at index {target_logit_index}.")
                 # Handle appropriately - maybe return zeros or raise an error?
                 # For now, just print warning. Let the calling function handle it.

            return final_logits, None # Return (B, V)
        else:
            # Training loss calculation (copied from notebook)
            # P is already defined above
            pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device)
            full_targets = torch.cat([pad_ignore, targets], dim=1)
            logits_lm  = logits[:, :-1, :].contiguous()
            targets_lm = full_targets[:, 1:].contiguous()
            loss = F.cross_entropy(
                logits_lm.view(-1, logits_lm.size(-1)),
                targets_lm.view(-1),
                ignore_index=-1
            )
            # Check for NaN/Inf in loss
            if torch.isnan(loss) or torch.isinf(loss):
                print("WARNING: NaN or Inf detected in loss calculation.")
                # Potentially add debugging info here (e.g., print shapes, inputs)
            return logits, loss

    # --- Constrained generation method (from Section 2.9) ---
    @torch.no_grad()
    def generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0):
        self.eval() # Ensure model is in eval mode
        B = idx.size(0)
        # The soft prompt occupies P positions, so only block_size - P token positions remain.
        P = self.soft_prompt.size(1)
        effective_block_size = self.config.block_size - P

        # Start with input index
        out = idx.clone() # Clone to avoid modifying original input

        # Ensure allowed_mask is on the correct device
        allowed_mask = allowed_mask.to(idx.device)

        finished = torch.zeros(B, dtype=torch.bool, device=idx.device)

        # Get global eos_id safely
        global eos_id
        current_eos_id = eos_id # Use the globally loaded eos_id

        for step in range(max_new_tokens):
            # Crop the context so that soft prompt + tokens never exceed the model's block size
            # (GPTWithSoftPrompt.forward does not crop on its own).
            ctx = out if out.size(1) <= effective_block_size else out[:, -effective_block_size:]

            # Forward pass - expects shape (B, T), model handles soft prompt internally
            # It returns logits for the *next* token prediction after the last token in ctx
            logits, _ = self(ctx) # Gets logits for last token prediction, shape (B, V)

            # Check for NaN/Inf in logits
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print(f"WARNING: NaN or Inf detected in logits during generation step {step}. Stopping generation.")
                # Return what we have so far, excluding potentially bad last token
                return out[:, idx.size(1):] # Or handle error differently

            # Apply the constraint mask. It is built as a (V,) tensor, so give it a batch
            # dimension and let it broadcast over the (B, V) logits.
            if allowed_mask.dim() == 1:
                current_mask = allowed_mask.unsqueeze(0)  # (1, V) broadcasts against (B, V)
            else:
                current_mask = allowed_mask

            logits = logits + current_mask

            # Sample next token
            if temperature <= 0:
                # Greedy decoding
                next_id = torch.argmax(logits, dim=-1)  # (B,)
            else:
                # Temperature sampling
                probs = F.softmax(logits / temperature, dim=-1)
                # Check for NaN/Inf in probs
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                     print(f"WARNING: NaN or Inf detected in probabilities during generation step {step}. Using argmax fallback.")
                     next_id = torch.argmax(logits, dim=-1) # Fallback to greedy
                else:
                    try:
                         next_id = torch.multinomial(probs, num_samples=1).squeeze(1) # (B,)
                    except RuntimeError as e:
                         print(f"WARNING: torch.multinomial failed: {e}. Using argmax fallback.")
                         next_id = torch.argmax(logits, dim=-1) # Fallback to greedy


            # Handle finished sequences (force EOS) and update output
            # Check if current_eos_id is valid
            if not isinstance(current_eos_id, int):
                print(f"Warning: Global eos_id is not an integer ({current_eos_id}). Defaulting to 0.")
                current_eos_id = 0
            next_id = next_id.masked_fill(finished, current_eos_id) # Use the validated eos_id

            # Check if next_id contains invalid values (e.g., negative)
            if (next_id < 0).any():
                 print(f"WARNING: Negative token ID generated: {next_id}. Clipping to 0.")
                 next_id = torch.clamp(next_id, min=0)


            # Append the next token ID
            out = torch.cat([out, next_id.unsqueeze(1)], dim=1)

            # Update finished status
            finished |= (next_id == current_eos_id)

            # Stop if all sequences in the batch are finished
            if bool(finished.all()):
                # print(f"Generation finished early at step {step+1}") # Optional debug info
                break
        # else:
            # print(f"Generation reached max_new_tokens ({max_new_tokens})") # Optional debug info

        # Return only the generated part (excluding the initial idx length)
        return out[:, idx.size(1):]
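
# Illustrative sketch of the constrained decoding trick used by generate_labels() above:
# disallowed vocabulary entries get -inf added to their logits, so neither argmax nor a
# softmax sample can ever pick them. Defined but never called; the vocab size and the
# allowed IDs below are made up for the example.
def _demo_constrained_argmax():
    vocab_size = 8
    allowed_ids = [2, 5, 7]                        # hypothetical "allowed" token IDs
    mask = torch.full((vocab_size,), float('-inf'))
    mask[allowed_ids] = 0.0                        # allowed entries contribute nothing
    logits = torch.randn(1, vocab_size)            # stand-in for model output, shape (B, V)
    next_id = torch.argmax(logits + mask, dim=-1)  # disallowed entries stay at -inf
    assert int(next_id) in allowed_ids             # the pick always comes from the allowed set
    return next_id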


# --- Tokenizer Helper Functions ---
# Added robustness and error checks
# Global tokenizer maps and special IDs, loaded once at startup
token2id, id2token = {}, {}
eos_id = 0 # Default, will be overwritten
pad_id = 0 # Default, will be overwritten
detokenizer = None

def load_tokenizer_data(dict_path):
    global token2id, id2token, eos_id, pad_id, detokenizer
    print(f"Loading vocabulary from {dict_path}...")
    local_token2id, local_id2token = {}, {}
    try:
        with open(dict_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                parts = line.split() # Split by whitespace
                if not parts: continue # Skip empty lines
                tok = parts[0]
                if tok in local_token2id:
                     print(f"Warning: Duplicate token '{tok}' found at line {i+1}. Keeping first occurrence.")
                     continue
                local_token2id[tok] = i
                local_id2token[i] = tok

        # Assign to global variables only after successful loading
        token2id = local_token2id
        id2token = local_id2token

        # Find the EOS token ID; try several common spellings rather than assuming </s> exists.
        possible_eos = ["</s>", "<|endoftext|>", "[EOS]"]
        found_eos = False
        for eos_tok in possible_eos:
            if eos_tok in token2id:
                eos_id = token2id[eos_tok]
                found_eos = True
                print(f"Found EOS token '{eos_tok}' with ID: {eos_id}")
                break
        if not found_eos:
             # If no common EOS found, fall back to the highest index or 0
             eos_id = max(token2id.values()) if token2id else 0
             print(f"Warning: Standard EOS tokens not found. Using highest index ({eos_id}) as EOS ID.")

        # Assign pad_id, often same as eos_id or a specific <pad> token
        pad_id = token2id.get("<pad>", eos_id) # Prefer <pad> if exists, else use eos_id
        print(f"Using PAD ID: {pad_id}")

        detokenizer = MosesDetokenizer(lang='en') # Initialize once
        print(f"Vocabulary loaded. Size: {len(token2id)}")
        if not detokenizer:
             raise ValueError("MosesDetokenizer failed to initialize.")

    except FileNotFoundError:
        print(f"ERROR: Vocabulary file not found at {dict_path}")
        raise
    except Exception as e:
        print(f"ERROR: Failed to load tokenizer data from {dict_path}: {e}")
        raise
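
# Minimal sketch of the dict.txt layout this loader expects: one entry per line, the first
# whitespace-separated field is the token, and the line index becomes its ID. The tokens
# and counts below are invented, not taken from the real dictionary file.
#
#   the 1061396
#   of 593677
#   cell@@ 48211
#   </s> 1
#
# With such a file, load_tokenizer_data() would map "the" -> 0, "of" -> 1, "cell@@" -> 2 and
# "</s>" -> 3, pick "</s>" as the EOS token, and (with no "<pad>" entry) reuse it as pad_id.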

def bpe_encode_lines(lines, shard_size=500, desc="BPE Encode"):
    """ Encodes lines using external fastBPE binary. Added error checking. """
    global BPE_CODES_PATH, FASTBPE_BIN_PATH
    # --- Input Validation ---
    if not isinstance(lines, list):
         print(f"Warning: bpe_encode_lines expected a list, got {type(lines)}. Attempting conversion.")
         try:
              lines = list(lines)
         except TypeError:
              raise ValueError("Input 'lines' must be a list or convertible to a list.")

    if not lines: return []

    # --- Path and Executable Checks ---
    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
    abs_bpe_codes_path = os.path.abspath(BPE_CODES_PATH)

    if not os.path.exists(abs_fastbpe_path):
        raise FileNotFoundError(f"fastBPE executable not found at {abs_fastbpe_path}")
    if not os.path.exists(abs_bpe_codes_path):
        raise FileNotFoundError(f"BPE codes file not found at {abs_bpe_codes_path}")
    if not os.access(abs_fastbpe_path, os.X_OK):
        print(f"Warning: fastBPE binary at {abs_fastbpe_path} is not executable. Attempting chmod...")
        try:
            os.chmod(abs_fastbpe_path, 0o755)
        except OSError as e:
            raise PermissionError(f"Failed to make fastBPE executable: {e}. Please check permissions.")

    out_tokens = []
    # Process in chunks
    with tempfile.TemporaryDirectory() as td:
        for start in range(0, len(lines), shard_size):
            chunk = lines[start:start+shard_size]
            src_path = os.path.join(td, f"src_{start}.txt")
            dst_path = os.path.join(td, f"dst_{start}.bpe")

            try:
                # Write chunk to temp file, ensuring strings
                with open(src_path, "w", encoding="utf-8") as f:
                    for s in chunk:
                        f.write(str(s or "").strip() + "\n") # Ensure string conversion

                # Run fastBPE
                cmd = [abs_fastbpe_path, "applybpe", dst_path, src_path, abs_bpe_codes_path]
                # print(f"Running command: {' '.join(cmd)}") # Debug command
                process = subprocess.run(
                    cmd,
                    capture_output=True, text=True, check=False # Don't check=True here, handle error below
                )

                # Check for errors specifically
                if process.returncode != 0:
                     # Log more details on failure
                     print(f"ERROR: fastBPE failed (exit code {process.returncode}) on chunk starting at index {start}.")
                     print(f"Command: {' '.join(cmd)}")
                     print(f"Stderr:\n{process.stderr}")
                     # Optionally print some input data
                     print(f"First line of input chunk: {chunk[0] if chunk else 'N/A'}")
                     raise subprocess.CalledProcessError(process.returncode, cmd, output=process.stdout, stderr=process.stderr)

                # Read results if successful
                with open(dst_path, "r", encoding="utf-8") as f:
                    for line in f:
                        out_tokens.append(line.strip().split())

            except subprocess.CalledProcessError as e:
                # Handle specific subprocess errors (already printed details)
                raise # Re-raise to stop execution
            except Exception as e:
                print(f"ERROR: Unexpected error during BPE encoding chunk starting at index {start}: {e}")
                traceback.print_exc() # Print full traceback for unexpected errors
                raise # Re-raise
    return out_tokens
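
# Illustrative call (requires the compiled fastBPE binary and bpecodes file configured above;
# the subword splits shown are invented -- real splits depend on the learned BPE codes):
#
#   bpe_encode_lines(["Tumor cells evade apoptosis."])
#   # -> [['Tum@@', 'or', 'cells', 'ev@@', 'ade', 'apopto@@', 'sis', '.']]
#
# One inner list per input line; the '@@' suffix marks a subword piece that
# bpe_decode_tokens() later glues back onto the piece that follows it.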


def tokens_to_ids(bpe_tokens):
    """ Converts BPE token strings to IDs using the global map. Added checks. """
    global token2id, pad_id
    if not isinstance(bpe_tokens, list):
         raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")

    ids = []
    oov_count = 0
    for t in bpe_tokens:
        if not isinstance(t, str):
             print(f"Warning: Non-string token found in bpe_tokens: {t}. Using pad_id.")
             ids.append(pad_id)
             oov_count += 1
             continue

        id_val = token2id.get(t, pad_id)
        ids.append(id_val)
        if id_val == pad_id and t not in token2id:
            oov_count += 1
            # print(f"Warning: OOV token '{t}' mapped to pad_id {pad_id}") # Reduce noise
    if oov_count > 0:
         print(f"Info: Found {oov_count} OOV tokens in sequence of length {len(bpe_tokens)}.")
    return ids, oov_count

def ids_to_tokens(ids):
    """ Converts IDs back to token strings. Added checks. """
    global id2token
    if not isinstance(ids, list):
        raise ValueError(f"Input 'ids' must be a list, got {type(ids)}.")

    tokens = []
    for i in ids:
        # Ensure ID is a valid integer before lookup
        try:
             # Handle potential floats or NaNs from generation
             if isinstance(i, float) and math.isnan(i):
                 token = "<nan>"
             else:
                 int_i = int(i)
                 token = id2token.get(int_i, "<unk>")
        except (ValueError, TypeError):
             print(f"Warning: Could not convert ID '{i}' to int. Using '<unk>'.")
             token = "<unk>"
        tokens.append(token)
    return tokens


def bpe_decode_tokens(bpe_tokens):
    """ Converts BPE token strings back to readable text. Added checks. """
    global detokenizer
    if detokenizer is None:
        raise RuntimeError("Detokenizer not initialized. Call load_tokenizer_data first.")
    if not isinstance(bpe_tokens, list):
         raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")

    # Ensure all items are strings before joining
    try:
        str_tokens = [str(t) for t in bpe_tokens]
    except Exception as e:
        print(f"Error converting tokens to strings: {e}. Tokens: {bpe_tokens}")
        return "<decoding error>"

    s = ' '.join(str_tokens).replace('@@ ', '')
    try:
        # Detokenizer might fail on empty or unusual input
        return detokenizer.detokenize(s.split()) if s.strip() else ""
    except Exception as e:
        print(f"Error during detokenization: {e}. Input string: '{s}'")
        return "<detokenization error>"


# --- Prediction Helper Functions ---

def to_canonical(pred_chunk: str):
    """ Maps a predicted text chunk to a canonical hallmark name. Added checks. """
    global HALLMARKS
    # Ensure input is a string
    if not isinstance(pred_chunk, str):
         # print(f"Warning: to_canonical received non-string input: {pred_chunk}. Returning None.")
         return None

    s = pred_chunk.strip().lower()
    low = [L.lower() for L in HALLMARKS]
    if not s: return None

    if s in low:
        return HALLMARKS[low.index(s)]

    # Use difflib for fuzzy matching
    try:
        best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
        return HALLMARKS[low.index(best[0])] if best else None
    except Exception as e:
        print(f"Error during difflib matching for '{s}': {e}")
        return None # Return None on error
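
# Small, uncalled sanity sketch for the fuzzy mapping above. The first case is an exact
# (case-insensitive) hit; the misspelled second case shows that difflib's 0.7 cutoff
# tolerates minor typos. Illustrative only, not a test suite.
def _demo_to_canonical():
    assert to_canonical("Resisting Cell Death") == "resisting cell death"
    assert to_canonical("inducing angiogenesiss") == "inducing angiogenesis"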

def build_allowed_token_mask(vocab_size, device):
    """ Builds the mask for constrained decoding. Added error checks. """
    global HALLMARKS, token2id, eos_id, pad_id
    allowed = set()

    # --- Input Validation ---
    if vocab_size <= 0:
        raise ValueError("Vocabulary size must be positive.")
    if not token2id:
        raise RuntimeError("Tokenizer vocabulary (token2id) not loaded.")

    print("Encoding hallmarks for mask...")
    try:
        # Ensure HALLMARKS is a list of strings
        if not isinstance(HALLMARKS, list) or not all(isinstance(h, str) for h in HALLMARKS):
             raise ValueError("HALLMARKS must be a list of strings.")
        hallmark_bpes = bpe_encode_lines(HALLMARKS, desc="BPE Hallmarks (for mask)")
        for bpe_list in hallmark_bpes:
            ids, _ = tokens_to_ids(bpe_list)
            allowed.update(ids)
        print(f"Encoded {len(HALLMARKS)} hallmarks.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert hallmarks for mask: {e}")
        raise

    print("Encoding separators for mask...")
    SEPS = [", ", ",", "; ", ";", "|"]
    try:
        sep_bpes = bpe_encode_lines(SEPS, desc="BPE Separators (for mask)")
        for bpe_list in sep_bpes:
             ids, _ = tokens_to_ids(bpe_list)
             allowed.update(ids)
        print(f"Encoded {len(SEPS)} separators.")
    except Exception as e:
         print(f"ERROR: Failed to BPE encode or convert separators for mask: {e}")
         raise

    # Add EOS token - Check if eos_id is valid
    if not isinstance(eos_id, int) or eos_id < 0 or eos_id >= vocab_size:
         print(f"Warning: Invalid EOS ID ({eos_id}). Defaulting mask EOS to 0.")
         effective_eos_id = 0
    else:
         effective_eos_id = eos_id
    allowed.add(effective_eos_id)
    print(f"Total allowed token IDs (including EOS {effective_eos_id}): {len(allowed)}")

    # Create the mask tensor on CPU first
    mask = torch.full((vocab_size,), float('-inf'), device=torch.device('cpu'))
    try:
        # Filter out potential invalid IDs before creating list for indexing
        # Ensure pad_id is valid if used for filtering
        effective_pad_id = pad_id if isinstance(pad_id, int) and 0 <= pad_id < vocab_size else -1 # Use -1 if pad_id is invalid

        valid_allowed_ids = []
        for id_ in allowed:
            if isinstance(id_, int) and 0 <= id_ < vocab_size: # Check type and range
                 # Filter out pad_id unless it's the same as the effective_eos_id
                 if id_ != effective_pad_id or id_ == effective_eos_id:
                      valid_allowed_ids.append(id_)
            # else: print(f"Warning: Invalid ID {id_} in allowed set skipped.") # Reduce noise

        if not valid_allowed_ids:
            raise ValueError("No valid token IDs found to allow in the mask.")

        # Check ranges again after filtering (belt and braces)
        max_valid_id = max(valid_allowed_ids)
        min_valid_id = min(valid_allowed_ids)
        if max_valid_id >= vocab_size or min_valid_id < 0:
             # This should ideally not happen if filtering worked
             raise IndexError(f"Filtered allowed IDs still out of range [{min_valid_id}, {max_valid_id}] for vocab size {vocab_size}.")

        # Apply mask
        mask[valid_allowed_ids] = 0.0 # Use list directly
        print(f"Mask created with {len(valid_allowed_ids)} allowed indices.")

    except IndexError as e:
        print(f"ERROR: Index error while creating mask. Vocab size: {vocab_size}. Error: {e}")
        # Find problematic IDs more carefully
        problem_ids = [i for i in allowed if not isinstance(i, int) or i < 0 or i >= vocab_size]
        print(f"Problematic IDs in allowed set: {problem_ids}")
        raise
    except Exception as e:
        print(f"ERROR: Unexpected error creating mask: {e}")
        traceback.print_exc()
        raise

    # Move final mask to target device
    try:
        target_device = torch.device(device) # Ensure device is a torch.device object
        return mask.to(target_device)
    except Exception as e:
        print(f"Error moving mask to device '{device}': {e}")
        raise


# --- Global Variables for Loaded Model and Assets ---
inference_model = None
ALLOWED_MASK = None
model_device = "cpu"
config = None # Added global config

# --- Initialization Function ---
def initialize_model_and_tokenizer():
    global inference_model, ALLOWED_MASK, model_device, token2id, config # Add config

    print("Initializing model...")
    # Determine device
    model_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {model_device}")

    # Load tokenizer data first (essential for vocab size)
    try:
        load_tokenizer_data(DICT_TXT_PATH)
        if not token2id: # Check if loading actually populated the dict
             raise ValueError("Tokenizer loading failed to populate token2id dictionary.")
    except Exception as e:
        print(f"FATAL: Could not load tokenizer data. Cannot proceed. Error: {e}")
        return False # Indicate failure

    # Define model config (MUST match finetuning config)
    try:
        # Ensure config is globally accessible after definition
        config = GPTConfig(
            vocab_size=len(token2id), # Get vocab size from loaded data
            block_size=128,           # Match training
            n_layer=6,                # Match training
            n_head=6,                 # Match training
            n_embd=384,               # Match training
            dropout=0.1,              # Match training (dropout is off in eval mode)
            bias=True                 # Match training
        )
        print(f"Model Config: {config}")
    except Exception as e:
        print(f"FATAL: Error creating GPTConfig: {e}")
        return False

    # Instantiate base and wrapped model (on CPU initially)
    try:
        base_gpt = GPT(config)
        inference_model = GPTWithSoftPrompt(base_gpt, prompt_len=1)
    except Exception as e:
        print(f"FATAL: Error instantiating model: {e}")
        traceback.print_exc()
        return False

    # Load finetuned weights
    print(f"Loading finetuned weights from: {FINETUNED_MODEL_PATH}")
    if not os.path.exists(FINETUNED_MODEL_PATH):
        print(f"ERROR: Model weights file not found at {FINETUNED_MODEL_PATH}")
        return False

    try:
        # Load state dict onto CPU first
        state_dict = torch.load(FINETUNED_MODEL_PATH, map_location='cpu')

        # Clean state dict keys (handle DDP 'module.' prefix)
        cleaned_state_dict = {}
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            cleaned_state_dict[name] = v

        # Load into model
        missing_keys, unexpected_keys = inference_model.load_state_dict(cleaned_state_dict, strict=False)
        if missing_keys:
            # Report only keys that correspond to real parameters/buffers of the model.
            # (get_parameter/get_buffer raise on mismatched kinds, so use name sets instead.)
            model_keys = {n for n, _ in inference_model.named_parameters()}
            model_keys |= {n for n, _ in inference_model.named_buffers()}
            missing_persistent = [k for k in missing_keys if k in model_keys]
            if missing_persistent:
                print("Warning: Missing keys during state dict load:", missing_persistent)
        if unexpected_keys:
            print("Warning: Unexpected keys during state dict load:", unexpected_keys)
        print("Weights loaded successfully.")

    except Exception as e:
        print(f"Error loading state dict from {FINETUNED_MODEL_PATH}: {e}")
        print("Ensure the model architecture matches the saved checkpoint and the file is not corrupted.")
        traceback.print_exc()
        return False

    # Move model to target device and set to eval mode
    try:
        inference_model.to(model_device)
        inference_model.eval()
        print(f"Model moved to device: {model_device} and set to eval mode.")
    except Exception as e:
        print(f"Error moving model to device '{model_device}': {e}")
        traceback.print_exc()
        return False


    # Build the allowed token mask (after model is on device)
    print("Building allowed token mask...")
    try:
        if config.vocab_size <= 0:
             raise ValueError("Vocabulary size must be positive to build mask.")
        # Ensure model_device is valid before passing
        device_obj = torch.device(model_device)
        ALLOWED_MASK = build_allowed_token_mask(config.vocab_size, device_obj)
        print("Allowed token mask created.")
    except Exception as e:
        print(f"ERROR: Failed to build allowed token mask: {e}")
        traceback.print_exc()
        return False

    return True # Indicate success


# --- Inference Function ---
def predict_hallmarks(abstract_text):
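    """
    End-to-end inference for one abstract: normalize whitespace, BPE-encode with fastBPE,
    map tokens to IDs and append EOS, run constrained greedy generation, decode the BPE
    output, split it on separators, and map each chunk to a canonical hallmark name.
    Returns a list of unique hallmark strings, or a single ["Error: ..."] entry on failure.
    """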
    global inference_model, ALLOWED_MASK, model_device, token2id, eos_id

    # --- Pre-computation Checks ---
    if inference_model is None:
         print("Error: Inference model is not loaded.")
         return ["Error: Model not loaded"]
    if ALLOWED_MASK is None:
         print("Error: Allowed mask is not built.")
         return ["Error: Mask not built"]
    if not token2id:
         print("Error: Tokenizer vocabulary not loaded.")
         return ["Error: Tokenizer not loaded"]

    # --- Input Validation ---
    if not isinstance(abstract_text, str):
        print(f"Warning: Received non-string abstract text type: {type(abstract_text)}. Attempting conversion.")
        try:
            abstract_text = str(abstract_text)
        except Exception:
             return ["Error: Invalid input type"]
    if not abstract_text.strip():
        print("Warning: Received empty or whitespace-only abstract text.")
        return [] # Return empty list for empty input


    try:
        # --- 1. Preprocess and Tokenize Input ---
        print("Tokenizing input abstract...")
        cleaned_abstract = " ".join(abstract_text.split())
        if not cleaned_abstract:
             print("Warning: Input abstract contains only whitespace after cleaning.")
             return []

        bpe_tokens_list = bpe_encode_lines([cleaned_abstract])
        if not bpe_tokens_list or not bpe_tokens_list[0]: # Check if list or first element is empty
            print("Warning: BPE encoding resulted in empty tokens.")
            return []
        bpe_tokens = bpe_tokens_list[0]

        input_ids_list, oov = tokens_to_ids(bpe_tokens)
        if oov > 0:
             print(f"Info: Input contained {oov} OOV tokens.")

        # Add EOS token
        input_ids = input_ids_list + [eos_id]

        # Convert to tensor and move to device
        input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(model_device)

        # --- 2. Generate Predictions ---
        print("Generating predictions...")
        with torch.no_grad():
            generated_ids_tensor = inference_model.generate_labels(
                input_tensor,
                allowed_mask=ALLOWED_MASK,
                max_new_tokens=30,
                temperature=0.0
            )

        # --- 3. Decode and Post-process ---
        print("Decoding and cleaning predictions...")
        if generated_ids_tensor is None or generated_ids_tensor.numel() == 0:
             print("Warning: Generation resulted in empty tensor.")
             generated_ids = []
        else:
             # Ensure tensor is on CPU before converting to list
             generated_ids = generated_ids_tensor[0].cpu().tolist()

        if not generated_ids:
             print("No tokens generated.")
             return []

        generated_tokens = ids_to_tokens(generated_ids)

        # Remove tokens after EOS if present
        try:
            eos_token_str = id2token.get(eos_id, "</s>") # Get string representation
            if eos_token_str in generated_tokens:
                 eos_index = generated_tokens.index(eos_token_str)
                 generated_tokens = generated_tokens[:eos_index]
        except ValueError:
             pass # EOS not found is okay

        # Decode BPE tokens to string
        generated_text = bpe_decode_tokens(generated_tokens).strip().lower()
        print(f"Raw generated text: '{generated_text}'")

        # Split potential multi-labels and map to canonical
        parts = []
        if generated_text:
             potential_parts = re.split(r'[;,|]\s*', generated_text)
             parts = [p.strip() for p in potential_parts if p.strip()]
             if not parts: # Handle case with no delimiters
                  parts = [generated_text]

        predicted_labels = []
        seen_labels = set()
        for p in parts:
            canonical_label = to_canonical(p)
            if canonical_label and canonical_label not in seen_labels:
                predicted_labels.append(canonical_label)
                seen_labels.add(canonical_label)

        print(f"Final predicted labels: {predicted_labels}")
        return predicted_labels

    # --- Error Handling ---
    except FileNotFoundError as fnf_err:
         print(f"ERROR during prediction (File Not Found - likely BPE related): {fnf_err}")
         traceback.print_exc()
         return ["Error: BPE file processing error"]
    except PermissionError as perm_err:
        print(f"ERROR during prediction (Permission Error - likely fastBPE): {perm_err}")
        traceback.print_exc()
        return ["Error: BPE execution permission"]
    except RuntimeError as run_err:
         if "CUDA out of memory" in str(run_err):
              print(f"ERROR: CUDA Out of Memory during prediction. Input length: {len(input_ids) if 'input_ids' in locals() else 'N/A'}")
              traceback.print_exc()
              return ["Error: Input too long (OOM)"]
         else:
              print(f"ERROR during prediction (PyTorch RuntimeError): {run_err}")
              traceback.print_exc()
              return ["Error: Model runtime error"]
    except Exception as e:
        print(f"ERROR during prediction (General Exception): {e}")
        traceback.print_exc()
        return ["Error: An unexpected error occurred"]


# --- Flask App ---
app = Flask(__name__)

# --- Load Model on Startup ---
model_initialized = False

@app.before_request
def ensure_model_loaded():
    """ Runs before every request; lazily initializes the model on the first one. """
    global model_initialized
    if not model_initialized:
        print("First request received, attempting to initialize model...")
        # Add basic locking if deploying with multiple workers (though not fully thread-safe here)
        # For true multi-worker safety, model loading should happen before workers fork.
        try:
            if initialize_model_and_tokenizer():
                model_initialized = True
                print("Model initialization successful.")
            else:
                print("FATAL: Model initialization failed during first request.")
                # We won't raise an error here, but subsequent requests will fail until fixed.
        except Exception as init_err:
             print(f"FATAL: Exception during model initialization: {init_err}")
             traceback.print_exc()


# --- Routes ---
@app.route('/')
def home():
    """ Renders the HTML frontend page. """
    # If model initialization failed, the page still renders; /predict reports the failure.
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    """ Handles prediction requests from the frontend. """
    global model_initialized
    # Check if model is ready
    if not model_initialized:
         print("Error: Model not initialized when /predict called.")
         # Return a specific status code like Service Unavailable
         return jsonify({'error': 'Model is not ready. Please try again later.'}), 503

    # Validate request format
    if not request.is_json:
        return jsonify({'error': 'Request must be JSON'}), 400

    data = request.get_json()
    abstract = data.get('abstract')

    # Validate input abstract
    if not abstract:
        return jsonify({'error': 'Missing "abstract" field in JSON request'}), 400
    if not isinstance(abstract, str):
         return jsonify({'error': '"abstract" field must be a string'}), 400
    if len(abstract.strip()) == 0:
        print("Received empty abstract, returning empty prediction.")
        return jsonify({'predictions': []})
    MAX_ABSTRACT_LEN = 10000 # Define max length
    if len(abstract) > MAX_ABSTRACT_LEN:
         print(f"Received overly long abstract ({len(abstract)} chars), rejecting.")
         return jsonify({'error': f'Input abstract is too long (max {MAX_ABSTRACT_LEN} chars)'}), 413 # Payload Too Large

    print(f"\n--- Received Prediction Request ---")
    print(f"Input Abstract (first 100 chars): {abstract[:100]}...")

    try:
        # Perform prediction
        predictions = predict_hallmarks(abstract)
        print(f"--- Prediction Complete ---")

        # Check if the result indicates an internal error occurred
        if isinstance(predictions, list) and len(predictions) > 0 and predictions[0].startswith("Error:"):
             print(f"Internal error during prediction: {predictions[0]}")
             # Return a generic server error to the client
             return jsonify({'error': 'An internal error occurred during prediction.'}), 500
        else:
             # Return successful predictions
             return jsonify({'predictions': predictions})

    except Exception as e:
        # Catch unexpected errors in the route handler itself
        print(f"--- Prediction Failed Unexpectedly in Route ---")
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({'error': 'An internal server error occurred.'}), 500
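
# Example client call (a minimal sketch, never invoked by the app). It assumes the server is
# running locally on port 5000 and that the third-party `requests` package is installed,
# which this app does not otherwise require.
def _example_client_request():
    import requests  # local import so the app itself does not depend on the package
    resp = requests.post(
        "http://localhost:5000/predict",
        json={"abstract": "Tumor cells in this model evade apoptosis through loss of p53 function."},
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["predictions"]  # list of canonical hallmark strings (possibly empty)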

# --- Run the App ---
if __name__ == '__main__':
    # Initialize model eagerly when running script directly
    if not model_initialized:
        print("Running script directly, initializing model eagerly...")
        if initialize_model_and_tokenizer():
            model_initialized = True
            print("Model initialization successful.")
        else:
            print("FATAL: Model initialization failed. Cannot start Flask server.")
            exit(1) # Exit if model fails to load on startup

    # Check fastBPE path validity before starting server
    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
    if not os.path.exists(abs_fastbpe_path):
         print(f"ERROR: fastBPE binary not found at '{abs_fastbpe_path}'.")
         print("Please ensure fastBPE is compiled and the path is correct relative to app.py.")
         exit(1)
    if not os.access(abs_fastbpe_path, os.X_OK):
        print(f"ERROR: fastBPE binary at '{abs_fastbpe_path}' is not executable.")
        print("Attempting to make it executable with 'chmod +x'...")
        try:
            os.chmod(abs_fastbpe_path, 0o755)
            print(f"Successfully made '{abs_fastbpe_path}' executable.")
        except OSError as e:
            print(f"ERROR: Failed to make fastBPE executable: {e}")
            print(f"Please set execute permissions manually (e.g., 'chmod +x {FASTBPE_BIN_PATH}').")
            exit(1)

    print("Starting Flask server...")
    # Use host='0.0.0.0' to make it accessible on your network
    # Set debug=False for production environments
    app.run(host='0.0.0.0', port=5000, debug=False)