From 9c5cc559c70931659fb54b43a823fa7ef0c6d40a Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Tue, 3 Feb 2026 00:19:48 +0100 Subject: [PATCH] test(signserver): increase test coverage from 1.5% to 45.2% Add comprehensive tests for signserver package using sqlmock for database function testing: - Server lifecycle tests (Start, Shutdown, acceptClients) - Connection handling tests (handleConnection, multiple connections) - Database function tests (getCharactersForUser, getLastCID, getUserRights, checkToken, registerToken, deleteCharacter, newUserChara, registerDBAccount, getReturnExpiry, getFriendsForCharacters, getGuildmatesForCharacters) - Session struct and packet handling tests All tests pass with race detection enabled. --- go.mod | 1 + go.sum | 3 + server/signserver/dbutils_test.go | 620 ++++++++++++++++++++++++++ server/signserver/session_test.go | 452 +++++++++++++++++++ server/signserver/sign_server_test.go | 358 +++++++++++++++ 5 files changed, 1434 insertions(+) create mode 100644 server/signserver/session_test.go diff --git a/go.mod b/go.mod index a501f6530..6631129e9 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( ) require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/felixge/httpsnoop v1.0.1 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/gorilla/websocket v1.4.2 // indirect diff --git a/go.sum b/go.sum index ea9329bb2..9c75126a2 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= @@ -185,6 +187,7 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= diff --git a/server/signserver/dbutils_test.go b/server/signserver/dbutils_test.go index c93b6031c..ffd8974ef 100644 --- a/server/signserver/dbutils_test.go +++ b/server/signserver/dbutils_test.go @@ -1,7 +1,13 @@ package signserver import ( + "database/sql" "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/jmoiron/sqlx" + "go.uber.org/zap" ) func TestCharacterStruct(t *testing.T) { @@ -296,3 +302,617 @@ func TestMultipleMembers(t *testing.T) { t.Error("Third member should have different CID") } } + +// Helper to create a test server with mocked database +func newTestServerWithMock(t *testing.T) (*Server, sqlmock.Sqlmock) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: zap.NewNop(), + db: sqlxDB, + } + + return server, mock +} + +func TestGetCharactersForUser(t *testing.T) { + server, mock := newTestServerWithMock(t) + + rows := sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hrp", "gr", "weapon_type", "last_login"}). + AddRow(1, false, false, "Hunter1", "desc1", 100, 50, 3, 1700000000). + AddRow(2, true, false, "Hunter2", "desc2", 200, 100, 7, 1700000001) + + mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id ASC"). + WithArgs(1). + WillReturnRows(rows) + + chars, err := server.getCharactersForUser(1) + if err != nil { + t.Errorf("getCharactersForUser() error: %v", err) + } + + if len(chars) != 2 { + t.Errorf("getCharactersForUser() returned %d characters, want 2", len(chars)) + } + + if chars[0].Name != "Hunter1" { + t.Errorf("First character name = %s, want Hunter1", chars[0].Name) + } + + if chars[1].IsFemale != true { + t.Error("Second character should be female") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetCharactersForUserNoCharacters(t *testing.T) { + server, mock := newTestServerWithMock(t) + + rows := sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hrp", "gr", "weapon_type", "last_login"}) + + mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id ASC"). + WithArgs(1). + WillReturnRows(rows) + + chars, err := server.getCharactersForUser(1) + if err != nil { + t.Errorf("getCharactersForUser() error: %v", err) + } + + if len(chars) != 0 { + t.Errorf("getCharactersForUser() returned %d characters, want 0", len(chars)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetCharactersForUserDBError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id ASC"). + WithArgs(1). + WillReturnError(sql.ErrConnDone) + + _, err := server.getCharactersForUser(1) + if err == nil { + t.Error("getCharactersForUser() should return error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetLastCID(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_character"}).AddRow(12345)) + + lastCID := server.getLastCID(1) + if lastCID != 12345 { + t.Errorf("getLastCID() = %d, want 12345", lastCID) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetLastCIDNoResult(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnError(sql.ErrNoRows) + + lastCID := server.getLastCID(1) + // Should return 0 on error + if lastCID != 0 { + t.Errorf("getLastCID() with no result = %d, want 0", lastCID) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetUserRights(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"rights"}).AddRow(30)) + + rights := server.getUserRights(1) + // Rights value is transformed by mhfcourse.GetCourseStruct + // The function should return a non-zero value when rights is set + if rights == 0 { + t.Error("getUserRights() should return non-zero value") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetUserRightsDefault(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnError(sql.ErrNoRows) + + rights := server.getUserRights(1) + // Default rights is 2, which is transformed by mhfcourse.GetCourseStruct + if rights == 0 { + t.Error("getUserRights() should return default rights on error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestCheckToken(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE user_id = \\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + exists, err := server.checkToken(1) + if err != nil { + t.Errorf("checkToken() error: %v", err) + } + if !exists { + t.Error("checkToken() should return true when token exists") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestCheckTokenNotExists(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE user_id = \\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + exists, err := server.checkToken(1) + if err != nil { + t.Errorf("checkToken() error: %v", err) + } + if exists { + t.Error("checkToken() should return false when token doesn't exist") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestCheckTokenError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE user_id = \\$1"). + WithArgs(1). + WillReturnError(sql.ErrConnDone) + + _, err := server.checkToken(1) + if err == nil { + t.Error("checkToken() should return error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestRegisterToken(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectExec("INSERT INTO sign_sessions \\(user_id, token\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs(1, "testtoken123"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := server.registerToken(1, "testtoken123") + if err != nil { + t.Errorf("registerToken() error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestRegisterTokenError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + mock.ExpectExec("INSERT INTO sign_sessions \\(user_id, token\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs(1, "testtoken123"). + WillReturnError(sql.ErrConnDone) + + err := server.registerToken(1, "testtoken123") + if err == nil { + t.Error("registerToken() should return error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestDeleteCharacter(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Token verification + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). + WithArgs("validtoken"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Check if new character + mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). + WithArgs(123). + WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false)) + + // Soft delete (update deleted flag) + mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1"). + WithArgs(123). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := server.deleteCharacter(123, "validtoken") + if err != nil { + t.Errorf("deleteCharacter() error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestDeleteNewCharacter(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Token verification + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). + WithArgs("validtoken"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Check if new character (is_new_character = true) + mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). + WithArgs(123). + WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(true)) + + // Hard delete for new characters + mock.ExpectExec("DELETE FROM characters WHERE id = \\$1"). + WithArgs(123). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := server.deleteCharacter(123, "validtoken") + if err != nil { + t.Errorf("deleteCharacter() error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestDeleteCharacterInvalidToken(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Token verification fails + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). + WithArgs("invalidtoken"). + WillReturnError(sql.ErrNoRows) + + err := server.deleteCharacter(123, "invalidtoken") + if err == nil { + t.Error("deleteCharacter() should return error for invalid token") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestNewUserChara(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Check for existing new characters + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + // Insert new character + mock.ExpectExec("INSERT INTO characters"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := server.newUserChara("testuser") + if err != nil { + t.Errorf("newUserChara() error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestNewUserCharaAlreadyHasNewChar(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Check for existing new characters - already has one + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Should not insert since user already has a new character + err := server.newUserChara("testuser") + // Error is nil but no insert happens + if err != nil { + t.Errorf("newUserChara() should return nil when user already has new char: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestNewUserCharaUserNotFound(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get user ID - not found + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("unknownuser"). + WillReturnError(sql.ErrNoRows) + + err := server.newUserChara("unknownuser") + if err == nil { + t.Error("newUserChara() should return error when user not found") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestRegisterDBAccount(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Insert user + mock.ExpectExec("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\)"). + WithArgs("newuser", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("newuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Insert initial character + mock.ExpectExec("INSERT INTO characters"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := server.registerDBAccount("newuser", "password123") + if err != nil { + t.Errorf("registerDBAccount() error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestRegisterDBAccountDuplicateUser(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Insert user fails (duplicate) + mock.ExpectExec("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\)"). + WithArgs("existinguser", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnError(sql.ErrNoRows) + + err := server.registerDBAccount("existinguser", "password123") + if err == nil { + t.Error("registerDBAccount() should return error for duplicate user") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetReturnExpiry(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get last login (recent) + recentLogin := time.Now().Add(-time.Hour * 24) // 1 day ago + mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(recentLogin)) + + // Get return expiry + mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"return_expires"}).AddRow(time.Now().Add(time.Hour * 24 * 30))) + + // Update last login + mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + expiry := server.getReturnExpiry(1) + + // Should return a future date + if expiry.Before(time.Now()) { + t.Error("getReturnExpiry() should return future date") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetReturnExpiryInactiveUser(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get last login (inactive - over 90 days ago) + oldLogin := time.Now().Add(-time.Hour * 24 * 100) // 100 days ago + mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(oldLogin)) + + // Update return expiry for returning user + mock.ExpectExec("UPDATE users SET return_expires=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Update last login + mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + expiry := server.getReturnExpiry(1) + + // Should return a future date (30 days from now for returning user) + if expiry.Before(time.Now()) { + t.Error("getReturnExpiry() should return future date for inactive user") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetFriendsForCharactersEmpty(t *testing.T) { + server, _ := newTestServerWithMock(t) + + // Empty character list + chars := []character{} + + friends := server.getFriendsForCharacters(chars) + if len(friends) != 0 { + t.Errorf("getFriendsForCharacters() for empty chars = %d, want 0", len(friends)) + } +} + +func TestGetGuildmatesForCharactersEmpty(t *testing.T) { + server, _ := newTestServerWithMock(t) + + // Empty character list + chars := []character{} + + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 0 { + t.Errorf("getGuildmatesForCharacters() for empty chars = %d, want 0", len(guildmates)) + } +} + +func TestGetFriendsForCharacters(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + } + + // Get friends CSV for character + mock.ExpectQuery("SELECT friends FROM characters WHERE id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"friends"}).AddRow("2,3")) + + // Query friends + mock.ExpectQuery("SELECT id, name FROM characters WHERE id=2 OR id=3"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(2, "Friend1"). + AddRow(3, "Friend2")) + + friends := server.getFriendsForCharacters(chars) + if len(friends) != 2 { + t.Errorf("getFriendsForCharacters() = %d, want 2", len(friends)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetGuildmatesForCharacters(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + } + + // Check if in guild + mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Get guild ID + mock.ExpectQuery("SELECT guild_id FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"guild_id"}).AddRow(100)) + + // Get guildmates + mock.ExpectQuery("SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=\\$1 AND character_id!=\\$2"). + WithArgs(100, uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(2, "Guildmate1"). + AddRow(3, "Guildmate2")) + + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 2 { + t.Errorf("getGuildmatesForCharacters() = %d, want 2", len(guildmates)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +func TestGetGuildmatesNotInGuild(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + } + + // Check if in guild - not in guild + mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 0 { + t.Errorf("getGuildmatesForCharacters() for non-guild member = %d, want 0", len(guildmates)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} diff --git a/server/signserver/session_test.go b/server/signserver/session_test.go new file mode 100644 index 000000000..b408954af --- /dev/null +++ b/server/signserver/session_test.go @@ -0,0 +1,452 @@ +package signserver + +import ( + "bytes" + "io" + "net" + "sync" + "testing" + "time" + + "erupe-ce/common/byteframe" + "erupe-ce/config" + "erupe-ce/network" + + "go.uber.org/zap" +) + +// mockConn implements net.Conn for testing +type mockConn struct { + readBuf *bytes.Buffer + writeBuf *bytes.Buffer + closed bool + mu sync.Mutex +} + +func newMockConn() *mockConn { + return &mockConn{ + readBuf: new(bytes.Buffer), + writeBuf: new(bytes.Buffer), + } +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.closed { + return 0, io.EOF + } + return m.readBuf.Read(b) +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.closed { + return 0, io.ErrClosedPipe + } + return m.writeBuf.Write(b) +} + +func (m *mockConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53312} +} + +func (m *mockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} +} + +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestSessionStruct(t *testing.T) { + logger := zap.NewNop() + conn := newMockConn() + + s := &Session{ + logger: logger, + server: nil, + rawConn: conn, + cryptConn: network.NewCryptConn(conn), + } + + if s.logger != logger { + t.Error("Session logger not set correctly") + } + if s.rawConn != conn { + t.Error("Session rawConn not set correctly") + } + if s.cryptConn == nil { + t.Error("Session cryptConn should not be nil") + } +} + +func TestSessionStructDefaults(t *testing.T) { + s := &Session{} + + if s.logger != nil { + t.Error("Default Session logger should be nil") + } + if s.server != nil { + t.Error("Default Session server should be nil") + } + if s.rawConn != nil { + t.Error("Default Session rawConn should be nil") + } + if s.cryptConn != nil { + t.Error("Default Session cryptConn should be nil") + } +} + +func TestSessionMutex(t *testing.T) { + s := &Session{} + + // Test that we can lock and unlock + s.Lock() + s.Unlock() + + // Test concurrent access + done := make(chan bool) + go func() { + s.Lock() + time.Sleep(10 * time.Millisecond) + s.Unlock() + done <- true + }() + + // Small delay to ensure goroutine starts + time.Sleep(5 * time.Millisecond) + + // This should block until the goroutine releases the lock + s.Lock() + s.Unlock() + + <-done +} + +func TestHandlePacketUnknownRequest(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + } + + conn := newMockConn() + session := &Session{ + logger: logger, + server: server, + rawConn: conn, + cryptConn: network.NewCryptConn(conn), + } + + // Create a packet with unknown request type + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("UNKNOWN:100")) + bf.WriteNullTerminatedBytes([]byte("data")) + + err := session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } +} + +func TestHandlePacketEmptyRequest(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + } + + conn := newMockConn() + session := &Session{ + logger: logger, + server: server, + rawConn: conn, + cryptConn: network.NewCryptConn(conn), + } + + // Create a packet with empty request type (just null terminator) + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("")) + + err := session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error for empty request: %v", err) + } +} + +func TestHandlePacketWithDevModeLogging(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: true, + DevModeOptions: config.DevModeOptions{ + LogInboundMessages: true, + }, + } + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + } + + conn := newMockConn() + session := &Session{ + logger: logger, + server: server, + rawConn: conn, + cryptConn: network.NewCryptConn(conn), + } + + // Create a packet with unknown request type + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("TEST:100")) + + err := session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() with dev mode returned error: %v", err) + } +} + +func TestHandlePacketRequestTypes(t *testing.T) { + tests := []struct { + name string + reqType string + }{ + {"unknown", "UNKNOWN:100"}, + {"invalid", "INVALID"}, + {"empty_version", "TEST:"}, + {"no_version", "NOVERSION"}, + {"special_chars", "TEST@#$:100"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{DevMode: false} + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + } + + conn := newMockConn() + session := &Session{ + logger: logger, + server: server, + rawConn: conn, + cryptConn: network.NewCryptConn(conn), + } + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte(tt.reqType)) + + err := session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket(%s) returned error: %v", tt.reqType, err) + } + }) + } +} + +func TestMockConnImplementsNetConn(t *testing.T) { + var _ net.Conn = (*mockConn)(nil) +} + +func TestMockConnReadWrite(t *testing.T) { + conn := newMockConn() + + // Write some data to the read buffer (simulating incoming data) + testData := []byte("hello") + conn.readBuf.Write(testData) + + // Read it back + buf := make([]byte, len(testData)) + n, err := conn.Read(buf) + if err != nil { + t.Errorf("Read() error: %v", err) + } + if n != len(testData) { + t.Errorf("Read() n = %d, want %d", n, len(testData)) + } + if !bytes.Equal(buf, testData) { + t.Errorf("Read() data = %v, want %v", buf, testData) + } + + // Write data + outData := []byte("world") + n, err = conn.Write(outData) + if err != nil { + t.Errorf("Write() error: %v", err) + } + if n != len(outData) { + t.Errorf("Write() n = %d, want %d", n, len(outData)) + } + if !bytes.Equal(conn.writeBuf.Bytes(), outData) { + t.Errorf("Write() buffer = %v, want %v", conn.writeBuf.Bytes(), outData) + } +} + +func TestMockConnClose(t *testing.T) { + conn := newMockConn() + + err := conn.Close() + if err != nil { + t.Errorf("Close() error: %v", err) + } + + if !conn.closed { + t.Error("conn.closed should be true after Close()") + } + + // Read after close should return EOF + buf := make([]byte, 10) + _, err = conn.Read(buf) + if err != io.EOF { + t.Errorf("Read() after close should return EOF, got: %v", err) + } + + // Write after close should return error + _, err = conn.Write([]byte("test")) + if err != io.ErrClosedPipe { + t.Errorf("Write() after close should return ErrClosedPipe, got: %v", err) + } +} + +func TestMockConnAddresses(t *testing.T) { + conn := newMockConn() + + local := conn.LocalAddr() + if local == nil { + t.Error("LocalAddr() should not be nil") + } + if local.String() != "127.0.0.1:53312" { + t.Errorf("LocalAddr() = %s, want 127.0.0.1:53312", local.String()) + } + + remote := conn.RemoteAddr() + if remote == nil { + t.Error("RemoteAddr() should not be nil") + } + if remote.String() != "127.0.0.1:12345" { + t.Errorf("RemoteAddr() = %s, want 127.0.0.1:12345", remote.String()) + } +} + +func TestMockConnDeadlines(t *testing.T) { + conn := newMockConn() + deadline := time.Now().Add(time.Second) + + if err := conn.SetDeadline(deadline); err != nil { + t.Errorf("SetDeadline() error: %v", err) + } + if err := conn.SetReadDeadline(deadline); err != nil { + t.Errorf("SetReadDeadline() error: %v", err) + } + if err := conn.SetWriteDeadline(deadline); err != nil { + t.Errorf("SetWriteDeadline() error: %v", err) + } +} + +func TestSessionWithCryptConn(t *testing.T) { + conn := newMockConn() + cryptConn := network.NewCryptConn(conn) + + if cryptConn == nil { + t.Fatal("NewCryptConn() returned nil") + } + + session := &Session{ + rawConn: conn, + cryptConn: cryptConn, + } + + if session.cryptConn != cryptConn { + t.Error("Session cryptConn not set correctly") + } +} + +// Note: Tests for DSGN:100, DLTSKEYSIGN:100, and DELETE:100 request types +// require a database connection. These are integration tests that should be +// run with a test database. The handlePacket method routes to these handlers +// which immediately access the database. + +func TestSessionWorkWithDevModeLogging(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: true, + DevModeOptions: config.DevModeOptions{ + LogInboundMessages: true, + }, + } + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + } + + // Use net.Pipe for bidirectional communication + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Close client side to cause read error + clientConn.Close() + + // work() should exit gracefully on read error + session.work() +} + +func TestSessionWorkWithEmptyRead(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + } + + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Close client side immediately to cause read failure + clientConn.Close() + + // work() should handle the read error gracefully + session.work() +} + +// Note: Tests for handleDSGNRequest require a database connection. +// The function immediately queries the database for user authentication. +// These tests should be implemented as integration tests with a test database +// or using sqlmock for database mocking. diff --git a/server/signserver/sign_server_test.go b/server/signserver/sign_server_test.go index 9846009f5..44feb508d 100644 --- a/server/signserver/sign_server_test.go +++ b/server/signserver/sign_server_test.go @@ -2,7 +2,13 @@ package signserver import ( "fmt" + "net" "testing" + "time" + + "erupe-ce/config" + + "go.uber.org/zap" ) func TestRespIDConstants(t *testing.T) { @@ -266,3 +272,355 @@ func TestConfigFields(t *testing.T) { t.Error("Config ErupeConfig should be nil") } } + +func TestServerStartAndShutdown(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Sign: config.Sign{ + Port: 0, // Use port 0 to get a random available port + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + if s == nil { + t.Fatal("NewServer() returned nil") + } + + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + + // Verify listener is set + if s.listener == nil { + t.Error("Server listener should not be nil after Start()") + } + + // Verify not shutting down initially + s.Lock() + if s.isShuttingDown { + t.Error("Server should not be shutting down after Start()") + } + s.Unlock() + + // Shutdown + s.Shutdown() + + // Verify shutdown flag is set + s.Lock() + if !s.isShuttingDown { + t.Error("Server should be shutting down after Shutdown()") + } + s.Unlock() +} + +func TestServerStartWithInvalidPort(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Sign: config.Sign{ + Port: -1, // Invalid port + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + + // Should fail with invalid port + if err == nil { + s.Shutdown() + t.Error("Start() should fail with invalid port") + } +} + +func TestServerMutex(t *testing.T) { + s := &Server{} + + // Test that we can lock and unlock + s.Lock() + s.Unlock() + + // Test concurrent access + done := make(chan bool) + go func() { + s.Lock() + time.Sleep(10 * time.Millisecond) + s.Unlock() + done <- true + }() + + // Small delay to ensure goroutine starts + time.Sleep(5 * time.Millisecond) + + // This should block until the goroutine releases the lock + s.Lock() + s.Unlock() + + <-done +} + +func TestServerShutdownIdempotent(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + + // First shutdown + s.Shutdown() + + // Verify state + s.Lock() + if !s.isShuttingDown { + t.Error("Server should be shutting down") + } + s.Unlock() +} + +func TestServerAcceptClientsExitsOnShutdown(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + + // Give acceptClients goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Shutdown should cause acceptClients to exit + s.Shutdown() + + // Give time for graceful exit + time.Sleep(10 * time.Millisecond) + + s.Lock() + if !s.isShuttingDown { + t.Error("Server should be marked as shutting down") + } + s.Unlock() +} + +func TestServerHandleConnection(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + // Connect to the server + addr := s.listener.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() error: %v", err) + } + defer conn.Close() + + // Send the 8 NULL bytes initialization + nullInit := make([]byte, 8) + _, err = conn.Write(nullInit) + if err != nil { + t.Fatalf("Write() error: %v", err) + } + + // Give time for handleConnection to process + time.Sleep(50 * time.Millisecond) +} + +func TestServerHandleConnectionWithShortInit(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + // Connect to the server + addr := s.listener.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() error: %v", err) + } + + // Send only 4 bytes instead of 8, then close + _, _ = conn.Write([]byte{0, 0, 0, 0}) + conn.Close() + + // Give time for handleConnection to handle the error + time.Sleep(50 * time.Millisecond) +} + +func TestServerHandleConnectionImmediateClose(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + // Connect and immediately close + addr := s.listener.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() error: %v", err) + } + conn.Close() + + // Give time for handleConnection to handle the error + time.Sleep(50 * time.Millisecond) +} + +func TestServerMultipleConnections(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + addr := s.listener.Addr().String() + + // Create multiple connections + conns := make([]net.Conn, 3) + for i := range conns { + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() %d error: %v", i, err) + } + conns[i] = conn + + // Send null init + nullInit := make([]byte, 8) + _, _ = conn.Write(nullInit) + } + + // Give time for connections to be processed + time.Sleep(50 * time.Millisecond) + + // Close all connections + for _, conn := range conns { + conn.Close() + } +} + +func TestServerListenerAddress(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Sign: config.Sign{ + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + addr := s.listener.Addr() + if addr == nil { + t.Error("Listener address should not be nil") + } + + // Should be a TCP address + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + t.Error("Listener address should be a TCP address") + } + + // Port should be assigned (non-zero since we requested port 0) + if tcpAddr.Port == 0 { + t.Error("Listener port should be assigned") + } +}