229 lines
8.9 KiB
C#
229 lines
8.9 KiB
C#
using System.ClientModel;
|
||
using System.Diagnostics.CodeAnalysis;
|
||
using System.Net;
|
||
using System.Reflection;
|
||
using System.Text;
|
||
using System.Text.Json;
|
||
using Dm.util;
|
||
using Microsoft.Agents.AI;
|
||
using Microsoft.AspNetCore.Http;
|
||
using Microsoft.Extensions.AI;
|
||
using Microsoft.Extensions.DependencyInjection;
|
||
using Microsoft.Extensions.Logging;
|
||
using ModelContextProtocol.Server;
|
||
using OpenAI;
|
||
using OpenAI.Chat;
|
||
using OpenAI.Responses;
|
||
using Volo.Abp.Domain.Repositories;
|
||
using Volo.Abp.Domain.Services;
|
||
using Yi.Framework.AiHub.Application.Contracts.Dtos.Chat;
|
||
using Yi.Framework.AiHub.Domain.AiGateWay;
|
||
using Yi.Framework.AiHub.Domain.Entities.Chat;
|
||
using Yi.Framework.AiHub.Domain.Entities.OpenApi;
|
||
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi;
|
||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||
|
||
namespace Yi.Framework.AiHub.Domain.Managers;
|
||
|
||
/// <summary>
/// Domain service that runs an AI agent conversation for a chat session and streams the agent's
/// output to the HTTP caller as Server-Sent Events (SSE).
/// </summary>
public class ChatManager : DomainService
{
    /// <summary>Base address of the OpenAI-compatible API the agent is created against.</summary>
    private const string AgentEndpoint = "https://yxai.chat/v1";

    /// <summary>
    /// Agent instructions ("You are a professional web AI assistant, good at answering user questions").
    /// </summary>
    private const string AgentInstructions = "你是一个专业的网页ai助手,擅长解答用户问题";

    /// <summary>
    /// MCP tool types found in the domain assembly (classes marked with
    /// <see cref="McpServerToolTypeAttribute"/>) together with their methods marked with
    /// <see cref="McpServerToolAttribute"/>. The assembly does not change at runtime, so the
    /// reflection scan runs once, on first use, instead of on every chat request.
    /// </summary>
    private static readonly Lazy<IReadOnlyList<(Type ToolType, MethodInfo[] ToolMethods)>> s_toolDescriptors =
        new Lazy<IReadOnlyList<(Type ToolType, MethodInfo[] ToolMethods)>>(DiscoverToolDescriptors);

    private readonly AiGateWayManager _aiGateWayManager;
    private readonly ILoggerFactory _loggerFactory;
    private readonly ILogger<ChatManager> _logger;
    private readonly ISqlSugarRepository<MessageAggregateRoot> _messageRepository;
    private readonly ISqlSugarRepository<AgentStoreAggregateRoot> _agentStoreRepository;
    private readonly ISqlSugarRepository<TokenAggregateRoot> _tokenRepository;

    public ChatManager(AiGateWayManager aiGateWayManager, ILoggerFactory loggerFactory,
        ISqlSugarRepository<MessageAggregateRoot> messageRepository,
        ISqlSugarRepository<AgentStoreAggregateRoot> agentStoreRepository,
        ISqlSugarRepository<TokenAggregateRoot> tokenRepository)
    {
        _aiGateWayManager = aiGateWayManager;
        _loggerFactory = loggerFactory;
        _logger = loggerFactory.CreateLogger<ChatManager>();
        _messageRepository = messageRepository;
        _agentStoreRepository = agentStoreRepository;
        _tokenRepository = tokenRepository;
    }

    /// <summary>
    /// Runs one agent turn for <paramref name="sessionId"/> and streams the result to the caller as SSE.
    /// </summary>
    /// <remarks>
    /// Each SSE event carries a JSON-serialized <see cref="AgentResultOutput"/>: tool call started,
    /// tool call finished, text, or token usage. A final <c>data: [DONE]</c> event ends the stream.
    /// Afterwards the conversation thread is serialized and saved, so the next request for the same
    /// session continues it.
    /// </remarks>
    /// <param name="httpContext">Current request; its response is used as the SSE stream.</param>
    /// <param name="sessionId">Chat session whose persisted agent thread is resumed (or created).</param>
    /// <param name="content">User message sent to the agent.</param>
    /// <param name="tokenId">Id of the stored API token used to authenticate against the agent endpoint.</param>
    /// <param name="modelId">Model name passed to the chat client.</param>
    /// <param name="userId">Requesting user. NOTE(review): not currently used by this method.</param>
    /// <param name="tools">NOTE(review): not currently used; every discovered MCP tool is offered to the agent.</param>
    /// <param name="cancellationToken">Cancels the agent run and the SSE writes.</param>
    /// <exception cref="InvalidOperationException">No token with <paramref name="tokenId"/> exists.</exception>
    public async Task AgentCompleteChatStreamAsync(HttpContext httpContext,
        Guid sessionId,
        string content,
        Guid tokenId,
        string modelId,
        Guid userId,
        List<string> tools,
        CancellationToken cancellationToken)
    {
        // Token status (enabled/expired/...) is validated in the application layer. Here we only require
        // that the token exists; otherwise `token.Token` below would throw a NullReferenceException.
        var token = await _tokenRepository.GetFirstAsync(x => x.Id == tokenId);
        if (token is null)
        {
            throw new InvalidOperationException($"Token '{tokenId}' was not found.");
        }

        // Declare the response as an SSE stream.
        var response = httpContext.Response;
        response.ContentType = "text/event-stream;charset=utf-8;";
        response.Headers.TryAdd("Cache-Control", "no-cache");
        response.Headers.TryAdd("Connection", "keep-alive");

        var client = new OpenAIClient(new ApiKeyCredential(token.Token),
            new OpenAIClientOptions { Endpoint = new Uri(AgentEndpoint) });
        var agent = client.GetChatClient(modelId).CreateAIAgent(AgentInstructions);

        // Resume this session's conversation thread from the database, or start a new one.
        var agentStore = await _agentStoreRepository.GetFirstAsync(x => x.SessionId == sessionId)
                         ?? new AgentStoreAggregateRoot(sessionId);

        AgentThread currentThread;
        if (!string.IsNullOrWhiteSpace(agentStore.Store))
        {
            JsonElement reloaded = JsonSerializer.Deserialize<JsonElement>(agentStore.Store, JsonSerializerOptions.Web);
            currentThread = agent.DeserializeThread(reloaded, JsonSerializerOptions.Web);
        }
        else
        {
            currentThread = agent.GetNewThread();
        }

        var chatOptions = new ChatOptions
        {
            Tools = new List<AITool>(GetTools()),
            ToolMode = ChatToolMode.Auto
        };
        var runOptions = new ChatClientAgentRunOptions(chatOptions);

        await foreach (var update in agent.RunStreamingAsync(content, currentThread, runOptions, cancellationToken))
        {
            foreach (var updateContent in update.Contents)
            {
                var output = ToAgentResultOutput(updateContent);
                if (output is null)
                {
                    // Only tool calls, tool results, text and usage are forwarded to the client.
                    continue;
                }

                await SendHttpStreamMessageAsync(httpContext, output, isDone: false, cancellationToken);

                if (updateContent is UsageContent usageContent)
                {
                    _logger.LogInformation(
                        "Agent usage for session {SessionId}: {TotalTokenCount} total tokens",
                        sessionId, usageContent.Details.TotalTokenCount);
                }
            }
        }

        // Tell the client the stream is complete.
        await SendHttpStreamMessageAsync(httpContext, null, isDone: true, cancellationToken);

        // Persist the thread (insert or update) so the next request for this session continues it.
        agentStore.Store = currentThread.Serialize(JsonSerializerOptions.Web).GetRawText();
        await _agentStoreRepository.InsertOrUpdateAsync(agentStore);
    }

    /// <summary>
    /// Maps one streamed agent content item to the SSE payload sent to the client.
    /// Returns <c>null</c> for content kinds that are not forwarded.
    /// </summary>
    private static AgentResultOutput? ToAgentResultOutput(AIContent updateContent) => updateContent switch
    {
        // A tool call started: report the tool's name.
        FunctionCallContent functionCall => new AgentResultOutput
        {
            TypeEnum = AgentResultTypeEnum.ToolCalling,
            Content = functionCall.Name
        },
        // A tool call finished: report its result.
        FunctionResultContent functionResult => new AgentResultOutput
        {
            TypeEnum = AgentResultTypeEnum.ToolCalled,
            Content = functionResult.Result
        },
        // Text produced by the model.
        TextContent textContent => new AgentResultOutput
        {
            TypeEnum = AgentResultTypeEnum.Text,
            Content = textContent.Text
        },
        // Token usage, forwarded to the client.
        // NOTE(review): an older comment mentioned storing the message for billing; nothing is persisted here.
        UsageContent usageContent => new AgentResultOutput
        {
            TypeEnum = AgentResultTypeEnum.Usage,
            Content = new ThorUsageResponse
            {
                InputTokens = Convert.ToInt32(usageContent.Details.InputTokenCount ?? 0),
                OutputTokens = Convert.ToInt32(usageContent.Details.OutputTokenCount ?? 0),
                TotalTokens = usageContent.Details.TotalTokenCount ?? 0,
            }
        },
        _ => null
    };

    /// <summary>
    /// Builds an <see cref="AIFunction"/> for every MCP tool method in the domain assembly.
    /// Each tool method is bound to an instance of its declaring type resolved from the service provider.
    /// </summary>
    /// <returns>All discovered tool functions (empty when no tool types exist).</returns>
    private List<AIFunction> GetTools()
    {
        var mcpTools = new List<AIFunction>();
        foreach (var (toolType, toolMethods) in s_toolDescriptors.Value)
        {
            // Instances are resolved on every call (only the reflection scan is cached), and for every
            // tool type, so an unregistered tool type still fails fast via GetRequiredService.
            var instance = LazyServiceProvider.GetRequiredService(toolType);
            foreach (var toolMethod in toolMethods)
            {
                mcpTools.Add(AIFunctionFactory.Create(toolMethod, instance));
            }
        }

        return mcpTools;
    }

    /// <summary>
    /// Scans the domain assembly for MCP tool types and their public tool methods.
    /// </summary>
    private static IReadOnlyList<(Type ToolType, MethodInfo[] ToolMethods)> DiscoverToolDescriptors()
    {
        return typeof(YiFrameworkAiHubDomainModule).Assembly.GetTypes()
            .Where(type => type.GetCustomAttribute<McpServerToolTypeAttribute>() is not null)
            .Select(type => (
                ToolType: type,
                ToolMethods: type.GetMethods()
                    .Where(method => method.GetCustomAttribute<McpServerToolAttribute>() is not null)
                    .ToArray()))
            .ToList();
    }

    /// <summary>
    /// Writes one SSE event (<c>data: ...</c> followed by a blank line) to the response and flushes it.
    /// </summary>
    /// <param name="httpContext">Request whose response body receives the event.</param>
    /// <param name="content">Payload to serialize as JSON; ignored when <paramref name="isDone"/> is true.</param>
    /// <param name="isDone">When true, sends the terminating <c>[DONE]</c> marker instead of a payload.</param>
    /// <param name="cancellationToken">Cancels the write and flush.</param>
    private async Task SendHttpStreamMessageAsync(HttpContext httpContext,
        AgentResultOutput? content,
        bool isDone = false,
        CancellationToken cancellationToken = default)
    {
        var response = httpContext.Response;
        string output = isDone
            ? "[DONE]"
            : JsonSerializer.Serialize(content, ThorJsonSerializer.DefaultOptions);

        await response.WriteAsync($"data: {output}\n\n", Encoding.UTF8, cancellationToken).ConfigureAwait(false);
        await response.Body.FlushAsync(cancellationToken).ConfigureAwait(false);
    }
}