feat: 完成用量统计功能模块
This commit is contained in:
@@ -5,6 +5,6 @@ namespace Yi.Framework.AiHub.Domain.AiChat;
|
||||
|
||||
public interface IChatService
|
||||
{
|
||||
public IAsyncEnumerable<string> CompleteChatAsync(AiModelDescribe aiModelDescribe, List<ChatMessage> messages,
|
||||
public IAsyncEnumerable<CompleteChatResponse> CompleteChatAsync(AiModelDescribe aiModelDescribe, List<ChatMessage> messages,
|
||||
CancellationToken cancellationToken);
|
||||
}
|
||||
@@ -12,7 +12,8 @@ public class AzureChatService : IChatService
|
||||
{
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<string> CompleteChatAsync(AiModelDescribe aiModelDescribe, List<ChatMessage> messages,
|
||||
public async IAsyncEnumerable<CompleteChatResponse> CompleteChatAsync(AiModelDescribe aiModelDescribe,
|
||||
List<ChatMessage> messages,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
{
|
||||
var endpoint = new Uri(aiModelDescribe.Endpoint);
|
||||
@@ -32,9 +33,28 @@ public class AzureChatService : IChatService
|
||||
|
||||
await foreach (StreamingChatCompletionUpdate update in response)
|
||||
{
|
||||
var result = new CompleteChatResponse();
|
||||
var isFinish = update.Usage?.OutputTokenCount is not null;
|
||||
if (isFinish)
|
||||
{
|
||||
result.IsFinish = true;
|
||||
result.TokenUsage = new TokenUsage
|
||||
{
|
||||
OutputTokenCount = update.Usage.OutputTokenCount,
|
||||
InputTokenCount = update.Usage.InputTokenCount,
|
||||
TotalTokenCount = update.Usage.TotalTokenCount
|
||||
};
|
||||
}
|
||||
|
||||
foreach (ChatMessageContentPart updatePart in update.ContentUpdate)
|
||||
{
|
||||
yield return updatePart.Text;
|
||||
result.Content = updatePart.Text;
|
||||
yield return result;
|
||||
}
|
||||
|
||||
if (isFinish)
|
||||
{
|
||||
yield return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,8 @@ public class AzureRestChatService : IChatService
|
||||
{
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<string> CompleteChatAsync(AiModelDescribe aiModelDescribe, List<ChatMessage> messages,
|
||||
public async IAsyncEnumerable<CompleteChatResponse> CompleteChatAsync(AiModelDescribe aiModelDescribe,
|
||||
List<ChatMessage> messages,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
{
|
||||
// 设置API URL
|
||||
@@ -61,51 +62,65 @@ public class AzureRestChatService : IChatService
|
||||
var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken);
|
||||
// 从流中读取数据并输出到控制台
|
||||
using var streamReader = new StreamReader(responseStream);
|
||||
string line;
|
||||
while ((line = await streamReader.ReadLineAsync(cancellationToken)) != null)
|
||||
while (await streamReader.ReadLineAsync(cancellationToken) is { } line)
|
||||
{
|
||||
var result = GetContent(line);
|
||||
if (result is not null)
|
||||
var result = new CompleteChatResponse();
|
||||
try
|
||||
{
|
||||
yield return result;
|
||||
var jsonObj = MapToJObject(line);
|
||||
var content = GetContent(jsonObj);
|
||||
var tokenUsage = GetTokenUsage(jsonObj);
|
||||
result= new CompleteChatResponse
|
||||
{
|
||||
TokenUsage = tokenUsage,
|
||||
IsFinish = tokenUsage is not null,
|
||||
Content = content
|
||||
};
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Console.WriteLine("解析失败");
|
||||
}
|
||||
|
||||
yield return result;
|
||||
}
|
||||
}
|
||||
|
||||
private string? GetContent(string line)
|
||||
private JObject? MapToJObject(string line)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(line))
|
||||
return null;
|
||||
string prefix = "data: ";
|
||||
line = line.Substring(prefix.Length);
|
||||
return JObject.Parse(line);
|
||||
}
|
||||
|
||||
private string? GetContent(JObject? jsonObj)
|
||||
{
|
||||
var contentToken = jsonObj.SelectToken("choices[0].delta.content");
|
||||
if (contentToken != null && contentToken.Type != JTokenType.Null)
|
||||
{
|
||||
return contentToken.ToString();
|
||||
}
|
||||
|
||||
try
|
||||
return null;
|
||||
}
|
||||
|
||||
private TokenUsage? GetTokenUsage(JObject? jsonObj)
|
||||
{
|
||||
var usage = jsonObj.SelectToken("usage");
|
||||
if (usage is not null && usage.Type != JTokenType.Null)
|
||||
{
|
||||
// 解析为JObject
|
||||
var jsonObj = JObject.Parse(line);
|
||||
var content = jsonObj["choices"][0]["delta"]["content"].ToString();
|
||||
return content;
|
||||
// // 判断choices是否存在且是数组,并且有元素
|
||||
// if (jsonObj.TryGetValue("choices", out var choicesToken) && choicesToken is JArray choicesArray &&
|
||||
// choicesArray.Count > 0)
|
||||
// {
|
||||
// var firstChoice = choicesArray[0] as JObject;
|
||||
// // 判断delta字段是否存在
|
||||
// if (firstChoice.TryGetValue("delta", out var deltaToken))
|
||||
// {
|
||||
// // 获取content字段
|
||||
// if (deltaToken.Type == JTokenType.Object && ((JObject)deltaToken).TryGetValue("content", out var contentToken))
|
||||
// {
|
||||
// return contentToken.ToString();
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
// 解析失败
|
||||
return null;
|
||||
var result = new TokenUsage()
|
||||
{
|
||||
OutputTokenCount = usage["completion_tokens"].ToObject<int>(),
|
||||
InputTokenCount = usage["prompt_tokens"].ToObject<int>(),
|
||||
TotalTokenCount = usage["total_tokens"].ToObject<int>()
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
using SqlSugar;
|
||||
using Mapster;
|
||||
using SqlSugar;
|
||||
using Volo.Abp.Domain.Entities.Auditing;
|
||||
using Yi.Framework.AiHub.Domain.Entities.ValueObjects;
|
||||
using Yi.Framework.AiHub.Domain.Shared.Dtos;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Entities;
|
||||
namespace Yi.Framework.AiHub.Domain.Entities.Chat;
|
||||
|
||||
[SugarTable("Ai_Message")]
|
||||
[SugarIndex($"index_{{table}}_{nameof(UserId)}_{nameof(SessionId)}",
|
||||
@@ -14,21 +17,32 @@ public class MessageAggregateRoot : FullAuditedAggregateRoot<Guid>
|
||||
{
|
||||
}
|
||||
|
||||
public MessageAggregateRoot(Guid userId, Guid sessionId, string content, string role, string modelId)
|
||||
public MessageAggregateRoot(Guid userId, Guid sessionId, string content, string role, string modelId,
|
||||
TokenUsage? tokenUsage)
|
||||
{
|
||||
UserId = userId;
|
||||
SessionId = sessionId;
|
||||
Content = content;
|
||||
Role = role;
|
||||
ModelId = modelId;
|
||||
if (tokenUsage is not null)
|
||||
{
|
||||
this.TokenUsage = tokenUsage.Adapt<TokenUsageValueObject>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public Guid UserId { get; set; }
|
||||
public Guid SessionId { get; set; }
|
||||
|
||||
[SugarColumn(ColumnDataType = StaticConfig.CodeFirst_BigString)]
|
||||
public string Content { get; set; }
|
||||
|
||||
public string Role { get; set; }
|
||||
public decimal DeductCost { get; set; }
|
||||
public decimal TotalTokens { get; set; }
|
||||
public string ModelId { get; set; }
|
||||
public string Remark { get; set; }
|
||||
public string? Remark { get; set; }
|
||||
|
||||
[SugarColumn(IsOwnsOne = true)] public TokenUsageValueObject TokenUsage { get; set; }
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
using SqlSugar;
|
||||
using Volo.Abp.Domain.Entities.Auditing;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Entities;
|
||||
namespace Yi.Framework.AiHub.Domain.Entities.Chat;
|
||||
|
||||
[SugarTable("Ai_Session")]
|
||||
[SugarIndex($"index_{{table}}_{nameof(UserId)}",$"{nameof(UserId)}", OrderByType.Asc)]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
using Volo.Abp.Domain.Entities.Auditing;
|
||||
using Yi.Framework.Core.Data;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Entities;
|
||||
namespace Yi.Framework.AiHub.Domain.Entities.Model;
|
||||
|
||||
/// <summary>
|
||||
/// ai应用
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
using SqlSugar;
|
||||
using Volo.Abp.Domain.Entities;
|
||||
using Volo.Abp.Domain.Entities.Auditing;
|
||||
using Yi.Framework.Core.Data;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Entities;
|
||||
namespace Yi.Framework.AiHub.Domain.Entities.Model;
|
||||
|
||||
/// <summary>
|
||||
/// ai模型定义
|
||||
|
||||
@@ -9,6 +9,16 @@ namespace Yi.Framework.AiHub.Domain.Entities;
|
||||
[SugarTable("Ai_UsageStatistics")]
|
||||
public class UsageStatisticsAggregateRoot : FullAuditedAggregateRoot<Guid>
|
||||
{
|
||||
public UsageStatisticsAggregateRoot()
|
||||
{
|
||||
}
|
||||
|
||||
public UsageStatisticsAggregateRoot(Guid userId, string modelId)
|
||||
{
|
||||
UserId = userId;
|
||||
ModelId = modelId;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 用户id
|
||||
/// </summary>
|
||||
@@ -19,18 +29,34 @@ public class UsageStatisticsAggregateRoot : FullAuditedAggregateRoot<Guid>
|
||||
/// </summary>
|
||||
public string ModelId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 输入使用token使用
|
||||
/// </summary>
|
||||
public decimal InputTokens { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 输出使用token使用
|
||||
/// </summary>
|
||||
public decimal OutputTokens { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 对话次数
|
||||
/// </summary>
|
||||
public int Number { get; set; }
|
||||
public int UsageTotalNumber { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 使用输出token总数
|
||||
/// </summary>
|
||||
public int UsageOutputTokenCount { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 使用输入总数
|
||||
/// </summary>
|
||||
public int UsageInputTokenCount { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 总token使用数量
|
||||
/// </summary>
|
||||
public int TotalTokenCount { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// 新增一次聊天统计
|
||||
/// </summary>
|
||||
public void AddOnceChat(int inputTokenCount, int outputTokenCount)
|
||||
{
|
||||
UsageTotalNumber += 1;
|
||||
UsageOutputTokenCount += outputTokenCount;
|
||||
UsageInputTokenCount += inputTokenCount;
|
||||
TotalTokenCount += (outputTokenCount + inputTokenCount);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
namespace Yi.Framework.AiHub.Domain.Entities.ValueObjects;
|
||||
|
||||
public class TokenUsageValueObject
|
||||
{
|
||||
public int OutputTokenCount { get; set; }
|
||||
|
||||
public int InputTokenCount { get; set; }
|
||||
|
||||
public int TotalTokenCount { get; set; }
|
||||
}
|
||||
@@ -4,6 +4,7 @@ using OpenAI.Chat;
|
||||
using Volo.Abp.Domain.Services;
|
||||
using Yi.Framework.AiHub.Domain.AiChat;
|
||||
using Yi.Framework.AiHub.Domain.Entities;
|
||||
using Yi.Framework.AiHub.Domain.Entities.Model;
|
||||
using Yi.Framework.AiHub.Domain.Shared.Dtos;
|
||||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||||
|
||||
@@ -57,7 +58,7 @@ public class AiGateWayManager : DomainService
|
||||
/// <param name="messages"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async IAsyncEnumerable<string> CompleteChatAsync(string modelId, List<ChatMessage> messages,
|
||||
public async IAsyncEnumerable<CompleteChatResponse> CompleteChatAsync(string modelId, List<ChatMessage> messages,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
{
|
||||
var modelDescribe = await GetModelAsync(modelId);
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
using Volo.Abp.Users;
|
||||
using Yi.Framework.AiHub.Application.Contracts.Dtos;
|
||||
using Yi.Framework.AiHub.Domain.Entities;
|
||||
using Yi.Framework.AiHub.Domain.Entities.Chat;
|
||||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||||
|
||||
namespace Yi.Framework.AiHub.Domain.Managers;
|
||||
@@ -16,15 +17,30 @@ public class AiMessageManager : DomainService
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 创建消息
|
||||
/// 创建系统消息
|
||||
/// </summary>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="input"></param>
|
||||
/// <returns></returns>
|
||||
public async Task CreateMessageAsync(Guid userId, Guid sessionId, MessageInputDto input)
|
||||
public async Task CreateSystemMessageAsync(Guid userId, Guid sessionId, MessageInputDto input)
|
||||
{
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId);
|
||||
input.Role = "system";
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId,input.TokenUsage);
|
||||
await _repository.InsertAsync(message);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// 创建系统消息
|
||||
/// </summary>
|
||||
/// <param name="sessionId"></param>
|
||||
/// <param name="userId"></param>
|
||||
/// <param name="input"></param>
|
||||
/// <returns></returns>
|
||||
public async Task CreateUserMessageAsync(Guid userId, Guid sessionId, MessageInputDto input)
|
||||
{
|
||||
input.Role = "user";
|
||||
var message = new MessageAggregateRoot(userId, sessionId, input.Content, input.Role, input.ModelId,input.TokenUsage);
|
||||
await _repository.InsertAsync(message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
using Medallion.Threading;
|
||||
using Volo.Abp.Domain.Services;
|
||||
using Yi.Framework.AiHub.Domain.Entities;
|
||||
using Yi.Framework.SqlSugarCore.Abstractions;
|
||||
|
||||
public class UsageStatisticsManager : DomainService
|
||||
{
|
||||
private readonly ISqlSugarRepository<UsageStatisticsAggregateRoot> _repository;
|
||||
|
||||
private IDistributedLockProvider DistributedLock =>
|
||||
LazyServiceProvider.LazyGetRequiredService<IDistributedLockProvider>();
|
||||
|
||||
public async Task SetUsageAsync(Guid userId, string modelId, int inputTokenCount, int outputTokenCount)
|
||||
{
|
||||
await using (await DistributedLock.AcquireLockAsync($"UsageStatistics:{userId.ToString()}"))
|
||||
{
|
||||
var entity = await _repository._DbQueryable.FirstAsync(x => x.UserId == userId && x.ModelId == modelId);
|
||||
//存在数据,更细
|
||||
if (entity is not null)
|
||||
{
|
||||
entity.AddOnceChat(inputTokenCount, outputTokenCount);
|
||||
await _repository.UpdateAsync(entity);
|
||||
}
|
||||
//不存在插入
|
||||
else
|
||||
{
|
||||
var usage = new UsageStatisticsAggregateRoot(userId, modelId);
|
||||
usage.AddOnceChat(inputTokenCount, outputTokenCount);
|
||||
await _repository.InsertAsync(usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class LazyServiceProvider
|
||||
{
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Azure.AI.OpenAI" Version="2.2.0-beta.4" />
|
||||
<PackageReference Include="Volo.Abp.Ddd.Domain" Version="$(AbpVersion)" />
|
||||
|
||||
<PackageReference Include="Volo.Abp.DistributedLocking" Version="$(AbpVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
||||
Reference in New Issue
Block a user