File size: 3,554 Bytes
a5eaebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from enum import Enum
import time
import traceback
import uuid
from werkzeug.exceptions import InternalServerError
from config.logger import CustomLogger,request_id_var
import torch
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import os
import sys
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer

try:
    if getattr(sys, 'frozen', False):
        application_path = sys._MEIPASS
    else:
        application_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..')
    log=CustomLogger()
    log.info(f"application_path : {application_path}")
    log.info("before loading model")
    
    device = "cuda"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    log.info(f"device : {device}")
    gpu=0 if torch.cuda.is_available() else -1

    # Loading model and tokenizer
    pipeline_kwargs={"return_token_type_ids": False,"max_length": 512,"truncation": True,"batch_size": 1}
    gibberishModel = AutoModelForSequenceClassification.from_pretrained(os.path.join(application_path, "models/gibberish")).to(device)
    gibberishTokenizer = AutoTokenizer.from_pretrained(os.path.join(application_path, "models/gibberish"))
    
    request_id_var.set("Startup")
    log_dict={}
    log.info("model loaded")

except Exception as e:
    log.error(f"Exception: {e}")
    log.error(f"Exception: {str(traceback.extract_tb(e.__traceback__)[0].lineno),e}")



class MatchType(Enum):
    """How the input text is chunked before being fed to the classifier."""

    SENTENCE = "sentence"  # split the prompt into individual sentences
    FULL = "full"          # classify the whole prompt as a single chunk

    def get_inputs(self, prompt: str) -> list[str]:
        """Return the list of text chunks to classify for *prompt*."""
        if self is MatchType.FULL:
            return [prompt]
        # SENTENCE mode: use NLTK's Punkt tokenizer to split into sentences.
        return PunktSentenceTokenizer().sentences_from_text(text=prompt)
    

class Gibberish:
    """Scans text for gibberish using the module-level pre-loaded model."""

    def scan(self, payload):
        """Classify ``payload['text']`` and score how gibberish-like it is.

        Args:
            payload: dict with keys ``'text'`` (string to classify) and
                ``'labels'`` (collection of label names that count as gibberish).

        Returns:
            dict with ``'result'`` — a list of
            ``{'gibberish_label': str, 'gibberish_score': float}`` entries —
            and ``'time_taken'`` as a string like ``"0.123s"``.

        Raises:
            InternalServerError: on any failure; details are recorded in
                ``log_dict`` and logged before raising.
        """
        log.info("inside gibberish_check")
        id = uuid.uuid4().hex
        request_id_var.set(id)
        log_dict[request_id_var.get()] = []
        try:
            st = time.time()
            text = payload['text']
            gibberish_labels = payload['labels']
            # Pipeline is built per request from the shared model/tokenizer.
            nlp = pipeline(task="text-classification", model=gibberishModel,
                           tokenizer=gibberishTokenizer, device=device,
                           model_kwargs=pipeline_kwargs)
            results_all = nlp(MatchType.FULL.get_inputs(text))
            log.debug(f"Gibberish detection finished :{results_all}")
            res = []
            for result in results_all:
                # Score is the confidence that the text IS gibberish; invert the
                # model confidence when the predicted label is a non-gibberish one.
                score = round(
                    result["score"] if result["label"] in gibberish_labels else 1 - result["score"],
                    2,
                )
                # BUG FIX: build a fresh dict per result. The original reused a
                # single dict object, so every element of `res` aliased the
                # final result instead of holding its own label/score.
                res.append({'gibberish_label': result["label"],
                            'gibberish_score': score})

            del nlp  # drop the pipeline (and any GPU references) promptly
            er = log_dict[request_id_var.get()]
            logobj = {"_id": id, "error": er}
            if len(er) != 0:
                log.debug(str(logobj))
            del log_dict[id]
            return {"result": res, "time_taken": str(round(time.time() - st, 3)) + "s"}

        except Exception as e:
            log.error("Error occured in gibberish_check")
            log.error(f"Exception: {str(traceback.extract_tb(e.__traceback__)[0].lineno),e}")
            log_dict[request_id_var.get()].append({"Line number": str(traceback.extract_tb(e.__traceback__)[0].lineno),
                                                   "Error": str(e),
                                                   "Error Module": "Failed at gibberish_check call"})
            raise InternalServerError()