diff --git a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/Chat/AiChatService.cs b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/Chat/AiChatService.cs index bc470ac5..815d416c 100644 --- a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/Chat/AiChatService.cs +++ b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/Chat/AiChatService.cs @@ -43,11 +43,15 @@ public class AiChatService : ApplicationService private readonly ILogger _logger; private readonly AiGateWayManager _aiGateWayManager; private readonly PremiumPackageManager _premiumPackageManager; + private readonly ChatManager _chatManager; public AiChatService(IHttpContextAccessor httpContextAccessor, AiBlacklistManager aiBlacklistManager, ISqlSugarRepository aiModelRepository, - ILogger logger, AiGateWayManager aiGateWayManager, PremiumPackageManager premiumPackageManager) + ILogger logger, + AiGateWayManager aiGateWayManager, + PremiumPackageManager premiumPackageManager, + ChatManager chatManager) { _httpContextAccessor = httpContextAccessor; _aiBlacklistManager = aiBlacklistManager; @@ -55,6 +59,7 @@ public class AiChatService : ApplicationService _logger = logger; _aiGateWayManager = aiGateWayManager; _premiumPackageManager = premiumPackageManager; + _chatManager = chatManager; } @@ -140,9 +145,19 @@ public class AiChatService : ApplicationService } } - //ai网关代理httpcontext - await _aiGateWayManager.CompleteChatStreamForStatisticsAsync(_httpContextAccessor.HttpContext, input, - CurrentUser.Id, sessionId, null, cancellationToken); + // 判断是否有工具调用 + if (input.Tools != null && input.Tools.Count > 0) + { + // 使用 ChatManager 处理支持工具调用的对话 + await _chatManager.CompleteChatWithToolsAsync(_httpContextAccessor.HttpContext, input, + CurrentUser.Id, sessionId, null, cancellationToken); + } + else + { + // 使用原来的 AiGateWayManager 处理普通对话 + await _aiGateWayManager.CompleteChatStreamForStatisticsAsync(_httpContextAccessor.HttpContext, input, + CurrentUser.Id, sessionId, null, cancellationToken); + } } diff --git a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/ChatManager.cs b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/ChatManager.cs new file mode 100644 index 00000000..bce3dc56 --- /dev/null +++ b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/ChatManager.cs @@ -0,0 +1,326 @@ +using System.Collections.Concurrent; +using System.Text; +using System.Text.Json; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Volo.Abp.Domain.Services; +using Yi.Framework.AiHub.Domain.AiGateWay; +using Yi.Framework.AiHub.Domain.Entities; +using Yi.Framework.AiHub.Domain.Entities.Model; +using Yi.Framework.AiHub.Domain.Shared.Dtos; +using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi; +using Yi.Framework.AiHub.Domain.Shared.Enums; +using Yi.Framework.SqlSugarCore.Abstractions; +using ThorJsonSerializer = Yi.Framework.AiHub.Domain.AiGateWay.ThorJsonSerializer; + +namespace Yi.Framework.AiHub.Domain.Managers; + +public class ChatManager : DomainService +{ + private readonly AiGateWayManager _aiGateWayManager; + private readonly ISqlSugarRepository _aiModelRepository; + private readonly ILogger _logger; + private readonly AiMessageManager _aiMessageManager; + private readonly UsageStatisticsManager _usageStatisticsManager; + + public ChatManager( + AiGateWayManager aiGateWayManager, + ISqlSugarRepository aiModelRepository, + ILogger logger, + AiMessageManager aiMessageManager, + UsageStatisticsManager usageStatisticsManager) + { + _aiGateWayManager = aiGateWayManager; + _aiModelRepository = aiModelRepository; + _logger = logger; + _aiMessageManager = aiMessageManager; + _usageStatisticsManager = usageStatisticsManager; + } + + /// + /// 支持工具调用的流式对话 + /// + /// + /// + /// + /// + /// + /// + public async Task CompleteChatWithToolsAsync( + HttpContext httpContext, + ThorChatCompletionsRequest request, + Guid? userId = null, + Guid? sessionId = null, + Guid? tokenId = null, + CancellationToken cancellationToken = default) + { + var response = httpContext.Response; + // 设置响应头,声明是 SSE 流 + response.ContentType = "text/event-stream;charset=utf-8;"; + response.Headers.TryAdd("Cache-Control", "no-cache"); + response.Headers.TryAdd("Connection", "keep-alive"); + + var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, request.Model); + var chatService = LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName); + + var tokenUsage = new ThorUsageResponse(); + var messageQueue = new ConcurrentQueue(); + var backupSystemContent = new StringBuilder(); + var isComplete = false; + + // 启动后台任务来消费队列 + var outputTask = Task.Run(async () => + { + while (!(isComplete && messageQueue.IsEmpty)) + { + if (messageQueue.TryDequeue(out var message)) + { + await response.WriteAsync(message, Encoding.UTF8, cancellationToken).ConfigureAwait(false); + await response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + } + + if (!isComplete) + { + await Task.Delay(TimeSpan.FromMilliseconds(75), cancellationToken).ConfigureAwait(false); + } + else + { + await Task.Delay(10, cancellationToken).ConfigureAwait(false); + } + } + }, cancellationToken); + + try + { + // 多轮工具调用循环 + var conversationMessages = request.Messages ?? new List(); + var maxIterations = 10; // 防止无限循环 + var currentIteration = 0; + + while (currentIteration < maxIterations) + { + currentIteration++; + + // 更新请求消息 + request.Messages = conversationMessages; + + // 调用 AI API + var completeChatResponse = chatService.CompleteChatStreamAsync(modelDescribe, request, cancellationToken); + + var currentToolCalls = new List(); + var hasToolCalls = false; + var currentContent = new StringBuilder(); + + await foreach (var data in completeChatResponse) + { + data.SupplementalMultiplier(modelDescribe.Multiplier); + + if (data.Usage is not null && (data.Usage.CompletionTokens > 0 || data.Usage.OutputTokens > 0)) + { + tokenUsage = data.Usage; + } + + // 检查是否有工具调用 + if (data.Choices?.FirstOrDefault()?.Delta?.ToolCalls != null) + { + hasToolCalls = true; + foreach (var toolCall in data.Choices.First().Delta.ToolCalls) + { + // 累积工具调用信息 + var existingToolCall = currentToolCalls.FirstOrDefault(tc => tc.Id == toolCall.Id); + if (existingToolCall == null) + { + currentToolCalls.Add(toolCall); + } + else + { + // 合并工具调用的 arguments(流式返回可能是分片的) + if (existingToolCall.Function != null && toolCall.Function?.Arguments != null) + { + existingToolCall.Function.Arguments += toolCall.Function.Arguments; + } + } + } + } + + // 检查 finish_reason + var finishReason = data.Choices?.FirstOrDefault()?.FinishReason; + if (finishReason == "tool_calls") + { + hasToolCalls = true; + } + + // 累积内容 + var content = data.Choices?.FirstOrDefault()?.Delta.Content; + if (!string.IsNullOrEmpty(content)) + { + currentContent.Append(content); + } + + var message = System.Text.Json.JsonSerializer.Serialize(data, ThorJsonSerializer.DefaultOptions); + backupSystemContent.Append(content); + messageQueue.Enqueue($"data: {message}\n\n"); + } + + // 如果没有工具调用,结束循环 + if (!hasToolCalls || currentToolCalls.Count == 0) + { + break; + } + + // 发送工具调用状态消息 + var toolCallStatusMessage = new + { + type = "tool_call_status", + status = "calling_tools", + tool_calls = currentToolCalls + }; + var statusJson = System.Text.Json.JsonSerializer.Serialize(toolCallStatusMessage, ThorJsonSerializer.DefaultOptions); + messageQueue.Enqueue($"data: {statusJson}\n\n"); + + // 将 AI 的工具调用消息添加到历史 + conversationMessages.Add(new ThorChatMessage + { + Role = ThorChatMessageRoleConst.Assistant, + Content = currentContent.ToString(), + ToolCalls = currentToolCalls + }); + + // 执行工具调用 + foreach (var toolCall in currentToolCalls) + { + try + { + // TODO: 实现 MCP 工具调用逻辑 + // 这里需要根据 toolCall.Function.Name 和 toolCall.Function.Arguments 调用相应的工具 + // 示例: + // var toolResult = await InvokeMcpToolAsync(toolCall.Function.Name, toolCall.Function.Arguments, cancellationToken); + + var toolResult = "TODO: 实现 MCP 工具调用"; + + // 将工具结果添加到消息历史 + conversationMessages.Add(ThorChatMessage.CreateToolMessage(toolResult, toolCall.Id)); + + // 发送工具执行结果消息 + var toolResultMessage = new + { + type = "tool_call_result", + tool_call_id = toolCall.Id, + tool_name = toolCall.Function?.Name, + result = toolResult + }; + var resultJson = System.Text.Json.JsonSerializer.Serialize(toolResultMessage, ThorJsonSerializer.DefaultOptions); + messageQueue.Enqueue($"data: {resultJson}\n\n"); + } + catch (Exception ex) + { + _logger.LogError(ex, $"工具调用失败: {toolCall.Function?.Name}"); + + // 将错误信息添加到消息历史 + var errorMessage = $"工具调用失败: {ex.Message}"; + conversationMessages.Add(ThorChatMessage.CreateToolMessage(errorMessage, toolCall.Id)); + + // 发送工具执行错误消息 + var toolErrorMessage = new + { + type = "tool_call_error", + tool_call_id = toolCall.Id, + tool_name = toolCall.Function?.Name, + error = errorMessage + }; + var errorJson = System.Text.Json.JsonSerializer.Serialize(toolErrorMessage, ThorJsonSerializer.DefaultOptions); + messageQueue.Enqueue($"data: {errorJson}\n\n"); + } + } + + // 继续下一轮对话,让 AI 根据工具结果生成回复 + } + } + catch (Exception e) + { + _logger.LogError(e, $"AI对话异常"); + var errorContent = $"对话AI异常,异常信息:\n当前AI模型:{request.Model}\n异常信息:{e.Message}\n异常堆栈:{e}"; + var model = new ThorChatCompletionsResponse() + { + Choices = new List() + { + new ThorChatChoiceResponse() + { + Delta = new ThorChatMessage() + { + Content = errorContent + } + } + } + }; + var message = JsonConvert.SerializeObject(model, new JsonSerializerSettings + { + ContractResolver = new CamelCasePropertyNamesContractResolver() + }); + backupSystemContent.Append(errorContent); + messageQueue.Enqueue($"data: {message}\n\n"); + } + + // 断开连接 + messageQueue.Enqueue("data: [DONE]\n\n"); + isComplete = true; + + await outputTask; + + // 保存消息和统计信息 + await _aiMessageManager.CreateUserMessageAsync(userId, sessionId, + new Shared.Dtos.MessageInputDto + { + Content = sessionId is null ? "不予存储" : request.Messages?.LastOrDefault()?.MessagesStore ?? string.Empty, + ModelId = request.Model, + TokenUsage = tokenUsage, + }, tokenId); + + await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId, + new Shared.Dtos.MessageInputDto + { + Content = sessionId is null ? "不予存储" : backupSystemContent.ToString(), + ModelId = request.Model, + TokenUsage = tokenUsage + }, tokenId); + + await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage, tokenId); + } + + /// + /// 获取模型 + /// + private async Task GetModelAsync(ModelApiTypeEnum modelApiType, string modelId) + { + var aiModelDescribe = await _aiModelRepository._DbQueryable + .LeftJoin((model, app) => model.AiAppId == app.Id) + .Where((model, app) => model.ModelId == modelId) + .Where((model, app) => model.ModelApiType == modelApiType) + .Select((model, app) => + new AiModelDescribe + { + AppId = app.Id, + AppName = app.Name, + Endpoint = app.Endpoint, + ApiKey = app.ApiKey, + OrderNum = model.OrderNum, + HandlerName = model.HandlerName, + ModelId = model.ModelId, + ModelName = model.Name, + Description = model.Description, + AppExtraUrl = app.ExtraUrl, + ModelExtraInfo = model.ExtraInfo, + Multiplier = model.Multiplier + }) + .FirstAsync(); + if (aiModelDescribe is null) + { + throw new UserFriendlyException($"【{modelId}】模型当前版本【{modelApiType}】格式不支持"); + } + + return aiModelDescribe; + } +} \ No newline at end of file