using System.Collections.Concurrent;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Volo.Abp.Domain.Services;
using Yi.Framework.AiHub.Domain.AiGateWay;
using Yi.Framework.AiHub.Domain.AiGateWay.Exceptions;
using Yi.Framework.AiHub.Domain.Entities.Model;
using Yi.Framework.AiHub.Domain.Shared.Dtos;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi.Embeddings;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi.Images;
using Yi.Framework.Core.Extensions;
using Yi.Framework.SqlSugarCore.Abstractions;

namespace Yi.Framework.AiHub.Domain.Managers;

/// <summary>
/// Domain service that resolves the handler registered for a requested model, forwards chat,
/// image and embedding requests to it, writes the result to the HTTP response, and records
/// messages and token usage.
/// </summary>
public class AiGateWayManager : DomainService
{
    private readonly ISqlSugarRepository _aiAppRepository;
    private readonly ILogger _logger;
    private readonly AiMessageManager _aiMessageManager;
    private readonly UsageStatisticsManager _usageStatisticsManager;
    private readonly ISpecialCompatible _specialCompatible;

    /// <summary>
    /// Creates the manager with its repository, logger and collaborating services.
    /// </summary>
    public AiGateWayManager(ISqlSugarRepository aiAppRepository, ILogger logger,
        AiMessageManager aiMessageManager, UsageStatisticsManager usageStatisticsManager,
        ISpecialCompatible specialCompatible)
    {
        _aiAppRepository = aiAppRepository;
        _logger = logger;
        _aiMessageManager = aiMessageManager;
        _usageStatisticsManager = usageStatisticsManager;
        _specialCompatible = specialCompatible;
    }

    /// <summary>
    /// Looks up a model by id across all configured apps and returns its description,
    /// combined with the endpoint and key of the app that owns it.
    /// </summary>
    /// <param name="modelId">Identifier of the requested model.</param>
    /// <returns>The description of the first model whose id matches.</returns>
    /// <exception cref="UserFriendlyException">No app exposes a model with this id.</exception>
    private async Task GetModelAsync(string modelId)
    {
        // NOTE(review): this loads every app together with its models on each call and
        // filters in memory; consider filtering in the query if the table grows.
        var allApp = await _aiAppRepository._DbQueryable.Includes(x => x.AiModels).ToListAsync();
        foreach (var app in allApp)
        {
            var model = app.AiModels.FirstOrDefault(x => x.ModelId == modelId);
            if (model is null)
            {
                continue;
            }

            return new AiModelDescribe
            {
                AppId = app.Id,
                AppName = app.Name,
                Endpoint = app.Endpoint,
                ApiKey = app.ApiKey,
                OrderNum = model.OrderNum,
                HandlerName = model.HandlerName,
                ModelId = model.ModelId,
                ModelName = model.Name,
                Description = model.Description,
                AppExtraUrl = app.ExtraUrl,
                ModelExtraInfo = model.ExtraInfo
            };
        }

        throw new UserFriendlyException($"{modelId}模型当前版本不支持");
    }

    /// <summary>
    /// Chat completion (streaming): yields every chunk produced by the handler that is
    /// registered under the model's handler name.
    /// </summary>
    /// <param name="request">Chat completion request; adjusted by the compatibility layer first.</param>
    /// <param name="cancellationToken">Token that stops the enumeration.</param>
    /// <returns>The stream of completion chunks.</returns>
    public async IAsyncEnumerable CompleteChatStreamAsync(
        ThorChatCompletionsRequest request,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        _specialCompatible.Compatible(request);
        var modelDescribe = await GetModelAsync(request.Model);
        var chatService = LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName);

        await foreach (var result in chatService.CompleteChatStreamAsync(modelDescribe, request, cancellationToken))
        {
            yield return result;
        }
    }

    /// <summary>
    /// Chat completion (non-streaming): runs the completion, records messages and usage when a
    /// user id is supplied, and writes the completion to the response as JSON.
    /// </summary>
    /// <param name="httpContext">Current HTTP context; the completion is written to its response.</param>
    /// <param name="request">Chat completion request.</param>
    /// <param name="userId">User to attribute messages and usage to; nothing is recorded when null.</param>
    /// <param name="sessionId">Optional chat session the messages belong to.</param>
    /// <param name="cancellationToken">Token that cancels the upstream call and the response write.</param>
    public async Task CompleteChatForStatisticsAsync(HttpContext httpContext,
        ThorChatCompletionsRequest request,
        Guid? userId = null,
        Guid? sessionId = null,
        CancellationToken cancellationToken = default)
    {
        _specialCompatible.Compatible(request);
        var response = httpContext.Response;

        // The content type is not set explicitly here; WriteAsJsonAsync sets the JSON content type.
        var modelDescribe = await GetModelAsync(request.Model);
        var chatService = LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName);
        var data = await chatService.CompleteChatAsync(modelDescribe, request, cancellationToken);

        if (userId is not null)
        {
            await _aiMessageManager.CreateUserMessageAsync(userId.Value, sessionId,
                new MessageInputDto
                {
                    // Null-safe on both the list and its last element: an empty message list
                    // must not crash after the completion has already been produced.
                    Content = request.Messages?.LastOrDefault()?.Content ?? string.Empty,
                    ModelId = request.Model,
                    TokenUsage = data.Usage,
                });

            await _aiMessageManager.CreateSystemMessageAsync(userId.Value, sessionId,
                new MessageInputDto
                {
                    Content = data.Choices?.FirstOrDefault()?.Delta?.Content,
                    ModelId = request.Model,
                    TokenUsage = data.Usage
                });

            await _usageStatisticsManager.SetUsageAsync(userId.Value, request.Model, data.Usage);
        }

        await response.WriteAsJsonAsync(data, cancellationToken);
    }

    /// <summary>
    /// Chat completion (streaming, with buffered output): relays completion chunks to the client
    /// as server-sent events through a paced queue, then records the conversation and token usage.
    /// </summary>
    /// <param name="httpContext">Current HTTP context; events are written to its response.</param>
    /// <param name="request">Chat completion request.</param>
    /// <param name="userId">User to attribute messages and usage to.</param>
    /// <param name="sessionId">Optional chat session the messages belong to.</param>
    /// <param name="cancellationToken">Token that cancels the upstream stream and the writer task.</param>
    public async Task CompleteChatStreamForStatisticsAsync(
        HttpContext httpContext,
        ThorChatCompletionsRequest request,
        Guid? userId = null,
        Guid? sessionId = null,
        CancellationToken cancellationToken = default)
    {
        var response = httpContext.Response;

        // Declare the response as an SSE stream.
        response.ContentType = "text/event-stream;charset=utf-8;";
        response.Headers.TryAdd("Cache-Control", "no-cache");
        response.Headers.TryAdd("Connection", "keep-alive");

        var gateWay = LazyServiceProvider.GetRequiredService();
        var completeChatResponse = gateWay.CompleteChatStreamAsync(request, cancellationToken);
        var tokenUsage = new ThorUsageResponse();

        // Buffered output: the producer below enqueues events, a background task writes them.
        var messageQueue = new ConcurrentQueue();
        var backupSystemContent = new StringBuilder();

        // Pause between writer iterations while the upstream is still producing (75 ms).
        var outputInterval = TimeSpan.FromMilliseconds(75);

        // Set once the upstream has finished and the terminating event has been queued.
        var isComplete = false;

        // Background consumer: writes at most one queued event per iteration.
        var outputTask = Task.Run(async () =>
        {
            while (!(isComplete && messageQueue.IsEmpty))
            {
                if (messageQueue.TryDequeue(out var message))
                {
                    await response.WriteAsync(message, Encoding.UTF8, cancellationToken).ConfigureAwait(false);
                    await response.Body.FlushAsync(cancellationToken).ConfigureAwait(false);
                }

                if (!isComplete)
                {
                    // Only pace the output while still receiving; once complete, drain without delay.
                    await Task.Delay(outputInterval, cancellationToken).ConfigureAwait(false);
                }
            }
        }, cancellationToken);

        // Failures raised while enumerating the stream surface here, at the outermost consumer.
        try
        {
            await foreach (var data in completeChatResponse)
            {
                if (data.Usage is not null)
                {
                    tokenUsage = data.Usage;
                }

                var payload = System.Text.Json.JsonSerializer.Serialize(data, ThorJsonSerializer.DefaultOptions);

                // Null-safe: a chunk without choices or without a delta must not abort the stream.
                backupSystemContent.Append(data.Choices?.FirstOrDefault()?.Delta?.Content);

                // Queue the event instead of writing it directly.
                messageQueue.Enqueue($"data: {payload}\n\n");
            }
        }
        catch (Exception e)
        {
            _logger.LogError(e, "Ai对话异常");

            // NOTE(review): the exception message and stack trace are sent to the client and
            // stored as the assistant message; confirm this exposure is intended.
            var errorContent = $"对话Ai异常,异常信息:\n当前Ai模型:{request.Model}\n异常信息:{e.Message}\n异常堆栈:{e}";
            var model = new ThorChatCompletionsResponse()
            {
                Choices = new List()
                {
                    new ThorChatChoiceResponse()
                    {
                        Delta = new ThorChatMessage()
                        {
                            Content = errorContent
                        }
                    }
                }
            };
            var errorPayload = JsonConvert.SerializeObject(model, new JsonSerializerSettings
            {
                ContractResolver = new CamelCasePropertyNamesContractResolver()
            });
            backupSystemContent.Append(errorContent);
            messageQueue.Enqueue($"data: {errorPayload}\n\n");
        }

        // Queue the terminating event, then let the writer drain the queue and stop.
        messageQueue.Enqueue("data: [DONE]\n\n");
        isComplete = true;

        // NOTE(review): if the writer task faulted or was canceled, this await rethrows and
        // the message/usage persistence below is skipped.
        await outputTask;

        await _aiMessageManager.CreateUserMessageAsync(userId, sessionId,
            new MessageInputDto
            {
                Content = request.Messages?.LastOrDefault()?.Content ?? string.Empty,
                ModelId = request.Model,
                TokenUsage = tokenUsage,
            });

        await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
            new MessageInputDto
            {
                Content = backupSystemContent.ToString(),
                ModelId = request.Model,
                TokenUsage = tokenUsage
            });

        await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage);
    }

    /// <summary>
    /// Image generation: creates an image with the handler registered for the model, writes the
    /// result to the response as JSON, and records messages and usage.
    /// </summary>
    /// <param name="context">Current HTTP context; the result is written to its response.</param>
    /// <param name="userId">User to attribute messages and usage to.</param>
    /// <param name="sessionId">Optional chat session the messages belong to.</param>
    /// <param name="request">Image creation request; the model falls back to "dall-e-2" when empty.</param>
    /// <exception cref="UserFriendlyException">Any failure, wrapped together with its details.</exception>
    public async Task CreateImageForStatisticsAsync(HttpContext context, Guid? userId, Guid? sessionId,
        ImageCreateRequest request)
    {
        try
        {
            var model = request.Model;
            if (string.IsNullOrEmpty(model))
            {
                model = "dall-e-2";
            }

            var modelDescribe = await GetModelAsync(model);

            // Resolve the service implementation registered for this model's handler.
            var imageService = LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName);

            var response = await imageService.CreateImage(request, modelDescribe);

            // A missing result list counts as a failure too, instead of throwing on `.Count`.
            if (response.Error is not null || response.Results is null || response.Results.Count == 0)
            {
                throw new BusinessException(response.Error?.Message ?? "图片生成失败", response.Error?.Code?.ToString());
            }

            await context.Response.WriteAsJsonAsync(response);

            await _aiMessageManager.CreateUserMessageAsync(userId, sessionId,
                new MessageInputDto
                {
                    Content = request.Prompt,
                    ModelId = model,
                    TokenUsage = response.Usage,
                });

            await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
                new MessageInputDto
                {
                    Content = response.Results?.FirstOrDefault()?.Url,
                    ModelId = model,
                    TokenUsage = response.Usage
                });

            await _usageStatisticsManager.SetUsageAsync(userId, model, response.Usage);
        }
        catch (Exception e)
        {
            var errorContent = $"图片生成Ai异常,异常信息:\n当前Ai模型:{request.Model}\n异常信息:{e.Message}\n异常堆栈:{e}";
            throw new UserFriendlyException(errorContent);
        }
    }

    /// <summary>
    /// Embedding generation: converts the input into an embedding request, calls the handler
    /// registered for the model, writes the result to the response as JSON, and records usage.
    /// </summary>
    /// <param name="context">Current HTTP context; its RequestAborted token cancels the upstream call.</param>
    /// <param name="userId">User to attribute usage to.</param>
    /// <param name="sessionId">Session id; not used while message statistics are disabled for embeddings.</param>
    /// <param name="input">Embedding input; <c>Input</c> may be a string or a JSON string/array.</param>
    /// <exception cref="UserFriendlyException">
    /// Any failure other than rate limiting (status 429) or unauthorized access (status 401).
    /// </exception>
    public async Task EmbeddingForStatisticsAsync(HttpContext context, Guid? userId, Guid? sessionId,
        ThorEmbeddingInput input)
    {
        try
        {
            if (input == null)
            {
                throw new Exception("模型校验异常");
            }

            using var embedding = Activity.Current?.Source.StartActivity("向量模型调用");

            var modelDescribe = await GetModelAsync(input.Model);

            // Resolve the service implementation registered for this model's handler.
            var embeddingService = LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName);

            var embeddingCreateRequest = new EmbeddingCreateRequest
            {
                Model = input.Model,
                EncodingFormat = input.EncodingFormat
            };

            // Convert the loosely typed input; several shapes are supported.
            if (input.Input is JsonElement str)
            {
                if (str.ValueKind == JsonValueKind.String)
                {
                    embeddingCreateRequest.Input = str.ToString();
                }
                else if (str.ValueKind == JsonValueKind.Array)
                {
                    embeddingCreateRequest.InputAsList = str.EnumerateArray().Select(x => x.ToString()).ToList();
                }
                else
                {
                    throw new Exception("Input,输入格式错误,非string或Array类型");
                }
            }
            else if (input.Input is string strInput)
            {
                embeddingCreateRequest.Input = strInput;
            }
            else
            {
                throw new Exception("Input,输入格式错误,未找到类型");
            }

            var stream = await embeddingService.EmbeddingAsync(embeddingCreateRequest, modelDescribe,
                context.RequestAborted);

            // Completion tokens are reported as 0, so the total equals the input token count.
            var usage = new ThorUsageResponse()
            {
                InputTokens = stream.Usage?.InputTokens ?? 0,
                CompletionTokens = 0,
                TotalTokens = stream.Usage?.InputTokens ?? 0
            };

            await context.Response.WriteAsJsonAsync(new
            {
                input.Model,
                stream.Data,
                stream.Error,
                stream.ObjectTypeName,
                Usage = usage
            });

            // Knowledge-base (embedding) calls are not recorded as messages for now; only usage is tracked.
            await _usageStatisticsManager.SetUsageAsync(userId, input.Model, usage);
        }
        catch (ThorRateLimitException)
        {
            context.Response.StatusCode = 429;
        }
        catch (UnauthorizedAccessException)
        {
            context.Response.StatusCode = 401;
        }
        catch (Exception e)
        {
            // `input` can be null here (the guard above throws inside this try), so read the
            // model null-safely; otherwise this handler would mask the original error.
            var errorContent = $"嵌入Ai异常,异常信息:\n当前Ai模型:{input?.Model}\n异常信息:{e.Message}\n异常堆栈:{e}";
            throw new UserFriendlyException(errorContent);
        }
    }
}