feat: 支持Claude模型API类型及尊享包校验与扣减逻辑

This commit is contained in:
ccnetcore
2025-10-14 22:17:21 +08:00
parent 31dc756868
commit 15713cf7fe
8 changed files with 963 additions and 30 deletions

View File

@@ -12,11 +12,13 @@ using Volo.Abp.Domain.Services;
using Yi.Framework.AiHub.Domain.AiGateWay;
using Yi.Framework.AiHub.Domain.AiGateWay.Exceptions;
using Yi.Framework.AiHub.Domain.Entities.Model;
using Yi.Framework.AiHub.Domain.Shared.Consts;
using Yi.Framework.AiHub.Domain.Shared.Dtos;
using Yi.Framework.AiHub.Domain.Shared.Dtos.Anthropic;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi.Embeddings;
using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi.Images;
using Yi.Framework.AiHub.Domain.Shared.Enums;
using Yi.Framework.Core.Extensions;
using Yi.Framework.SqlSugarCore.Abstractions;
using JsonSerializer = System.Text.Json.JsonSerializer;
@@ -27,21 +29,24 @@ namespace Yi.Framework.AiHub.Domain.Managers;
public class AiGateWayManager : DomainService
{
private readonly ISqlSugarRepository<AiAppAggregateRoot> _aiAppRepository;
private readonly ISqlSugarRepository<AiModelEntity> _aiModelRepository;
private readonly ILogger<AiGateWayManager> _logger;
private readonly AiMessageManager _aiMessageManager;
private readonly UsageStatisticsManager _usageStatisticsManager;
private readonly ISpecialCompatible _specialCompatible;
private PremiumPackageManager? _premiumPackageManager;
public AiGateWayManager(ISqlSugarRepository<AiAppAggregateRoot> aiAppRepository, ILogger<AiGateWayManager> logger,
AiMessageManager aiMessageManager, UsageStatisticsManager usageStatisticsManager,
ISpecialCompatible specialCompatible)
ISpecialCompatible specialCompatible, ISqlSugarRepository<AiModelEntity> aiModelRepository)
{
_aiAppRepository = aiAppRepository;
_logger = logger;
_aiMessageManager = aiMessageManager;
_usageStatisticsManager = usageStatisticsManager;
_specialCompatible = specialCompatible;
_aiModelRepository = aiModelRepository;
}
private PremiumPackageManager PremiumPackageManager =>
@@ -50,17 +55,17 @@ public class AiGateWayManager : DomainService
/// <summary>
/// 获取模型
/// </summary>
/// <param name="modelApiType"></param>
/// <param name="modelId"></param>
/// <returns></returns>
private async Task<AiModelDescribe> GetModelAsync(string modelId)
private async Task<AiModelDescribe> GetModelAsync(ModelApiTypeEnum modelApiType, string modelId)
{
var allApp = await _aiAppRepository._DbQueryable.Includes(x => x.AiModels).ToListAsync();
foreach (var app in allApp)
{
var model = app.AiModels.FirstOrDefault(x => x.ModelId == modelId);
if (model is not null)
{
return new AiModelDescribe
var aiModelDescribe = await _aiModelRepository._DbQueryable
.LeftJoin<AiAppAggregateRoot>((model, app) => model.AiAppId == app.Id)
.Where((model, app) => model.ModelId == modelId)
.Where((model, app) => model.ModelApiType == modelApiType)
.Select((model, app) =>
new AiModelDescribe
{
AppId = app.Id,
AppName = app.Name,
@@ -73,11 +78,14 @@ public class AiGateWayManager : DomainService
Description = model.Description,
AppExtraUrl = app.ExtraUrl,
ModelExtraInfo = model.ExtraInfo
};
}
})
.FirstAsync();
if (aiModelDescribe is null)
{
throw new UserFriendlyException($"【{modelId}】模型当前版本【{modelApiType}】格式不支持");
}
throw new UserFriendlyException($"{modelId}模型当前版本不支持");
return aiModelDescribe;
}
@@ -92,7 +100,7 @@ public class AiGateWayManager : DomainService
[EnumeratorCancellation] CancellationToken cancellationToken)
{
_specialCompatible.Compatible(request);
var modelDescribe = await GetModelAsync(request.Model);
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, request.Model);
var chatService =
LazyServiceProvider.GetRequiredKeyedService<IChatCompletionService>(modelDescribe.HandlerName);
@@ -122,7 +130,7 @@ public class AiGateWayManager : DomainService
var response = httpContext.Response;
// 设置响应头,声明是 json
//response.ContentType = "application/json; charset=UTF-8";
var modelDescribe = await GetModelAsync(request.Model);
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, request.Model);
var chatService =
LazyServiceProvider.GetRequiredKeyedService<IChatCompletionService>(modelDescribe.HandlerName);
var data = await chatService.CompleteChatAsync(modelDescribe, request, cancellationToken);
@@ -277,6 +285,16 @@ public class AiGateWayManager : DomainService
});
await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage);
// 扣减尊享token包用量
if (userId is not null && PremiumPackageConst.ModeIds.Contains(request.Model))
{
var totalTokens = tokenUsage.TotalTokens ?? 0;
if (totalTokens > 0)
{
await PremiumPackageManager.ConsumeTokensAsync(userId.Value, totalTokens);
}
}
}
@@ -297,7 +315,7 @@ public class AiGateWayManager : DomainService
var model = request.Model;
if (string.IsNullOrEmpty(model)) model = "dall-e-2";
var modelDescribe = await GetModelAsync(model);
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, model);
// 获取渠道指定的实现类型的服务
var imageService =
@@ -329,6 +347,16 @@ public class AiGateWayManager : DomainService
});
await _usageStatisticsManager.SetUsageAsync(userId, model, response.Usage);
// 扣减尊享token包用量
if (userId is not null && PremiumPackageConst.ModeIds.Contains(request.Model))
{
var totalTokens = response.Usage.TotalTokens ?? 0;
if (totalTokens > 0)
{
await PremiumPackageManager.ConsumeTokensAsync(userId.Value, totalTokens);
}
}
}
catch (Exception e)
{
@@ -357,7 +385,7 @@ public class AiGateWayManager : DomainService
using var embedding =
Activity.Current?.Source.StartActivity("向量模型调用");
var modelDescribe = await GetModelAsync(input.Model);
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.OpenAi, input.Model);
// 获取渠道指定的实现类型的服务
var embeddingService =
@@ -461,7 +489,7 @@ public class AiGateWayManager : DomainService
[EnumeratorCancellation] CancellationToken cancellationToken)
{
_specialCompatible.AnthropicCompatible(request);
var modelDescribe = await GetModelAsync(request.Model);
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.Claude, request.Model);
var chatService =
LazyServiceProvider.GetRequiredKeyedService<IAnthropicChatCompletionService>(modelDescribe.HandlerName);
@@ -491,7 +519,7 @@ public class AiGateWayManager : DomainService
var response = httpContext.Response;
// 设置响应头,声明是 json
//response.ContentType = "application/json; charset=UTF-8";
var modelDescribe = await GetModelAsync(request.Model);
var modelDescribe = await GetModelAsync(ModelApiTypeEnum.Claude, request.Model);
var chatService =
LazyServiceProvider.GetRequiredKeyedService<IAnthropicChatCompletionService>(modelDescribe.HandlerName);
var data = await chatService.ChatCompletionsAsync(modelDescribe, request, cancellationToken);
@@ -516,14 +544,10 @@ public class AiGateWayManager : DomainService
await _usageStatisticsManager.SetUsageAsync(userId.Value, request.Model, data.TokenUsage);
// 扣减尊享token包用量
var totalTokens = data.TokenUsage.TotalTokens??0;
var totalTokens = data.TokenUsage.TotalTokens ?? 0;
if (totalTokens > 0)
{
var consumeSuccess = await PremiumPackageManager.ConsumeTokensAsync(userId.Value, totalTokens);
if (!consumeSuccess)
{
_logger.LogWarning($"用户 {userId.Value} 尊享token包扣减失败消耗token数: {totalTokens}");
}
await PremiumPackageManager.ConsumeTokensAsync(userId.Value, totalTokens);
}
}
@@ -562,10 +586,11 @@ public class AiGateWayManager : DomainService
await foreach (var responseResult in completeChatResponse)
{
//message_start是为了保底机制
if (responseResult.Item1.Contains("message_delta")||responseResult.Item1.Contains("message_start"))
if (responseResult.Item1.Contains("message_delta") || responseResult.Item1.Contains("message_start"))
{
tokenUsage = responseResult.Item2?.TokenUsage;
}
backupSystemContent.Append(responseResult.Item2?.Delta?.Text);
await WriteAsEventStreamDataAsync(httpContext, responseResult.Item1, responseResult.Item2,
cancellationToken);
@@ -622,7 +647,7 @@ public class AiGateWayManager : DomainService
// 扣减尊享token包用量
if (userId.HasValue && tokenUsage is not null)
{
var totalTokens = tokenUsage.TotalTokens??0;
var totalTokens = tokenUsage.TotalTokens ?? 0;
if (totalTokens > 0)
{
var consumeSuccess = await PremiumPackageManager.ConsumeTokensAsync(userId.Value, totalTokens);