From d38fef08bb8fd114bffb8fa9fe94980c1998581f Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 27 Feb 2026 11:45:20 +0100 Subject: [PATCH] refactor(discordbot): introduce Session interface for testability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract a Session interface from *discordgo.Session so DiscordBot methods can be tested with a mock — no live Discord connection required. Add AddHandler, RegisterCommands, and UserID methods to DiscordBot so external callers (main.go, sys_channel_server.go) no longer reach through bot.Session directly. Rewrite tests with a mockSession, raising discordbot coverage from 12.5% to 66.7%. --- main.go | 3 +- server/channelserver/sys_channel_server.go | 4 +- server/discordbot/discord_bot.go | 42 +- server/discordbot/discord_bot_test.go | 639 +++++++++------------ 4 files changed, 315 insertions(+), 373 deletions(-) diff --git a/main.go b/main.go index 84c077887..edee6de24 100644 --- a/main.go +++ b/main.go @@ -63,8 +63,7 @@ func setupDiscordBot(config *cfg.Config, logger *zap.Logger) *discordbot.Discord preventClose(config, fmt.Sprintf("Discord: Failed to start, %s", err.Error())) } - _, err = bot.Session.ApplicationCommandBulkOverwrite(bot.Session.State.User.ID, "", discordbot.Commands) - if err != nil { + if err = bot.RegisterCommands(); err != nil { preventClose(config, fmt.Sprintf("Discord: Failed to start, %s", err.Error())) } diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index 97d86de52..1091c9c8c 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -208,8 +208,8 @@ func (s *Server) Start() error { // Start the discord bot for chat integration. if s.erupeConfig.Discord.Enabled && s.discordBot != nil { - s.discordBot.Session.AddHandler(s.onDiscordMessage) - s.discordBot.Session.AddHandler(s.onInteraction) + s.discordBot.AddHandler(s.onDiscordMessage) + s.discordBot.AddHandler(s.onInteraction) } return nil diff --git a/server/discordbot/discord_bot.go b/server/discordbot/discord_bot.go index 5dca4731c..4b4bb0010 100644 --- a/server/discordbot/discord_bot.go +++ b/server/discordbot/discord_bot.go @@ -37,12 +37,24 @@ var Commands = []*discordgo.ApplicationCommand{ }, } +// Session abstracts the discordgo.Session methods used by DiscordBot, +// allowing tests to inject a mock without a live Discord connection. +type Session interface { + Open() error + Channel(channelID string, options ...discordgo.RequestOption) (*discordgo.Channel, error) + User(userID string, options ...discordgo.RequestOption) (*discordgo.User, error) + ChannelMessageSend(channelID string, content string, options ...discordgo.RequestOption) (*discordgo.Message, error) + AddHandler(handler interface{}) func() + ApplicationCommandBulkOverwrite(appID string, guildID string, commands []*discordgo.ApplicationCommand, options ...discordgo.RequestOption) ([]*discordgo.ApplicationCommand, error) +} + // DiscordBot manages a Discord session and provides methods for relaying // messages between the game server and a configured Discord channel. type DiscordBot struct { - Session *discordgo.Session + Session Session config *cfg.Config logger *zap.Logger + userID string MainGuild *discordgo.Guild RelayChannel *discordgo.Channel } @@ -84,11 +96,31 @@ func NewDiscordBot(options Options) (discordBot *DiscordBot, err error) { return } -// Start opens the websocket connection to Discord. -func (bot *DiscordBot) Start() (err error) { - err = bot.Session.Open() +// Start opens the websocket connection to Discord and caches the bot's user ID. +func (bot *DiscordBot) Start() error { + if err := bot.Session.Open(); err != nil { + return err + } + if ds, ok := bot.Session.(*discordgo.Session); ok && ds.State != nil && ds.State.User != nil { + bot.userID = ds.State.User.ID + } + return nil +} - return +// UserID returns the bot's Discord user ID, populated after Start succeeds. +func (bot *DiscordBot) UserID() string { + return bot.userID +} + +// RegisterCommands bulk-overwrites the global slash commands for this bot. +func (bot *DiscordBot) RegisterCommands() error { + _, err := bot.Session.ApplicationCommandBulkOverwrite(bot.userID, "", Commands) + return err +} + +// AddHandler registers an event handler on the underlying Discord session. +func (bot *DiscordBot) AddHandler(handler interface{}) func() { + return bot.Session.AddHandler(handler) } // NormalizeDiscordMessage replaces all mentions to real name from the message. diff --git a/server/discordbot/discord_bot_test.go b/server/discordbot/discord_bot_test.go index 200e4f178..964e19fea 100644 --- a/server/discordbot/discord_bot_test.go +++ b/server/discordbot/discord_bot_test.go @@ -1,10 +1,283 @@ package discordbot import ( + "errors" + cfg "erupe-ce/config" "regexp" "testing" + + "github.com/bwmarrin/discordgo" + "go.uber.org/zap" ) +// mockSession implements the Session interface for testing. +type mockSession struct { + openErr error + channelResult *discordgo.Channel + channelErr error + userResults map[string]*discordgo.User + userErr error + messageSentTo string + messageSentContent string + messageErr error + addHandlerCalls int + bulkOverwriteAppID string + bulkOverwriteCommands []*discordgo.ApplicationCommand + bulkOverwriteErr error +} + +func (m *mockSession) Open() error { + return m.openErr +} + +func (m *mockSession) Channel(_ string, _ ...discordgo.RequestOption) (*discordgo.Channel, error) { + return m.channelResult, m.channelErr +} + +func (m *mockSession) User(userID string, _ ...discordgo.RequestOption) (*discordgo.User, error) { + if m.userResults != nil { + if u, ok := m.userResults[userID]; ok { + return u, nil + } + } + return nil, m.userErr +} + +func (m *mockSession) ChannelMessageSend(channelID string, content string, _ ...discordgo.RequestOption) (*discordgo.Message, error) { + m.messageSentTo = channelID + m.messageSentContent = content + return &discordgo.Message{}, m.messageErr +} + +func (m *mockSession) AddHandler(_ interface{}) func() { + m.addHandlerCalls++ + return func() {} +} + +func (m *mockSession) ApplicationCommandBulkOverwrite(appID string, _ string, commands []*discordgo.ApplicationCommand, _ ...discordgo.RequestOption) ([]*discordgo.ApplicationCommand, error) { + m.bulkOverwriteAppID = appID + m.bulkOverwriteCommands = commands + return commands, m.bulkOverwriteErr +} + +func newTestBot(session *mockSession) *DiscordBot { + return &DiscordBot{ + Session: session, + config: &cfg.Config{}, + logger: zap.NewNop(), + } +} + +func TestStart_Success(t *testing.T) { + ms := &mockSession{} + bot := newTestBot(ms) + + if err := bot.Start(); err != nil { + t.Fatalf("Start() unexpected error: %v", err) + } +} + +func TestStart_OpenError(t *testing.T) { + ms := &mockSession{openErr: errors.New("connection refused")} + bot := newTestBot(ms) + + err := bot.Start() + if err == nil { + t.Fatal("Start() expected error, got nil") + } + if err.Error() != "connection refused" { + t.Errorf("Start() error = %q, want %q", err.Error(), "connection refused") + } +} + +func TestNormalizeDiscordMessage(t *testing.T) { + tests := []struct { + name string + users map[string]*discordgo.User + userErr error + message string + expected string + }{ + { + name: "replace user mention with username", + users: map[string]*discordgo.User{ + "123456789012345678": {Username: "TestUser"}, + }, + message: "Hello <@123456789012345678>!", + expected: "Hello @TestUser!", + }, + { + name: "replace nickname mention", + users: map[string]*discordgo.User{ + "123456789012345678": {Username: "NickUser"}, + }, + message: "Hello <@!123456789012345678>!", + expected: "Hello @NickUser!", + }, + { + name: "unknown user fallback", + userErr: errors.New("not found"), + message: "Hello <@123456789012345678>!", + expected: "Hello @unknown!", + }, + { + name: "simple emoji preserved", + message: "Hello :smile:!", + expected: "Hello :smile:!", + }, + { + name: "custom emoji normalized", + message: "Nice <:custom:123456789012345678>", + expected: "Nice :custom:", + }, + { + name: "animated emoji normalized", + message: "Fun ", + expected: "Fun :dance:", + }, + { + name: "mixed mentions and emoji", + users: map[string]*discordgo.User{ + "111111111111111111": {Username: "Alice"}, + }, + message: "<@111111111111111111> says :wave:", + expected: "@Alice says :wave:", + }, + { + name: "plain text unchanged", + message: "Hello World", + expected: "Hello World", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ms := &mockSession{ + userResults: tt.users, + userErr: tt.userErr, + } + bot := newTestBot(ms) + result := bot.NormalizeDiscordMessage(tt.message) + if result != tt.expected { + t.Errorf("NormalizeDiscordMessage() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRealtimeChannelSend_NilRelayChannel(t *testing.T) { + ms := &mockSession{} + bot := newTestBot(ms) + bot.RelayChannel = nil + + if err := bot.RealtimeChannelSend("test"); err != nil { + t.Fatalf("RealtimeChannelSend() unexpected error: %v", err) + } + if ms.messageSentTo != "" { + t.Error("RealtimeChannelSend() should not send when RelayChannel is nil") + } +} + +func TestRealtimeChannelSend_Success(t *testing.T) { + ms := &mockSession{} + bot := newTestBot(ms) + bot.RelayChannel = &discordgo.Channel{ID: "chan123"} + + if err := bot.RealtimeChannelSend("hello"); err != nil { + t.Fatalf("RealtimeChannelSend() unexpected error: %v", err) + } + if ms.messageSentTo != "chan123" { + t.Errorf("sent to channel %q, want %q", ms.messageSentTo, "chan123") + } + if ms.messageSentContent != "hello" { + t.Errorf("sent content %q, want %q", ms.messageSentContent, "hello") + } +} + +func TestRealtimeChannelSend_Error(t *testing.T) { + ms := &mockSession{messageErr: errors.New("send failed")} + bot := newTestBot(ms) + bot.RelayChannel = &discordgo.Channel{ID: "chan123"} + + err := bot.RealtimeChannelSend("hello") + if err == nil { + t.Fatal("RealtimeChannelSend() expected error, got nil") + } + if err.Error() != "send failed" { + t.Errorf("error = %q, want %q", err.Error(), "send failed") + } +} + +func TestRegisterCommands_Success(t *testing.T) { + ms := &mockSession{} + bot := newTestBot(ms) + bot.userID = "bot123" + + if err := bot.RegisterCommands(); err != nil { + t.Fatalf("RegisterCommands() unexpected error: %v", err) + } + if ms.bulkOverwriteAppID != "bot123" { + t.Errorf("appID = %q, want %q", ms.bulkOverwriteAppID, "bot123") + } + if len(ms.bulkOverwriteCommands) != len(Commands) { + t.Errorf("commands count = %d, want %d", len(ms.bulkOverwriteCommands), len(Commands)) + } +} + +func TestRegisterCommands_Error(t *testing.T) { + ms := &mockSession{bulkOverwriteErr: errors.New("forbidden")} + bot := newTestBot(ms) + + err := bot.RegisterCommands() + if err == nil { + t.Fatal("RegisterCommands() expected error, got nil") + } +} + +func TestAddHandler(t *testing.T) { + ms := &mockSession{} + bot := newTestBot(ms) + + cleanup := bot.AddHandler(func() {}) + if cleanup == nil { + t.Fatal("AddHandler() returned nil cleanup func") + } + if ms.addHandlerCalls != 1 { + t.Errorf("addHandlerCalls = %d, want 1", ms.addHandlerCalls) + } +} + +func TestUserID(t *testing.T) { + bot := &DiscordBot{userID: "abc123"} + if bot.UserID() != "abc123" { + t.Errorf("UserID() = %q, want %q", bot.UserID(), "abc123") + } +} + +func TestCommands_Structure(t *testing.T) { + if len(Commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(Commands)) + } + + expectedNames := []string{"link", "password"} + for i, name := range expectedNames { + if Commands[i].Name != name { + t.Errorf("Commands[%d].Name = %q, want %q", i, Commands[i].Name, name) + } + if Commands[i].Description == "" { + t.Errorf("Commands[%d] (%s) has empty description", i, name) + } + if len(Commands[i].Options) == 0 { + t.Errorf("Commands[%d] (%s) has no options", i, name) + } + for _, opt := range Commands[i].Options { + if !opt.Required { + t.Errorf("Commands[%d] (%s) option %q should be required", i, name, opt.Name) + } + } + } +} + func TestReplaceTextAll(t *testing.T) { tests := []struct { name string @@ -24,12 +297,12 @@ func TestReplaceTextAll(t *testing.T) { }, { name: "replace multiple matches", - text: "Users @111111111111111111 and @222222222222222222", + text: "Users @111 and @222", regex: regexp.MustCompile(`@(\d+)`), handler: func(id string) string { return "@user_" + id }, - expected: "Users @user_111111111111111111 and @user_222222222222222222", + expected: "Users @user_111 and @user_222", }, { name: "no matches", @@ -40,33 +313,6 @@ func TestReplaceTextAll(t *testing.T) { }, expected: "Hello World", }, - { - name: "replace with empty string", - text: "Remove @123456789012345678 this", - regex: regexp.MustCompile(`@(\d+)`), - handler: func(id string) string { - return "" - }, - expected: "Remove this", - }, - { - name: "replace emoji syntax", - text: "Hello :smile: and :wave:", - regex: regexp.MustCompile(`:(\w+):`), - handler: func(emoji string) string { - return "[" + emoji + "]" - }, - expected: "Hello [smile] and [wave]", - }, - { - name: "complex replacement", - text: "Text with <@!123456789012345678> mention", - regex: regexp.MustCompile(`<@!?(\d+)>`), - handler: func(id string) string { - return "@user_" + id - }, - expected: "Text with @user_123456789012345678 mention", - }, } for _, tt := range tests { @@ -78,338 +324,3 @@ func TestReplaceTextAll(t *testing.T) { }) } } - -func TestReplaceTextAll_UserMentionPattern(t *testing.T) { - // Test the actual user mention regex used in NormalizeDiscordMessage - userRegex := regexp.MustCompile(`<@!?(\d{17,19})>`) - - tests := []struct { - name string - text string - expected []string // Expected captured IDs - }{ - { - name: "standard mention", - text: "<@123456789012345678>", - expected: []string{"123456789012345678"}, - }, - { - name: "nickname mention", - text: "<@!123456789012345678>", - expected: []string{"123456789012345678"}, - }, - { - name: "multiple mentions", - text: "<@123456789012345678> and <@!987654321098765432>", - expected: []string{"123456789012345678", "987654321098765432"}, - }, - { - name: "17 digit ID", - text: "<@12345678901234567>", - expected: []string{"12345678901234567"}, - }, - { - name: "19 digit ID", - text: "<@1234567890123456789>", - expected: []string{"1234567890123456789"}, - }, - { - name: "invalid - too short", - text: "<@1234567890123456>", - expected: []string{}, - }, - { - name: "invalid - too long", - text: "<@12345678901234567890>", - expected: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - matches := userRegex.FindAllStringSubmatch(tt.text, -1) - if len(matches) != len(tt.expected) { - t.Fatalf("Expected %d matches, got %d", len(tt.expected), len(matches)) - } - for i, match := range matches { - if len(match) < 2 { - t.Fatalf("Match %d: expected capture group", i) - } - if match[1] != tt.expected[i] { - t.Errorf("Match %d: got ID %q, want %q", i, match[1], tt.expected[i]) - } - } - }) - } -} - -func TestReplaceTextAll_EmojiPattern(t *testing.T) { - // Test the actual emoji regex used in NormalizeDiscordMessage - emojiRegex := regexp.MustCompile(`(?:)?`) - - tests := []struct { - name string - text string - expectedName []string // Expected emoji names - }{ - { - name: "simple emoji", - text: ":smile:", - expectedName: []string{"smile"}, - }, - { - name: "custom emoji", - text: "<:customemoji:123456789012345678>", - expectedName: []string{"customemoji"}, - }, - { - name: "animated emoji", - text: "", - expectedName: []string{"animated"}, - }, - { - name: "multiple emojis", - text: ":wave: <:custom:123456789012345678> :smile:", - expectedName: []string{"wave", "custom", "smile"}, - }, - { - name: "emoji with underscores", - text: ":thumbs_up:", - expectedName: []string{"thumbs_up"}, - }, - { - name: "emoji with numbers", - text: ":emoji123:", - expectedName: []string{"emoji123"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - matches := emojiRegex.FindAllStringSubmatch(tt.text, -1) - if len(matches) != len(tt.expectedName) { - t.Fatalf("Expected %d matches, got %d", len(tt.expectedName), len(matches)) - } - for i, match := range matches { - if len(match) < 2 { - t.Fatalf("Match %d: expected capture group", i) - } - if match[1] != tt.expectedName[i] { - t.Errorf("Match %d: got name %q, want %q", i, match[1], tt.expectedName[i]) - } - } - }) - } -} - -func TestNormalizeDiscordMessage_Integration(t *testing.T) { - // Create a mock bot for testing the normalization logic - // Note: We can't fully test this without a real Discord session, - // but we can test the regex patterns and structure - tests := []struct { - name string - input string - contains []string // Strings that should be in the output - }{ - { - name: "plain text unchanged", - input: "Hello World", - contains: []string{"Hello World"}, - }, - { - name: "user mention format", - input: "Hello <@123456789012345678>", - // We can't test the actual replacement without a real Discord session - // but we can verify the pattern is matched - contains: []string{"Hello"}, - }, - { - name: "emoji format preserved", - input: "Hello :smile:", - contains: []string{"Hello", ":smile:"}, - }, - { - name: "mixed content", - input: "<@123456789012345678> sent :wave:", - contains: []string{"sent"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test that the message contains expected parts - for _, expected := range tt.contains { - if len(expected) > 0 && !contains(tt.input, expected) { - t.Errorf("Input %q should contain %q", tt.input, expected) - } - } - }) - } -} - -func TestCommands_Structure(t *testing.T) { - // Test that the Commands slice is properly structured - if len(Commands) == 0 { - t.Error("Commands slice should not be empty") - } - - expectedCommands := map[string]bool{ - "link": false, - "password": false, - } - - for _, cmd := range Commands { - if cmd.Name == "" { - t.Error("Command should have a name") - } - if cmd.Description == "" { - t.Errorf("Command %q should have a description", cmd.Name) - } - - if _, exists := expectedCommands[cmd.Name]; exists { - expectedCommands[cmd.Name] = true - } - } - - // Verify expected commands exist - for name, found := range expectedCommands { - if !found { - t.Errorf("Expected command %q not found in Commands", name) - } - } -} - -func TestCommands_LinkCommand(t *testing.T) { - var linkCmd *struct { - Name string - Description string - Options []struct { - Type int - Name string - Description string - Required bool - } - } - - // Find the link command - for _, cmd := range Commands { - if cmd.Name == "link" { - // Verify structure - if cmd.Description == "" { - t.Error("Link command should have a description") - } - if len(cmd.Options) == 0 { - t.Error("Link command should have options") - } - - // Verify token option - for _, opt := range cmd.Options { - if opt.Name == "token" { - if !opt.Required { - t.Error("Token option should be required") - } - if opt.Description == "" { - t.Error("Token option should have a description") - } - return - } - } - t.Error("Link command should have a 'token' option") - } - } - - if linkCmd == nil { - t.Error("Link command not found") - } -} - -func TestCommands_PasswordCommand(t *testing.T) { - // Find the password command - for _, cmd := range Commands { - if cmd.Name == "password" { - // Verify structure - if cmd.Description == "" { - t.Error("Password command should have a description") - } - if len(cmd.Options) == 0 { - t.Error("Password command should have options") - } - - // Verify password option - for _, opt := range cmd.Options { - if opt.Name == "password" { - if !opt.Required { - t.Error("Password option should be required") - } - if opt.Description == "" { - t.Error("Password option should have a description") - } - return - } - } - t.Error("Password command should have a 'password' option") - } - } - - t.Error("Password command not found") -} - -func TestDiscordBotStruct(t *testing.T) { - // Test that the DiscordBot struct can be initialized - _ = &DiscordBot{ - Session: nil, // Can't create real session in tests - MainGuild: nil, - RelayChannel: nil, - } -} - -func TestOptionsStruct(t *testing.T) { - // Test that the Options struct can be initialized - opts := Options{ - Config: nil, - Logger: nil, - } - - // Just verify we can create the struct - _ = opts -} - -// Helper function -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr)) -} - -func containsHelper(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - -func BenchmarkReplaceTextAll(b *testing.B) { - text := "Message with <@123456789012345678> and <@!987654321098765432> mentions and :smile: :wave: emojis" - userRegex := regexp.MustCompile(`<@!?(\d{17,19})>`) - handler := func(id string) string { - return "@user_" + id - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ReplaceTextAll(text, userRegex, handler) - } -} - -func BenchmarkReplaceTextAll_NoMatches(b *testing.B) { - text := "Message with no mentions or special syntax at all, just plain text" - userRegex := regexp.MustCompile(`<@!?(\d{17,19})>`) - handler := func(id string) string { - return "@user_" + id - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ReplaceTextAll(text, userRegex, handler) - } -}