Spaces:
Running
Running
| from flask import Flask, request, jsonify, render_template_string | |
| from sentence_transformers import SentenceTransformer, util | |
| import logging | |
| import sys | |
| import signal | |
# Configure root logging so INFO-level records from the service are emitted.
logging.basicConfig(level=logging.INFO)

# Create the Flask application and replace its logger with the named
# "CodeSearchAPI" logger, so the service's records carry that logger name.
app = Flask(__name__)
app.logger = logging.getLogger("CodeSearchAPI")
| # 预定义代码片段 | |
| CODE_SNIPPETS = [ | |
| "def sort_list(x): return sorted(x)", | |
| """def count_above_threshold(elements, threshold=0): | |
| return sum(1 for e in elements if e > threshold)""", | |
| """def find_min_max(elements): | |
| return min(elements), max(elements)""" | |
| """def count_evens(nums): | |
| return len([n for n in nums if n % 2 == 0])""", | |
| """def reverse_string(s): | |
| return s[::-1]""", | |
| """def is_prime(n): | |
| if n < 2: | |
| return False | |
| for i in range(2, int(n**0.5)+1): | |
| if n % i == 0: | |
| return False | |
| return True""", | |
| """def factorial(n): | |
| result = 1 | |
| for i in range(1, n+1): | |
| result *= i | |
| return result""", | |
| """def sum_of_squares(nums): | |
| return sum(map(lambda x: x**2, nums))""" | |
| ] | |
| # 全局服务状态 | |
| service_ready = False | |
| # 优雅关闭处理 | |
| def handle_shutdown(signum, frame): | |
| app.logger.info("收到终止信号,开始关闭...") | |
| sys.exit(0) | |
| signal.signal(signal.SIGTERM, handle_shutdown) | |
| signal.signal(signal.SIGINT, handle_shutdown) | |
| # 初始化模型和预计算编码 | |
| try: | |
| app.logger.info("开始加载模型...") | |
| model = SentenceTransformer( | |
| "flax-sentence-embeddings/st-codesearch-distilroberta-base", | |
| cache_folder="/model-cache" | |
| ) | |
| # 预计算代码片段的编码(强制使用 CPU) | |
| code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu") | |
| service_ready = True | |
| app.logger.info("服务初始化完成") | |
| except Exception as e: | |
| app.logger.error("初始化失败: %s", str(e)) | |
| raise | |
| # Hugging Face 健康检查端点,必须响应根路径 | |
| def hf_health_check(): | |
| # 如果请求接受 HTML,则返回一个简单的 HTML 页面(包含测试链接) | |
| if request.accept_mimetypes.accept_html: | |
| html = """ | |
| <h2>CodeSearch API</h2> | |
| <p>服务状态:{{ status }}</p> | |
| <p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p> | |
| """ | |
| status = "ready" if service_ready else "initializing" | |
| return render_template_string(html, status=status) | |
| # 否则返回 JSON 格式的健康检查 | |
| if service_ready: | |
| return jsonify({"status": "ready"}), 200 | |
| else: | |
| return jsonify({"status": "initializing"}), 503 | |
| # 搜索 API 端点,同时支持 GET 和 POST 请求 | |
| def handle_search(): | |
| if not service_ready: | |
| app.logger.info("服务未就绪") | |
| return jsonify({"error": "服务正在初始化"}), 503 | |
| try: | |
| # 根据请求方法提取查询内容 | |
| if request.method == 'GET': | |
| query = request.args.get('query', '').strip() | |
| else: | |
| data = request.get_json() or {} | |
| query = data.get('query', '').strip() | |
| if not query: | |
| app.logger.info("收到空的查询请求") | |
| return jsonify({"error": "查询不能为空"}), 400 | |
| # 记录接收到的查询 | |
| app.logger.info("收到查询请求: %s", query) | |
| # 对查询进行编码,并进行语义搜索 | |
| query_emb = model.encode(query, convert_to_tensor=True, device="cpu") | |
| hits = util.semantic_search(query_emb, code_emb, top_k=1)[0] | |
| best = hits[0] | |
| result = { | |
| "code": CODE_SNIPPETS[best['corpus_id']], | |
| "score": round(float(best['score']), 4) | |
| } | |
| # 记录返回结果 | |
| app.logger.info("返回结果: %s", result) | |
| return jsonify(result) | |
| except Exception as e: | |
| app.logger.error("请求处理失败: %s", str(e)) | |
| return jsonify({"error": "服务器内部错误"}), 500 | |
| if __name__ == "__main__": | |
| # 本地测试用,Hugging Face Spaces 通常通过 gunicorn 启动 | |
| app.run(host='0.0.0.0', port=7860) | |