feat: 完成token功能
This commit is contained in:
@@ -120,12 +120,14 @@ public class AiGateWayManager : DomainService
|
||||
/// <param name="request"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="tokenId">Token Id(Web端传null或Guid.Empty)</param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async Task CompleteChatForStatisticsAsync(HttpContext httpContext,
|
||||
ThorChatCompletionsRequest request,
|
||||
Guid? userId = null,
|
||||
Guid? sessionId = null,
|
||||
Guid? tokenId = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
_specialCompatible.Compatible(request);
|
||||
@@ -145,7 +147,7 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : request.Messages?.LastOrDefault().Content ?? string.Empty,
|
||||
ModelId = request.Model,
|
||||
TokenUsage = data.Usage,
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _aiMessageManager.CreateSystemMessageAsync(userId.Value, sessionId,
|
||||
new MessageInputDto
|
||||
@@ -154,9 +156,9 @@ public class AiGateWayManager : DomainService
|
||||
sessionId is null ? "不予存储" : data.Choices?.FirstOrDefault()?.Delta.Content ?? string.Empty,
|
||||
ModelId = request.Model,
|
||||
TokenUsage = data.Usage
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _usageStatisticsManager.SetUsageAsync(userId.Value, request.Model, data.Usage);
|
||||
await _usageStatisticsManager.SetUsageAsync(userId.Value, request.Model, data.Usage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (PremiumPackageConst.ModeIds.Contains(request.Model))
|
||||
@@ -179,6 +181,7 @@ public class AiGateWayManager : DomainService
|
||||
/// <param name="request"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="tokenId">Token Id(Web端传null或Guid.Empty)</param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async Task CompleteChatStreamForStatisticsAsync(
|
||||
@@ -186,6 +189,7 @@ public class AiGateWayManager : DomainService
|
||||
ThorChatCompletionsRequest request,
|
||||
Guid? userId = null,
|
||||
Guid? sessionId = null,
|
||||
Guid? tokenId = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var response = httpContext.Response;
|
||||
@@ -288,7 +292,7 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : request.Messages?.LastOrDefault()?.Content ?? string.Empty,
|
||||
ModelId = request.Model,
|
||||
TokenUsage = tokenUsage,
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
|
||||
new MessageInputDto
|
||||
@@ -296,9 +300,9 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : backupSystemContent.ToString(),
|
||||
ModelId = request.Model,
|
||||
TokenUsage = tokenUsage
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage);
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (userId is not null && PremiumPackageConst.ModeIds.Contains(request.Model))
|
||||
@@ -319,10 +323,11 @@ public class AiGateWayManager : DomainService
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="request"></param>
|
||||
/// <param name="tokenId">Token Id(Web端传null或Guid.Empty)</param>
|
||||
/// <exception cref="BusinessException"></exception>
|
||||
/// <exception cref="Exception"></exception>
|
||||
public async Task CreateImageForStatisticsAsync(HttpContext context, Guid? userId, Guid? sessionId,
|
||||
ImageCreateRequest request)
|
||||
ImageCreateRequest request, Guid? tokenId = null)
|
||||
{
|
||||
try
|
||||
{
|
||||
@@ -350,7 +355,7 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : request.Prompt,
|
||||
ModelId = model,
|
||||
TokenUsage = response.Usage,
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
|
||||
new MessageInputDto
|
||||
@@ -358,9 +363,9 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : response.Results?.FirstOrDefault()?.Url,
|
||||
ModelId = model,
|
||||
TokenUsage = response.Usage
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, model, response.Usage);
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, model, response.Usage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (userId is not null && PremiumPackageConst.ModeIds.Contains(request.Model))
|
||||
@@ -384,13 +389,14 @@ public class AiGateWayManager : DomainService
|
||||
/// 向量生成
|
||||
/// </summary>
|
||||
/// <param name="context"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="input"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="tokenId">Token Id(Web端传null或Guid.Empty)</param>
|
||||
/// <exception cref="Exception"></exception>
|
||||
/// <exception cref="BusinessException"></exception>
|
||||
public async Task EmbeddingForStatisticsAsync(HttpContext context, Guid? userId, Guid? sessionId,
|
||||
ThorEmbeddingInput input)
|
||||
ThorEmbeddingInput input, Guid? tokenId = null)
|
||||
{
|
||||
try
|
||||
{
|
||||
@@ -474,7 +480,7 @@ public class AiGateWayManager : DomainService
|
||||
// TokenUsage = usage
|
||||
// });
|
||||
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, input.Model, usage);
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, input.Model, usage, tokenId);
|
||||
}
|
||||
catch (ThorRateLimitException)
|
||||
{
|
||||
@@ -522,12 +528,14 @@ public class AiGateWayManager : DomainService
|
||||
/// <param name="request"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="tokenId">Token Id(Web端传null或Guid.Empty)</param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async Task AnthropicCompleteChatForStatisticsAsync(HttpContext httpContext,
|
||||
AnthropicInput request,
|
||||
Guid? userId = null,
|
||||
Guid? sessionId = null,
|
||||
Guid? tokenId = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
_specialCompatible.AnthropicCompatible(request);
|
||||
@@ -549,7 +557,7 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : request.Messages?.FirstOrDefault()?.Content ?? string.Empty,
|
||||
ModelId = request.Model,
|
||||
TokenUsage = data.TokenUsage,
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _aiMessageManager.CreateSystemMessageAsync(userId.Value, sessionId,
|
||||
new MessageInputDto
|
||||
@@ -557,9 +565,9 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : data.content?.FirstOrDefault()?.text,
|
||||
ModelId = request.Model,
|
||||
TokenUsage = data.TokenUsage
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _usageStatisticsManager.SetUsageAsync(userId.Value, request.Model, data.TokenUsage);
|
||||
await _usageStatisticsManager.SetUsageAsync(userId.Value, request.Model, data.TokenUsage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
var totalTokens = data.TokenUsage.TotalTokens ?? 0;
|
||||
@@ -579,6 +587,7 @@ public class AiGateWayManager : DomainService
|
||||
/// <param name="request"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="tokenId">Token Id(Web端传null或Guid.Empty)</param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async Task AnthropicCompleteChatStreamForStatisticsAsync(
|
||||
@@ -586,6 +595,7 @@ public class AiGateWayManager : DomainService
|
||||
AnthropicInput request,
|
||||
Guid? userId = null,
|
||||
Guid? sessionId = null,
|
||||
Guid? tokenId = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var response = httpContext.Response;
|
||||
@@ -627,7 +637,7 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : request.Messages?.LastOrDefault()?.Content ?? string.Empty,
|
||||
ModelId = request.Model,
|
||||
TokenUsage = tokenUsage,
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _aiMessageManager.CreateSystemMessageAsync(userId, sessionId,
|
||||
new MessageInputDto
|
||||
@@ -635,9 +645,9 @@ public class AiGateWayManager : DomainService
|
||||
Content = sessionId is null ? "不予存储" : backupSystemContent.ToString(),
|
||||
ModelId = request.Model,
|
||||
TokenUsage = tokenUsage
|
||||
});
|
||||
}, tokenId);
|
||||
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage);
|
||||
await _usageStatisticsManager.SetUsageAsync(userId, request.Model, tokenUsage, tokenId);
|
||||
|
||||
// 扣减尊享token包用量
|
||||
if (userId.HasValue && tokenUsage is not null)
|
||||
|
||||
@@ -19,28 +19,30 @@ public class AiMessageManager : DomainService
|
||||
/// <summary>
|
||||
/// 创建系统消息
|
||||
/// </summary>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="input"></param>
|
||||
/// <param name="userId">用户Id</param>
|
||||
/// <param name="sessionId">会话Id</param>
|
||||
/// <param name="input">消息输入</param>
|
||||
/// <param name="tokenId">Token Id(Web端传Guid.Empty)</param>
|
||||
/// <returns></returns>
|
||||
public async Task CreateSystemMessageAsync(Guid? userId, Guid? sessionId, MessageInputDto input)
|
||||
public async Task CreateSystemMessageAsync(Guid? userId, Guid? sessionId, MessageInputDto input, Guid? tokenId = null)
|
||||
{
|
||||
input.Role = "system";
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId,input.TokenUsage);
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId, input.TokenUsage, tokenId);
|
||||
await _repository.InsertAsync(message);
|
||||
}
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// 创建系统消息
|
||||
/// 创建用户消息
|
||||
/// </summary>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="input"></param>
|
||||
/// <param name="userId">用户Id</param>
|
||||
/// <param name="sessionId">会话Id</param>
|
||||
/// <param name="input">消息输入</param>
|
||||
/// <param name="tokenId">Token Id(Web端传Guid.Empty)</param>
|
||||
/// <returns></returns>
|
||||
public async Task CreateUserMessageAsync(Guid? userId, Guid? sessionId, MessageInputDto input)
|
||||
public async Task CreateUserMessageAsync(Guid? userId, Guid? sessionId, MessageInputDto input, Guid? tokenId = null)
|
||||
{
|
||||
input.Role = "user";
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId,input.TokenUsage);
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId, input.TokenUsage, tokenId);
|
||||
await _repository.InsertAsync(message);
|
||||
}
|
||||
}
|
||||
@@ -1,64 +1,134 @@
|
||||
using Volo.Abp.Domain.Services;
|
||||
using Volo.Abp.Users;
|
||||
using SqlSugar;
|
||||
using Volo.Abp.Domain.Services;
|
||||
using Yi.Framework.AiHub.Domain.Entities;
|
||||
using Yi.Framework.AiHub.Domain.Entities.OpenApi;
|
||||
using Yi.Framework.AiHub.Domain.Shared.Consts;
|
||||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Managers;
|
||||
|
||||
/// <summary>
|
||||
/// Token验证结果
|
||||
/// </summary>
|
||||
public class TokenValidationResult
|
||||
{
|
||||
/// <summary>
|
||||
/// 用户Id
|
||||
/// </summary>
|
||||
public Guid UserId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Token Id
|
||||
/// </summary>
|
||||
public Guid TokenId { get; set; }
|
||||
}
|
||||
|
||||
public class TokenManager : DomainService
|
||||
{
|
||||
private readonly ISqlSugarRepository<TokenAggregateRoot> _tokenRepository;
|
||||
private readonly ISqlSugarRepository<UsageStatisticsAggregateRoot> _usageStatisticsRepository;
|
||||
|
||||
public TokenManager(ISqlSugarRepository<TokenAggregateRoot> tokenRepository)
|
||||
public TokenManager(
|
||||
ISqlSugarRepository<TokenAggregateRoot> tokenRepository,
|
||||
ISqlSugarRepository<UsageStatisticsAggregateRoot> usageStatisticsRepository)
|
||||
{
|
||||
_tokenRepository = tokenRepository;
|
||||
_usageStatisticsRepository = usageStatisticsRepository;
|
||||
}
|
||||
|
||||
public async Task<string?> GetAsync(Guid userId)
|
||||
{
|
||||
var entity = await _tokenRepository._DbQueryable.FirstAsync(x => x.UserId == userId);
|
||||
if (entity is not null)
|
||||
{
|
||||
return entity.Token;
|
||||
}
|
||||
else
|
||||
{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task CreateAsync(Guid userId)
|
||||
{
|
||||
var entity = await _tokenRepository._DbQueryable.FirstAsync(x => x.UserId == userId);
|
||||
if (entity is not null)
|
||||
{
|
||||
entity.ResetToken();
|
||||
await _tokenRepository.UpdateAsync(entity);
|
||||
}
|
||||
else
|
||||
{
|
||||
var token = new TokenAggregateRoot(userId);
|
||||
await _tokenRepository.InsertAsync(token);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<Guid> GetUserIdAsync(string? token)
|
||||
/// <summary>
|
||||
/// 验证Token并返回用户Id和TokenId
|
||||
/// </summary>
|
||||
/// <param name="token">Token密钥</param>
|
||||
/// <param name="modelId">模型Id(用于判断是否是尊享模型需要检查额度)</param>
|
||||
/// <returns>Token验证结果</returns>
|
||||
public async Task<TokenValidationResult> ValidateTokenAsync(string? token, string? modelId = null)
|
||||
{
|
||||
if (token is null)
|
||||
{
|
||||
throw new UserFriendlyException("当前请求未包含token", "401");
|
||||
}
|
||||
|
||||
if (token.StartsWith("yi-"))
|
||||
if (!token.StartsWith("yi-"))
|
||||
{
|
||||
var entity = await _tokenRepository._DbQueryable.Where(x => x.Token == token).FirstAsync();
|
||||
if (entity is null)
|
||||
{
|
||||
throw new UserFriendlyException("当前请求token无效", "401");
|
||||
}
|
||||
|
||||
return entity.UserId;
|
||||
throw new UserFriendlyException("当前请求token非法", "401");
|
||||
}
|
||||
throw new UserFriendlyException("当前请求token非法", "401");
|
||||
|
||||
var entity = await _tokenRepository._DbQueryable
|
||||
.Where(x => x.Token == token)
|
||||
.FirstAsync();
|
||||
|
||||
if (entity is null)
|
||||
{
|
||||
throw new UserFriendlyException("当前请求token无效", "401");
|
||||
}
|
||||
|
||||
// 检查Token是否被禁用
|
||||
if (entity.IsDisabled)
|
||||
{
|
||||
throw new UserFriendlyException("当前Token已被禁用,请启用后再使用", "403");
|
||||
}
|
||||
|
||||
// 检查Token是否过期
|
||||
if (entity.ExpireTime.HasValue && entity.ExpireTime.Value < DateTime.Now)
|
||||
{
|
||||
throw new UserFriendlyException("当前Token已过期,请更新过期时间或创建新的Token", "403");
|
||||
}
|
||||
|
||||
// 如果是尊享模型且Token设置了额度限制,检查是否超限
|
||||
if (!string.IsNullOrEmpty(modelId) &&
|
||||
PremiumPackageConst.ModeIds.Contains(modelId) &&
|
||||
entity.PremiumQuotaLimit.HasValue)
|
||||
{
|
||||
var usedQuota = await GetTokenPremiumUsedQuotaAsync(entity.UserId, entity.Id);
|
||||
if (usedQuota >= entity.PremiumQuotaLimit.Value)
|
||||
{
|
||||
throw new UserFriendlyException($"当前Token的尊享包额度已用完(已使用:{usedQuota},限制:{entity.PremiumQuotaLimit.Value}),请调整额度限制或使用其他Token", "403");
|
||||
}
|
||||
}
|
||||
|
||||
return new TokenValidationResult
|
||||
{
|
||||
UserId = entity.UserId,
|
||||
TokenId = entity.Id
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 获取Token的尊享包已使用额度
|
||||
/// </summary>
|
||||
private async Task<long> GetTokenPremiumUsedQuotaAsync(Guid userId, Guid tokenId)
|
||||
{
|
||||
var premiumModelIds = PremiumPackageConst.ModeIds;
|
||||
|
||||
var usedQuota = await _usageStatisticsRepository._DbQueryable
|
||||
.Where(x => x.UserId == userId && x.TokenId == tokenId && premiumModelIds.Contains(x.ModelId))
|
||||
.SumAsync(x => x.TotalTokenCount);
|
||||
|
||||
return usedQuota;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 获取用户的Token(兼容旧接口,返回第一个可用的Token)
|
||||
/// </summary>
|
||||
[Obsolete("请使用 ValidateTokenAsync 方法")]
|
||||
public async Task<string?> GetAsync(Guid userId)
|
||||
{
|
||||
var entity = await _tokenRepository._DbQueryable
|
||||
.Where(x => x.UserId == userId && !x.IsDisabled)
|
||||
.OrderBy(x => x.CreationTime)
|
||||
.FirstAsync();
|
||||
|
||||
return entity?.Token;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 获取用户Id(兼容旧接口)
|
||||
/// </summary>
|
||||
[Obsolete("请使用 ValidateTokenAsync 方法")]
|
||||
public async Task<Guid> GetUserIdAsync(string? token)
|
||||
{
|
||||
var result = await ValidateTokenAsync(token);
|
||||
return result.UserId;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,8 +18,10 @@ public class UsageStatisticsManager : DomainService
|
||||
private IDistributedLockProvider DistributedLock =>
|
||||
LazyServiceProvider.LazyGetRequiredService<IDistributedLockProvider>();
|
||||
|
||||
public async Task SetUsageAsync(Guid? userId, string modelId, ThorUsageResponse? tokenUsage)
|
||||
public async Task SetUsageAsync(Guid? userId, string modelId, ThorUsageResponse? tokenUsage, Guid? tokenId = null)
|
||||
{
|
||||
var actualTokenId = tokenId ?? Guid.Empty;
|
||||
|
||||
long inputTokenCount = tokenUsage?.PromptTokens
|
||||
?? tokenUsage?.InputTokens
|
||||
?? 0;
|
||||
@@ -28,10 +30,10 @@ public class UsageStatisticsManager : DomainService
|
||||
?? tokenUsage?.OutputTokens
|
||||
?? 0;
|
||||
|
||||
await using (await DistributedLock.AcquireLockAsync($"UsageStatistics:{userId?.ToString()}"))
|
||||
await using (await DistributedLock.AcquireLockAsync($"UsageStatistics:{userId?.ToString()}:{actualTokenId}:{modelId}"))
|
||||
{
|
||||
var entity = await _repository._DbQueryable.FirstAsync(x => x.UserId == userId && x.ModelId == modelId);
|
||||
//存在数据,更细
|
||||
var entity = await _repository._DbQueryable.FirstAsync(x => x.UserId == userId && x.ModelId == modelId && x.TokenId == actualTokenId);
|
||||
//存在数据,更新
|
||||
if (entity is not null)
|
||||
{
|
||||
entity.AddOnceChat(inputTokenCount, outputTokenCount);
|
||||
@@ -40,7 +42,7 @@ public class UsageStatisticsManager : DomainService
|
||||
//不存在插入
|
||||
else
|
||||
{
|
||||
var usage = new UsageStatisticsAggregateRoot(userId, modelId);
|
||||
var usage = new UsageStatisticsAggregateRoot(userId, modelId, actualTokenId);
|
||||
usage.AddOnceChat(inputTokenCount, outputTokenCount);
|
||||
await _repository.InsertAsync(usage);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user