Working LLM chatbot

This commit is contained in:
Myx
2024-06-01 13:16:19 +02:00
parent a703f7f9a2
commit ddf5bfdd3d
11 changed files with 122 additions and 79 deletions

View File

@@ -1,64 +1,62 @@
using System.Text; using System.Text;
using Discord.WebSocket; using Discord.WebSocket;
using Lunaris2.Handler.MusicPlayer;
using Lunaris2.SlashCommand;
using MediatR; using MediatR;
using Microsoft.Extensions.Options;
using OllamaSharp; using OllamaSharp;
using OllamaSharp.Models;
namespace Lunaris2.Handler.ChatCommand namespace Lunaris2.Handler.ChatCommand
{ {
public record ChatCommand(SocketSlashCommand Message) : IRequest; public record ChatCommand(SocketMessage Message, string FilteredMessage) : IRequest;
public class ChatHandler : IRequestHandler<ChatCommand> public class ChatHandler : IRequestHandler<ChatCommand>
{ {
private readonly Uri _uri = new("http://192.168.50.54:11434");
private readonly OllamaApiClient _ollama; private readonly OllamaApiClient _ollama;
private SocketSlashCommand _context; private readonly Dictionary<ulong, Chat?> _chatContexts = new();
public ChatHandler() public ChatHandler(IOptions<ChatSettings> chatSettings)
{ {
_ollama = new OllamaApiClient(_uri) var uri = new Uri(chatSettings.Value.Url);
_ollama = new OllamaApiClient(uri)
{ {
SelectedModel = "lunaris" SelectedModel = chatSettings.Value.Model
}; };
} }
public async Task Handle(ChatCommand command, CancellationToken cancellationToken) public async Task Handle(ChatCommand command, CancellationToken cancellationToken)
{ {
_context = command.Message; var channelId = command.Message.Channel.Id;
_chatContexts.TryAdd(channelId, null);
var userMessage = _context.GetOptionValueByName(Option.Input);
var userMessage = command.FilteredMessage;
using var setTyping = command.Message.Channel.EnterTypingState(); using var setTyping = command.Message.Channel.EnterTypingState();
await command.Message.DeferAsync();
if (string.IsNullOrWhiteSpace(userMessage)) if (string.IsNullOrWhiteSpace(userMessage))
{ {
await command.Message.ModifyOriginalResponseAsync(properties => properties.Content = "Am I expected to read your mind?"); await command.Message.Channel.SendMessageAsync("Am I expected to read your mind?");
setTyping.Dispose(); setTyping.Dispose();
return; return;
} }
var response = await GenerateResponse(userMessage, cancellationToken); var response = await GenerateResponse(userMessage, channelId, cancellationToken);
await command.Message.ModifyOriginalResponseAsync(properties => properties.Content = response); await command.Message.Channel.SendMessageAsync(response);
setTyping.Dispose();
} }
private async Task<string> GenerateResponse(string userMessage, CancellationToken cancellationToken) private async Task<string> GenerateResponse(string userMessage, ulong channelId, CancellationToken cancellationToken)
{ {
var response = new StringBuilder(); var response = new StringBuilder();
ConversationContext? chatContext = null;
chatContext = await _ollama.StreamCompletion( if (_chatContexts[channelId] == null)
userMessage, {
chatContext, _chatContexts[channelId] = _ollama.Chat(stream => response.Append(stream.Message?.Content ?? ""));
Streamer, }
cancellationToken: cancellationToken);
await _chatContexts[channelId].Send(userMessage, cancellationToken);
return response.ToString(); return response.ToString();
void Streamer(GenerateCompletionResponseStream stream) =>
response.Append(stream.Response);
} }
} }
} }

View File

@@ -0,0 +1,7 @@
namespace Lunaris2.Handler.ChatCommand;
public class ChatSettings
{
public string Url { get; set; }
public string Model { get; set; }
}

View File

@@ -1,37 +1,37 @@
using Lunaris2.Handler.GoodByeCommand; using System.Text.RegularExpressions;
using Lunaris2.Handler.MusicPlayer.JoinCommand; using Discord.WebSocket;
using Lunaris2.Handler.MusicPlayer.PlayCommand;
using Lunaris2.Handler.MusicPlayer.SkipCommand;
using Lunaris2.Notification; using Lunaris2.Notification;
using Lunaris2.SlashCommand;
using MediatR; using MediatR;
namespace Lunaris2.Handler; namespace Lunaris2.Handler;
public class MessageReceivedHandler(ISender mediator) : INotificationHandler<MessageReceivedNotification> public class MessageReceivedHandler : INotificationHandler<MessageReceivedNotification>
{ {
private readonly DiscordSocketClient _client;
private readonly ISender _mediatir;
public MessageReceivedHandler(DiscordSocketClient client, ISender mediatir)
{
_client = client;
_mediatir = mediatir;
}
public async Task Handle(MessageReceivedNotification notification, CancellationToken cancellationToken) public async Task Handle(MessageReceivedNotification notification, CancellationToken cancellationToken)
{ {
switch (notification.Message.CommandName) await BotMentioned(notification, cancellationToken);
}
private async Task BotMentioned(MessageReceivedNotification notification, CancellationToken cancellationToken)
{
if (notification.Message.MentionedUsers.Any(user => user.Id == _client.CurrentUser.Id))
{ {
case Command.Hello.Name: // The bot was mentioned
await mediator.Send(new HelloCommand.HelloCommand(notification.Message), cancellationToken); const string pattern = "<.*?>";
break; const string replacement = "";
case Command.Goodbye.Name: var regex = new Regex(pattern);
await mediator.Send(new GoodbyeCommand(notification.Message), cancellationToken); var messageContent = regex.Replace(notification.Message.Content, replacement);
break;
case Command.Join.Name: await _mediatir.Send(new ChatCommand.ChatCommand(notification.Message, messageContent), cancellationToken);
await mediator.Send(new JoinCommand(notification.Message), cancellationToken);
break;
case Command.Play.Name:
await mediator.Send(new PlayCommand(notification.Message), cancellationToken);
break;
case Command.Skip.Name:
await mediator.Send(new SkipCommand(notification.Message), cancellationToken);
break;
case Command.Chat.Name:
await mediator.Send(new ChatCommand.ChatCommand(notification.Message), cancellationToken);
break;
} }
} }
} }

View File

@@ -0,0 +1,34 @@
using Lunaris2.Handler.GoodByeCommand;
using Lunaris2.Handler.MusicPlayer.JoinCommand;
using Lunaris2.Handler.MusicPlayer.PlayCommand;
using Lunaris2.Handler.MusicPlayer.SkipCommand;
using Lunaris2.Notification;
using Lunaris2.SlashCommand;
using MediatR;
namespace Lunaris2.Handler;
public class SlashCommandReceivedHandler(ISender mediator) : INotificationHandler<SlashCommandReceivedNotification>
{
public async Task Handle(SlashCommandReceivedNotification notification, CancellationToken cancellationToken)
{
switch (notification.Message.CommandName)
{
case Command.Hello.Name:
await mediator.Send(new HelloCommand.HelloCommand(notification.Message), cancellationToken);
break;
case Command.Goodbye.Name:
await mediator.Send(new GoodbyeCommand(notification.Message), cancellationToken);
break;
case Command.Join.Name:
await mediator.Send(new JoinCommand(notification.Message), cancellationToken);
break;
case Command.Play.Name:
await mediator.Send(new PlayCommand(notification.Message), cancellationToken);
break;
case Command.Skip.Name:
await mediator.Send(new SkipCommand(notification.Message), cancellationToken);
break;
}
}
}

View File

@@ -19,13 +19,19 @@ public class DiscordEventListener(DiscordSocketClient client, IServiceScopeFacto
public async Task StartAsync() public async Task StartAsync()
{ {
client.SlashCommandExecuted += OnMessageReceivedAsync; client.SlashCommandExecuted += OnSlashCommandRecievedAsync;
client.MessageReceived += OnMessageReceivedAsync;
await Task.CompletedTask; await Task.CompletedTask;
} }
private async Task OnMessageReceivedAsync(SocketSlashCommand arg) private async Task OnMessageReceivedAsync(SocketMessage arg)
{ {
await Mediator.Publish(new MessageReceivedNotification(arg), _cancellationToken); await Mediator.Publish(new MessageReceivedNotification(arg), _cancellationToken);
} }
private async Task OnSlashCommandRecievedAsync(SocketSlashCommand arg)
{
await Mediator.Publish(new SlashCommandReceivedNotification(arg), _cancellationToken);
}
} }

View File

@@ -3,7 +3,7 @@ using MediatR;
namespace Lunaris2.Notification; namespace Lunaris2.Notification;
public class MessageReceivedNotification(SocketSlashCommand message) : INotification public class MessageReceivedNotification(SocketMessage message) : INotification
{ {
public SocketSlashCommand Message { get; } = message ?? throw new ArgumentNullException(nameof(message)); public SocketMessage Message { get; } = message ?? throw new ArgumentNullException(nameof(message));
} }

View File

@@ -0,0 +1,9 @@
using Discord.WebSocket;
using MediatR;
namespace Lunaris2.Notification;
public class SlashCommandReceivedNotification(SocketSlashCommand message) : INotification
{
public SocketSlashCommand Message { get; } = message ?? throw new ArgumentNullException(nameof(message));
}

View File

@@ -3,6 +3,7 @@ using Discord;
using Discord.Commands; using Discord.Commands;
using Discord.Interactions; using Discord.Interactions;
using Discord.WebSocket; using Discord.WebSocket;
using Lunaris2.Handler.ChatCommand;
using Lunaris2.Handler.MusicPlayer; using Lunaris2.Handler.MusicPlayer;
using Lunaris2.Notification; using Lunaris2.Notification;
using Lunaris2.SlashCommand; using Lunaris2.SlashCommand;
@@ -54,7 +55,9 @@ public class Program
nodeConfiguration.Authorization = configuration["LavaLinkPassword"]; nodeConfiguration.Authorization = configuration["LavaLinkPassword"];
}) })
.AddSingleton<LavaNode>() .AddSingleton<LavaNode>()
.AddSingleton<MusicEmbed>(); .AddSingleton<MusicEmbed>()
.AddSingleton<ChatSettings>()
.Configure<ChatSettings>(configuration.GetSection("LLM"));
client.Ready += () => Client_Ready(client); client.Ready += () => Client_Ready(client);
client.Log += Log; client.Log += Log;

View File

@@ -21,23 +21,6 @@ public static class Command
public const string Description = "Say goodbye to the bot!"; public const string Description = "Say goodbye to the bot!";
} }
public static class Chat
{
public const string Name = "chat";
public const string Description = "Chat with the bot!";
public static readonly List<SlashCommandOptionBuilder>? Options = new()
{
new SlashCommandOptionBuilder
{
Name = "message",
Description = "Chat with Lunaris",
Type = ApplicationCommandOptionType.String,
IsRequired = true
}
};
}
public static class Join public static class Join
{ {
public const string Name = "join"; public const string Name = "join";

View File

@@ -12,8 +12,7 @@ public static class SlashCommandRegistration
RegisterCommand(client, Command.Join.Name, Command.Join.Description); RegisterCommand(client, Command.Join.Name, Command.Join.Description);
RegisterCommand(client, Command.Skip.Name, Command.Skip.Description); RegisterCommand(client, Command.Skip.Name, Command.Skip.Description);
RegisterCommand(client, Command.Play.Name, Command.Play.Description, Command.Play.Options); RegisterCommand(client, Command.Play.Name, Command.Play.Description, Command.Play.Options);
RegisterCommand(client, Command.Stop.Name, Command.Stop.Description); RegisterCommand(client, Command.Stop.Name, Command.Stop.Description);
RegisterCommand(client, Command.Chat.Name, Command.Chat.Description, Command.Chat.Options);
} }
private static void RegisterCommand( private static void RegisterCommand(

View File

@@ -10,4 +10,8 @@
"LavaLinkPassword": "youshallnotpass", "LavaLinkPassword": "youshallnotpass",
"LavaLinkHostname": "127.0.0.1", "LavaLinkHostname": "127.0.0.1",
"LavaLinkPort": 2333 "LavaLinkPort": 2333
"LLM": {
"Url": "http://192.168.50.54:11434",
"Model": "gemma"
}
} }