1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
def _mcp_chat_for_request(sessionId, data):
    """Return the MCPChat cached for *sessionId*, creating one from *data* if absent.

    Reads ``modelType`` (default ``'deepseek'``) and ``systemPrompt`` (default
    ``''``), which are only used when a new session is created, and ``mcpUrl``
    (falling back to ``get_env_variable("MCP_URL")``), which is applied to new
    and existing sessions alike whenever it is non-empty and differs from the
    session's current ``mcp_url``.
    """
    modelType = data.get('modelType', 'deepseek')
    systemPrompt = data.get('systemPrompt', '')
    mcp_url = data.get('mcpUrl') or get_env_variable("MCP_URL")

    mcpChat = mcpChatSessions.get(sessionId)
    if not mcpChat:
        logger.info(f"创建新的MCP聊天会话: sessionId={sessionId}, modelType={modelType}")
        mcpChat = MCPChat(modelType, systemPrompt)
        mcpChatSessions[sessionId] = mcpChat
    if mcp_url and mcp_url != mcpChat.mcp_url:
        logger.info(f"设置自定义MCP URL: {mcp_url}")
        mcpChat.mcp_url = mcp_url
    return mcpChat


def _start_response_thread(sessionId, mcpChat, userInput, response_queue):
    """Start a daemon thread running ``process_async_response`` for one message.

    The worker writes its SSE output into *response_queue*, which is read by the
    open event-stream generator for this session.
    """
    has_tool_call_flag = [False]
    logger.debug(f"启动异步处理线程: sessionId={sessionId}")
    thread = threading.Thread(
        target=process_async_response,
        args=(sessionId, mcpChat, userInput, response_queue, has_tool_call_flag),
        daemon=True,
    )
    thread.start()
    return thread


@chat_bp.route('/stream', methods=['GET', 'POST'])
def stream_chat():
    """Streaming chat endpoint: an SSE event stream with MCP tool-call support.

    Two usage modes:

    1. Two-phase: open the stream with GET (or POST with ``createStream=true``
       and no ``message``), read the ``session_id`` event, then POST a
       ``message`` with that ``sessionId``.  That POST answers immediately with
       ``{"status": "processing", "sessionId": ...}``; the model output is
       delivered on the already-open stream.
    2. Single-phase: POST with ``createStream=true`` and a ``message``; the
       response itself is the event stream.

    Every stream starts with a ``session_id`` event.  It ends with an ``end``
    event as soon as the worker signals completion of a response (a ``None``
    item on the session queue), so one stream carries one response.

    A missing ``sessionId`` gets a fresh UUID.  A non-stream POST returns 400
    when ``message`` is empty or when no stream is open for the session.
    """
    if request.method == 'POST':
        data = request.get_json()
        if not isinstance(data, dict):
            # A JSON null/array/scalar body would otherwise crash on data.get()
            # with a 500; treat it as empty so validation below returns a 400.
            data = {}
        sessionId = data.get('sessionId')
        create_new_stream = data.get('createStream', False)
        logger.info(f"收到POST流式请求: sessionId={sessionId}, createStream={create_new_stream}")
    else:
        data = {}
        sessionId = request.args.get('sessionId')
        create_new_stream = False
        logger.info(f"收到GET流式请求: sessionId={sessionId}")

    if not sessionId:
        sessionId = str(uuid.uuid4())
        logger.info(f"生成新的会话ID: {sessionId}")

    if request.method == 'GET' or create_new_stream:
        def generate_session():
            # Register the queue before the client can learn the sessionId
            # (first yield), so an immediate follow-up POST finds the stream.
            # An unstarted generator never runs this code, so nothing leaks then.
            response_queue = queue.Queue()
            active_streams[sessionId] = response_queue
            try:
                logger.debug(f"开始生成SSE会话: sessionId={sessionId}")
                yield f"data: {json.dumps({'type': 'session_id', 'session_id': sessionId})}\n\n"
                logger.debug(f"已发送会话ID: {sessionId}")
                logger.debug(f"已创建响应队列: sessionId={sessionId}")

                # create_new_stream can only be true for POST requests.
                if create_new_stream:
                    userInput = data.get('message')
                    if userInput:
                        logger.info(f"处理POST流式请求中的消息: sessionId={sessionId}, message={userInput[:50]}...")
                        mcpChat = _mcp_chat_for_request(sessionId, data)
                        _start_response_thread(sessionId, mcpChat, userInput, response_queue)

                while True:
                    try:
                        # Blocking get with a short timeout instead of busy-polling.
                        item = response_queue.get(timeout=0.1)
                    except queue.Empty:
                        continue
                    if item is None:
                        logger.debug(f"收到结束信号: sessionId={sessionId}")
                        yield f"data: {json.dumps({'type': 'end'})}\n\n"
                        break
                    logger.debug(f"发送SSE数据: {item[:100]}...")
                    yield item
            finally:
                # Only unregister our own queue: a newer stream may have replaced
                # it under the same sessionId, and must keep its registration.
                if active_streams.get(sessionId) is response_queue:
                    logger.debug(f"清理响应队列: sessionId={sessionId}")
                    del active_streams[sessionId]

        return Response(
            stream_with_context(generate_session()),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no',
                'Connection': 'keep-alive'
            }
        )

    # Two-phase mode: push the message into an already-open stream.
    userInput = data.get('message')
    if not userInput:
        return jsonify({"error": "消息内容不能为空"}), 400

    # Look up the stream before creating a chat session, so a request without an
    # open stream (e.g. one that received a freshly generated UUID) does not
    # leave an orphaned MCPChat in mcpChatSessions.  A single .get() also avoids
    # a KeyError if the stream closes between a membership test and an index.
    response_queue = active_streams.get(sessionId)
    if response_queue is None:
        logger.warning(f"未找到活动的流连接: sessionId={sessionId}")
        return jsonify({
            "error": "没有找到活动的流连接,请先建立GET连接或使用createStream=true",
            "sessionId": sessionId
        }), 400

    mcpChat = _mcp_chat_for_request(sessionId, data)
    _start_response_thread(sessionId, mcpChat, userInput, response_queue)
    return jsonify({"status": "processing", "sessionId": sessionId})
def process_async_response(sessionId, mcpChat, userInput, response_queue, has_tool_call_flag):
    """Run one chat turn for *userInput* and stream the results into *response_queue*.

    Intended to run on its own worker thread: it runs the model call to
    completion on a fresh event loop.  Items put on *response_queue* are
    SSE-formatted strings (``data: {...}\\n\\n``) whose JSON ``type`` is
    ``chunk``, ``tool_call`` or ``error``.  They are followed by exactly one
    ``None`` end-of-response sentinel, which is sent even if the model stream
    fails or finishes without a completion callback.

    After the model signals completion, every fenced block tagged ``mcp`` in the
    full response text is parsed as JSON ``{"tool": ..., "params": {...}}``.
    Each such tool is executed via ``mcpChat._call_tool_streaming`` and its
    output chunks are streamed as ``chunk`` events.

    Args:
        sessionId: Session identifier; used here only for logging.
        mcpChat: Chat session object.  ``get_streaming_response(userInput,
            callback)`` is awaited and ``_call_tool_streaming(name, params)`` is
            iterated with ``async for``.
        userInput: The user's message text.
        response_queue: Queue consumed by the SSE stream generator.
        has_tool_call_flag: One-element list.  ``has_tool_call_flag[0]`` is set
            to True when at least one ``mcp`` tool-call block is found.
    """
    complete_response = ""
    end_sent = False

    def sse(payload):
        """Format *payload* as a single SSE ``data:`` event."""
        return f"data: {json.dumps(payload)}\n\n"

    def send_end():
        """Put the ``None`` end-of-response sentinel on the queue, at most once."""
        nonlocal end_sent
        if not end_sent:
            end_sent = True
            response_queue.put(None)

    def simple_callback(chunk, is_complete):
        """Receive one streamed model chunk; on completion run tool calls and end the response."""
        nonlocal complete_response
        logger.debug(f"AI输出: sessionId={sessionId}, chunk={chunk}, is_complete={is_complete}")
        # NOTE(review): guards against a None chunk (e.g. a bare completion
        # signal), which would otherwise break the concatenation below.
        chunk = chunk or ""
        complete_response += chunk
        # Whitespace-only chunks are accumulated but not forwarded to the client.
        if chunk.strip():
            response_queue.put(sse({'type': 'chunk', 'content': chunk}))
        if is_complete:
            logger.info(f"AI响应完成: sessionId={sessionId}")
            logger.info(f"完整响应: {complete_response}")
            process_complete_response(complete_response, response_queue)
            send_end()

    def process_complete_response(response_text, response_queue):
        """Execute, in order, every ```mcp fenced tool call found in *response_text*."""
        tool_calls_found = False
        for match in re.finditer(r"```mcp\s*([\s\S]*?)\s*```", response_text):
            tool_call_text = match.group(1).strip()
            logger.info(f"检测到工具调用: {tool_call_text}")
            tool_calls_found = True
            has_tool_call_flag[0] = True
            process_tool_call(tool_call_text, response_queue)
        if not tool_calls_found:
            logger.info("未检测到工具调用")

    def process_tool_call(tool_call_text, response_queue):
        """Parse one tool-call JSON block and run the tool, reporting failures as chunk events."""
        try:
            tool_call = json.loads(tool_call_text)
            tool_name = tool_call.get("tool")
            params = tool_call.get("params", {})
            logger.info(f"解析工具调用: tool_name={tool_name}, params={params}")
            response_queue.put(sse({'type': 'tool_call'}))

            # This callback presumably runs while the outer event loop is running
            # on this thread, where a nested asyncio.run() is not allowed.  So the
            # tool coroutine gets its own loop on a helper thread, and this thread
            # waits for it to finish.  asyncio.run also finalizes the tool's async
            # generator and closes the loop.
            def run_tool_call():
                asyncio.run(execute_tool_and_send_results(tool_name, params, response_queue))

            tool_thread = threading.Thread(target=run_tool_call, daemon=True)
            tool_thread.start()
            logger.info(f"等待工具调用完成: tool_name={tool_name}")
            tool_thread.join()
            logger.info(f"工具调用线程已完成: tool_name={tool_name}")
        except json.JSONDecodeError as e:
            error_msg = f"❌ 工具调用格式错误: {tool_call_text}, 错误: {str(e)}"
            logger.error(f"工具调用JSON解析错误: {error_msg}")
            response_queue.put(sse({'type': 'chunk', 'content': error_msg}))
        except Exception as e:
            error_msg = f"❌ 工具调用处理失败: {str(e)}"
            logger.error(f"工具调用处理异常: {error_msg}", exc_info=True)
            response_queue.put(sse({'type': 'chunk', 'content': error_msg}))

    async def execute_tool_and_send_results(tool_name, params, response_queue):
        """Stream each result chunk of one tool call to the queue as a chunk event."""
        try:
            logger.debug(f"开始流式调用工具: tool_name={tool_name}")
            sent_count = 0
            async for result_chunk in mcpChat._call_tool_streaming(tool_name, params):
                logger.debug(f"工具调用结果: {result_chunk[:100]}...")
                response_queue.put(sse({'type': 'chunk', 'content': result_chunk}))
                sent_count += 1
            logger.info(f"工具调用完成,共发送了 {sent_count} 个结果块")
        except Exception as e:
            error_msg = f"❌ 工具 '{tool_name}' 调用失败: {str(e)}"
            logger.error(f"工具调用失败: {error_msg}", exc_info=True)
            response_queue.put(sse({'type': 'chunk', 'content': error_msg}))

    async def run_streaming():
        await mcpChat.get_streaming_response(userInput, simple_callback)

    try:
        logger.info(f"开始获取流式响应: sessionId={sessionId}, userInput={userInput[:50]}...")
        asyncio.run(run_streaming())
    except Exception as e:
        error_message = str(e)
        logger.error(f"流式处理错误: {error_message}", exc_info=True)
        response_queue.put(sse({'type': 'error', 'error': error_message}))
    finally:
        # Always terminate the response.  Without this, a model stream that ends
        # without an is_complete=True callback would leave the SSE consumer
        # waiting forever.  send_end() is a no-op if the callback already sent it.
        send_end()
# Registry of open SSE streams: sessionId -> queue.Queue of pre-formatted SSE
# strings, where a None item marks the end of a response. stream_chat's event
# generator registers and removes entries; follow-up POSTs look them up to push
# a message's output into an already-open stream.
active_streams = dict()
|