File size: 7,311 Bytes
00f3ac8
 
 
 
 
 
 
 
 
 
 
 
a9c363a
 
00f3ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b135ec8
00f3ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""
Gradio web app for Shakespeare-style text generation using the trained GPT model.
This app provides an interactive interface for users to generate Shakespeare-style text
with customizable parameters.
"""

import os
import torch
import gradio as gr
from model import GPT, GPTConfig
import tiktoken

# Make CPU the default device for tensors created without an explicit `device=`.
# NOTE(review): the generator below may still move the model to CUDA and places its
# input tensors on that device explicitly, so this only affects tensors that omit
# a device argument.
torch.set_default_device('cpu')

class ShakespeareTextGenerator:
    """Sample Shakespeare-style text from a trained GPT checkpoint.

    Wraps model loading (``GPT``/``GPTConfig`` from the project's ``model``
    module) and autoregressive sampling with temperature, top-k and top-p
    (nucleus) filtering. Text is tokenized with tiktoken's ``gpt2`` encoding.
    """

    def __init__(self, model_path='compressed_model_cpu_compatible.pt'):
        """Load the checkpoint at ``model_path`` and prepare the model for inference.

        The checkpoint must be a dict with ``'config'`` (keyword arguments for
        ``GPTConfig``) and ``'model_state_dict'``; an optional ``'dtype'``
        entry (``'float16'`` or ``'float32'``) selects the parameter precision.

        Note: ``torch.load`` is pickle-based, so only load trusted checkpoints.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        checkpoint = torch.load(model_path, map_location=self.device)

        self.config = GPTConfig(**checkpoint['config'])
        self.model = GPT(self.config)

        # Match the stored precision before loading weights. float16 is only
        # honoured on CUDA; on CPU the model keeps its default precision and
        # load_state_dict copies (converting) the stored weights into it.
        dtype = checkpoint.get('dtype')
        if dtype == 'float16' and self.device == 'cuda':
            self.model.half()
        elif dtype == 'float32':
            self.model.float()

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = tiktoken.get_encoding('gpt2')
        # Id of the end-of-text special token; sampling stops when it is drawn.
        self.end_token = self.tokenizer.encode(
            '<|endoftext|>', allowed_special={'<|endoftext|>'}
        )[0]

    def generate(self,
                 prompt,
                 max_length=100,
                 temperature=0.7,
                 top_k=50,
                 top_p=0.9,
                 num_return_sequences=1):
        """Generate ``num_return_sequences`` continuations of ``prompt``.

        Numeric arguments may arrive as floats (e.g. from UI sliders);
        ``max_length``, ``top_k`` and ``num_return_sequences`` are truncated
        to ``int``.

        Args:
            prompt: Text to continue. An empty prompt starts from the
                end-of-text token, which is not included in the output.
            max_length: Maximum number of new tokens per sequence.
            temperature: Logit divisor; higher is more random. Values <= 0
                select greedy (argmax) decoding.
            top_k: Keep only the k most likely tokens (0 disables it; values
                larger than the vocabulary keep every token).
            top_p: Nucleus threshold; values >= 1.0 disable it.
            num_return_sequences: Number of independent samples to draw.

        Returns:
            A list of decoded strings, each the prompt followed by the
            generated tokens. A sequence stops early when the end-of-text
            token is drawn; that token is not included in the text.
        """
        max_length = int(max_length)
        top_k = int(top_k)
        num_return_sequences = int(num_return_sequences)

        # Special tokens typed into the prompt are not treated as special.
        prompt_ids = self.tokenizer.encode(prompt, allowed_special=set())
        # An empty prompt would hand the model a zero-length input; seed it
        # with the end-of-text token (the usual GPT-2 convention for
        # unconditional sampling) and drop that seed when decoding.
        skip = 0
        if not prompt_ids:
            prompt_ids = [self.end_token]
            skip = 1
        input_ids = torch.tensor(
            prompt_ids, dtype=torch.long, device=self.device
        ).unsqueeze(0)

        # NOTE(review): assumes GPTConfig.block_size is the model's maximum
        # context length (nanoGPT convention) — confirm against model.py.
        block_size = getattr(self.config, 'block_size', None)

        generated_sequences = []
        with torch.no_grad():
            for _ in range(num_return_sequences):
                cur_ids = input_ids.clone()

                for _ in range(max_length):
                    # Feed at most block_size tokens so long prompts or long
                    # generations never overflow the model's context window.
                    context = cur_ids if block_size is None else cur_ids[:, -block_size:]
                    outputs, _ = self.model(context)
                    next_token = self._sample_next_token(
                        outputs[:, -1, :], temperature, top_k, top_p
                    )
                    if next_token.item() == self.end_token:
                        break
                    cur_ids = torch.cat([cur_ids, next_token], dim=1)

                generated_sequences.append(
                    self.tokenizer.decode(cur_ids[0, skip:].tolist())
                )

        return generated_sequences

    @staticmethod
    def _sample_next_token(logits, temperature, top_k, top_p):
        """Pick one token id per row of ``logits`` (batch, vocab); return a (batch, 1) long tensor."""
        if temperature <= 0:
            # Zero temperature means deterministic decoding; dividing by it
            # would produce inf/nan logits.
            return torch.argmax(logits, dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k > 0:
            # topk raises if k exceeds the vocabulary, so clamp it.
            k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, k).values[:, -1, None]
            logits = logits.masked_fill(logits < kth_largest, float('-inf'))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept and
            # the most likely token always survives.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order.
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float('-inf'))

        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

# Load the model once at import time so every Gradio request reuses it.
generator = ShakespeareTextGenerator()


def generate_text(prompt, max_length, temperature, top_k, top_p, num_sequences):
    """Gradio callback: generate text and join multiple samples with a separator.

    Slider values are cast to ``int`` where the generator needs an integer,
    because Gradio may deliver slider values as floats. Any exception is logged
    with its traceback and reported in the output box instead of crashing the UI.
    """
    import logging

    try:
        sequences = generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            num_return_sequences=int(num_sequences),
        )
    except Exception as e:
        # The broad catch is deliberate at this UI boundary, but keep the
        # traceback in the server log so failures stay debuggable.
        logging.getLogger(__name__).exception("Text generation failed")
        return f"Error: {str(e)}"
    return "\n\n---\n\n".join(sequences)

# Create Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            lines=3,
            label="Prompt",
            placeholder="Enter your prompt here...",
            value="To be, or not to be,",
        ),
        gr.Slider(
            minimum=10,
            maximum=500,
            value=100,
            step=10,
            label="Maximum Length",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature (randomness)",
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=50,
            step=5,
            label="Top-k",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=1,
            maximum=5,
            value=1,
            step=1,
            label="Number of Sequences",
        ),
    ],
    outputs=gr.Textbox(
        lines=10,
        label="Generated Text",
    ),
    title="Shakespeare-Style Text Generator",
    # The description is rendered as Markdown. The previous triple-quoted
    # string carried 4 spaces of source indentation on every line after the
    # first, so the bullets were not recognised as list items and collapsed
    # into one paragraph. Build it unindented, with blank lines around the list.
    description=(
        "Generate Shakespeare-style text using a fine-tuned GPT model. "
        "Training repository: [https://github.com/dhairyag/ShakespeareGPT-Forge]"
        "(https://github.com/dhairyag/ShakespeareGPT-Forge)\n"
        "\n"
        "Adjust the parameters to control the generation:\n"
        "\n"
        "- Temperature: Higher values make the output more random\n"
        "- Top-k: Limits the vocabulary to the k most likely tokens\n"
        "- Top-p: Limits the cumulative probability of tokens considered\n"
        "- Number of Sequences: Generate multiple variations"
    ),
    examples=[
        ["To be, or not to be,", 100, 0.7, 50, 0.9, 1],
        ["O Romeo, Romeo,", 150, 0.8, 40, 0.85, 2],
        ["All the world's a stage,", 200, 0.6, 60, 0.95, 1],
    ],
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()