Spaces:
Running
Running
| from flask import Flask, request, jsonify | |
| from sentence_transformers import SentenceTransformer, util | |
| import logging | |
| import os | |
# Logging setup: configure the root logger at INFO level, then attach a
# named logger for this service to the Flask app.
logging.basicConfig(level=logging.INFO)
app = Flask(__name__)
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined code snippets: the fixed corpus the search endpoint returns
# results from. It is embedded once at module load (see `code_emb`).
# NOTE(review): this listing lost all indentation, so the `return` lines
# inside the triple-quoted snippets appear unindented here. Confirm they are
# indented in the real file; otherwise the returned snippets are not valid
# Python.
CODE_SNIPPETS = [
"""def sort_list(x): return sorted(x)""",
"""def count_above_threshold(elements, threshold=0):
return sum(1 for e in elements if e > threshold)""",
"""def find_min_max(elements):
return min(elements), max(elements)"""
]
# Readiness flag checked by the request handlers.
# Because any initialization failure below is re-raised (aborting module
# import), this is True whenever the app is actually able to serve requests;
# the handlers' 503 branches are therefore defensive.
model_ready = False
try:
    # Load the embedding model, using HF_HOME as the cache folder
    # (pre-downloaded cache). os.getenv returns None when HF_HOME is unset.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.getenv("HF_HOME"),
    )
    # Pre-compute the corpus embeddings once, so each request only has to
    # encode its query.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
    model_ready = True
    app.logger.info("模型加载完成,服务就绪")
except Exception as e:
    # logger.exception keeps the full traceback (plain .error(str(e)) drops it);
    # lazy %-formatting yields the same message text.
    app.logger.exception("模型初始化失败: %s", e)
    raise
def health_check():
    """Health-check endpoint: report whether the model has finished loading.

    Returns:
        ``({"status": "ready"}, 200)`` when ``model_ready`` is true, otherwise
        ``({"status": "initializing"}, 503)``, with the dict serialized via
        ``jsonify``.
    """
    # NOTE(review): no @app.route decorator is visible for this view in this
    # listing; confirm it is registered with the Flask app.
    body, status_code = (
        ({"status": "ready"}, 200)
        if model_ready
        else ({"status": "initializing"}, 503)
    )
    return jsonify(body), status_code
def handle_search():
    """Search handler: return the stored snippet most similar to the query.

    Expects a JSON object body such as ``{"query": "..."}``. The stripped
    query is encoded with ``model`` and compared against the precomputed
    ``code_emb`` using ``util.semantic_search`` with ``top_k=1``.

    Returns:
        * 503 if ``model_ready`` is false;
        * 415 if the request is not ``application/json``;
        * 400 if the body is not a JSON object, or ``query`` is not a string,
          or is missing/blank;
        * 500 if encoding or search fails (the traceback is logged);
        * otherwise ``{"code": <snippet>, "score": <score rounded to 4 dp>}``.
    """
    # NOTE(review): no @app.route decorator is visible for this view in this
    # listing; confirm it is registered with the Flask app.
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Request validation. These are client errors, so they are answered with
    # 4xx here, outside the try block that maps failures to 500.
    if not request.is_json:
        return jsonify({"error": "需要 application/json"}), 415
    # silent=True returns None for malformed JSON instead of raising
    # BadRequest, so the client gets this API's JSON-shaped 400 error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "请求体必须是 JSON 对象"}), 400
    query = data.get("query", "")
    if not isinstance(query, str):
        return jsonify({"error": "查询必须是字符串"}), 400
    query = query.strip()
    if not query:
        return jsonify({"error": "查询不能为空"}), 400

    # Process the query.
    try:
        query_emb = model.encode(query, convert_to_tensor=True)
        # semantic_search returns one hit list per query; there is one query.
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        return jsonify({
            "code": CODE_SNIPPETS[best["corpus_id"]],
            "score": round(float(best["score"]), 4),
        })
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # Run Flask's built-in server when executed as a script, listening on
    # every network interface at port 8080.
    app.run(port=8080, host="0.0.0.0")