feat: 新增工具调用

This commit is contained in:
ccnetcore
2025-12-23 00:49:17 +08:00
parent 8f515f76c0
commit 81089cc058
2 changed files with 345 additions and 4 deletions

View File

@@ -43,11 +43,15 @@ public class AiChatService : ApplicationService
private readonly ILogger<AiChatService> _logger;
private readonly AiGateWayManager _aiGateWayManager;
private readonly PremiumPackageManager _premiumPackageManager;
private readonly ChatManager _chatManager;
public AiChatService(IHttpContextAccessor httpContextAccessor,
AiBlacklistManager aiBlacklistManager,
ISqlSugarRepository<AiModelEntity> aiModelRepository,
ILogger<AiChatService> logger, AiGateWayManager aiGateWayManager, PremiumPackageManager premiumPackageManager)
ILogger<AiChatService> logger,
AiGateWayManager aiGateWayManager,
PremiumPackageManager premiumPackageManager,
ChatManager chatManager)
{
_httpContextAccessor = httpContextAccessor;
_aiBlacklistManager = aiBlacklistManager;
@@ -55,6 +59,7 @@ public class AiChatService : ApplicationService
_logger = logger;
_aiGateWayManager = aiGateWayManager;
_premiumPackageManager = premiumPackageManager;
_chatManager = chatManager;
}
@@ -140,9 +145,19 @@ public class AiChatService : ApplicationService
}
}
//ai网关代理httpcontext
await _aiGateWayManager.CompleteChatStreamForStatisticsAsync(_httpContextAccessor.HttpContext, input,
CurrentUser.Id, sessionId, null, cancellationToken);
// 判断是否有工具调用
if (input.Tools != null && input.Tools.Count > 0)
{
// 使用 ChatManager 处理支持工具调用的对话
await _chatManager.CompleteChatWithToolsAsync(_httpContextAccessor.HttpContext, input,
CurrentUser.Id, sessionId, null, cancellationToken);
}
else
{
// 使用原来的 AiGateWayManager 处理普通对话
await _aiGateWayManager.CompleteChatStreamForStatisticsAsync(_httpContextAccessor.HttpContext, input,
CurrentUser.Id, sessionId, null, cancellationToken);
}
}

View File

@@ -0,0 +1,326 @@
using System.Collections.Concurrent;
using System.Text;
using System.Text.Json;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Volo.Abp.Domain.Services;
using Yi.Framework.AiHub.Domain.AiGateWay;
using Yi.Framework.AiHub.Domain.Entities;
using Yi.Framework.AiHub.Domain.Entities.Model;
using Yi.Framework.AiHub.Domain.Shared.Dtos;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi;
using Yi.Framework.AiHub.Domain.Shared.Enums;
using Yi.Framework.SqlSugarCore.Abstractions;
using ThorJsonSerializer = Yi.Framework.AiHub.Domain.AiGateWay.ThorJsonSerializer;
namespace Yi.Framework.AiHub.Domain.Managers;
public class ChatManager : DomainService
{
private readonly AiGateWayManager _aiGateWayManager;
private readonly ISqlSugarRepository<AiModelEntity> _aiModelRepository;
private readonly ILogger<ChatManager> _logger;
private readonly AiMessageManager _aiMessageManager;
private readonly UsageStatisticsManager _usageStatisticsManager;
public ChatManager(
AiGateWayManager aiGateWayManager,
ISqlSugarRepository<AiModelEntity> aiModelRepository,
ILogger<ChatManager> logger,
AiMessageManager aiMessageManager,
UsageStatisticsManager usageStatisticsManager)
{
_aiGateWayManager = aiGateWayManager;
_aiModelRepository = aiModelRepository;
_logger = logger;
_aiMessageManager = aiMessageManager;
_usageStatisticsManager = usageStatisticsManager;
}
/// <summary>
/// 支持工具调用的流式对话
/// </summary>
/// <param name="httpContext"></param>
/// <param name="request"></param>
/// <param name="userId"></param>
/// <param name="sessionId"></param>
/// <param name="tokenId"></param>
/// <param name="cancellationToken"></param>
public async Task CompleteChatWithToolsAsync(
HttpContext httpContext,
ThorChatCompletionsRequest request,
Guid? userId = null,
Guid? sessionId = null,
Guid? tokenId = null,
CancellationToken cancellationToken = default)
{
var response = httpContext.Response;
// 设置响应头,声明是 SSE 流
response.ContentType = "text/event-stream;charset=utf-8;";
response.Headers.TryAdd("Cache-Control", "no-cache");
response.Headers.TryAdd("Connection", "keep-alive");
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, request.Model);
var chatService = LazyServiceProvider.GetRequiredKeyedService<IChatCompletionService>(modelDescribe.HandlerName);
var tokenUsage = new ThorUsageResponse();
var messageQueue = new ConcurrentQueue<string>();
var backupSystemContent = new StringBuilder();
var isComplete = false;
// 启动后台任务来消费队列
var outputTask = Task.Run(async () =>
{
while (!(isComplete && messageQueue.IsEmpty))
{
if (messageQueue.TryDequeue(out var message))
{
await response.WriteAsync(message, Encoding.UTF8, cancellationToken).ConfigureAwait(false);
await response.Body.FlushAsync(cancellationToken).ConfigureAwait(false);
}
if (!isComplete)
{
await Task.Delay(TimeSpan.FromMilliseconds(75), cancellationToken).ConfigureAwait(false);
}
else
{
await Task.Delay(10, cancellationToken).ConfigureAwait(false);
}
}
}, cancellationToken);
try
{
// 多轮工具调用循环
var conversationMessages = request.Messages ?? new List<ThorChatMessage>();
var maxIterations = 10; // 防止无限循环
var currentIteration = 0;
while (currentIteration < maxIterations)
{
currentIteration++;
// 更新请求消息
request.Messages = conversationMessages;
// 调用 AI API
var completeChatResponse = chatService.CompleteChatStreamAsync(modelDescribe, request, cancellationToken);
var currentToolCalls = new List<ThorToolCall>();
var hasToolCalls = false;
var currentContent = new StringBuilder();
await foreach (var data in completeChatResponse)
{
data.SupplementalMultiplier(modelDescribe.Multiplier);
if (data.Usage is not null && (data.Usage.CompletionTokens > 0 || data.Usage.OutputTokens > 0))
{
tokenUsage = data.Usage;
}
// 检查是否有工具调用
if (data.Choices?.FirstOrDefault()?.Delta?.ToolCalls != null)
{
hasToolCalls = true;
foreach (var toolCall in data.Choices.First().Delta.ToolCalls)
{
// 累积工具调用信息
var existingToolCall = currentToolCalls.FirstOrDefault(tc => tc.Id == toolCall.Id);
if (existingToolCall == null)
{
currentToolCalls.Add(toolCall);
}
else
{
// 合并工具调用的 arguments流式返回可能是分片的
if (existingToolCall.Function != null && toolCall.Function?.Arguments != null)
{
existingToolCall.Function.Arguments += toolCall.Function.Arguments;
}
}
}
}
// 检查 finish_reason
var finishReason = data.Choices?.FirstOrDefault()?.FinishReason;
if (finishReason == "tool_calls")
{
hasToolCalls = true;
}
// 累积内容
var content = data.Choices?.FirstOrDefault()?.Delta.Content;
if (!string.IsNullOrEmpty(content))
{
currentContent.Append(content);
}
var message = System.Text.Json.JsonSerializer.Serialize(data, ThorJsonSerializer.DefaultOptions);
backupSystemContent.Append(content);
messageQueue.Enqueue($"data: {message}\n\n");
}
// 如果没有工具调用,结束循环
if (!hasToolCalls || currentToolCalls.Count == 0)
{
break;
}
// 发送工具调用状态消息
var toolCallStatusMessage = new
{
type = "tool_call_status",
status = "calling_tools",
tool_calls = currentToolCalls
};
var statusJson = System.Text.Json.JsonSerializer.Serialize(toolCallStatusMessage, ThorJsonSerializer.DefaultOptions);
messageQueue.Enqueue($"data: {statusJson}\n\n");
// 将 AI 的工具调用消息添加到历史
conversationMessages.Add(new ThorChatMessage
{
Role = ThorChatMessageRoleConst.Assistant,
Content = currentContent.ToString(),
ToolCalls = currentToolCalls
});
// 执行工具调用
foreach (var toolCall in currentToolCalls)
{
try
{
// TODO: 实现 MCP 工具调用逻辑
// 这里需要根据 toolCall.Function.Name 和 toolCall.Function.Arguments 调用相应的工具
// 示例:
// var toolResult = await InvokeMcpToolAsync(toolCall.Function.Name, toolCall.Function.Arguments, cancellationToken);
var toolResult = "TODO: 实现 MCP 工具调用";
// 将工具结果添加到消息历史
conversationMessages.Add(ThorChatMessage.CreateToolMessage(toolResult, toolCall.Id));
// 发送工具执行结果消息
var toolResultMessage = new
{
type = "tool_call_result",
tool_call_id = toolCall.Id,
tool_name = toolCall.Function?.Name,
result = toolResult
};
var resultJson = System.Text.Json.JsonSerializer.Serialize(toolResultMessage, ThorJsonSerializer.DefaultOptions);
messageQueue.Enqueue($"data: {resultJson}\n\n");
}
catch (Exception ex)
{
_logger.LogError(ex, $"工具调用失败: {toolCall.Function?.Name}");
// 将错误信息添加到消息历史
var errorMessage = $"工具调用失败: {ex.Message}";
conversationMessages.Add(ThorChatMessage.CreateToolMessage(errorMessage, toolCall.Id));
// 发送工具执行错误消息
var toolErrorMessage = new
{
type = "tool_call_error",
tool_call_id = toolCall.Id,
tool_name = toolCall.Function?.Name,
error = errorMessage
};
var errorJson = System.Text.Json.JsonSerializer.Serialize(toolErrorMessage, ThorJsonSerializer.DefaultOptions);
messageQueue.Enqueue($"data: {errorJson}\n\n");
}
}
// 继续下一轮对话,让 AI 根据工具结果生成回复
}
}
catch (Exception e)
{
_logger.LogError(e, $"AI对话异常");
var errorContent = $"对话AI异常异常信息\n当前AI模型{request.Model}\n异常信息{e.Message}\n异常堆栈{e}";
var model = new ThorChatCompletionsResponse()
{
Choices = new List<ThorChatChoiceResponse>()
{
new ThorChatChoiceResponse()
{
Delta = new ThorChatMessage()
{
Content = errorContent
}
}
}
};
var message = JsonConvert.SerializeObject(model, new JsonSerializerSettings
{
ContractResolver = new CamelCasePropertyNamesContractResolver()
});
backupSystemContent.Append(errorContent);
messageQueue.Enqueue($"data: {message}\n\n");
}
// 断开连接
messageQueue.Enqueue("data: [DONE]\n\n");
isComplete = true;
await outputTask;
// 保存消息和统计信息
await _aiMessageManager.CreateUserMessageAsync(userId, sessionId,
new Shared.Dtos.MessageInputDto
{
Content = sessionId is null ? "不予存储" : request.Messages?.LastOrDefault()?.MessagesStore ?? string.Empty,
ModelId = request.Model,
TokenUsage = tokenUsage,
}, tokenId);
await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
new Shared.Dtos.MessageInputDto
{
Content = sessionId is null ? "不予存储" : backupSystemContent.ToString(),
ModelId = request.Model,
TokenUsage = tokenUsage
}, tokenId);
await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage, tokenId);
}
/// <summary>
/// 获取模型
/// </summary>
private async Task<AiModelDescribe> GetModelAsync(ModelApiTypeEnum modelApiType, string modelId)
{
var aiModelDescribe = await _aiModelRepository._DbQueryable
.LeftJoin<AiAppAggregateRoot>((model, app) => model.AiAppId == app.Id)
.Where((model, app) => model.ModelId == modelId)
.Where((model, app) => model.ModelApiType == modelApiType)
.Select((model, app) =>
new AiModelDescribe
{
AppId = app.Id,
AppName = app.Name,
Endpoint = app.Endpoint,
ApiKey = app.ApiKey,
OrderNum = model.OrderNum,
HandlerName = model.HandlerName,
ModelId = model.ModelId,
ModelName = model.Name,
Description = model.Description,
AppExtraUrl = app.ExtraUrl,
ModelExtraInfo = model.ExtraInfo,
Multiplier = model.Multiplier
})
.FirstAsync();
if (aiModelDescribe is null)
{
throw new UserFriendlyException($"【{modelId}】模型当前版本【{modelApiType}】格式不支持");
}
return aiModelDescribe;
}
}