326 lines
13 KiB
C#
326 lines
13 KiB
C#
using System.Collections.Concurrent;
|
||
using System.Text;
|
||
using System.Text.Json;
|
||
using Microsoft.AspNetCore.Http;
|
||
using Microsoft.Extensions.DependencyInjection;
|
||
using Microsoft.Extensions.Logging;
|
||
using Newtonsoft.Json;
|
||
using Newtonsoft.Json.Serialization;
|
||
using Volo.Abp.Domain.Services;
|
||
using Yi.Framework.AiHub.Domain.AiGateWay;
|
||
using Yi.Framework.AiHub.Domain.Entities;
|
||
using Yi.Framework.AiHub.Domain.Entities.Model;
|
||
using Yi.Framework.AiHub.Domain.Shared.Dtos;
|
||
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi;
|
||
using Yi.Framework.AiHub.Domain.Shared.Enums;
|
||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||
using ThorJsonSerializer = Yi.Framework.AiHub.Domain.AiGateWay.ThorJsonSerializer;
|
||
|
||
namespace Yi.Framework.AiHub.Domain.Managers;
|
||
|
||
/// <summary>
/// Domain service that runs a streaming (Server-Sent Events) chat completion with a
/// multi-round tool-calling loop, then stores the exchanged messages and the token usage.
/// </summary>
public class ChatManager : DomainService
{
    /// <summary>
    /// Upper bound on model round-trips for one request; prevents an endless tool-call loop.
    /// </summary>
    private const int MaxToolIterations = 10;

    private readonly AiGateWayManager _aiGateWayManager;
    private readonly ISqlSugarRepository<AiModelEntity> _aiModelRepository;
    private readonly ILogger<ChatManager> _logger;
    private readonly AiMessageManager _aiMessageManager;
    private readonly UsageStatisticsManager _usageStatisticsManager;

    /// <summary>
    /// Creates the manager with its collaborating services (supplied by dependency injection).
    /// </summary>
    public ChatManager(
        AiGateWayManager aiGateWayManager,
        ISqlSugarRepository<AiModelEntity> aiModelRepository,
        ILogger<ChatManager> logger,
        AiMessageManager aiMessageManager,
        UsageStatisticsManager usageStatisticsManager)
    {
        _aiGateWayManager = aiGateWayManager;
        _aiModelRepository = aiModelRepository;
        _logger = logger;
        _aiMessageManager = aiMessageManager;
        _usageStatisticsManager = usageStatisticsManager;
    }

    /// <summary>
    /// Streaming chat completion with tool-call support.
    /// Every chunk returned by the model is forwarded to the client as an SSE <c>data:</c> event.
    /// When the model requests tool calls, status/result events are emitted, the results are
    /// appended to the conversation and the model is called again (at most
    /// <see cref="MaxToolIterations"/> rounds). The stream always ends with <c>data: [DONE]</c>.
    /// </summary>
    /// <remarks>
    /// Tool execution itself is still a placeholder: every tool call currently yields a fixed
    /// "TODO" result string.
    /// </remarks>
    /// <param name="httpContext">HTTP context whose response receives the SSE stream.</param>
    /// <param name="request">Chat completion request; its <c>Messages</c> list is extended in place with assistant/tool messages.</param>
    /// <param name="userId">User the messages and usage are recorded for.</param>
    /// <param name="sessionId">Chat session; when <c>null</c> the message bodies are not stored.</param>
    /// <param name="tokenId">Token identifier passed through to message and usage persistence.</param>
    /// <param name="cancellationToken">Cancels the upstream model call and the response writer.</param>
    public async Task CompleteChatWithToolsAsync(
        HttpContext httpContext,
        ThorChatCompletionsRequest request,
        Guid? userId = null,
        Guid? sessionId = null,
        Guid? tokenId = null,
        CancellationToken cancellationToken = default)
    {
        var response = httpContext.Response;

        // Declare the response as an SSE stream.
        response.ContentType = "text/event-stream;charset=utf-8;";
        response.Headers.TryAdd("Cache-Control", "no-cache");
        response.Headers.TryAdd("Connection", "keep-alive");

        var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, request.Model);
        var chatService = LazyServiceProvider.GetRequiredKeyedService<IChatCompletionService>(modelDescribe.HandlerName);

        var tokenUsage = new ThorUsageResponse();
        var messageQueue = new ConcurrentQueue<string>();
        var backupSystemContent = new StringBuilder();
        var isComplete = false;

        // Capture the user's message NOW: the tool loop below appends assistant/tool messages to
        // the very same list, so reading "the last message" afterwards would store the wrong one.
        var userMessageContent = sessionId is null
            ? "不予存储"
            : request.Messages?.LastOrDefault()?.MessagesStore ?? string.Empty;

        // Background writer: drains the queue into the HTTP response.
        // It polls on a fixed cadence (75 ms while producing, 10 ms while draining), which also
        // paces how fast events reach the client.
        var outputTask = Task.Run(async () =>
        {
            while (!(isComplete && messageQueue.IsEmpty))
            {
                if (messageQueue.TryDequeue(out var message))
                {
                    await response.WriteAsync(message, Encoding.UTF8, cancellationToken).ConfigureAwait(false);
                    await response.Body.FlushAsync(cancellationToken).ConfigureAwait(false);
                }

                if (!isComplete)
                {
                    await Task.Delay(TimeSpan.FromMilliseconds(75), cancellationToken).ConfigureAwait(false);
                }
                else
                {
                    await Task.Delay(10, cancellationToken).ConfigureAwait(false);
                }
            }
        }, cancellationToken);

        try
        {
            // Multi-round tool-calling loop.
            var conversationMessages = request.Messages ?? new List<ThorChatMessage>();

            for (var iteration = 0; iteration < MaxToolIterations; iteration++)
            {
                // The request always carries the full conversation so far.
                request.Messages = conversationMessages;

                // Call the AI API.
                var completeChatResponse = chatService.CompleteChatStreamAsync(modelDescribe, request, cancellationToken);

                var currentToolCalls = new List<ThorToolCall>();
                var hasToolCalls = false;
                var currentContent = new StringBuilder();

                await foreach (var data in completeChatResponse)
                {
                    data.SupplementalMultiplier(modelDescribe.Multiplier);

                    // NOTE(review): usage is overwritten, not summed, so after several tool
                    // rounds only the last reported usage is recorded — confirm this is intended.
                    if (data.Usage is not null && (data.Usage.CompletionTokens > 0 || data.Usage.OutputTokens > 0))
                    {
                        tokenUsage = data.Usage;
                    }

                    var choice = data.Choices?.FirstOrDefault();
                    var delta = choice?.Delta;

                    // Accumulate streamed tool-call fragments.
                    if (delta?.ToolCalls != null)
                    {
                        hasToolCalls = true;
                        foreach (var toolCall in delta.ToolCalls)
                        {
                            MergeToolCallDelta(currentToolCalls, toolCall);
                        }
                    }

                    // The finish reason can also announce tool calls.
                    if (choice?.FinishReason == "tool_calls")
                    {
                        hasToolCalls = true;
                    }

                    // Accumulate text content. Delta may be absent on some chunks, hence "?.".
                    var content = delta?.Content;
                    if (!string.IsNullOrEmpty(content))
                    {
                        currentContent.Append(content);
                    }

                    var message = System.Text.Json.JsonSerializer.Serialize(data, ThorJsonSerializer.DefaultOptions);
                    backupSystemContent.Append(content);
                    messageQueue.Enqueue(ToSseEvent(message));
                }

                // No tool calls requested: the answer is complete.
                if (!hasToolCalls || currentToolCalls.Count == 0)
                {
                    break;
                }

                // Tell the client that tools are being called.
                var toolCallStatusMessage = new
                {
                    type = "tool_call_status",
                    status = "calling_tools",
                    tool_calls = currentToolCalls
                };
                var statusJson = System.Text.Json.JsonSerializer.Serialize(toolCallStatusMessage, ThorJsonSerializer.DefaultOptions);
                messageQueue.Enqueue(ToSseEvent(statusJson));

                // Record the assistant's tool-call message in the history.
                conversationMessages.Add(new ThorChatMessage
                {
                    Role = ThorChatMessageRoleConst.Assistant,
                    Content = currentContent.ToString(),
                    ToolCalls = currentToolCalls
                });

                // Execute the tool calls.
                foreach (var toolCall in currentToolCalls)
                {
                    try
                    {
                        // TODO: implement the MCP tool invocation.
                        // The tool to run is selected by toolCall.Function.Name with
                        // toolCall.Function.Arguments, e.g.:
                        // var toolResult = await InvokeMcpToolAsync(toolCall.Function.Name, toolCall.Function.Arguments, cancellationToken);
                        var toolResult = "TODO: 实现 MCP 工具调用";

                        // Add the tool result to the message history.
                        conversationMessages.Add(ThorChatMessage.CreateToolMessage(toolResult, toolCall.Id));

                        // Send the tool result to the client.
                        var toolResultMessage = new
                        {
                            type = "tool_call_result",
                            tool_call_id = toolCall.Id,
                            tool_name = toolCall.Function?.Name,
                            result = toolResult
                        };
                        var resultJson = System.Text.Json.JsonSerializer.Serialize(toolResultMessage, ThorJsonSerializer.DefaultOptions);
                        messageQueue.Enqueue(ToSseEvent(resultJson));
                    }
                    catch (Exception ex)
                    {
                        // Structured template: a tool name containing braces cannot break formatting.
                        _logger.LogError(ex, "工具调用失败: {ToolName}", toolCall.Function?.Name);

                        // Feed the failure back to the model as the tool's result.
                        var errorMessage = $"工具调用失败: {ex.Message}";
                        conversationMessages.Add(ThorChatMessage.CreateToolMessage(errorMessage, toolCall.Id));

                        // Send the tool error to the client.
                        var toolErrorMessage = new
                        {
                            type = "tool_call_error",
                            tool_call_id = toolCall.Id,
                            tool_name = toolCall.Function?.Name,
                            error = errorMessage
                        };
                        var errorJson = System.Text.Json.JsonSerializer.Serialize(toolErrorMessage, ThorJsonSerializer.DefaultOptions);
                        messageQueue.Enqueue(ToSseEvent(errorJson));
                    }
                }

                // Next round: let the model answer based on the tool results.
            }
        }
        catch (Exception e)
        {
            _logger.LogError(e, "AI对话异常");

            // NOTE(review): this sends the full exception text, including the stack trace, to the
            // client. Kept as-is for compatibility; consider returning only a generic message.
            var errorContent = $"对话AI异常,异常信息:\n当前AI模型:{request.Model}\n异常信息:{e.Message}\n异常堆栈:{e}";
            var model = new ThorChatCompletionsResponse()
            {
                Choices = new List<ThorChatChoiceResponse>()
                {
                    new ThorChatChoiceResponse()
                    {
                        Delta = new ThorChatMessage()
                        {
                            Content = errorContent
                        }
                    }
                }
            };
            var message = JsonConvert.SerializeObject(model, new JsonSerializerSettings
            {
                ContractResolver = new CamelCasePropertyNamesContractResolver()
            });
            backupSystemContent.Append(errorContent);
            messageQueue.Enqueue(ToSseEvent(message));
        }

        // End of stream: enqueue the terminator, then let the writer drain the queue.
        messageQueue.Enqueue("data: [DONE]\n\n");
        isComplete = true;

        try
        {
            await outputTask;
        }
        finally
        {
            // Persist even when the writer failed or was cancelled (e.g. the client disconnected),
            // so tokens that were already consumed are still recorded. The writer's exception
            // continues to propagate to the caller afterwards.
            await SaveConversationAsync(
                userId,
                sessionId,
                tokenId,
                request.Model,
                userMessageContent,
                sessionId is null ? "不予存储" : backupSystemContent.ToString(),
                tokenUsage);
        }
    }

    /// <summary>
    /// Wraps a payload as a single SSE <c>data:</c> event.
    /// </summary>
    private static string ToSseEvent(string payload)
    {
        return $"data: {payload}\n\n";
    }

    /// <summary>
    /// Folds one streamed tool-call fragment into the list of tool calls collected so far.
    /// A fragment with an id either starts a new call or extends the call with that id.
    /// A fragment without an id is treated as a continuation of the most recent call, because
    /// streamed continuation fragments typically omit the id.
    /// </summary>
    /// <param name="accumulated">Tool calls collected for the current model round.</param>
    /// <param name="fragment">Tool-call fragment taken from the current chunk.</param>
    private static void MergeToolCallDelta(List<ThorToolCall> accumulated, ThorToolCall fragment)
    {
        var target = string.IsNullOrEmpty(fragment.Id)
            ? accumulated.LastOrDefault()
            : accumulated.FirstOrDefault(tc => tc.Id == fragment.Id);

        if (target == null)
        {
            accumulated.Add(fragment);
            return;
        }

        // Arguments may arrive split over several fragments; concatenate them.
        if (target.Function != null && fragment.Function?.Arguments != null)
        {
            target.Function.Arguments += fragment.Function.Arguments;
        }
    }

    /// <summary>
    /// Stores the user message, the generated (system) message and the usage statistics.
    /// </summary>
    private async Task SaveConversationAsync(
        Guid? userId,
        Guid? sessionId,
        Guid? tokenId,
        string modelId,
        string userContent,
        string systemContent,
        ThorUsageResponse tokenUsage)
    {
        await _aiMessageManager.CreateUserMessageAsync(userId, sessionId,
            new Shared.Dtos.MessageInputDto
            {
                Content = userContent,
                ModelId = modelId,
                TokenUsage = tokenUsage,
            }, tokenId);

        await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
            new Shared.Dtos.MessageInputDto
            {
                Content = systemContent,
                ModelId = modelId,
                TokenUsage = tokenUsage
            }, tokenId);

        await _usageStatisticsManager.SetUsageAsync(userId, modelId, tokenUsage, tokenId);
    }

    /// <summary>
    /// Loads the model description (model joined with its owning app) for the given model id
    /// and API type.
    /// </summary>
    /// <param name="modelApiType">API format the model must be registered for.</param>
    /// <param name="modelId">Model identifier from the request.</param>
    /// <returns>The matching model description.</returns>
    /// <exception cref="UserFriendlyException">No model matches the id and API type.</exception>
    private async Task<AiModelDescribe> GetModelAsync(ModelApiTypeEnum modelApiType, string modelId)
    {
        var aiModelDescribe = await _aiModelRepository._DbQueryable
            .LeftJoin<AiAppAggregateRoot>((model, app) => model.AiAppId == app.Id)
            .Where((model, app) => model.ModelId == modelId)
            .Where((model, app) => model.ModelApiType == modelApiType)
            .Select((model, app) =>
                new AiModelDescribe
                {
                    AppId = app.Id,
                    AppName = app.Name,
                    Endpoint = app.Endpoint,
                    ApiKey = app.ApiKey,
                    OrderNum = model.OrderNum,
                    HandlerName = model.HandlerName,
                    ModelId = model.ModelId,
                    ModelName = model.Name,
                    Description = model.Description,
                    AppExtraUrl = app.ExtraUrl,
                    ModelExtraInfo = model.ExtraInfo,
                    Multiplier = model.Multiplier
                })
            .FirstAsync();

        if (aiModelDescribe is null)
        {
            throw new UserFriendlyException($"【{modelId}】模型当前版本【{modelApiType}】格式不支持");
        }

        return aiModelDescribe;
    }
}