# Hugging Face Space: Flask chatbot API (page-scrape status banner "Spaces: Running" removed)
"""Flask API serving a fine-tuned GPT-2 chatbot."""
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the chat UI

# Global model state, populated by load_chatbot_model() at import time.
MODEL_PATH = "./models/fine-tuned-gpt2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
def load_chatbot_model():
    """Load the fine-tuned GPT-2 model and tokenizer into module globals.

    Idempotent: returns immediately if both globals are already set.
    Side effects: mutates module-level ``tokenizer`` and ``model`` and
    moves the model to ``device``.
    """
    global tokenizer, model
    # Check BOTH globals: the original only checked `model`, so a load that
    # failed after setting `model` would leave `tokenizer` as None forever.
    if model is None or tokenizer is None:
        print(f"Loading chatbot model from {MODEL_PATH}...")
        print(f"Using device: {device}")
        tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
        model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
        model.to(device)
        print("Model loaded successfully!")


# Load model at import time so the first request does not pay the load cost.
load_chatbot_model()
# NOTE(review): the @app.route decorators appear to have been stripped when
# this file was scraped; the path below is inferred from the handler's
# purpose — confirm against the original source.
@app.route('/')
def index():
    """Serve the chat interface page (templates/index.html)."""
    return render_template('index.html')
# NOTE(review): route decorator lost in the paste; path inferred — confirm.
@app.route('/api')
def root():
    """Return basic service metadata as JSON."""
    return jsonify({
        "message": "Chatbot API",
        "status": "running",
        "model": "fine-tuned-gpt2",
        "device": str(device)
    })
# NOTE(review): route decorator lost in the paste; path inferred — confirm.
@app.route('/health')
def health():
    """Health-check endpoint: reports whether the model is loaded."""
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device)
    })
# NOTE(review): route decorator lost in the paste; path/method inferred — confirm.
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a chatbot response from the conversation history.

    Expects JSON: {"user": [user messages...], "ai": [AI replies...]}.
    Returns JSON: {"response": str, "device": str}, or {"error": str}
    with HTTP 500 on failure.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 500
    try:
        # silent=True makes a missing/invalid JSON body yield None instead of
        # raising; fall back to {} so the .get() calls below are safe.
        # (The original crashed with AttributeError on a non-JSON POST.)
        data = request.get_json(silent=True) or {}
        user_messages = data.get("user", [])
        ai_messages = data.get("ai", [])

        # Keep only the most recent turns so the prompt stays short: up to
        # 7 user messages and 6 AI replies. Plain slicing already handles
        # shorter lists; no length check needed.
        user_msgs = user_messages[-7:]
        ai_msgs = ai_messages[-6:]

        # Interleave completed exchanges; the final user message has no AI
        # reply yet, so it is appended separately with a trailing "<AI>" cue.
        combined_prompt = ""
        for user_message, ai_message in zip(user_msgs[:-1], ai_msgs):
            combined_prompt += f"<user> {user_message}{tokenizer.eos_token}<AI> {ai_message}{tokenizer.eos_token}"
        if user_msgs:
            combined_prompt += f"<user> {user_msgs[-1]}{tokenizer.eos_token}<AI>"

        inputs = tokenizer.encode(combined_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones(inputs.shape, device=device)
        # NOTE(review): temperature/top_k/top_p have no effect under beam
        # search unless do_sample=True; kept as-is to preserve behavior —
        # confirm whether sampling was intended.
        outputs = model.generate(
            inputs,
            max_new_tokens=50,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=attention_mask,
            repetition_penalty=1.2
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text contains the whole prompt; the new reply follows
        # the LAST "<AI>" marker. Then drop anything after a hallucinated
        # "<user>" tag and scrub any leftover markers.
        if "<AI>" in response:
            response = response.split("<AI>")[-1].strip()
        if "<user>" in response:
            response = response.split("<user>")[0].strip()
        response = response.replace("<AI>", "").replace("<user>", "").strip()

        # Fall back to a canned reply rather than returning an empty string.
        if not response:
            response = "I'm not sure how to respond to that."
        return jsonify({
            "response": response,
            "device": str(device)
        })
    except Exception as e:
        # API boundary: report the failure to the client instead of crashing.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the standard Hugging Face Spaces
    # port); debug disabled for production-like serving.
    app.run(host="0.0.0.0", port=7860, debug=False)