diff --git a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/OpenApiService.cs b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/OpenApiService.cs index ff4b46d8..26dcda77 100644 --- a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/OpenApiService.cs +++ b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Application/Services/OpenApiService.cs @@ -262,7 +262,7 @@ public class OpenApiService : ApplicationService /// /// /// - [HttpPost("openApi/v1beta/models/{modelId}:generateContent")] + [HttpPost("openApi/v1beta/models/{modelId}:{action:regex(^(generateContent|streamGenerateContent)$)}")] public async Task GenerateContentAsync([FromBody] JsonElement input, [FromRoute] string modelId, [FromQuery] string? alt, CancellationToken cancellationToken) @@ -297,13 +297,15 @@ public class OpenApiService : ApplicationService //ai网关代理httpcontext if (alt == "sse") { - // await _aiGateWayManager.OpenAiResponsesStreamForStatisticsAsync(_httpContextAccessor.HttpContext, - // input, - // userId, null, tokenId, cancellationToken); + await _aiGateWayManager.GeminiGenerateContentStreamForStatisticsAsync(_httpContextAccessor.HttpContext, + modelId, input, + userId, + null, tokenId, + cancellationToken); } else { - await _aiGateWayManager.GeminiGenerateContentAsyncForStatisticsAsync(_httpContextAccessor.HttpContext, + await _aiGateWayManager.GeminiGenerateContentForStatisticsAsync(_httpContextAccessor.HttpContext, modelId, input, userId, null, tokenId, @@ -321,6 +323,13 @@ public class OpenApiService : ApplicationService { return apiKeyHeader.Trim(); } + + // 再从 谷歌 获取 + string googApiKeyHeader = httpContext.Request.Headers["x-goog-api-key"]; + if (!string.IsNullOrWhiteSpace(googApiKeyHeader)) + { + return googApiKeyHeader.Trim(); + } // 再检查 Authorization 头 string authHeader = httpContext.Request.Headers["Authorization"]; diff --git a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain.Shared/Dtos/Gemini/GeminiGenerateContentAcquirer.cs b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain.Shared/Dtos/Gemini/GeminiGenerateContentAcquirer.cs index 8695d4b4..642fd850 100644 --- a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain.Shared/Dtos/Gemini/GeminiGenerateContentAcquirer.cs +++ b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain.Shared/Dtos/Gemini/GeminiGenerateContentAcquirer.cs @@ -6,12 +6,12 @@ namespace Yi.Framework.AiHub.Domain.Shared.Dtos.Gemini; public static class GeminiGenerateContentAcquirer { - public static ThorUsageResponse GetUsage(JsonElement response) + public static ThorUsageResponse? GetUsage(JsonElement response) { var usage = response.GetPath("usageMetadata"); if (!usage.HasValue) { - return new ThorUsageResponse(); + return null; } var inputTokens = usage.Value.GetPath("promptTokenCount").GetInt(); diff --git a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/AiGateWay/Impl/ThorGemini/Chats/GeminiGenerateContentService.cs b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/AiGateWay/Impl/ThorGemini/Chats/GeminiGenerateContentService.cs index 6f1a0831..460e20b8 100644 --- a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/AiGateWay/Impl/ThorGemini/Chats/GeminiGenerateContentService.cs +++ b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/AiGateWay/Impl/ThorGemini/Chats/GeminiGenerateContentService.cs @@ -5,25 +5,69 @@ using System.Text.Json; using Microsoft.Extensions.Logging; using Yi.Framework.AiHub.Domain.AiGateWay.Exceptions; using Yi.Framework.AiHub.Domain.Shared.Dtos; +using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi; using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi.Responses; namespace Yi.Framework.AiHub.Domain.AiGateWay.Impl.ThorGemini.Chats; -public class GeminiGenerateContentService(ILogger logger,IHttpClientFactory httpClientFactory):IGeminiGenerateContentService +public class GeminiGenerateContentService( + ILogger logger, + IHttpClientFactory httpClientFactory) : IGeminiGenerateContentService { - public IAsyncEnumerable GenerateContentStreamAsync(AiModelDescribe aiModelDescribe, JsonElement input, + public async IAsyncEnumerable GenerateContentStreamAsync(AiModelDescribe options, JsonElement input, CancellationToken cancellationToken) { - throw new NotImplementedException(); + var response = await httpClientFactory.CreateClient().PostJsonAsync( + options?.Endpoint.TrimEnd('/') + $"/v1beta/models/{options.ModelId}:streamGenerateContent?alt=sse", + input, null, new Dictionary() + { + { "x-goog-api-key", options.ApiKey } + }).ConfigureAwait(false); + + + // 大于等于400的状态码都认为是异常 + if (response.StatusCode >= HttpStatusCode.BadRequest) + { + var error = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + logger.LogError("Gemini生成异常 请求地址:{Address}, StatusCode: {StatusCode} Response: {Response}", + options.Endpoint, + response.StatusCode, error); + + throw new Exception("Gemini生成异常" + response.StatusCode); + } + + using var stream = new StreamReader(await response.Content.ReadAsStreamAsync(cancellationToken)); + + using StreamReader reader = new(await response.Content.ReadAsStreamAsync(cancellationToken)); + string? line = string.Empty; + + while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null) + { + line += Environment.NewLine; + if (string.IsNullOrWhiteSpace(line)) + { + continue; + } + + if (!line.StartsWith(OpenAIConstant.Data)) continue; + + var data = line[OpenAIConstant.Data.Length..].Trim(); + + var result = JsonSerializer.Deserialize(data, + ThorJsonSerializer.DefaultOptions); + + yield return result; + } } - public async Task GenerateContentAsync(AiModelDescribe options,JsonElement input, CancellationToken cancellationToken) + public async Task GenerateContentAsync(AiModelDescribe options, JsonElement input, + CancellationToken cancellationToken) { var response = await httpClientFactory.CreateClient().PostJsonAsync( options?.Endpoint.TrimEnd('/') + $"/v1beta/models/{options.ModelId}:generateContent", - input,null, new Dictionary() + input, null, new Dictionary() { - {"x-goog-api-key",options.ApiKey} + { "x-goog-api-key", options.ApiKey } }).ConfigureAwait(false); if (response.StatusCode == HttpStatusCode.Unauthorized) @@ -41,7 +85,8 @@ public class GeminiGenerateContentService(ILogger if (response.StatusCode >= HttpStatusCode.BadRequest) { var error = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - logger.LogError("Gemini 生成异常 请求地址:{Address}, StatusCode: {StatusCode} Response: {Response}", options.Endpoint, + logger.LogError("Gemini 生成异常 请求地址:{Address}, StatusCode: {StatusCode} Response: {Response}", + options.Endpoint, response.StatusCode, error); throw new BusinessException("Gemini 生成异常", response.StatusCode.ToString()); diff --git a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/AiGateWayManager.cs b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/AiGateWayManager.cs index a9b48196..2eda300b 100644 --- a/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/AiGateWayManager.cs +++ b/Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/AiGateWayManager.cs @@ -799,7 +799,7 @@ public class AiGateWayManager : DomainService /// /// /// - public async Task GeminiGenerateContentAsyncForStatisticsAsync(HttpContext httpContext, + public async Task GeminiGenerateContentForStatisticsAsync(HttpContext httpContext, string modelId, JsonElement request, Guid? userId = null, @@ -814,7 +814,7 @@ public class AiGateWayManager : DomainService LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName); var data = await chatService.GenerateContentAsync(modelDescribe, request, cancellationToken); - var tokenUsage= GeminiGenerateContentAcquirer.GetUsage(data); + var tokenUsage = GeminiGenerateContentAcquirer.GetUsage(data); tokenUsage.SetSupplementalMultiplier(modelDescribe.Multiplier); if (userId is not null) @@ -847,9 +847,88 @@ public class AiGateWayManager : DomainService await response.WriteAsJsonAsync(data, cancellationToken); } - - - + + + /// + /// Gemini 生成-流式-缓存处理 + /// + /// + /// + /// + /// + /// + /// Token Id(Web端传null或Guid.Empty) + /// + /// + public async Task GeminiGenerateContentStreamForStatisticsAsync( + HttpContext httpContext, + string modelId, + JsonElement request, + Guid? userId = null, + Guid? sessionId = null, + Guid? tokenId = null, + CancellationToken cancellationToken = default) + { + var response = httpContext.Response; + // 设置响应头,声明是 SSE 流 + response.ContentType = "text/event-stream;charset=utf-8;"; + response.Headers.TryAdd("Cache-Control", "no-cache"); + response.Headers.TryAdd("Connection", "keep-alive"); + + var modelDescribe = await GetModelAsync(ModelApiTypeEnum.GenerateContent, modelId); + var chatService = + LazyServiceProvider.GetRequiredKeyedService(modelDescribe.HandlerName); + + var completeChatResponse = chatService.GenerateContentStreamAsync(modelDescribe,request, cancellationToken); + ThorUsageResponse? tokenUsage = null; + try + { + await foreach (var responseResult in completeChatResponse) + { + if ( responseResult!.Value.GetPath("candidates", 0, "finishReason").GetString() == "STOP") + { + tokenUsage = GeminiGenerateContentAcquirer.GetUsage(responseResult!.Value); + tokenUsage.SetSupplementalMultiplier(modelDescribe.Multiplier); + } + await response.WriteAsync($"data: {JsonSerializer.Serialize(responseResult)}\n\n", Encoding.UTF8, cancellationToken).ConfigureAwait(false); + await response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + catch (Exception e) + { + _logger.LogError(e, $"Ai生成异常"); + var errorContent = $"生成Ai异常,异常信息:\n当前Ai模型:{modelId}\n异常信息:{e.Message}\n异常堆栈:{e}"; + throw new UserFriendlyException(errorContent); + } + + await _aiMessageManager.CreateUserMessageAsync(userId, sessionId, + new MessageInputDto + { + Content = "不予存储" , + ModelId = modelId, + TokenUsage = tokenUsage, + }, tokenId); + + await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId, + new MessageInputDto + { + Content = "不予存储" , + ModelId = modelId, + TokenUsage = tokenUsage + }, tokenId); + + await _usageStatisticsManager.SetUsageAsync(userId, modelId, tokenUsage, tokenId); + + // 扣减尊享token包用量 + if (userId.HasValue && tokenUsage is not null) + { + var totalTokens = tokenUsage.TotalTokens ?? 0; + if (tokenUsage.TotalTokens > 0) + { + await PremiumPackageManager.TryConsumeTokensAsync(userId.Value, totalTokens); + } + } + }