BigSalmon commited on
Commit
557c130
·
1 Parent(s): 2146630

Create new file

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers.activations import get_activation
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from transformers import AutoTokenizer, AutoModel
10
+ from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
11
+ import math
12
+
13
+
14
+ st.title('GPT2:')
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ @st.cache(allow_output_mutation=True)
19
+ def get_model():
20
+ tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln85Paraphrase")
21
+ model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln85Paraphrase")
22
+ return model, tokenizer
23
+
24
+ model, tokenizer = get_model()
25
+
26
+ g = """***
27
+
28
+ original: sports teams are profitable for owners. [MASK], their valuations experience a dramatic uptick.
29
+ infill: sports teams are profitable for owners. ( accumulating vast sums / stockpiling treasure / realizing benefits / cashing in / registering robust financials / scoring on balance sheets ), their valuations experience a dramatic uptick.
30
+
31
+ ***
32
+
33
+ original:"""
34
+
35
+ def prefix_format(sentence):
36
+ words = sentence.split()
37
+ if "[MASK]" in sentence:
38
+ words2 = words.index("[MASK]")
39
+ #print(words2)
40
+ output = ("<|SUF|> " + ' '.join(words[words2+1:]) + " <|PRE|> " + ' '.join(words[:words2]) + " <|MID|>")
41
+ st.write(output)
42
+ else:
43
+ st.write("Add [MASK] to sentence")
44
+
45
+ with st.form(key='my_form'):
46
+ prompt = st.text_area(label='Enter sentence', value=g)
47
+ submit_button = st.form_submit_button(label='Submit')
48
+ if submit_button:
49
+ with torch.no_grad():
50
+ outputs = model(sequence, labels=input_ids)
51
+ loss, logits = outputs[:2]
52
+ perplex = math.exp(loss)
53
+ st.write(perplex)