Merge branch 'ai-agent' into ai-hub
# Conflicts: # Yi.Abp.Net8/module/ai-hub/Yi.Framework.AiHub.Domain/Managers/AiGateWayManager.cs
This commit is contained in:
@@ -107,35 +107,6 @@ public class AzureDatabricksChatCompletionsService(ILogger<AzureDatabricksChatCo
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// var content = result?.Choices?.FirstOrDefault()?.Delta;
|
||||
//
|
||||
// if (first && content?.Content == OpenAIConstant.ThinkStart)
|
||||
// {
|
||||
// isThink = true;
|
||||
// continue;
|
||||
// // 需要将content的内容转换到其他字段
|
||||
// }
|
||||
//
|
||||
// if (isThink && content?.Content?.Contains(OpenAIConstant.ThinkEnd) == true)
|
||||
// {
|
||||
// isThink = false;
|
||||
// // 需要将content的内容转换到其他字段
|
||||
// continue;
|
||||
// }
|
||||
//
|
||||
// if (isThink && result?.Choices != null)
|
||||
// {
|
||||
// // 需要将content的内容转换到其他字段
|
||||
// foreach (var choice in result.Choices)
|
||||
// {
|
||||
// choice.Delta.ReasoningContent = choice.Delta.Content;
|
||||
// choice.Delta.Content = string.Empty;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// first = false;
|
||||
|
||||
yield return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,9 @@ using Yi.Framework.AiHub.Domain.Shared.Dtos.OpenAi;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.AiGateWay.Impl.ThorCustomOpenAI.Chats;
|
||||
|
||||
public sealed class OpenAiChatCompletionsService(ILogger<OpenAiChatCompletionsService> logger,IHttpClientFactory httpClientFactory)
|
||||
public sealed class OpenAiChatCompletionsService(
|
||||
ILogger<OpenAiChatCompletionsService> logger,
|
||||
IHttpClientFactory httpClientFactory)
|
||||
: IChatCompletionService
|
||||
{
|
||||
public async IAsyncEnumerable<ThorChatCompletionsResponse> CompleteChatStreamAsync(AiModelDescribe options,
|
||||
@@ -19,8 +21,18 @@ public sealed class OpenAiChatCompletionsService(ILogger<OpenAiChatCompletionsSe
|
||||
using var openai =
|
||||
Activity.Current?.Source.StartActivity("OpenAI 对话流式补全");
|
||||
|
||||
var endpoint = options?.Endpoint.TrimEnd('/');
|
||||
|
||||
//兼容 v1结尾
|
||||
if (endpoint != null && endpoint.EndsWith("/v1", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
endpoint = endpoint.Substring(0, endpoint.Length - "/v1".Length);
|
||||
}
|
||||
|
||||
var requestUri = endpoint + "/v1/chat/completions";
|
||||
|
||||
var response = await httpClientFactory.CreateClient().HttpRequestRaw(
|
||||
options?.Endpoint.TrimEnd('/') + "/chat/completions",
|
||||
requestUri,
|
||||
chatCompletionCreate, options.ApiKey);
|
||||
|
||||
openai?.SetTag("Model", chatCompletionCreate.Model);
|
||||
@@ -130,8 +142,16 @@ public sealed class OpenAiChatCompletionsService(ILogger<OpenAiChatCompletionsSe
|
||||
using var openai =
|
||||
Activity.Current?.Source.StartActivity("OpenAI 对话补全");
|
||||
|
||||
var endpoint = options?.Endpoint.TrimEnd('/');
|
||||
|
||||
//兼容 v1结尾
|
||||
if (endpoint != null && endpoint.EndsWith("/v1", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
endpoint = endpoint.Substring(0, endpoint.Length - "/v1".Length);
|
||||
}
|
||||
var requestUri = endpoint + "/v1/chat/completions";
|
||||
var response = await httpClientFactory.CreateClient().PostJsonAsync(
|
||||
options?.Endpoint.TrimEnd('/') + "/chat/completions",
|
||||
requestUri,
|
||||
chatCompletionCreate, options.ApiKey).ConfigureAwait(false);
|
||||
|
||||
openai?.SetTag("Model", chatCompletionCreate.Model);
|
||||
@@ -152,7 +172,8 @@ public sealed class OpenAiChatCompletionsService(ILogger<OpenAiChatCompletionsSe
|
||||
if (response.StatusCode >= HttpStatusCode.BadRequest)
|
||||
{
|
||||
var error = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
|
||||
logger.LogError("OpenAI对话异常 请求地址:{Address}, StatusCode: {StatusCode} Response: {Response}", options.Endpoint,
|
||||
logger.LogError("OpenAI对话异常 请求地址:{Address}, StatusCode: {StatusCode} Response: {Response}",
|
||||
options.Endpoint,
|
||||
response.StatusCode, error);
|
||||
|
||||
throw new BusinessException("OpenAI对话异常", response.StatusCode.ToString());
|
||||
|
||||
@@ -22,7 +22,16 @@ public class OpenAiResponseService(ILogger<OpenAiResponseService> logger,IHttpCl
|
||||
|
||||
var client = httpClientFactory.CreateClient();
|
||||
|
||||
var response = await client.HttpRequestRaw(options.Endpoint.TrimEnd('/') + "/responses", input, options.ApiKey);
|
||||
var endpoint = options?.Endpoint.TrimEnd('/');
|
||||
|
||||
//兼容 v1结尾
|
||||
if (endpoint != null && endpoint.EndsWith("/v1", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
endpoint = endpoint.Substring(0, endpoint.Length - "/v1".Length);
|
||||
}
|
||||
var requestUri = endpoint + "/v1/responses";
|
||||
|
||||
var response = await client.HttpRequestRaw(requestUri, input, options.ApiKey);
|
||||
|
||||
openai?.SetTag("Model", input.Model);
|
||||
openai?.SetTag("Response", response.StatusCode.ToString());
|
||||
@@ -86,8 +95,17 @@ public class OpenAiResponseService(ILogger<OpenAiResponseService> logger,IHttpCl
|
||||
using var openai =
|
||||
Activity.Current?.Source.StartActivity("OpenAI 响应");
|
||||
|
||||
var endpoint = options?.Endpoint.TrimEnd('/');
|
||||
|
||||
//兼容 v1结尾
|
||||
if (endpoint != null && endpoint.EndsWith("/v1", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
endpoint = endpoint.Substring(0, endpoint.Length - "/v1".Length);
|
||||
}
|
||||
var requestUri = endpoint + "/v1/responses";
|
||||
|
||||
var response = await httpClientFactory.CreateClient().PostJsonAsync(
|
||||
options?.Endpoint.TrimEnd('/') + "/responses",
|
||||
requestUri,
|
||||
chatCompletionCreate, options.ApiKey).ConfigureAwait(false);
|
||||
|
||||
openai?.SetTag("Model", chatCompletionCreate.Model);
|
||||
|
||||
@@ -23,9 +23,17 @@ public sealed class DeepSeekChatCompletionsService(ILogger<DeepSeekChatCompletio
|
||||
|
||||
using var openai =
|
||||
Activity.Current?.Source.StartActivity("OpenAI 对话流式补全");
|
||||
|
||||
|
||||
var endpoint = options?.Endpoint.TrimEnd('/');
|
||||
//兼容 v1结尾
|
||||
if (endpoint != null && endpoint.EndsWith("/v1", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
endpoint = endpoint.Substring(0, endpoint.Length - "/v1".Length);
|
||||
}
|
||||
var requestUri = endpoint + "/v1/chat/completions";
|
||||
|
||||
var response = await httpClientFactory.CreateClient().HttpRequestRaw(
|
||||
options?.Endpoint.TrimEnd('/') + "/chat/completions",
|
||||
requestUri,
|
||||
chatCompletionCreate, options.ApiKey);
|
||||
|
||||
openai?.SetTag("Model", chatCompletionCreate.Model);
|
||||
@@ -92,40 +100,6 @@ public sealed class DeepSeekChatCompletionsService(ILogger<DeepSeekChatCompletio
|
||||
|
||||
var result = JsonSerializer.Deserialize<ThorChatCompletionsResponse>(line,
|
||||
ThorJsonSerializer.DefaultOptions);
|
||||
|
||||
// var content = result?.Choices?.FirstOrDefault()?.Delta;
|
||||
//
|
||||
// // if (first && string.IsNullOrWhiteSpace(content?.Content) && string.IsNullOrEmpty(content?.ReasoningContent))
|
||||
// // {
|
||||
// // continue;
|
||||
// // }
|
||||
//
|
||||
// if (first && content.Content == OpenAIConstant.ThinkStart)
|
||||
// {
|
||||
// isThink = true;
|
||||
// //continue;
|
||||
// // 需要将content的内容转换到其他字段
|
||||
// }
|
||||
//
|
||||
// if (isThink && content.Content.Contains(OpenAIConstant.ThinkEnd))
|
||||
// {
|
||||
// isThink = false;
|
||||
// // 需要将content的内容转换到其他字段
|
||||
// //continue;
|
||||
// }
|
||||
//
|
||||
// if (isThink)
|
||||
// {
|
||||
// // 需要将content的内容转换到其他字段
|
||||
// foreach (var choice in result.Choices)
|
||||
// {
|
||||
// //choice.Delta.ReasoningContent = choice.Delta.Content;
|
||||
// //choice.Delta.Content = string.Empty;
|
||||
// }
|
||||
// }
|
||||
|
||||
// first = false;
|
||||
|
||||
yield return result;
|
||||
}
|
||||
}
|
||||
@@ -142,8 +116,16 @@ public sealed class DeepSeekChatCompletionsService(ILogger<DeepSeekChatCompletio
|
||||
options.Endpoint = "https://api.deepseek.com/v1";
|
||||
}
|
||||
|
||||
var endpoint = options?.Endpoint.TrimEnd('/');
|
||||
//兼容 v1结尾
|
||||
if (endpoint != null && endpoint.EndsWith("/v1", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
endpoint = endpoint.Substring(0, endpoint.Length - "/v1".Length);
|
||||
}
|
||||
var requestUri = endpoint + "/v1/chat/completions";
|
||||
|
||||
var response = await httpClientFactory.CreateClient().PostJsonAsync(
|
||||
options?.Endpoint.TrimEnd('/') + "/chat/completions",
|
||||
requestUri,
|
||||
chatCompletionCreate, options.ApiKey).ConfigureAwait(false);
|
||||
|
||||
openai?.SetTag("Model", chatCompletionCreate.Model);
|
||||
|
||||
@@ -14,23 +14,18 @@ public class ImageStoreTaskAggregateRoot : FullAuditedAggregateRoot<Guid>
|
||||
public string Prompt { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 参考图Base64
|
||||
/// 参考图PrefixBase64(带前缀,如 data:image/png;base64,xxx)
|
||||
/// </summary>
|
||||
[SugarColumn(IsJson = true)]
|
||||
public List<string> ReferenceImagesBase64 { get; set; }
|
||||
[SugarColumn(IsJson = true, ColumnDataType = StaticConfig.CodeFirst_BigString)]
|
||||
public List<string> ReferenceImagesPrefixBase64 { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 参考图url
|
||||
/// </summary>
|
||||
[SugarColumn(IsJson = true)]
|
||||
public List<string> ReferenceImagesUrl { get; set; }
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// 图片base64
|
||||
/// </summary>
|
||||
public string? StoreBase64 { get; set; }
|
||||
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// 图片绝对路径
|
||||
/// </summary>
|
||||
@@ -46,6 +41,43 @@ public class ImageStoreTaskAggregateRoot : FullAuditedAggregateRoot<Guid>
|
||||
/// </summary>
|
||||
public Guid UserId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 用户名称
|
||||
/// </summary>
|
||||
public string? UserName { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 模型id
|
||||
/// </summary>
|
||||
public string ModelId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 错误信息
|
||||
/// </summary>
|
||||
[SugarColumn(ColumnDataType = StaticConfig.CodeFirst_BigString)]
|
||||
public string? ErrorInfo { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 发布状态
|
||||
/// </summary>
|
||||
public PublishStatusEnum PublishStatus { get; set; } = PublishStatusEnum.Unpublished;
|
||||
|
||||
/// <summary>
|
||||
/// 分类标签
|
||||
/// </summary>
|
||||
[SugarColumn(IsJson = true)]
|
||||
public List<string> Categories { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// 是否匿名
|
||||
/// </summary>
|
||||
public bool IsAnonymous { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// 密钥id
|
||||
/// </summary>
|
||||
public Guid? TokenId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 设置成功
|
||||
/// </summary>
|
||||
@@ -55,4 +87,18 @@ public class ImageStoreTaskAggregateRoot : FullAuditedAggregateRoot<Guid>
|
||||
TaskStatus = TaskStatusEnum.Success;
|
||||
StoreUrl = storeUrl;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 设置发布
|
||||
/// </summary>
|
||||
/// <param name="isAnonymous"></param>
|
||||
/// <param name="categories"></param>
|
||||
public void SetPublish(bool isAnonymous,List<string> categories)
|
||||
{
|
||||
this.PublishStatus = PublishStatusEnum.Published;
|
||||
this.IsAnonymous = isAnonymous;
|
||||
this.Categories = categories;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -80,4 +80,9 @@ public class AiModelEntity : Entity<Guid>, IOrderNum, ISoftDelete
|
||||
/// 模型图标URL
|
||||
/// </summary>
|
||||
public string? IconUrl { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 是否为尊享模型
|
||||
/// </summary>
|
||||
public bool IsPremium { get; set; }
|
||||
}
|
||||
@@ -92,7 +92,7 @@ public class AiGateWayManager : DomainService
|
||||
{
|
||||
throw new UserFriendlyException($"【{modelId}】模型当前版本【{modelApiType}】格式不支持");
|
||||
}
|
||||
// ✅ 统一处理 -nx 后缀(网关层模型规范化)
|
||||
// ✅ 统一处理 yi- 后缀(网关层模型规范化)
|
||||
if (!string.IsNullOrEmpty(aiModelDescribe.ModelId) &&
|
||||
aiModelDescribe.ModelId.StartsWith("yi-", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
@@ -158,7 +158,12 @@ public class AiGateWayManager : DomainService
|
||||
await _usageStatisticsManager.SetUsageAsync(userId.Value, sourceModelId, data.Usage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (PremiumPackageConst.ModeIds.Contains(sourceModelId))
|
||||
var isPremium = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.ModelId == request.Model)
|
||||
.Select(x => x.IsPremium)
|
||||
.FirstAsync();
|
||||
|
||||
if (isPremium)
|
||||
{
|
||||
var totalTokens = data.Usage?.TotalTokens ?? 0;
|
||||
if (totalTokens > 0)
|
||||
@@ -315,12 +320,20 @@ public class AiGateWayManager : DomainService
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, sourceModelId, tokenUsage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (userId is not null && PremiumPackageConst.ModeIds.Contains(sourceModelId))
|
||||
if (userId is not null)
|
||||
{
|
||||
var totalTokens = tokenUsage.TotalTokens ?? 0;
|
||||
if (totalTokens > 0)
|
||||
var isPremium = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.ModelId == request.Model)
|
||||
.Select(x => x.IsPremium)
|
||||
.FirstAsync();
|
||||
|
||||
if (isPremium)
|
||||
{
|
||||
await PremiumPackageManager.TryConsumeTokensAsync(userId.Value, totalTokens);
|
||||
var totalTokens = tokenUsage.TotalTokens ?? 0;
|
||||
if (totalTokens > 0)
|
||||
{
|
||||
await PremiumPackageManager.TryConsumeTokensAsync(userId.Value, totalTokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -378,12 +391,20 @@ public class AiGateWayManager : DomainService
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, model, response.Usage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (userId is not null && PremiumPackageConst.ModeIds.Contains(request.Model))
|
||||
if (userId is not null)
|
||||
{
|
||||
var totalTokens = response.Usage.TotalTokens ?? 0;
|
||||
if (totalTokens > 0)
|
||||
var isPremium = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.ModelId == request.Model)
|
||||
.Select(x => x.IsPremium)
|
||||
.FirstAsync();
|
||||
|
||||
if (isPremium)
|
||||
{
|
||||
await PremiumPackageManager.TryConsumeTokensAsync(userId.Value, totalTokens);
|
||||
var totalTokens = response.Usage.TotalTokens ?? 0;
|
||||
if (totalTokens > 0)
|
||||
{
|
||||
await PremiumPackageManager.TryConsumeTokensAsync(userId.Value, totalTokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -982,7 +1003,7 @@ public class AiGateWayManager : DomainService
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private const string ImageStoreHost = "http://localhost:19001/api/app";
|
||||
/// <summary>
|
||||
/// Gemini 生成(Image)-非流式-缓存处理
|
||||
/// 返回图片绝对路径
|
||||
@@ -1011,16 +1032,16 @@ public class AiGateWayManager : DomainService
|
||||
var data = await chatService.GenerateContentAsync(modelDescribe, request, cancellationToken);
|
||||
|
||||
//解析json,获取base64字符串
|
||||
var imageBase64 = GeminiGenerateContentAcquirer.GetImageBase64(data);
|
||||
var imagePrefixBase64 = GeminiGenerateContentAcquirer.GetImagePrefixBase64(data);
|
||||
|
||||
//远程调用上传接口,将base64转换为URL
|
||||
var httpClient = LazyServiceProvider.LazyGetRequiredService<IHttpClientFactory>().CreateClient();
|
||||
var uploadUrl = $"https://ccnetcore.com/prod-api/ai-hub/ai-image/upload-base64";
|
||||
var content = new StringContent(JsonSerializer.Serialize(imageBase64), Encoding.UTF8, "application/json");
|
||||
// var uploadUrl = $"https://ccnetcore.com/prod-api/ai-hub/ai-image/upload-base64";
|
||||
var uploadUrl = $"{ImageStoreHost}/ai-image/upload-base64";
|
||||
var content = new StringContent(JsonSerializer.Serialize(imagePrefixBase64), Encoding.UTF8, "application/json");
|
||||
var uploadResponse = await httpClient.PostAsync(uploadUrl, content, cancellationToken);
|
||||
uploadResponse.EnsureSuccessStatusCode();
|
||||
var storeUrl = await uploadResponse.Content.ReadAsStringAsync(cancellationToken);
|
||||
storeUrl = storeUrl.Trim('"'); // 移除JSON字符串的引号
|
||||
|
||||
var tokenUsage = new ThorUsageResponse
|
||||
{
|
||||
@@ -1047,8 +1068,7 @@ public class AiGateWayManager : DomainService
|
||||
}
|
||||
|
||||
//设置存储base64和url
|
||||
imageStoreTask.StoreBase64 = imageBase64;
|
||||
imageStoreTask.SetSuccess(storeUrl);
|
||||
imageStoreTask.SetSuccess($"{ImageStoreHost}{storeUrl}");
|
||||
await _imageStoreTaskRepository.UpdateAsync(imageStoreTask);
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ using Volo.Abp.Domain.Services;
|
||||
using Yi.Framework.AiHub.Application.Contracts.Dtos.Chat;
|
||||
using Yi.Framework.AiHub.Domain.AiGateWay;
|
||||
using Yi.Framework.AiHub.Domain.Entities.Chat;
|
||||
using Yi.Framework.AiHub.Domain.Entities.Model;
|
||||
using Yi.Framework.AiHub.Domain.Entities.OpenApi;
|
||||
using Yi.Framework.AiHub.Domain.Shared.Attributes;
|
||||
using Yi.Framework.AiHub.Domain.Shared.Consts;
|
||||
@@ -34,12 +35,13 @@ public class ChatManager : DomainService
|
||||
private readonly UsageStatisticsManager _usageStatisticsManager;
|
||||
private readonly PremiumPackageManager _premiumPackageManager;
|
||||
private readonly AiGateWayManager _aiGateWayManager;
|
||||
private readonly ISqlSugarRepository<AiModelEntity, Guid> _aiModelRepository;
|
||||
|
||||
public ChatManager(ILoggerFactory loggerFactory,
|
||||
ISqlSugarRepository<MessageAggregateRoot> messageRepository,
|
||||
ISqlSugarRepository<AgentStoreAggregateRoot> agentStoreRepository, AiMessageManager aiMessageManager,
|
||||
UsageStatisticsManager usageStatisticsManager, PremiumPackageManager premiumPackageManager,
|
||||
AiGateWayManager aiGateWayManager)
|
||||
AiGateWayManager aiGateWayManager, ISqlSugarRepository<AiModelEntity, Guid> aiModelRepository)
|
||||
{
|
||||
_loggerFactory = loggerFactory;
|
||||
_messageRepository = messageRepository;
|
||||
@@ -48,6 +50,7 @@ public class ChatManager : DomainService
|
||||
_usageStatisticsManager = usageStatisticsManager;
|
||||
_premiumPackageManager = premiumPackageManager;
|
||||
_aiGateWayManager = aiGateWayManager;
|
||||
_aiModelRepository = aiModelRepository;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -207,7 +210,12 @@ public class ChatManager : DomainService
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, modelId, usage, tokenId);
|
||||
|
||||
//扣减尊享token包用量
|
||||
if (PremiumPackageConst.ModeIds.Contains(modelId))
|
||||
var isPremium = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.ModelId == modelId)
|
||||
.Select(x => x.IsPremium)
|
||||
.FirstAsync();
|
||||
|
||||
if (isPremium)
|
||||
{
|
||||
var totalTokens = usage?.TotalTokens ?? 0;
|
||||
if (totalTokens > 0)
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Volo.Abp.Caching;
|
||||
using Volo.Abp.Domain.Services;
|
||||
using Yi.Framework.AiHub.Domain.Entities.Model;
|
||||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Managers;
|
||||
|
||||
/// <summary>
|
||||
/// 模型管理器
|
||||
/// </summary>
|
||||
public class ModelManager : DomainService
|
||||
{
|
||||
private readonly ISqlSugarRepository<AiModelEntity> _aiModelRepository;
|
||||
private readonly IDistributedCache<List<string>, string> _distributedCache;
|
||||
private readonly ILogger<ModelManager> _logger;
|
||||
private const string PREMIUM_MODEL_IDS_CACHE_KEY = "PremiumModelIds";
|
||||
|
||||
public ModelManager(
|
||||
ISqlSugarRepository<AiModelEntity> aiModelRepository,
|
||||
IDistributedCache<List<string>, string> distributedCache,
|
||||
ILogger<ModelManager> logger)
|
||||
{
|
||||
_aiModelRepository = aiModelRepository;
|
||||
_distributedCache = distributedCache;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 获取所有尊享模型ID列表(使用分布式缓存,10分钟过期)
|
||||
/// </summary>
|
||||
/// <returns>尊享模型ID列表</returns>
|
||||
public async Task<List<string>> GetPremiumModelIdsAsync()
|
||||
{
|
||||
var output = await _distributedCache.GetOrAddAsync(
|
||||
PREMIUM_MODEL_IDS_CACHE_KEY,
|
||||
async () =>
|
||||
{
|
||||
// 从数据库查询
|
||||
var premiumModelIds = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.IsPremium)
|
||||
.Select(x => x.ModelId)
|
||||
.ToListAsync();
|
||||
return premiumModelIds;
|
||||
},
|
||||
() => new Microsoft.Extensions.Caching.Distributed.DistributedCacheEntryOptions
|
||||
{
|
||||
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(1)
|
||||
}
|
||||
);
|
||||
return output ?? new List<string>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 判断指定模型是否为尊享模型
|
||||
/// </summary>
|
||||
/// <param name="modelId">模型ID</param>
|
||||
/// <returns>是否为尊享模型</returns>
|
||||
public async Task<bool> IsPremiumModelAsync(string modelId)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(modelId))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var premiumModelIds = await GetPremiumModelIdsAsync();
|
||||
return premiumModelIds.Contains(modelId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 清除尊享模型ID缓存
|
||||
/// </summary>
|
||||
public async Task ClearPremiumModelIdsCacheAsync()
|
||||
{
|
||||
await _distributedCache.RemoveAsync(PREMIUM_MODEL_IDS_CACHE_KEY);
|
||||
_logger.LogInformation("已清除尊享模型ID分布式缓存");
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
using SqlSugar;
|
||||
using Volo.Abp.Domain.Services;
|
||||
using Yi.Framework.AiHub.Domain.Entities;
|
||||
using Yi.Framework.AiHub.Domain.Entities.Model;
|
||||
using Yi.Framework.AiHub.Domain.Entities.OpenApi;
|
||||
using Yi.Framework.AiHub.Domain.Shared.Consts;
|
||||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||||
@@ -21,43 +22,62 @@ public class TokenValidationResult
|
||||
/// Token Id
|
||||
/// </summary>
|
||||
public Guid TokenId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// token
|
||||
/// </summary>
|
||||
public string Token { get; set; }
|
||||
}
|
||||
|
||||
public class TokenManager : DomainService
|
||||
{
|
||||
private readonly ISqlSugarRepository<TokenAggregateRoot> _tokenRepository;
|
||||
private readonly ISqlSugarRepository<UsageStatisticsAggregateRoot> _usageStatisticsRepository;
|
||||
private readonly ISqlSugarRepository<AiModelEntity, Guid> _aiModelRepository;
|
||||
|
||||
public TokenManager(
|
||||
ISqlSugarRepository<TokenAggregateRoot> tokenRepository,
|
||||
ISqlSugarRepository<UsageStatisticsAggregateRoot> usageStatisticsRepository)
|
||||
ISqlSugarRepository<UsageStatisticsAggregateRoot> usageStatisticsRepository,
|
||||
ISqlSugarRepository<AiModelEntity, Guid> aiModelRepository)
|
||||
{
|
||||
_tokenRepository = tokenRepository;
|
||||
_usageStatisticsRepository = usageStatisticsRepository;
|
||||
_aiModelRepository = aiModelRepository;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 验证Token并返回用户Id和TokenId
|
||||
/// </summary>
|
||||
/// <param name="token">Token密钥</param>
|
||||
/// <param name="tokenOrId">Token密钥或者TokenId</param>
|
||||
/// <param name="modelId">模型Id(用于判断是否是尊享模型需要检查额度)</param>
|
||||
/// <returns>Token验证结果</returns>
|
||||
public async Task<TokenValidationResult> ValidateTokenAsync(string? token, string? modelId = null)
|
||||
public async Task<TokenValidationResult> ValidateTokenAsync(object tokenOrId, string? modelId = null)
|
||||
{
|
||||
if (token is null)
|
||||
|
||||
if (tokenOrId is null)
|
||||
{
|
||||
throw new UserFriendlyException("当前请求未包含token", "401");
|
||||
}
|
||||
|
||||
if (!token.StartsWith("yi-"))
|
||||
|
||||
TokenAggregateRoot entity;
|
||||
if (tokenOrId is Guid tokenId)
|
||||
{
|
||||
throw new UserFriendlyException("当前请求token非法", "401");
|
||||
entity = await _tokenRepository._DbQueryable
|
||||
.Where(x => x.Id == tokenId)
|
||||
.FirstAsync();
|
||||
}
|
||||
|
||||
var entity = await _tokenRepository._DbQueryable
|
||||
.Where(x => x.Token == token)
|
||||
.FirstAsync();
|
||||
|
||||
else
|
||||
{
|
||||
var tokenStr = tokenOrId.ToString();
|
||||
if (!tokenStr.StartsWith("yi-"))
|
||||
{
|
||||
throw new UserFriendlyException("当前请求token非法", "401");
|
||||
}
|
||||
entity = await _tokenRepository._DbQueryable
|
||||
.Where(x => x.Token == tokenStr)
|
||||
.FirstAsync();
|
||||
}
|
||||
|
||||
if (entity is null)
|
||||
{
|
||||
throw new UserFriendlyException("当前请求token无效", "401");
|
||||
@@ -76,21 +96,28 @@ public class TokenManager : DomainService
|
||||
}
|
||||
|
||||
// 如果是尊享模型且Token设置了额度限制,检查是否超限
|
||||
if (!string.IsNullOrEmpty(modelId) &&
|
||||
PremiumPackageConst.ModeIds.Contains(modelId) &&
|
||||
entity.PremiumQuotaLimit.HasValue)
|
||||
if (!string.IsNullOrEmpty(modelId) && entity.PremiumQuotaLimit.HasValue)
|
||||
{
|
||||
var usedQuota = await GetTokenPremiumUsedQuotaAsync(entity.UserId, entity.Id);
|
||||
if (usedQuota >= entity.PremiumQuotaLimit.Value)
|
||||
var isPremium = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.ModelId == modelId)
|
||||
.Select(x => x.IsPremium)
|
||||
.FirstAsync();
|
||||
|
||||
if (isPremium)
|
||||
{
|
||||
throw new UserFriendlyException($"当前Token的尊享包额度已用完(已使用:{usedQuota},限制:{entity.PremiumQuotaLimit.Value}),请调整额度限制或使用其他Token", "403");
|
||||
var usedQuota = await GetTokenPremiumUsedQuotaAsync(entity.UserId, entity.Id);
|
||||
if (usedQuota >= entity.PremiumQuotaLimit.Value)
|
||||
{
|
||||
throw new UserFriendlyException($"当前Token的尊享包额度已用完(已使用:{usedQuota},限制:{entity.PremiumQuotaLimit.Value}),请调整额度限制或使用其他Token", "403");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new TokenValidationResult
|
||||
{
|
||||
UserId = entity.UserId,
|
||||
TokenId = entity.Id
|
||||
TokenId = entity.Id,
|
||||
Token = entity.Token
|
||||
};
|
||||
}
|
||||
|
||||
@@ -99,7 +126,11 @@ public class TokenManager : DomainService
|
||||
/// </summary>
|
||||
private async Task<long> GetTokenPremiumUsedQuotaAsync(Guid userId, Guid tokenId)
|
||||
{
|
||||
var premiumModelIds = PremiumPackageConst.ModeIds;
|
||||
// 先获取所有尊享模型的ModelId列表
|
||||
var premiumModelIds = await _aiModelRepository._DbQueryable
|
||||
.Where(x => x.IsPremium)
|
||||
.Select(x => x.ModelId)
|
||||
.ToListAsync();
|
||||
|
||||
var usedQuota = await _usageStatisticsRepository._DbQueryable
|
||||
.Where(x => x.UserId == userId && x.TokenId == tokenId && premiumModelIds.Contains(x.ModelId))
|
||||
|
||||
Reference in New Issue
Block a user